scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-11 04:11:10 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/_gen_ai/_Perplexity.py
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Functionality:
|
|
8
|
+
- Implements Perplexity AI interface using OpenAI-compatible API
|
|
9
|
+
- Provides access to Llama and Mixtral models
|
|
10
|
+
Input:
|
|
11
|
+
- User prompts and chat history
|
|
12
|
+
- Model configurations and API credentials
|
|
13
|
+
Output:
|
|
14
|
+
- Generated text responses from Perplexity models
|
|
15
|
+
- Token usage statistics
|
|
16
|
+
Prerequisites:
|
|
17
|
+
- Perplexity API key
|
|
18
|
+
- openai package
|
|
19
|
+
|
|
20
|
+
DEPRECATED: This module is deprecated. Please use perplexity_provider.py instead.
|
|
21
|
+
The new provider-based architecture offers better modularity and maintainability.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
"""Imports"""
|
|
25
|
+
import os
|
|
26
|
+
import sys
|
|
27
|
+
from pprint import pprint
|
|
28
|
+
from typing import Dict, Generator, List, Optional
|
|
29
|
+
import warnings
|
|
30
|
+
|
|
31
|
+
import matplotlib.pyplot as plt
|
|
32
|
+
from openai import OpenAI
|
|
33
|
+
|
|
34
|
+
from .base_genai import BaseGenAI
|
|
35
|
+
|
|
36
|
+
warnings.warn(
|
|
37
|
+
"perplexity.py is deprecated. Please use perplexity_provider.py instead. "
|
|
38
|
+
"See PROVIDER_MIGRATION_GUIDE.md for migration instructions.",
|
|
39
|
+
DeprecationWarning,
|
|
40
|
+
stacklevel=2,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
"""Functions & Classes"""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Perplexity(BaseGenAI):
    """Deprecated Perplexity AI client built on the OpenAI-compatible API.

    Wraps chat-completion calls against ``https://api.perplexity.ai`` for the
    Llama/Mixtral "sonar" model family. Superseded by the provider-based
    implementation in ``perplexity_provider.py``.
    """

    def __init__(
        self,
        system_setting: str = "",
        model: str = "",
        api_key: str = "",
        stream: bool = False,
        seed: Optional[int] = None,
        n_keep: int = 1,
        temperature: float = 1.0,
        chat_history: Optional[List[Dict[str, str]]] = None,
        max_tokens: Optional[int] = None,
    ) -> None:
        # NOTE(review): ``seed`` is accepted for signature parity but is not
        # forwarded to the base class (matches the original behavior).
        # Default token budget tracks the model's context window: models with
        # "128k" in their name get 128_000 tokens, everything else 32_000.
        if max_tokens is None:
            max_tokens = 128_000 if "128k" in model else 32_000

        super().__init__(
            system_setting=system_setting,
            model=model,
            api_key=api_key,
            stream=stream,
            n_keep=n_keep,
            temperature=temperature,
            provider="Perplexity",
            chat_history=chat_history,
            max_tokens=max_tokens,
        )

    def _init_client(self) -> OpenAI:
        # Perplexity exposes an OpenAI-compatible endpoint, so the stock
        # OpenAI client works once the base URL is swapped out.
        return OpenAI(api_key=self.api_key, base_url="https://api.perplexity.ai")

    def _api_call_static(self) -> str:
        """Issue one non-streaming chat completion and return its text.

        Also accumulates prompt/completion token counts onto
        ``self.input_tokens`` / ``self.output_tokens``.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=self.history,
            max_tokens=self.max_tokens,
            stream=False,
            temperature=self.temperature,
        )

        # Kept from the original: dumps the raw API response to stdout.
        print(completion)

        self.input_tokens += completion.usage.prompt_tokens
        self.output_tokens += completion.usage.completion_tokens

        return completion.choices[0].message.content

    def _api_call_stream(self) -> Generator[str, None, None]:
        """Yield text deltas from a streaming chat completion."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.history,
            max_tokens=self.max_tokens,
            n=1,
            stream=self.stream,
            temperature=self.temperature,
        )

        for chunk in response:
            # On the terminating chunk, record token usage when the API
            # attaches a usage payload (some chunks omit it).
            if chunk and chunk.choices[0].finish_reason == "stop":
                # Kept from the original: debug dump of the final choices.
                print(chunk.choices)
                try:
                    self.input_tokens += chunk.usage.prompt_tokens
                    self.output_tokens += chunk.usage.completion_tokens
                except AttributeError:
                    pass

            if chunk.choices:
                delta_text = chunk.choices[0].delta.content
                if delta_text:
                    yield delta_text

    def _get_available_models(self) -> List[str]:
        """Return the model identifiers this client accepts."""
        return [
            "llama-3.1-sonar-small-128k-online",
            "llama-3.1-sonar-large-128k-online",
            "llama-3.1-sonar-huge-128k-online",
            "llama-3.1-sonar-small-128k-chat",
            "llama-3.1-sonar-large-128k-chat",
            "llama-3-sonar-small-32k-chat",
            "llama-3-sonar-small-32k-online",
            "llama-3-sonar-large-32k-chat",
            "llama-3-sonar-large-32k-online",
            "llama-3-8b-instruct",
            "llama-3-70b-instruct",
            "mixtral-8x7b-instruct",
        ]
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def main() -> None:
    """Smoke-test the Perplexity client through the GenAI factory.

    NOTE(review): this definition is immediately shadowed by the second
    ``main()`` defined below in this module, so it never runs from the
    ``__main__`` guard — consider removing or renaming one of the two.
    """
    # Fixed: this module lives in scitex/ai/genai/, where the factory module
    # is named ``genai_factory`` (``_genai_factory`` exists only under
    # scitex/ai/_gen_ai/), so the old ``from ._genai_factory import ...``
    # raised ImportError.
    from .genai_factory import genai_factory as GenAI

    models = [
        "llama-3.1-sonar-small-128k-online",
        "llama-3.1-sonar-large-128k-online",
        "llama-3.1-sonar-huge-128k-online",
    ]
    # Non-streaming call; the API key comes from the environment.
    ai = GenAI(model=models[0], api_key=os.getenv("PERPLEXITY_API_KEY"), stream=False)
    out = ai("tell me about important citations for epilepsy prediction with citations")
    print(out)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def main():
    """Query the Perplexity REST API directly and pretty-print citations.

    NOTE(review): this redefinition shadows the ``main()`` defined above,
    so this is the function actually invoked from the ``__main__`` guard.
    """
    import requests

    endpoint = "https://api.perplexity.ai/chat/completions"

    request_body = {
        "model": "llama-3.1-sonar-small-128k-online",
        "messages": [
            {"role": "system", "content": "Be precise and concise."},
            {
                "role": "user",
                "content": "tell me useful citations (scientific peer-reviewed papers) for epilepsy seizure prediction.",
            },
        ],
        "max_tokens": 4096,
        "temperature": 0.2,
        "top_p": 0.9,
        "search_domain_filter": ["perplexity.ai"],
        "return_images": False,
        "return_related_questions": False,
        "search_recency_filter": "month",
        "top_k": 0,
        "stream": False,
        "presence_penalty": 0,
        "frequency_penalty": 1,
    }

    # Bearer-token auth via the PERPLEXITY_API_KEY environment variable.
    token = os.getenv("PERPLEXITY_API_KEY")
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    response = requests.request(
        "POST", endpoint, json=request_body, headers=request_headers
    )

    # Only the "citations" field of the JSON response is of interest here.
    pprint(response.json()["citations"])
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
if __name__ == "__main__":
    import scitex

    # scitex.gen.start returns a config object plus replacements for
    # sys.stdout / sys.stderr and the pyplot module (as the tuple unpacking
    # shows); run the demo, then tear the session down quietly.
    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.gen.start(sys, plt, verbose=False)
    main()
    scitex.gen.close(CONFIG, verbose=False, notify=False)
|
|
201
|
+
|
|
202
|
+
# EOF
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-13 20:25:55 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/genai/perplexity_provider.py
|
|
5
|
+
|
|
6
|
+
"""Perplexity AI provider implementation using the new component-based architecture.
|
|
7
|
+
|
|
8
|
+
This module provides integration with Perplexity's API using an OpenAI-compatible interface.
|
|
9
|
+
Perplexity offers access to various Llama and Mixtral models with online search capabilities.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
from typing import Dict, List, Iterator, Optional, Any
|
|
14
|
+
import openai
|
|
15
|
+
from openai import OpenAI
|
|
16
|
+
|
|
17
|
+
from .base_provider import BaseProvider, ProviderConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PerplexityProvider(BaseProvider):
    """Perplexity AI provider implementation.

    Talks to Perplexity's OpenAI-compatible chat-completions endpoint at
    ``https://api.perplexity.ai`` through the ``openai`` client library.
    """

    # Model names accepted by the Perplexity API.
    SUPPORTED_MODELS = [
        "llama-3.1-sonar-small-128k-online",
        "llama-3.1-sonar-large-128k-online",
        "llama-3.1-sonar-huge-128k-online",
        "llama-3.1-sonar-small-128k-chat",
        "llama-3.1-sonar-large-128k-chat",
        "llama-3-sonar-small-32k-chat",
        "llama-3-sonar-small-32k-online",
        "llama-3-sonar-large-32k-chat",
        "llama-3-sonar-large-32k-online",
        "llama-3-8b-instruct",
        "llama-3-70b-instruct",
        "mixtral-8x7b-instruct",
    ]

    DEFAULT_MODEL = "llama-3.1-sonar-small-128k-online"

    def __init__(self, config: ProviderConfig):
        """Initialize Perplexity provider.

        Args:
            config: Provider configuration.

        Raises:
            ValueError: If no API key is given via ``config`` or the
                PERPLEXITY_API_KEY environment variable.
        """
        self.config = config
        self.api_key = config.api_key or os.getenv("PERPLEXITY_API_KEY")

        if not self.api_key:
            raise ValueError(
                "Perplexity API key not provided. Set PERPLEXITY_API_KEY environment variable or pass api_key."
            )

        # Perplexity exposes an OpenAI-compatible API, so the stock OpenAI
        # client works once pointed at their base URL.
        self.client = OpenAI(api_key=self.api_key, base_url="https://api.perplexity.ai")

        # Set default max_tokens based on model if not provided.
        # NOTE(review): "128k"/"32k" in the model name is the context-window
        # size; using it as the completion max_tokens may exceed what the API
        # accepts for output -- confirm against Perplexity's documented limits.
        if self.config.max_tokens is None:
            if "128k" in (config.model or self.DEFAULT_MODEL):
                self.config.max_tokens = 128_000
            else:
                self.config.max_tokens = 32_000

    def validate_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Validate message format.

        Args:
            messages: List of message dictionaries.

        Raises:
            ValueError: If ``messages`` is empty, a message lacks a ``role``
                or ``content`` key, or a role is not system/user/assistant.
        """
        if not messages:
            raise ValueError("Messages cannot be empty")

        for msg in messages:
            if "role" not in msg:
                raise ValueError(f"Missing role in message: {msg}")
            if "content" not in msg:
                raise ValueError(f"Missing content in message: {msg}")
            if msg["role"] not in ["system", "user", "assistant"]:
                raise ValueError(f"Invalid role: {msg['role']}")

    def format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format messages for the Perplexity API.

        Prepends the configured system prompt (if any) to the messages.

        Args:
            messages: List of message dictionaries.

        Returns:
            A new list containing the optional system prompt followed by the
            original messages (the input list is not mutated).
        """
        formatted = []

        # Add system prompt if configured
        if self.config.system_prompt:
            formatted.append({"role": "system", "content": self.config.system_prompt})

        # Add user messages
        formatted.extend(messages)

        return formatted

    def complete(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        """Generate a (non-streaming) completion.

        Args:
            messages: List of message dictionaries.
            **kwargs: Additional parameters forwarded to the API; these
                override the config-derived defaults.

        Returns:
            Dict with ``content``, ``model``, ``usage`` (prompt/completion/
            total token counts) and ``finish_reason``.

        Raises:
            ValueError: If the messages are malformed.
            RuntimeError: If the API call fails (original exception chained).
        """
        self.validate_messages(messages)
        formatted_messages = self.format_messages(messages)

        # Config parameters form the defaults; explicit kwargs win.
        params = {
            "model": self.config.model or self.DEFAULT_MODEL,
            "messages": formatted_messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": False,
        }
        params.update(kwargs)

        try:
            response = self.client.chat.completions.create(**params)

            return {
                "content": response.choices[0].message.content,
                "model": response.model,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                },
                "finish_reason": response.choices[0].finish_reason,
            }
        except Exception as e:
            # Chain the original exception so the underlying cause (network
            # error, auth failure, ...) stays visible in the traceback.
            raise RuntimeError(f"Perplexity API error: {str(e)}") from e

    def stream(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Iterator[Dict[str, Any]]:
        """Stream a completion.

        Args:
            messages: List of message dictionaries.
            **kwargs: Additional parameters forwarded to the API; these
                override the config-derived defaults.

        Yields:
            Dicts with a ``content`` chunk; a final dict may additionally
            carry ``usage`` and/or ``finish_reason``.

        Raises:
            ValueError: If the messages are malformed.
            RuntimeError: If the API call fails (original exception chained).
        """
        self.validate_messages(messages)
        formatted_messages = self.format_messages(messages)

        # Config parameters form the defaults; explicit kwargs win.
        params = {
            "model": self.config.model or self.DEFAULT_MODEL,
            "messages": formatted_messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": True,
        }
        params.update(kwargs)

        try:
            stream = self.client.chat.completions.create(**params)

            for chunk in stream:
                if chunk.choices:
                    content = chunk.choices[0].delta.content
                    if content:
                        yield {
                            "content": content,
                            "model": (
                                chunk.model
                                if hasattr(chunk, "model")
                                else params["model"]
                            ),
                        }

                # Check for usage in final chunk
                if hasattr(chunk, "usage") and chunk.usage:
                    yield {
                        "content": "",
                        "usage": {
                            "prompt_tokens": chunk.usage.prompt_tokens,
                            "completion_tokens": chunk.usage.completion_tokens,
                            "total_tokens": chunk.usage.total_tokens,
                        },
                        "finish_reason": (
                            chunk.choices[0].finish_reason if chunk.choices else None
                        ),
                    }
                elif chunk.choices and chunk.choices[0].finish_reason == "stop":
                    # Handle case where usage might be in a stop chunk
                    yield {
                        "content": "",
                        "finish_reason": "stop",
                    }

        except Exception as e:
            # Chain the original exception to preserve the root cause.
            raise RuntimeError(f"Perplexity API error: {str(e)}") from e
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-15 12:00:00"
|
|
4
|
+
# Author: Yusuke Watanabe (ywatanabe@alumni.u-tokyo.ac.jp)
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Provider base implementation using composition pattern.
|
|
8
|
+
|
|
9
|
+
This module provides the concrete base class that combines all components
|
|
10
|
+
to implement the provider interface.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import warnings
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import Any, Dict, Iterator, List, Optional, Union
|
|
16
|
+
|
|
17
|
+
from .auth_manager import AuthManager
|
|
18
|
+
from .base_provider import BaseProvider
|
|
19
|
+
from .chat_history import ChatHistory
|
|
20
|
+
from .cost_tracker import CostTracker, TokenUsage
|
|
21
|
+
from .image_processor import ImageProcessor
|
|
22
|
+
from .model_registry import ModelInfo, ModelRegistry
|
|
23
|
+
from .response_handler import ResponseHandler
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class ProviderConfig:
    """Configuration for provider initialization.

    Field order is part of the positional-constructor interface; do not
    reorder.
    """

    # API key; when None, providers fall back to environment-variable lookup.
    api_key: Optional[str] = None
    # Model identifier to request from the provider.
    model: str = "gpt-3.5-turbo"
    # Optional system prompt prepended to every conversation.
    system_prompt: Optional[str] = None
    # Whether responses are streamed (iterator) or returned whole.
    stream: bool = False
    # Optional sampling seed for reproducibility (provider-dependent).
    seed: Optional[int] = None
    # Maximum completion tokens; None lets the provider choose a default.
    max_tokens: Optional[int] = None
    # Sampling temperature; 0.0 favors deterministic output.
    temperature: float = 0.0
    # Number of draft completions to request.
    n_draft: int = 1
    # Extra provider-specific keyword arguments merged into each API call.
    kwargs: Optional[Dict[str, Any]] = None
|
|
40
|
+
|
|
41
|
+
class ProviderBase(BaseProvider):
    """
    Base implementation using composition pattern.

    This class combines all components (auth, model registry, history, cost
    tracking, response handling, image processing) to provide a complete
    implementation of the provider interface. Concrete providers should
    inherit from this class and implement ``_make_api_call``.
    """

    def __init__(
        self,
        provider_name: str,
        config: ProviderConfig,
        auth_manager: Optional[AuthManager] = None,
        model_registry: Optional[ModelRegistry] = None,
        chat_history: Optional[ChatHistory] = None,
        cost_tracker: Optional[CostTracker] = None,
        response_handler: Optional[ResponseHandler] = None,
        image_processor: Optional[ImageProcessor] = None,
    ):
        """Initialize provider with components.

        Args:
            provider_name: Registry/auth key identifying the provider.
            config: Provider configuration.
            auth_manager, model_registry, chat_history, cost_tracker,
            response_handler, image_processor: Optional component overrides;
                defaults are constructed when omitted.
        """
        self.provider_name = provider_name
        self.config = config

        # Initialize components (composition over inheritance).
        self.auth_manager = auth_manager or AuthManager()
        self.model_registry = model_registry or ModelRegistry()
        self.chat_history = chat_history or ChatHistory()
        self.cost_tracker = cost_tracker or CostTracker()
        self.response_handler = response_handler or ResponseHandler()
        self.image_processor = image_processor or ImageProcessor()

        # Get and validate API key
        self.api_key = self.auth_manager.get_api_key(provider_name, config.api_key)

        # Mirror config fields as attributes for convenient access.
        self.model = config.model
        self.system_prompt = config.system_prompt
        self.stream = config.stream
        self.seed = config.seed
        self.max_tokens = config.max_tokens
        self.temperature = config.temperature
        self.n_draft = config.n_draft
        self.kwargs = config.kwargs or {}

        # Get model info
        self.model_info = self._get_model_info()

    def _get_model_info(self) -> ModelInfo:
        """Get model information from the registry, with safe defaults.

        Returns:
            The registry entry for ``self.model``, or a default ``ModelInfo``
            (4096 tokens, no image support, streaming allowed) with a warning
            when the model is unknown.
        """
        model_info = self.model_registry.get_model_info(self.model)
        if not model_info:
            # Create default model info if not found
            model_info = ModelInfo(
                name=self.model,
                provider=self.provider_name,
                max_tokens=4096,  # Default
                supports_images=False,
                supports_streaming=True,
            )
            warnings.warn(
                f"Model {self.model} not found in registry. Using defaults.",
                UserWarning,
            )
        return model_info

    def call(
        self,
        messages: Union[str, List[Dict[str, Any]]],
        **kwargs: Any,
    ) -> Union[str, Iterator[str]]:
        """
        Main method to interact with the AI provider.

        Parameters
        ----------
        messages : Union[str, List[Dict[str, Any]]]
            Input messages or prompt
        **kwargs : Any
            Additional parameters for the API call

        Returns
        -------
        Union[str, Iterator[str]]
            Response text, or a streaming iterator when ``self.stream`` is set
        """
        # Per-call kwargs override instance-level kwargs.
        call_kwargs = {**self.kwargs, **kwargs}

        # Normalize input into a list of role/content dicts.
        processed_messages = self._process_messages(messages)

        # Add system prompt if provided
        if self.system_prompt:
            processed_messages = self._add_system_prompt(processed_messages)

        # Process images if present
        processed_messages = self._process_images_in_messages(processed_messages)

        # Store messages in history (system prompts are not recorded).
        for msg in processed_messages:
            if msg["role"] != "system":
                self.chat_history.add_message(msg["role"], msg["content"])

        # Ensure alternating messages
        self.chat_history.ensure_alternating()

        # Make API call (to be implemented by concrete providers)
        response = self._make_api_call(processed_messages, **call_kwargs)

        # Handle response based on stream mode
        if self.stream:
            return self._handle_streaming_response(response)
        else:
            return self._handle_static_response(response)

    def _process_messages(
        self, messages: Union[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """Normalize input into the standard list-of-dicts message format.

        A bare string becomes a single user message; a list is returned as-is.
        """
        if isinstance(messages, str):
            return [{"role": "user", "content": messages}]
        return messages

    def _add_system_prompt(
        self, messages: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Return messages with ``self.system_prompt`` as the leading system message.

        Builds a new list instead of modifying the input: the original
        implementation mutated the caller's list (and the first message dict)
        in place, which could surprise callers that reuse their message list
        across calls.
        """
        system_msg = {"role": "system", "content": self.system_prompt}
        if messages and messages[0]["role"] == "system":
            # Replace the existing system prompt without touching the
            # caller's dict.
            return [system_msg] + messages[1:]
        # Insert the system prompt at the beginning.
        return [system_msg] + list(messages)

    def _process_images_in_messages(
        self, messages: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Process images in messages if the model supports them.

        Multimodal ``content`` lists with ``{"type": "image", "path": ...}``
        entries have the file loaded and replaced by inline data; other
        entries pass through unchanged. Returns the input unchanged when the
        model has no image support.
        """
        if not self.model_info.supports_images:
            return messages

        processed_messages = []
        for msg in messages:
            if isinstance(msg.get("content"), list):
                # Process multimodal content
                processed_content = []
                for item in msg["content"]:
                    if item.get("type") == "image" and "path" in item:
                        # Load and (if needed) downscale the image file.
                        image_data = self.image_processor.process_image(
                            item["path"], max_size=item.get("max_size", 2048)
                        )
                        processed_content.append(
                            {
                                "type": "image",
                                "data": image_data["data"],
                                "mime_type": image_data["mime_type"],
                            }
                        )
                    else:
                        processed_content.append(item)

                processed_messages.append(
                    {
                        "role": msg["role"],
                        "content": processed_content,
                    }
                )
            else:
                processed_messages.append(msg)

        return processed_messages

    def _handle_static_response(self, response: Any) -> str:
        """Extract text from a non-streaming API response.

        Also records token usage with the cost tracker and appends the
        assistant reply to the chat history.
        """
        result = self.response_handler.handle_static_response(
            response, self.provider_name
        )

        # Track usage
        if result.usage:
            self.cost_tracker.track_usage(
                self.model,
                TokenUsage(
                    input_tokens=result.usage.input_tokens,
                    output_tokens=result.usage.output_tokens,
                ),
            )

        # Add to history
        self.chat_history.add_message("assistant", result.content)

        return result.content

    def _handle_streaming_response(self, response: Any) -> Iterator[str]:
        """Yield text chunks from a streaming API response.

        Accumulates token usage across chunks and, once the stream is
        exhausted, records the totals and appends the full assistant reply
        to the chat history.
        """
        full_content = []
        total_usage = TokenUsage()

        for chunk in self.response_handler.handle_streaming_response(
            response, self.provider_name
        ):
            if chunk.content:
                full_content.append(chunk.content)
                yield chunk.content

            if chunk.usage:
                total_usage.input_tokens += chunk.usage.input_tokens
                total_usage.output_tokens += chunk.usage.output_tokens

        # Track total usage
        if total_usage.input_tokens > 0 or total_usage.output_tokens > 0:
            self.cost_tracker.track_usage(self.model, total_usage)

        # Add complete response to history
        complete_content = "".join(full_content)
        if complete_content:
            self.chat_history.add_message("assistant", complete_content)

    def _make_api_call(self, messages: List[Dict[str, Any]], **kwargs: Any) -> Any:
        """
        Make API call to the provider.

        This method must be implemented by concrete providers.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement _make_api_call"
        )

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get usage statistics from the cost tracker."""
        return self.cost_tracker.get_usage_stats()

    def reset_usage_stats(self) -> None:
        """Reset usage statistics."""
        self.cost_tracker.reset()

    def clear_history(self) -> None:
        """Clear chat history."""
        self.chat_history.clear()

    def get_history(self) -> List[Dict[str, str]]:
        """Get chat history."""
        return self.chat_history.get_messages()

    def set_system_prompt(self, prompt: str) -> None:
        """Update system prompt used for subsequent calls."""
        self.system_prompt = prompt

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"{self.__class__.__name__}("
            f"provider={self.provider_name}, "
            f"model={self.model}, "
            f"stream={self.stream})"
        )


## EOF
|