scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-15 12:00:00"
|
|
4
|
+
# Author: Yusuke Watanabe (ywatanabe@alumni.u-tokyo.ac.jp)
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Factory for creating AI provider instances.
|
|
8
|
+
|
|
9
|
+
This module provides a factory pattern for instantiating different AI providers
|
|
10
|
+
with consistent configuration.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Any, Dict, Optional, Type, Union
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
from .base_provider import BaseProvider, Provider
|
|
18
|
+
from .provider_base import ProviderConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ProviderRegistry:
    """Registry mapping ``Provider`` enum values to implementation classes.

    Besides storing implementations, it resolves free-form user input -- a
    provider name, a short alias such as ``"claude"``, or a full model name
    such as ``"gpt-4o-mini"`` -- to a ``Provider`` enum member.
    """

    # Model-name prefixes consulted as the last resort in ``resolve_provider``.
    # Checked in insertion order; the first provider with a matching prefix
    # wins. Built once at class level instead of on every call.
    _MODEL_PREFIXES: Dict[Provider, tuple[str, ...]] = {
        Provider.OPENAI: (
            "gpt-",
            "o1-",
            "text-davinci",
            "text-curie",
            "text-babbage",
            "text-ada",
        ),
        Provider.ANTHROPIC: ("claude-",),
        Provider.GOOGLE: ("gemini-", "palm-", "bison"),
        # "llama2-"/"llama3-" mirror the "llama2"/"llama3" aliases so model
        # ids such as "llama3-70b-8192" resolve instead of raising.
        Provider.GROQ: ("mixtral-", "llama-", "llama2-", "llama3-"),
        Provider.PERPLEXITY: ("pplx-", "perplexity-"),
        Provider.DEEPSEEK: ("deepseek-",),
    }

    def __init__(self):
        """Initialize the registry with provider storage and aliases."""
        # Provider enum -> registered implementation class.
        self._providers: Dict[Provider, Type[BaseProvider]] = {}
        # Lower-case alias -> provider enum; consulted after exact enum values.
        self._aliases: Dict[str, Provider] = {
            # OpenAI aliases
            "openai": Provider.OPENAI,
            "gpt": Provider.OPENAI,
            "gpt-3": Provider.OPENAI,
            "gpt-3.5": Provider.OPENAI,
            "gpt-4": Provider.OPENAI,
            "gpt-4o": Provider.OPENAI,
            "o1": Provider.OPENAI,
            # Anthropic aliases
            "anthropic": Provider.ANTHROPIC,
            "claude": Provider.ANTHROPIC,
            "claude-2": Provider.ANTHROPIC,
            "claude-3": Provider.ANTHROPIC,
            "claude-3-opus": Provider.ANTHROPIC,
            "claude-3-sonnet": Provider.ANTHROPIC,
            "claude-3-haiku": Provider.ANTHROPIC,
            # Google aliases
            "google": Provider.GOOGLE,
            "gemini": Provider.GOOGLE,
            "bard": Provider.GOOGLE,
            "bison": Provider.GOOGLE,
            "palm": Provider.GOOGLE,
            # Groq aliases
            "groq": Provider.GROQ,
            "mixtral": Provider.GROQ,
            "llama": Provider.GROQ,
            "llama2": Provider.GROQ,
            "llama3": Provider.GROQ,
            # Perplexity aliases
            "perplexity": Provider.PERPLEXITY,
            "pplx": Provider.PERPLEXITY,
            # DeepSeek aliases
            "deepseek": Provider.DEEPSEEK,
            "deepseek-coder": Provider.DEEPSEEK,
            "deepseek-chat": Provider.DEEPSEEK,
        }

    def register(self, provider: Provider, provider_class: Type[BaseProvider]) -> None:
        """
        Register a provider implementation.

        Registering the same provider twice replaces the earlier class.

        Parameters
        ----------
        provider : Provider
            Provider enum value
        provider_class : Type[BaseProvider]
            Provider implementation class
        """
        self._providers[provider] = provider_class

    def get(self, provider: Provider) -> Type[BaseProvider]:
        """
        Get a registered provider class.

        Parameters
        ----------
        provider : Provider
            Provider enum value

        Returns
        -------
        Type[BaseProvider]
            Provider implementation class

        Raises
        ------
        ValueError
            If provider is not registered
        """
        try:
            return self._providers[provider]
        except KeyError:
            raise ValueError(f"Provider {provider} is not registered") from None

    def resolve_provider(self, provider_or_model: Union[str, Provider]) -> Provider:
        """
        Resolve a provider from a string or model name.

        Resolution order:

        1. A ``Provider`` member is returned unchanged.
        2. Exact match against a ``Provider`` enum value (case-insensitive).
        3. Exact match against the alias table.
        4. Prefix match against known model-name prefixes.

        Leading and trailing whitespace is ignored.

        Parameters
        ----------
        provider_or_model : str or Provider
            Provider name, alias, or model name (or a Provider member)

        Returns
        -------
        Provider
            Resolved provider enum value

        Raises
        ------
        ValueError
            If provider cannot be resolved
        """
        if isinstance(provider_or_model, Provider):
            return provider_or_model

        provider_lower = provider_or_model.strip().lower()

        # Direct provider name match
        for p in Provider:
            if p.value == provider_lower:
                return p

        # Alias match
        alias_match = self._aliases.get(provider_lower)
        if alias_match is not None:
            return alias_match

        # Infer from model name prefixes (str.startswith accepts a tuple)
        for provider, prefixes in self._MODEL_PREFIXES.items():
            if provider_lower.startswith(prefixes):
                return provider

        raise ValueError(f"Cannot resolve provider from: {provider_or_model}")

    def list_providers(self) -> list[Provider]:
        """
        List registered providers.

        Returns
        -------
        list[Provider]
            Registered provider enums, in registration order
        """
        return list(self._providers.keys())
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# One-shot guard for ``_auto_register``: flipped to True on its first call so
# the provider imports are attempted at most once per process.
_auto_register_called: bool = False

# Process-wide registry instance used by the module-level helper functions.
_registry: ProviderRegistry = ProviderRegistry()
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _auto_register():
    """Import each bundled provider module once so it can register itself.

    The provider modules are expected to register themselves when imported
    (their code is not part of this module). Each module is imported
    independently. A provider whose optional dependency is missing therefore
    no longer prevents the remaining providers from loading. All import
    failures are collected and reported in a single ``UserWarning``.

    Calls after the first are no-ops.
    """
    global _auto_register_called
    if _auto_register_called:
        return
    # Set before importing so any re-entrant call made while the imports
    # run returns immediately.
    _auto_register_called = True

    import importlib
    import warnings

    provider_modules = (
        "mock_provider",  # For testing
        "anthropic_provider",
        "openai_provider",
        "google_provider",
        "groq_provider",
        "perplexity_provider",
        "deepseek_provider",
        "llama_provider",
    )

    if not __package__:
        # Relative imports need a parent package; report it the same way the
        # original ``from . import ...`` form did (as a warning, not a crash).
        warnings.warn(
            "Failed to import some providers: "
            "attempted relative import with no known parent package",
            stacklevel=2,
        )
        return

    failures = []
    for module_name in provider_modules:
        try:
            importlib.import_module(f".{module_name}", package=__package__)
        except ImportError as e:
            # Log import errors but continue with the next provider
            failures.append(f"{module_name} ({e})")

    if failures:
        warnings.warn(
            "Failed to import some providers: " + "; ".join(failures),
            stacklevel=2,
        )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class ModelRegistry:
    """Lookup of model names per provider (placeholder).

    No model metadata is bundled yet, so every query yields an empty list.
    """

    @staticmethod
    def get_models_for_provider(provider: str) -> list[str]:
        """Return the model names known for ``provider``.

        Currently a stub: the result is always a new, empty list, regardless
        of ``provider``.
        """
        known_models: list[str] = list()
        return known_models
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Module-level convenience functions
|
|
210
|
+
def register_provider(name: str, provider_class: Type[BaseProvider]) -> None:
    """
    Register a provider implementation with the module-level registry.

    Parameters
    ----------
    name : str
        Provider identifier; it is resolved to a provider enum with
        ``_registry.resolve_provider`` before registration
    provider_class : Type[BaseProvider]
        Provider implementation class to associate with that enum

    Raises
    ------
    ValueError
        Propagated from ``resolve_provider`` when ``name`` cannot be resolved
    """
    _registry.register(_registry.resolve_provider(name), provider_class)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def create_provider(
    provider: str,
    api_key: Optional[str] = None,
    model: str = "gpt-3.5-turbo",
    system_prompt: Optional[str] = None,
    stream: bool = False,
    seed: Optional[int] = None,
    max_tokens: Optional[int] = None,
    temperature: float = 0.0,
    n_draft: int = 1,
    **kwargs: Any,
) -> BaseProvider:
    """
    Build a provider instance from a name, alias, or model name.

    Parameters
    ----------
    provider : str
        Provider name, alias, or model name (resolved via the registry)
    api_key : Optional[str]
        API key for authentication
    model : str
        Model name to use
    system_prompt : Optional[str]
        System prompt to prepend to messages
    stream : bool
        Whether to stream responses
    seed : Optional[int]
        Random seed for reproducibility
    max_tokens : Optional[int]
        Maximum tokens in response
    temperature : float
        Sampling temperature
    n_draft : int
        Number of drafts to generate
    **kwargs : Any
        Additional provider-specific parameters; forwarded as a single dict
        in ``ProviderConfig.kwargs``

    Returns
    -------
    BaseProvider
        Instance of the registered class, constructed with one ProviderConfig

    Raises
    ------
    ValueError
        Propagated from ``resolve_provider`` when ``provider`` cannot be
        resolved
    """
    # Give the bundled provider modules a chance to self-register first.
    _auto_register()

    provider_class = _registry.get(_registry.resolve_provider(provider))

    settings = ProviderConfig(
        api_key=api_key,
        model=model,
        system_prompt=system_prompt,
        stream=stream,
        seed=seed,
        max_tokens=max_tokens,
        temperature=temperature,
        n_draft=n_draft,
        kwargs=kwargs,
    )
    return provider_class(settings)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def GenAI(
    api_key: Optional[str] = None,
    model: str = "gpt-3.5-turbo",
    system_prompt: Optional[str] = None,
    stream: bool = False,
    seed: Optional[int] = None,
    max_tokens: Optional[int] = None,
    temperature: float = 0.0,
    n_draft: int = 1,
    provider: Optional[str] = None,
    **kwargs: Any,
) -> BaseProvider:
    """
    Create an AI provider instance (backward-compatible entry point).

    Keeps the legacy ``GenAI(...)`` call signature. When ``provider`` is
    falsy (``None`` or empty string), the model name itself is passed to
    ``create_provider`` as the provider hint, so the provider is inferred
    from the model name.

    Parameters
    ----------
    api_key : Optional[str]
        API key for authentication
    model : str
        Model name to use
    system_prompt : Optional[str]
        System prompt to prepend to messages
    stream : bool
        Whether to stream responses
    seed : Optional[int]
        Random seed for reproducibility
    max_tokens : Optional[int]
        Maximum tokens in response
    temperature : float
        Sampling temperature
    n_draft : int
        Number of drafts to generate
    provider : Optional[str]
        Provider name (if not specified, inferred from model)
    **kwargs : Any
        Additional provider-specific parameters

    Returns
    -------
    BaseProvider
        Provider instance
    """
    # An explicit provider wins; otherwise let resolve_provider infer it
    # from the model name.
    provider_hint = provider if provider else model
    return create_provider(
        provider=provider_hint,
        api_key=api_key,
        model=model,
        system_prompt=system_prompt,
        stream=stream,
        seed=seed,
        max_tokens=max_tokens,
        temperature=temperature,
        n_draft=n_draft,
        **kwargs,
    )
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
## EOF
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2025-05-31 10:20:00"
|
|
4
|
+
# Author: ywatanabe
|
|
5
|
+
# File: ./src/scitex/ai/genai/response_handler.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Handles response processing for AI providers.
|
|
9
|
+
|
|
10
|
+
This module provides response handling functionality including:
|
|
11
|
+
- Static response processing
|
|
12
|
+
- Stream response handling
|
|
13
|
+
- Output formatting
|
|
14
|
+
- Error response generation
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import sys
|
|
18
|
+
from typing import Generator, List, Union, Optional, Any
|
|
19
|
+
|
|
20
|
+
from .format_output_func import format_output_func
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ResponseHandler:
    """Handles processing of AI provider responses.

    Example
    -------
    >>> handler = ResponseHandler()
    >>> # Process static response
    >>> result = handler.process_static("Hello, world!")
    >>> print(result)
    Hello, world!

    >>> # Process stream
    >>> stream = ["Hello", ", ", "world!"]
    >>> for chunk in handler.process_stream(stream):
    ...     print(chunk, end="")
    Hello, world!
    """

    def __init__(self):
        """Initialize response handler with an empty chunk buffer."""
        # Non-empty chunks recorded by the most recent process_stream() call.
        self._accumulated_response = []

    def process_static(self, response: str, format_output: bool = False) -> str:
        """Process a static (non-streaming) response.

        Parameters
        ----------
        response : str
            The response text to process
        format_output : bool
            Whether to apply ``format_output`` to the text

        Returns
        -------
        str
            The (optionally formatted) response text
        """
        if format_output:
            response = self.format_output(response)
        return response

    def process_stream(
        self, stream: Generator[str, None, None], format_output: bool = False
    ) -> Generator[str, None, None]:
        """Process a streaming response.

        Non-empty chunks are yielded unchanged as they arrive. They are also
        recorded, so ``get_accumulated_response()`` returns the full text once
        the stream is exhausted. The recorded chunks are reset when iteration
        starts. Any iterable of strings is accepted (e.g. a list).

        Parameters
        ----------
        stream : Generator[str, None, None]
            The stream of response chunks; falsy chunks ("" / None) are skipped
        format_output : bool
            Accepted for backward compatibility; it does NOT alter the
            yielded chunks, because they have already been emitted by the time
            the full text is known. To format the complete text, call
            ``self.format_output(self.get_accumulated_response())`` after
            the stream is exhausted.

        Yields
        ------
        str
            Response chunks as they arrive
        """
        # Previously, format_output=True formatted the joined text here and
        # then discarded the result; that dead work has been removed.
        self._accumulated_response = []

        for chunk in stream:
            if chunk:
                self._accumulated_response.append(chunk)
                yield chunk

    def format_output(self, text: str) -> str:
        """Apply output formatting to text.

        Parameters
        ----------
        text : str
            Text to format

        Returns
        -------
        str
            Result of ``format_output_func(text)``
        """
        return format_output_func(text)

    def yield_stream_with_print(self, stream: Generator[str, None, None]) -> str:
        """Write stream chunks to stdout as they arrive and return the full text.

        Despite the name, this is not a generator. It consumes the whole
        stream, and it does not touch the handler's accumulated-response
        buffer.

        Parameters
        ----------
        stream : Generator[str, None, None]
            The stream to process; falsy chunks are skipped

        Returns
        -------
        str
            Complete accumulated response
        """
        accumulated = []

        for chunk in stream:
            if chunk:
                sys.stdout.write(chunk)
                # Flush per chunk so partial output is visible immediately.
                sys.stdout.flush()
                accumulated.append(chunk)

        return "".join(accumulated)

    def create_error_response(
        self, error_messages: List[str], as_stream: bool = False
    ) -> Union[str, Generator[str, None, None]]:
        """Create an error response.

        Parameters
        ----------
        error_messages : List[str]
            Error messages; they are concatenated with no separator
        as_stream : bool
            Whether to return the text as a stream of chunks

        Returns
        -------
        Union[str, Generator[str, None, None]]
            Error response as string or stream
        """
        error_text = "".join(error_messages)

        if not as_stream:
            return error_text

        return self._text_to_stream(error_text)

    def _text_to_stream(self, text: str) -> Generator[str, None, None]:
        """Convert text to a stream generator.

        Parameters
        ----------
        text : str
            Text to convert

        Yields
        ------
        str
            Consecutive slices of at most 50 characters (nothing for "")
        """
        chunk_size = 50
        for i in range(0, len(text), chunk_size):
            yield text[i : i + chunk_size]

    def extract_content_from_response(self, response: Any) -> str:
        """Extract text content from various response formats.

        Lookup order: the string itself; dict keys "content", "text",
        "message"; attributes ``content``, ``text``, ``message``; finally
        ``str(response)``.

        Parameters
        ----------
        response : Any
            Response object from provider

        Returns
        -------
        str
            Extracted text content
        """
        # Handle string responses
        if isinstance(response, str):
            return response

        # Handle dict responses
        if isinstance(response, dict):
            if "content" in response:
                return str(response["content"])
            if "text" in response:
                return str(response["text"])
            if "message" in response:
                return str(response["message"])

        # Handle object responses with attributes
        if hasattr(response, "content"):
            return str(response.content)
        if hasattr(response, "text"):
            return str(response.text)
        if hasattr(response, "message"):
            return str(response.message)

        # Fallback to string conversion
        return str(response)

    def reset(self) -> None:
        """Clear the recorded stream chunks."""
        self._accumulated_response = []

    def get_accumulated_response(self) -> str:
        """Get the text recorded by the most recent ``process_stream`` call.

        Returns
        -------
        str
            Concatenation of the recorded chunks
        """
        return "".join(self._accumulated_response)

    def __repr__(self) -> str:
        """String representation of ResponseHandler."""
        return f"ResponseHandler(accumulated_chunks={len(self._accumulated_response)})"
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# EOF
|
scitex/ai/layer/_Pass.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-20 00:29:47 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/layer/_Pass.py
|
|
5
|
+
|
|
6
|
+
# Author's original absolute source path, recorded as metadata; it is not
# referenced anywhere else in this module.
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/layer/_Pass.py"
|
|
7
|
+
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Pass(nn.Module):
    """Identity layer: ``forward`` hands back its input untouched.

    Has no parameters; it can serve as a no-op placeholder where a module
    is required.
    """

    def __init__(self):
        super(Pass, self).__init__()

    def forward(self, x):
        """Return ``x`` as-is (the same object, no copy)."""
        output = x
        return output
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# EOF
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-07 18:53:03 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/loss/_L1L2Losses.py
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def l1(model, lambda_l1=0.01):
    """Return the sum of absolute values of all parameters of ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose ``parameters()`` are penalized.
    lambda_l1 : float, optional
        Accepted for backward compatibility but NOT applied: the returned
        value is unscaled, so callers multiply by their own coefficient.

    Returns
    -------
    torch.Tensor
        0-dim tensor on the device of the first parameter (CPU if the model
        has no parameters), differentiable w.r.t. the parameters.
    """
    params = list(model.parameters())
    # Follow the parameters' device instead of hard-coding ``.cuda()``,
    # which crashed on CPU-only machines.
    device = params[0].device if params else torch.device("cpu")
    total = torch.tensor(0.0, device=device)
    for param in params:
        total += param.abs().sum()
    return total
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def l2(model, lambda_l2=0.01):
    """Return the sum of per-parameter L2 norms of ``model``.

    Each parameter tensor contributes ``torch.norm(param)`` (the 2-norm of
    the flattened tensor). Note that this is the *un-squared* norm, unlike
    classic weight decay, which sums squared weights.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose ``parameters()`` are penalized.
    lambda_l2 : float, optional
        Accepted for backward compatibility but NOT applied: the returned
        value is unscaled, so callers multiply by their own coefficient.

    Returns
    -------
    torch.Tensor
        0-dim tensor on the device of the first parameter (CPU if the model
        has no parameters), differentiable w.r.t. the parameters.
    """
    params = list(model.parameters())
    # Follow the parameters' device instead of hard-coding ``.cuda()``,
    # which crashed on CPU-only machines.
    device = params[0].device if params else torch.device("cpu")
    total = torch.tensor(0.0, device=device)
    for param in params:
        # torch.norm already returns a 0-dim tensor; no extra .sum() needed.
        total += torch.norm(param)
    return total
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def elastic(model, alpha=1.0, l1_ratio=0.5):
    """Elastic-net style penalty mixing this module's ``l1`` and ``l2``.

    Returns ``alpha * (l1_ratio * l1(model) + (1 - l1_ratio) * l2(model))``,
    with ``l1``/``l2`` called using their default arguments.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose parameters are penalized.
    alpha : float
        Overall scale of the penalty.
    l1_ratio : float
        Weight of the L1 term, in [0, 1]; the L2 term gets ``1 - l1_ratio``.

    Raises
    ------
    ValueError
        If ``l1_ratio`` lies outside [0, 1] (or is NaN). This was previously
        an ``assert``, which is silently skipped under ``python -O``.
    """
    if not 0 <= l1_ratio <= 1:
        raise ValueError(f"l1_ratio must be in [0, 1], got {l1_ratio!r}")

    L1 = l1(model)
    L2 = l2(model)

    return alpha * (l1_ratio * L1 + (1 - l1_ratio) * L2)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# EOF
|