scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2025-05-31 10:30:00"
|
|
4
|
+
# Author: ywatanabe
|
|
5
|
+
# File: ./src/scitex/ai/genai/base_provider.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Abstract base class for AI provider implementations.
|
|
9
|
+
|
|
10
|
+
This module defines the interface that all AI providers must implement,
|
|
11
|
+
ensuring consistency across different providers (OpenAI, Anthropic, etc.).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from typing import Any, Dict, List, Generator, Optional
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Provider(str, Enum):
    """Identifiers of the AI backends known to this package.

    The enum mixes in ``str``, so every member compares equal to its
    lowercase value (``Provider.OPENAI == "openai"``).
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    GROQ = "groq"
    DEEPSEEK = "deepseek"
    LLAMA = "llama"
    PERPLEXITY = "perplexity"
    MOCK = "mock"  # stand-in backend used for testing

    def __str__(self) -> str:
        # Render the raw identifier ("openai") rather than "Provider.OPENAI".
        identifier: str = self.value
        return identifier
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Role(str, Enum):
    """Message roles for chat conversations.

    The enum mixes in ``str``, so members compare equal to their plain-string
    values (``Role.USER == "user"``).
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

    def __str__(self) -> str:
        # Render as the bare value ("user"), consistent with Provider.__str__.
        # Without this override str() yields "Role.USER". Depending on the
        # Python version, f-strings do as well, which is wrong if a role is
        # interpolated into a message payload.
        return self.value
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class ProviderConfig:
    """Settings bundle handed to an AI provider.

    Only ``provider`` and ``model`` are required; every other field has a
    default value.
    """

    # Backend identifier. Provider members also fit, since they subclass str.
    provider: str
    # Model name for the chosen backend.
    model: str
    # Credential for the backend; None when not supplied.
    api_key: Optional[str] = field(default=None)
    # System prompt text; defaults to an empty string.
    system_prompt: str = field(default="")
    # Sampling temperature (default 1.0).
    temperature: float = field(default=1.0)
    # Maximum token count (default 4096).
    max_tokens: int = field(default=4096)
    # Whether streaming is requested (default False).
    stream: bool = field(default=False)
    # Optional seed value.
    seed: Optional[int] = field(default=None)
    # NOTE(review): semantics of n_keep are not visible here; presumably the
    # number of history entries to retain. Confirm against the providers.
    n_keep: int = field(default=1)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class CompletionResponse:
    """Provider-neutral container for a completion result.

    Holds the generated text, the token counts and, optionally, the raw
    object the provider returned.
    """

    # Generated text.
    content: str
    # Token count attributed to the input.
    input_tokens: int
    # Token count attributed to the output.
    output_tokens: int
    # Reason generation ended; defaults to "stop".
    finish_reason: str = field(default="stop")
    # Raw provider-specific response, if one is attached.
    provider_response: Optional[Any] = field(default=None)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class BaseProvider(ABC):
|
|
71
|
+
"""Abstract base class for AI providers.
|
|
72
|
+
|
|
73
|
+
All AI provider implementations must inherit from this class
|
|
74
|
+
and implement the required abstract methods.
|
|
75
|
+
|
|
76
|
+
Example
|
|
77
|
+
-------
|
|
78
|
+
>>> class MyProvider(BaseProvider):
|
|
79
|
+
... def init_client(self) -> Any:
|
|
80
|
+
... return MyAPIClient(self.api_key)
|
|
81
|
+
...
|
|
82
|
+
... def format_history(self, history: List[Dict]) -> List[Dict]:
|
|
83
|
+
... # Provider-specific formatting
|
|
84
|
+
... return history
|
|
85
|
+
...
|
|
86
|
+
... def call_static(self, messages: List[Dict], **kwargs) -> Any:
|
|
87
|
+
... # Make API call
|
|
88
|
+
... return self.client.complete(messages)
|
|
89
|
+
...
|
|
90
|
+
... def call_stream(self, messages: List[Dict], **kwargs) -> Generator:
|
|
91
|
+
... # Make streaming API call
|
|
92
|
+
... for chunk in self.client.stream(messages):
|
|
93
|
+
... yield chunk
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
@abstractmethod
def init_client(self) -> Any:
    """Build and return the backend-specific API client.

    Concrete providers construct and configure whatever client object their
    service needs (for example an OpenAI or Anthropic SDK client).

    Returns
    -------
    Any
        The provider's client object.
    """
    ...
|
|
109
|
+
|
|
110
|
+
@abstractmethod
def format_history(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a standard conversation history into this provider's format.

    Backends differ in how they expect conversation turns to be laid out.
    Implementations reshape the package's standard representation into the
    form their API accepts.

    Parameters
    ----------
    history : List[Dict[str, Any]]
        Conversation history in the standard format.

    Returns
    -------
    List[Dict[str, Any]]
        The history reshaped for the provider's API.
    """
    ...
|
|
129
|
+
|
|
130
|
+
@abstractmethod
def call_static(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
    """Send one non-streaming request to the provider.

    Parameters
    ----------
    messages : List[Dict[str, Any]]
        Conversation messages, already formatted for this provider.
    **kwargs
        Extra provider-specific options.

    Returns
    -------
    Any
        Whatever response object the provider returns.
    """
    ...
|
|
147
|
+
|
|
148
|
+
@abstractmethod
def call_stream(
    self, messages: List[Dict[str, Any]], **kwargs
) -> Generator[str, None, None]:
    """Send a streaming request and produce the reply incrementally.

    Parameters
    ----------
    messages : List[Dict[str, Any]]
        Conversation messages, already formatted for this provider.
    **kwargs
        Extra provider-specific options.

    Yields
    ------
    str
        Successive pieces of the response text.
    """
    ...
|
|
167
|
+
|
|
168
|
+
@property
@abstractmethod
def supports_streaming(self) -> bool:
    """Report whether streamed responses are available from this provider.

    Returns
    -------
    bool
        ``True`` when streaming is supported.
    """
    ...
|
|
179
|
+
|
|
180
|
+
@property
@abstractmethod
def supports_images(self) -> bool:
    """Report whether this provider accepts image inputs.

    Returns
    -------
    bool
        ``True`` when images are supported.
    """
    ...
|
|
191
|
+
|
|
192
|
+
@property
@abstractmethod
def max_context_length(self) -> int:
    """Largest context window, measured in tokens.

    Returns
    -------
    int
        Token limit of the context window.
    """
    ...
|
|
203
|
+
|
|
204
|
+
def get_capabilities(self) -> Dict[str, Any]:
    """Summarise this provider's capabilities in one dictionary.

    Returns
    -------
    Dict[str, Any]
        Mapping with the keys ``supports_streaming``, ``supports_images``
        and ``max_context_length`` (in that order). Each value is read from
        the attribute of the same name on ``self``.
    """
    capability_names = (
        "supports_streaming",
        "supports_images",
        "max_context_length",
    )
    return {name: getattr(self, name) for name in capability_names}
|
|
217
|
+
|
|
218
|
+
def extract_tokens_from_response(self, response: Any) -> Dict[str, int]:
    """Report token usage recorded in a provider response.

    This base version ignores ``response`` and reports zero usage.
    Providers should override it to read the real counts from their own
    response format.

    Parameters
    ----------
    response : Any
        Provider-specific response object.

    Returns
    -------
    Dict[str, int]
        A new dict with ``input_tokens`` and ``output_tokens`` keys.
    """
    return dict(input_tokens=0, output_tokens=0)
|
|
235
|
+
|
|
236
|
+
def handle_rate_limit(self, error: Exception) -> bool:
    """Decide whether a rate-limit error should trigger a retry.

    The base implementation never retries: it ignores ``error`` and
    answers False. Providers can override it to add retry logic or other
    handling.

    Parameters
    ----------
    error : Exception
        The error that occurred.

    Returns
    -------
    bool
        True if the error was handled and the operation should be retried.
    """
    should_retry = False
    return should_retry
|
|
253
|
+
|
|
254
|
+
def validate_model(self, model: str) -> bool:
    """Check whether ``model`` is supported by this provider.

    The base implementation accepts every model name. Providers should
    override it to check against their own list of supported models.

    Parameters
    ----------
    model : str
        Model name to validate.

    Returns
    -------
    bool
        True if the model is supported.
    """
    is_supported = True
    return is_supported
|
|
271
|
+
|
|
272
|
+
def get_error_message(self, error: Exception) -> str:
    """Extract a user-friendly message from an exception.

    Returns ``str(error)``. Exceptions raised without arguments (for
    example ``TimeoutError()``) stringify to an empty string. In that case
    the exception's class name is returned instead, so callers never
    display a blank message. Providers can override this for richer
    messages.

    Parameters
    ----------
    error : Exception
        The error that occurred.

    Returns
    -------
    str
        Non-empty, user-friendly error message.
    """
    message = str(error)
    # An empty str() carries no information; the type name at least tells
    # the user what kind of failure happened.
    return message if message else type(error).__name__
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# EOF
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-04 01:37:36 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/_gen_ai/_calc_cost.py
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Functionality:
|
|
8
|
+
- Calculates usage costs for AI model API calls
|
|
9
|
+
- Handles token-based pricing for different models
|
|
10
|
+
Input:
|
|
11
|
+
- Model name
|
|
12
|
+
- Number of input and output tokens used
|
|
13
|
+
Output:
|
|
14
|
+
- Total cost in USD based on token usage
|
|
15
|
+
Prerequisites:
|
|
16
|
+
- MODELS parameter dictionary with pricing information
|
|
17
|
+
- pandas package
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from typing import Union, Any
|
|
21
|
+
import pandas as pd
|
|
22
|
+
|
|
23
|
+
from .params import MODELS
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def calc_cost(
    model: str,
    input_tokens: int,
    output_tokens: int,
    *,
    models: Any = None,
) -> float:
    """Calculate API usage cost based on token count.

    The ``input_cost`` and ``output_cost`` columns of the pricing table are
    treated as USD per one million tokens: the weighted token total is
    divided by 1,000,000.

    Example
    -------
    >>> table = [{"name": "my-model", "input_cost": 2.0, "output_cost": 4.0}]
    >>> cost = calc_cost("my-model", 1000, 500, models=table)
    >>> print(f"${cost:.4f}")
    $0.0040

    Parameters
    ----------
    model : str
        Name of the AI model; matched exactly against the ``name`` column.
    input_tokens : int
        Number of input tokens used.
    output_tokens : int
        Number of output tokens used.
    models : optional, keyword-only
        Pricing table accepted by ``pd.DataFrame`` (e.g. a list of dicts or
        a DataFrame). It must have ``name``, ``input_cost`` and
        ``output_cost`` columns. Defaults to the module-level ``MODELS``.

    Returns
    -------
    float
        Total cost in USD, as a plain Python float. If several rows share
        the same name, the first one is used.

    Raises
    ------
    ValueError
        If model is not found in the pricing table.
    """
    table = MODELS if models is None else models
    models_df = pd.DataFrame(table)
    matches = models_df.loc[models_df["name"] == model]

    if matches.empty:
        raise ValueError(f"Model '{model}' not found in pricing table")

    # Price only the first matching row instead of computing every match.
    row = matches.iloc[0]
    cost = (
        input_tokens * row["input_cost"] + output_tokens * row["output_cost"]
    ) / 1_000_000

    # Convert the numpy scalar to the builtin float promised by the signature.
    return float(cost)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# def calc_cost(model, input_tokens, output_tokens):
|
|
69
|
+
# indi = MODELS["name"] == model
|
|
70
|
+
# costs = MODELS[["input_cost", "output_cost"]][indi]
|
|
71
|
+
# cost = (
|
|
72
|
+
# input_tokens * costs["input_cost"]
|
|
73
|
+
# + output_tokens * costs["output_cost"]
|
|
74
|
+
# ) / 1_000_000
|
|
75
|
+
# return cost.iloc[0]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# EOF
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2025-05-31 10:10:00"
|
|
4
|
+
# Author: ywatanabe
|
|
5
|
+
# File: ./src/scitex/ai/genai/chat_history.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Manages conversation history for AI providers.
|
|
9
|
+
|
|
10
|
+
This module handles chat history management including:
|
|
11
|
+
- Message storage and retrieval
|
|
12
|
+
- Role alternation enforcement
|
|
13
|
+
- System message handling
|
|
14
|
+
- History truncation
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from typing import List, Dict, Optional, Any
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from copy import deepcopy
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class Message:
    """Represents a single message in chat history.

    Attributes
    ----------
    role : str
        Message role (system, user, assistant)
    content : str
        Message content
    images : Optional[List[str]]
        Optional base64-encoded images
    """

    role: str
    content: str
    images: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert message to a plain dictionary.

        The ``images`` key is present only when the message has at least
        one image. The image list is copied, so mutating the returned dict
        cannot alter the stored message.

        Returns
        -------
        Dict[str, Any]
            ``{"role": ..., "content": ...}`` plus ``"images"`` when present.
        """
        d: Dict[str, Any] = {"role": self.role, "content": self.content}
        if self.images:
            # Copy to avoid handing out a reference to internal state.
            d["images"] = list(self.images)
        return d
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ChatHistory:
    """Manages conversation history with role enforcement.

    Example
    -------
    >>> history = ChatHistory(n_keep=5)
    >>> history.add_message("user", "Hello")
    >>> history.add_message("assistant", "Hi there!")
    >>> messages = history.get_messages()
    >>> print(len(messages))
    2

    Parameters
    ----------
    system_prompt : Optional[str]
        Optional system prompt; when truthy it is stored as the first message.
    n_keep : int
        Number of recent exchanges (pairs of messages) to keep in addition
        to the system message (default: 1). Values <= 0 disable trimming.
    """

    VALID_ROLES = {"system", "user", "assistant"}

    def __init__(self, system_prompt: Optional[str] = None, n_keep: int = 1):
        """Initialize chat history manager.

        Parameters
        ----------
        system_prompt : Optional[str]
            Optional system prompt
        n_keep : int
            Number of recent exchanges to keep (-1, or any value <= 0,
            keeps all)
        """
        self.system_prompt = system_prompt or ""
        self.n_keep = n_keep
        self.messages: List[Message] = []

        # Add system message if provided
        if system_prompt:
            self.messages.append(Message(role="system", content=system_prompt))

    def _has_system(self) -> bool:
        """Return True if the first stored message is a system message."""
        return bool(self.messages) and self.messages[0].role == "system"

    def add_message(
        self, role: str, content: str, images: Optional[List[str]] = None
    ) -> None:
        """Add a message to the history.

        A ``"system"`` message never accumulates. It replaces the existing
        system message, or is inserted at index 0 if there is none, and
        ``system_prompt`` is updated to match. Images passed with a system
        message are ignored. User and assistant messages are appended, and
        the history is then trimmed to ``n_keep`` exchanges.

        Parameters
        ----------
        role : str
            Message role ("user", "assistant", "system")
        content : str
            Message content
        images : Optional[List[str]]
            Optional images for multimodal messages

        Raises
        ------
        ValueError
            If role is invalid
        """
        if role not in self.VALID_ROLES:
            raise ValueError(f"Invalid role: {role}. Must be one of {self.VALID_ROLES}")

        if role == "system":
            # Keep exactly one system message, at index 0: trimming, clear()
            # and the sequence check only look for it there.
            system_msg = Message(role="system", content=content)
            if self._has_system():
                self.messages[0] = system_msg
            else:
                self.messages.insert(0, system_msg)
            self.system_prompt = content
            return

        self.messages.append(Message(role=role, content=content, images=images))
        self._trim_history()

    def _trim_history(self) -> None:
        """Keep only the last ``n_keep`` exchanges (two messages each).

        The system message at index 0, if any, is always retained.
        ``n_keep <= 0`` disables trimming. A value of 0 cannot be honoured
        literally without discarding the message that was just added.
        """
        if self.n_keep <= 0:
            return

        head = self.messages[:1] if self._has_system() else []
        body = self.messages[len(head):]
        limit = self.n_keep * 2  # limit >= 2 here, so body[-limit:] is safe
        if len(body) > limit:
            self.messages = head + body[-limit:]

    def format_for_api(self, provider: str) -> List[Dict[str, Any]]:
        """Format messages for specific provider API.

        Parameters
        ----------
        provider : str
            Provider name (openai, anthropic, google); case-insensitive.
            Any other name yields ``Message.to_dict()`` for every message.

        Returns
        -------
        List[Dict[str, Any]]
            Formatted messages
        """
        formatters = {
            "openai": self._format_for_openai,
            "anthropic": self._format_for_anthropic,
            "google": self._format_for_google,
        }
        formatter = formatters.get(provider.lower())
        if formatter is None:
            # Default format
            return [msg.to_dict() for msg in self.messages]
        return formatter()

    def _format_for_openai(self) -> List[Dict[str, Any]]:
        """Format messages for the OpenAI chat API (system messages included)."""
        formatted: List[Dict[str, Any]] = []

        for msg in self.messages:
            if msg.images:
                # Multimodal message: a text part followed by one part per image.
                content: List[Dict[str, Any]] = [{"type": "text", "text": msg.content}]
                content.extend(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img}"},
                    }
                    for img in msg.images
                )
                formatted.append({"role": msg.role, "content": content})
            else:
                formatted.append({"role": msg.role, "content": msg.content})

        return formatted

    def _format_for_anthropic(self) -> List[Dict[str, Any]]:
        """Format messages for the Anthropic Messages API.

        System messages are excluded because Anthropic takes the system
        prompt as a separate request parameter. Images are sent as base64
        ``image`` content blocks after the text block.
        """
        formatted: List[Dict[str, Any]] = []

        for msg in self.messages:
            if msg.role == "system":
                continue  # Anthropic handles system separately
            if msg.images:
                content: List[Dict[str, Any]] = [{"type": "text", "text": msg.content}]
                content.extend(
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": img,
                        },
                    }
                    for img in msg.images
                )
                formatted.append({"role": msg.role, "content": content})
            else:
                formatted.append({"role": msg.role, "content": msg.content})

        return formatted

    def _format_for_google(self) -> List[Dict[str, Any]]:
        """Format messages for the Google Gemini API.

        Gemini ``contents`` accept only the roles ``"user"`` and ``"model"``.
        Assistant messages are therefore relabelled ``"model"``, and system
        messages are skipped, since Gemini takes system instructions
        separately.
        """
        formatted: List[Dict[str, Any]] = []

        for msg in self.messages:
            if msg.role == "system":
                continue
            parts: List[Dict[str, Any]] = [{"text": msg.content}]
            if msg.images:
                parts.extend(
                    {"inline_data": {"mime_type": "image/jpeg", "data": img}}
                    for img in msg.images
                )
            role = "model" if msg.role == "assistant" else msg.role
            formatted.append({"role": role, "parts": parts})

        return formatted

    def ensure_valid_sequence(self) -> None:
        """Ensure messages follow valid sequence rules, mutating in place.

        - The first non-system message must be from the user. A placeholder
          user message ``"Hello"`` is inserted if it is from the assistant.
        - Roles must alternate. A placeholder ``"..."`` message of the
          opposite role is inserted between consecutive same-role messages.
        """
        if not self.messages:
            return

        # Skip system message if present
        start_idx = 1 if self._has_system() else 0

        # Ensure starts with user
        if (
            len(self.messages) > start_idx
            and self.messages[start_idx].role == "assistant"
        ):
            self.messages.insert(start_idx, Message(role="user", content="Hello"))

        # Ensure alternating
        i = start_idx
        while i < len(self.messages) - 1:
            current = self.messages[i]
            next_msg = self.messages[i + 1]

            if current.role == next_msg.role:
                filler_role = "assistant" if current.role == "user" else "user"
                self.messages.insert(i + 1, Message(role=filler_role, content="..."))
            i += 1

    def clear(self) -> None:
        """Clear history, keeping only system message if present."""
        self.messages = self.messages[:1] if self._has_system() else []

    def get_messages(self) -> List[Message]:
        """Get copy of messages.

        Returns
        -------
        List[Message]
            Deep copy of the message list
        """
        return deepcopy(self.messages)

    def __len__(self) -> int:
        """Get number of messages in history (including the system message)."""
        return len(self.messages)

    def __repr__(self) -> str:
        """String representation of ChatHistory."""
        return f"ChatHistory(messages={len(self.messages)}, n_keep={self.n_keep})"
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# Backward compatibility aliases
|
|
277
|
+
def get_history(self) -> List[Dict[str, Any]]:
    """Return the stored messages as a list of plain dicts.

    Backward-compatibility alias that is attached to ``ChatHistory`` at
    import time. Each entry is the message's ``to_dict()`` result.
    """
    history: List[Dict[str, Any]] = []
    for message in self.messages:
        history.append(message.to_dict())
    return history
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def ensure_alternating(self) -> None:
    """Backward-compatible alias that delegates to ``ensure_valid_sequence``."""
    validate = self.ensure_valid_sequence
    validate()
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def ensure_user_first(self) -> None:
    """Backward-compatible alias that delegates to ``ensure_valid_sequence``."""
    validate = self.ensure_valid_sequence
    validate()
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def reset(self, system_message: Optional[str] = None) -> None:
    """Reset history (backward compatibility).

    Drops all non-system messages via ``clear()``, which keeps an existing
    system message at index 0. When ``system_message`` is truthy, it becomes
    the single system message (replacing any existing one) and
    ``system_prompt`` is updated to match.
    """
    self.clear()
    if system_message:
        self.system_prompt = system_message
        new_system = Message(role="system", content=system_message)
        # clear() may have kept the old system message; replace it instead
        # of appending a second system message after it.
        if self.messages and self.messages[0].role == "system":
            self.messages[0] = new_system
        else:
            self.messages.insert(0, new_system)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
# Add backward compatibility methods to ChatHistory
|
|
301
|
+
# Attach the backward-compatibility helpers above as ChatHistory methods,
# each under its own function name.
for _alias in (get_history, ensure_alternating, ensure_user_first, reset):
    setattr(ChatHistory, _alias.__name__, _alias)
del _alias  # don't leave the loop variable behind as a module global
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# EOF
|