scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2025-05-31 10:00:00"
|
|
4
|
+
# Author: ywatanabe
|
|
5
|
+
# File: ./src/scitex/ai/genai/auth_manager.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Handles API key management and validation for AI providers.
|
|
9
|
+
|
|
10
|
+
This module provides secure handling of API keys including:
|
|
11
|
+
- Environment variable retrieval
|
|
12
|
+
- Key validation
|
|
13
|
+
- Masked key display for security
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Optional, Dict, Any
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AuthManager:
|
|
21
|
+
"""Manages API key authentication for AI providers.
|
|
22
|
+
|
|
23
|
+
Example
|
|
24
|
+
-------
|
|
25
|
+
>>> auth = AuthManager("sk-abc123...", "OpenAI")
|
|
26
|
+
>>> auth.validate_key()
|
|
27
|
+
True
|
|
28
|
+
>>> auth.get_masked_key()
|
|
29
|
+
'sk-a****3...'
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
api_key : str
|
|
34
|
+
The API key for authentication
|
|
35
|
+
provider : str
|
|
36
|
+
The provider name (e.g., "OpenAI", "Anthropic")
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
# Mapping of providers to environment variable names
|
|
40
|
+
ENV_VAR_MAPPING: Dict[str, str] = {
|
|
41
|
+
"OpenAI": "OPENAI_API_KEY",
|
|
42
|
+
"Anthropic": "ANTHROPIC_API_KEY",
|
|
43
|
+
"Google": "GOOGLE_API_KEY",
|
|
44
|
+
"Groq": "GROQ_API_KEY",
|
|
45
|
+
"DeepSeek": "DEEPSEEK_API_KEY",
|
|
46
|
+
"Perplexity": "PERPLEXITY_API_KEY",
|
|
47
|
+
"Llama": "LLAMA_API_KEY",
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
def __init__(self, api_key: Optional[str], provider: str):
|
|
51
|
+
"""Initialize AuthManager with API key and provider.
|
|
52
|
+
|
|
53
|
+
Parameters
|
|
54
|
+
----------
|
|
55
|
+
api_key : Optional[str]
|
|
56
|
+
The API key. If None, will attempt to get from environment
|
|
57
|
+
provider : str
|
|
58
|
+
The provider name
|
|
59
|
+
"""
|
|
60
|
+
# Normalize provider name to lowercase
|
|
61
|
+
self.provider = provider.lower()
|
|
62
|
+
|
|
63
|
+
# Check if provider is known
|
|
64
|
+
if self.provider not in [p.lower() for p in self.ENV_VAR_MAPPING.keys()]:
|
|
65
|
+
raise ValueError(f"Unknown provider: {provider}")
|
|
66
|
+
|
|
67
|
+
# Store private api key
|
|
68
|
+
self._api_key = api_key or self.get_key_from_env(provider)
|
|
69
|
+
|
|
70
|
+
if not self._api_key:
|
|
71
|
+
raise ValueError(f"No API key provided for {provider}")
|
|
72
|
+
|
|
73
|
+
# Public property for backward compatibility
|
|
74
|
+
self.api_key = self._api_key
|
|
75
|
+
|
|
76
|
+
def validate(self) -> bool:
|
|
77
|
+
"""Validate the API key format.
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
bool
|
|
82
|
+
True if key appears valid, False otherwise
|
|
83
|
+
|
|
84
|
+
Raises
|
|
85
|
+
------
|
|
86
|
+
ValueError
|
|
87
|
+
If API key is missing or invalid
|
|
88
|
+
"""
|
|
89
|
+
if not self._api_key:
|
|
90
|
+
raise ValueError("No API key configured")
|
|
91
|
+
|
|
92
|
+
# Basic validation - ensure key is not empty and has reasonable length
|
|
93
|
+
if len(self._api_key) < 10:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
f"API key for {self.provider} appears to be invalid (too short)"
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
return True
|
|
99
|
+
|
|
100
|
+
def get_masked_key(self) -> str:
|
|
101
|
+
"""Get a masked version of the API key for display.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
str
|
|
106
|
+
Masked API key showing only first few and last few characters
|
|
107
|
+
"""
|
|
108
|
+
if not self._api_key:
|
|
109
|
+
return "No API key"
|
|
110
|
+
|
|
111
|
+
if len(self._api_key) <= 8:
|
|
112
|
+
return "*****"
|
|
113
|
+
|
|
114
|
+
# Match test expectation: show first 3 and last 4 characters with "..." in between
|
|
115
|
+
return f"{self._api_key[:3]}...{self._api_key[-4:]}"
|
|
116
|
+
|
|
117
|
+
def get_client_config(self) -> Dict[str, Any]:
|
|
118
|
+
"""Get client configuration for the provider.
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
Dict[str, any]
|
|
123
|
+
Configuration dictionary for the provider client
|
|
124
|
+
"""
|
|
125
|
+
config = {"api_key": self._api_key}
|
|
126
|
+
|
|
127
|
+
# Provider-specific configurations
|
|
128
|
+
if self.provider == "openai":
|
|
129
|
+
# Check for optional organization
|
|
130
|
+
org = os.getenv("OPENAI_ORGANIZATION")
|
|
131
|
+
if org:
|
|
132
|
+
config["organization"] = org
|
|
133
|
+
elif self.provider == "anthropic":
|
|
134
|
+
config["max_retries"] = 3
|
|
135
|
+
elif self.provider == "google":
|
|
136
|
+
# Google only needs API key
|
|
137
|
+
pass
|
|
138
|
+
|
|
139
|
+
return config
|
|
140
|
+
|
|
141
|
+
@classmethod
|
|
142
|
+
def get_key_from_env(cls, provider: str) -> Optional[str]:
|
|
143
|
+
"""Get API key from environment variable.
|
|
144
|
+
|
|
145
|
+
Parameters
|
|
146
|
+
----------
|
|
147
|
+
provider : str
|
|
148
|
+
The provider name
|
|
149
|
+
|
|
150
|
+
Returns
|
|
151
|
+
-------
|
|
152
|
+
Optional[str]
|
|
153
|
+
The API key if found in environment, None otherwise
|
|
154
|
+
"""
|
|
155
|
+
# Find the env var name case-insensitively
|
|
156
|
+
for p, env_var in cls.ENV_VAR_MAPPING.items():
|
|
157
|
+
if p.lower() == provider.lower():
|
|
158
|
+
return os.getenv(env_var)
|
|
159
|
+
|
|
160
|
+
# Try generic pattern if not found
|
|
161
|
+
env_var = f"{provider.upper()}_API_KEY"
|
|
162
|
+
return os.getenv(env_var)
|
|
163
|
+
|
|
164
|
+
def get_api_key(self, provider: str, api_key: Optional[str] = None) -> str:
|
|
165
|
+
"""Get API key for a provider.
|
|
166
|
+
|
|
167
|
+
This method is used by provider_base.py for compatibility.
|
|
168
|
+
|
|
169
|
+
Parameters
|
|
170
|
+
----------
|
|
171
|
+
provider : str
|
|
172
|
+
The provider name
|
|
173
|
+
api_key : Optional[str]
|
|
174
|
+
Explicit API key, if None will use environment
|
|
175
|
+
|
|
176
|
+
Returns
|
|
177
|
+
-------
|
|
178
|
+
str
|
|
179
|
+
The API key
|
|
180
|
+
|
|
181
|
+
Raises
|
|
182
|
+
------
|
|
183
|
+
ValueError
|
|
184
|
+
If no API key is found
|
|
185
|
+
"""
|
|
186
|
+
if api_key:
|
|
187
|
+
return api_key
|
|
188
|
+
|
|
189
|
+
key = self.get_key_from_env(provider)
|
|
190
|
+
if not key:
|
|
191
|
+
raise ValueError(f"No API key provided for {provider}")
|
|
192
|
+
|
|
193
|
+
return key
|
|
194
|
+
|
|
195
|
+
def __repr__(self) -> str:
    """Return a debug string with the provider name and the masked key."""
    provider = self.provider
    masked = self.get_masked_key()
    return f"AuthManager(provider={provider}, key={masked})"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# EOF
|
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Timestamp: "2025-05-03 11:55:54 (ywatanabe)"
|
|
4
|
+
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/ai/_gen_ai/_BaseGenAI.py
|
|
5
|
+
# ----------------------------------------
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
# Module path bookkeeping (hard-coded relative path and its directory).
__FILE__ = "./src/scitex/ai/_gen_ai/_BaseGenAI.py"
__DIR__ = os.path.split(__FILE__)[0]
|
|
10
|
+
# ----------------------------------------
|
|
11
|
+
|
|
12
|
+
import base64
|
|
13
|
+
import sys
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from typing import Any, Dict, Generator, List, Optional, Union
|
|
16
|
+
|
|
17
|
+
import matplotlib.pyplot as plt
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from ...io._load import load
|
|
21
|
+
from .calc_cost import calc_cost
|
|
22
|
+
from .format_output_func import format_output_func
|
|
23
|
+
from .params import MODELS
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BaseGenAI(ABC):
    """Abstract base for provider-specific generative-AI chat wrappers.

    Subclasses supply ``_init_client``, ``_api_call_static`` and
    ``_api_call_stream`` (and may override ``_api_format_history``). This base
    class assembles prompts, keeps a bounded chat history, buffers errors so
    they are reported on the next call, and handles streamed output.
    """

    def __init__(
        self,
        system_setting: str = "",
        model: str = "",
        api_key: str = "",
        stream: bool = False,
        seed: Optional[int] = None,
        n_keep: int = 1,
        temperature: float = 1.0,
        provider: str = "",
        chat_history: Optional[List[Dict[str, str]]] = None,
        max_tokens: int = 4_096,
    ) -> None:
        """Store configuration, validate the model and create the client.

        Errors raised while validating the model or building the client are
        printed and buffered instead of raised; they are returned by the next
        ``__call__`` (see ``gen_error``).
        """
        self.provider = provider
        self.system_setting = system_setting
        self.model = model
        self.api_key = api_key
        self.stream = stream
        self.seed = seed
        self.n_keep = n_keep
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.input_tokens = 0
        self.output_tokens = 0
        self._error_messages: List[str] = []

        self.reset(system_setting)
        # NOTE(review): this replaces the history reset() just built, so the
        # system message is not kept in self.history; subclasses presumably
        # read self.system_setting directly — confirm.
        self.history = chat_history if chat_history else []

        try:
            self.verify_model()
            self.client = self._init_client()
        except Exception as error:
            print(error)
            self._error_messages.append(f"\nError:\n{str(error)}")

    @classmethod
    def list_models(cls, provider: Optional[str] = None) -> List[str]:
        """Print and return the models available for ``provider``.

        Parameters
        ----------
        provider : Optional[str]
            Case-insensitive substring matched against each row's
            ``api_key_env`` in ``MODELS``. If None or empty, all models are
            listed.

        Returns
        -------
        List[str]
            Names of the matching models.
        """
        if provider:
            mask = [
                provider.lower() in api_key_env.lower()
                for api_key_env in MODELS["api_key_env"]
            ]
            selected = MODELS[mask]
        else:
            selected = MODELS

        models = selected.name.tolist()
        providers = selected.provider.tolist()

        # Distinct loop name so the ``provider`` argument is not shadowed.
        for model_provider, model in zip(providers, models):
            print(f"- {model_provider} - {model}")

        return models

    def gen_error(
        self, return_stream: bool
    ) -> tuple[bool, Optional[Union[str, Generator]]]:
        """Flush buffered error messages, if any.

        Returns
        -------
        tuple[bool, Optional[Union[str, Generator]]]
            ``(False, None)`` when nothing is buffered. Otherwise ``True`` and
            the errors: one joined string in non-stream mode, a generator when
            ``return_stream`` is set, or else the string produced by echoing
            that generator through ``_yield_stream``.
        """
        if not self._error_messages:
            return False, None

        error_msgs = self._error_messages
        self._error_messages = []

        if not self.stream:
            return True, "".join(error_msgs)

        stream_obj = self._to_stream(error_msgs)
        return True, (
            stream_obj if return_stream else self._yield_stream(stream_obj)
        )

    def __call__(
        self,
        prompt: Optional[str] = None,
        prompt_file: Optional[str] = None,
        images: Optional[List[Any]] = None,
        format_output: bool = False,
        return_stream: bool = False,
    ) -> Optional[Union[str, Generator]]:
        """Send a prompt (optionally with a prompt file and images) to the model.

        Parameters
        ----------
        prompt : Optional[str]
            Prompt text.
        prompt_file : Optional[str]
            Path loaded with ``load``; its lines are appended after the prompt.
        images : Optional[List[Any]]
            File paths, raw bytes or base64 strings attached to the message.
        format_output : bool
            Post-process non-streamed output with ``format_output_func``.
        return_stream : bool
            In stream mode, return the raw generator instead of echoing it.

        Returns
        -------
        Optional[Union[str, Generator]]
            The reply text, a generator (``return_stream`` in stream mode),
            buffered error text, or None when no prompt was given.
        """
        # ----------------------------------------
        # Handles Prompt and Prompt File
        if (not prompt) and (not prompt_file):
            print("Please input prompt\n")
            return None

        if prompt_file:
            prompt = self._compose_prompt(prompt, load(prompt_file))

        if prompt.strip() == "":
            print("Please input prompt\n")
            return None
        # ----------------------------------------

        self.update_history("user", prompt, images=images)

        error_flag, error_obj = self.gen_error(return_stream)
        if error_flag:
            return error_obj

        try:
            if not self.stream:
                return self._call_static(format_output)

            if return_stream:
                # NOTE(review): the caller consumes this generator, so the
                # assistant reply is not recorded in self.history here.
                return self._call_stream(format_output)

            return self._yield_stream(self._call_stream(format_output))

        except Exception as error:
            self._error_messages.append(f"\nError:\n{str(error)}")
            error_flag, error_obj = self.gen_error(return_stream)
            if error_flag:
                return error_obj
            return None

    @staticmethod
    def _compose_prompt(prompt: Optional[str], file_content: Any) -> str:
        """Combine an optional prompt with the lines of a prompt file.

        Each file line is passed through ``repr`` (outer quotes stripped) so
        special characters appear escaped. Non-empty prompt text and file text
        are joined by a blank line; a missing prompt contributes nothing.
        """
        if isinstance(file_content, str):
            # A whole-file string would otherwise be escaped per character.
            file_content = file_content.splitlines()
        escaped_content = [repr(line)[1:-1] for line in file_content]
        prompt_text = str(prompt).strip() if prompt else ""
        file_text = "\n".join(escaped_content).strip()
        return "\n\n".join(part for part in (prompt_text, file_text) if part)

    def _yield_stream(self, stream_obj: Generator) -> str:
        """Echo non-empty chunks to stdout, record and return the joined text."""
        accumulated = []
        for chunk in stream_obj:
            if chunk:
                sys.stdout.write(chunk)
                sys.stdout.flush()
                accumulated.append(chunk)
        result = "".join(accumulated)
        self.update_history("assistant", result)
        return result

    def _call_static(self, format_output: bool = True) -> str:
        """Run a non-streaming API call, optionally format it, record it."""
        out_text = self._api_call_static()
        out_text = format_output_func(out_text) if format_output else out_text
        self.update_history("assistant", out_text)
        return out_text

    def _call_stream(self, format_output: Optional[bool] = None) -> Generator:
        """Return the provider's stream; ``format_output`` is currently unused."""
        return self._api_call_stream()

    @abstractmethod
    def _init_client(self) -> Any:
        """Returns client"""
        pass

    def _api_format_history(self, history):
        """Returns chat_history by handling differences in API expectations"""
        return history

    @abstractmethod
    def _api_call_static(self) -> str:
        """Returns out_text by handling differences in API expectations"""
        pass

    @abstractmethod
    def _api_call_stream(self) -> Generator:
        """Returns stream by handling differences in API expectations"""
        pass

    def _get_available_models(self) -> List[str]:
        """Names of models whose ``api_key_env`` contains this provider name."""
        indi = [
            self.provider.lower() in api_key_env.lower()
            for api_key_env in MODELS["api_key_env"]
        ]
        return MODELS[indi].name.tolist()

    @property
    def available_models(self) -> List[str]:
        return self._get_available_models()

    def reset(self, system_setting: str = "") -> None:
        """Clear the history, seeding it with a system message if given."""
        self.history = []
        if system_setting:
            self.history.append({"role": "system", "content": system_setting})

    @staticmethod
    def _merge_contents(first: Any, second: Any) -> Any:
        """Merge two message contents of the same role.

        Two strings are joined by a blank line. If either side is a list of
        content parts (e.g. text plus images), both are converted to part
        lists and concatenated, so image payloads are never stringified.
        """
        if isinstance(first, str) and isinstance(second, str):
            return f"{first}\n\n{second}"

        def as_parts(content: Any) -> List[Any]:
            if isinstance(content, list):
                return list(content)
            return [{"type": "text", "text": str(content)}]

        return as_parts(first) + as_parts(second)

    def _ensure_alternative_history(
        self, history: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Merge trailing consecutive same-role messages into one (in place)."""
        while len(history) >= 2 and history[-1]["role"] == history[-2]["role"]:
            last_content = history.pop()["content"]
            history[-1]["content"] = self._merge_contents(
                history[-1]["content"], last_content
            )
        return history

    @staticmethod
    def _ensure_start_from_user(history: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Drop leading messages until the history starts with a user turn."""
        while history and history[0]["role"] != "user":
            history.pop(0)
        return history

    @staticmethod
    def _ensure_base64_encoding(image, max_size=512):
        """Return ``image`` as a base64-encoded JPEG string.

        ``str`` input is first tried as a file path; if it cannot be opened or
        encoded it is assumed to already be base64 and returned unchanged.
        ``bytes`` input is decoded as an image. Images larger than
        ``max_size`` on their longest side are downscaled, keeping aspect.

        Raises
        ------
        ValueError
            If ``image`` is neither ``str`` nor ``bytes``.
        """
        import io

        from PIL import Image

        def resize_image(img):
            # Calculate new dimensions while maintaining aspect ratio
            ratio = max_size / max(img.size)
            if ratio < 1:
                new_size = tuple(int(dim * ratio) for dim in img.size)
                img = img.resize(new_size, Image.Resampling.LANCZOS)
            return img

        def to_base64_jpeg(img):
            img = resize_image(img)
            # JPEG cannot store alpha/palette modes (e.g. RGBA PNGs).
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")
            buffer = io.BytesIO()
            img.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")

        if isinstance(image, str):
            try:
                # Try to open and resize as file path; close the file afterwards.
                with Image.open(image) as img:
                    return to_base64_jpeg(img)
            except Exception:
                # If fails, assume it's already base64 string
                return image
        elif isinstance(image, bytes):
            with Image.open(io.BytesIO(image)) as img:
                return to_base64_jpeg(img)
        else:
            raise ValueError("Unsupported image format")

    def update_history(self, role: str, content: str, images=None) -> None:
        """Append a message and re-normalize the history.

        With ``images``, ``content`` becomes a list: one text part followed by
        one ``"_image"`` part per image. The history is then trimmed to the
        last ``n_keep`` messages, trailing same-role messages are merged,
        leading non-user messages are dropped, and ``_api_format_history`` is
        applied.
        """
        if images is not None:
            content = [
                {"type": "text", "text": content},
                *[
                    {
                        "type": "_image",
                        "_image": self._ensure_base64_encoding(image),
                    }
                    for image in images
                ],
            ]

        self.history.append({"role": role, "content": content})

        if len(self.history) > self.n_keep:
            # NOTE(review): n_keep == 0 slices as [0:] and keeps everything —
            # confirm whether that is intended.
            self.history = self.history[-self.n_keep :]

        self.history = self._ensure_alternative_history(self.history)
        self.history = self._ensure_start_from_user(self.history)
        self.history = self._api_format_history(self.history)

    def verify_model(self) -> None:
        """Raise ValueError if ``self.model`` is not among ``available_models``."""
        if self.model not in self.available_models:
            message = (
                f"Specified model {self.model} is not supported for the API Key ({self.masked_api_key}). "
                f"Available models for {str(self)} are as follows:\n{self.available_models}"
            )
            raise ValueError(message)

    @property
    def masked_api_key(self) -> str:
        """First and last four characters of the API key, separated by ``****``."""
        return f"{self.api_key[:4]}****{self.api_key[-4:]}"

    def _add_masked_api_key(self, text: str) -> str:
        """Append a parenthesized masked API key line to ``text``."""
        return text + f"\n(API Key: {self.masked_api_key})"

    @property
    def cost(self) -> float:
        """Cost of the tokens counted so far, via ``calc_cost``."""
        return calc_cost(self.model, self.input_tokens, self.output_tokens)

    @staticmethod
    def _to_stream(string: Union[str, List[str]]) -> Generator[str, None, None]:
        """Converts string or list of strings to generator for streaming."""
        chunks = string if isinstance(string, list) else [string]
        for chunk in chunks:
            if chunk:
                yield chunk
+
|
|
320
|
+
def main() -> None:
    """Placeholder script entry point; intentionally does nothing."""
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
if __name__ == "__main__":
    import scitex

    # scitex.gen.start / scitex.gen.close bracket the call to main();
    # start() also rebinds sys.stdout, sys.stderr and plt.
    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.gen.start(
        sys,
        plt,
        verbose=False,
    )
    main()
    scitex.gen.close(CONFIG, verbose=False, notify=False)

"""
python src/scitex/ai/_gen_ai/_BaseGenAI.py
python -m src.scitex.ai._gen_ai._BaseGenAI
"""
|
|
335
|
+
|
|
336
|
+
# EOF
|