scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
scitex/ai/genai/llama.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-05 21:11:08 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/genai/llama.py
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
DEPRECATED: This module is deprecated. Please use llama_provider.py instead.
|
|
8
|
+
The new provider-based architecture offers better modularity and maintainability.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
"""Imports"""
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
from typing import List, Optional
|
|
15
|
+
import warnings
|
|
16
|
+
|
|
17
|
+
import matplotlib.pyplot as plt
|
|
18
|
+
import scitex
|
|
19
|
+
|
|
20
|
+
try:
    from llama import Dialog
    from llama import Llama as _Llama
except Exception:
    # The Meta ``llama`` package is an optional dependency.  Stay best-effort
    # (any failure while importing it is tolerated), but no longer swallow
    # KeyboardInterrupt/SystemExit, and always bind both names so later code
    # can test for ``None`` instead of hitting a NameError.
    Dialog = None
    _Llama = None
|
|
25
|
+
|
|
26
|
+
from .base_genai import BaseGenAI
|
|
27
|
+
|
|
28
|
+
# Emit a deprecation notice as soon as this legacy module is imported.
warnings.warn(
    message=(
        "llama.py is deprecated. Please use llama_provider.py instead. "
        "See PROVIDER_MIGRATION_GUIDE.md for migration instructions."
    ),
    category=DeprecationWarning,
    stacklevel=2,
)
|
|
34
|
+
|
|
35
|
+
"""Functions & Classes"""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def print_envs():
    """Print the four distributed-run environment variables and their values.

    Each variable is read with ``os.getenv``; when it is unset, the fallback
    listed below is shown instead.  Nothing is written back to the
    environment, and the report ends with one blank line.
    """
    fallbacks = (
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "12355"),
        ("WORLD_SIZE", "1"),
        ("RANK", "0"),
    )

    print("Environment Variable Settings:")
    for key, fallback in fallbacks:
        value = os.getenv(key, fallback)
        print(f"{key}: {value}")
    print()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Llama(BaseGenAI):
    """Deprecated ``BaseGenAI`` front end for a locally stored Meta Llama model.

    The constructor records the checkpoint/tokenizer locations and generation
    limits on the instance, makes sure the distributed-run environment
    variables have values, and then delegates the remaining setup to
    ``BaseGenAI.__init__``.

    NOTE(review): ``self.client``, ``self.history`` and ``self.temperature``
    are read below but never assigned in this class; they are presumably
    provided by ``BaseGenAI`` -- confirm against the base class.
    """

    def __init__(
        self,
        ckpt_dir: str = "",
        tokenizer_path: str = "",
        system_setting: str = "",
        model: str = "Meta-Llama-3-8B",
        max_seq_len: int = 32_768,
        max_batch_size: int = 4,
        max_gen_len: Optional[int] = None,
        stream: bool = False,
        seed: Optional[int] = None,
        n_keep: int = 1,
        temperature: float = 1.0,
        provider="Llama",
        chat_history=None,
        **kwargs,
    ):
        """Store Llama-specific settings and initialise the base class.

        Args:
            ckpt_dir: Checkpoint directory.  When empty, it is derived from
                ``model`` (``"<model dir>/"``).
            tokenizer_path: Tokenizer file.  When empty, it is derived from
                ``model`` (``"./<model dir>/tokenizer.model"``).
            system_setting: Forwarded to ``BaseGenAI.__init__``.
            model: Model name; also forwarded to ``BaseGenAI.__init__``.
            max_seq_len: Passed to ``_Llama.build`` in ``_init_client``.
            max_batch_size: Passed to ``_Llama.build`` in ``_init_client``.
            max_gen_len: Passed to ``chat_completion`` in ``_api_call_static``.
            stream, seed, n_keep, temperature, chat_history: Forwarded to
                ``BaseGenAI.__init__`` unchanged.
            provider: Accepted for signature compatibility; not used here.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        # Give the distributed-run variables single-process defaults while
        # keeping any value that is already exported.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        os.environ.setdefault("WORLD_SIZE", "1")
        os.environ.setdefault("RANK", "0")
        print_envs()

        # The default model name already starts with "Meta-"; prepending the
        # prefix unconditionally produced "Meta-Meta-Llama-3-8B/".  Only add
        # it when it is missing, so names such as "Llama-3-8B" still resolve
        # to "Meta-Llama-3-8B/" exactly as before.
        model_name = str(model)
        model_dir = model_name if model_name.startswith("Meta-") else f"Meta-{model_name}"

        self.ckpt_dir = ckpt_dir if ckpt_dir else f"{model_dir}/"
        self.tokenizer_path = (
            tokenizer_path if tokenizer_path else f"./{model_dir}/tokenizer.model"
        )
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.max_gen_len = max_gen_len

        super().__init__(
            system_setting=system_setting,
            model=model,
            api_key="",  # local inference: no API key is supplied
            stream=stream,
            seed=seed,
            n_keep=n_keep,
            temperature=temperature,
            chat_history=chat_history,
        )

    def __str__(self):
        """Return the fixed provider label ``"Llama"``."""
        return "Llama"

    def _init_client(self):
        """Build and return the Llama generator from the stored settings.

        Raises:
            ImportError: If the optional ``llama`` package could not be
                imported when this module was loaded.
        """
        # The module-level import of ``llama`` is best-effort, so ``_Llama``
        # may be unbound or ``None``.  Fail with a clear message in both cases.
        llama_cls = globals().get("_Llama")
        if llama_cls is None:
            raise ImportError(
                "Llama package is not installed. "
                "Please install it from the official Meta repository."
            )

        return llama_cls.build(
            ckpt_dir=self.ckpt_dir,
            tokenizer_path=self.tokenizer_path,
            max_seq_len=self.max_seq_len,
            max_batch_size=self.max_batch_size,
        )

    def _api_call_static(self):
        """Run one chat completion over ``self.history`` and return its text."""
        # chat_completion takes a batch of dialogs; send the history as a
        # batch of one and read the first (only) result back.
        dialogs: List[Dialog] = [self.history]
        results = self.client.chat_completion(
            dialogs,
            max_gen_len=self.max_gen_len,
            temperature=self.temperature,
            top_p=0.9,
        )
        return results[0]["generation"]["content"]

    def _api_call_stream(self):
        """Yield the static completion one character at a time.

        Streaming is simulated: the full response is generated first and then
        replayed character by character.
        """
        yield from self._api_call_static()

    # def _get_available_models(self):
    #     # Llama3 doesn't have a list of available models, so we'll return a placeholder
    #     return ["llama3"]

    def verify_model(self):
        """Do nothing: model verification is skipped for Llama."""
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def main():
    """Demo: build a ``Llama`` with placeholder paths and send it ``"Hi"``.

    The checkpoint and tokenizer paths are placeholders and must be replaced
    with real locations before this can run successfully.
    """
    demo_settings = {
        "ckpt_dir": "/path/to/checkpoint",
        "tokenizer_path": "/path/to/tokenizer",
        "system_setting": "You are a helpful assistant.",
        "max_seq_len": 512,
        "max_batch_size": 4,
        "stream": True,
        "temperature": 0.7,
    }
    chat = Llama(**demo_settings)
    chat("Hi")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
if __name__ == "__main__":
    # Script entry point: open a scitex session, run the demo, then close
    # the session again.
    _session = scitex.gen.start(sys, plt, verbose=False)
    CONFIG, sys.stdout, sys.stderr, plt, CC = _session
    main()
    scitex.gen.close(CONFIG, verbose=False, notify=False)
|
|
154
|
+
|
|
155
|
+
# EOF
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-13 20:25:55 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/genai/llama_provider.py
|
|
5
|
+
|
|
6
|
+
"""Llama provider implementation using the new component-based architecture.
|
|
7
|
+
|
|
8
|
+
This module provides integration with local Llama models through the official Llama library.
|
|
9
|
+
It supports loading and running Llama models locally with full control over model parameters.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
from typing import Dict, List, Iterator, Optional, Any
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from llama import Llama as _Llama
|
|
17
|
+
from llama import Dialog
|
|
18
|
+
except ImportError:
|
|
19
|
+
_Llama = None
|
|
20
|
+
Dialog = None
|
|
21
|
+
print(
|
|
22
|
+
"Warning: llama package not installed. "
|
|
23
|
+
"Install with the official Meta Llama repository instructions."
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from .base_provider import BaseProvider, ProviderConfig
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LlamaProvider(BaseProvider):
    """Llama provider implementation for local model inference.

    Wraps a generator built with ``llama.Llama.build`` and exposes
    ``complete`` / ``stream`` methods that take chat-style message dicts
    (``{"role": ..., "content": ...}``).
    """

    SUPPORTED_MODELS = [
        "Meta-Llama-3-8B",
        "Meta-Llama-3-70B",
        "Meta-Llama-3.1-8B",
        "Meta-Llama-3.1-70B",
        "Meta-Llama-3.1-405B",
        "Llama-2-7b",
        "Llama-2-13b",
        "Llama-2-70b",
    ]

    DEFAULT_MODEL = "Meta-Llama-3-8B"

    # Roles accepted by validate_messages().
    _VALID_ROLES = ("system", "user", "assistant")

    def __init__(self, config: ProviderConfig):
        """Initialize Llama provider and load the model.

        Args:
            config: Provider configuration. ``model``, ``max_tokens``,
                ``temperature`` and ``system_prompt`` are read directly;
                ``ckpt_dir``, ``tokenizer_path``, ``max_seq_len`` and
                ``max_batch_size`` are optional extras read via ``getattr``.

        Raises:
            ImportError: If the ``llama`` package could not be imported.
            RuntimeError: If ``Llama.build`` fails for any reason; the
                original exception is attached as ``__cause__``.
        """
        self.config = config
        self.model_name = config.model or self.DEFAULT_MODEL

        # Llama-specific configuration; paths default to a directory named
        # after the model, relative to the current working directory.
        self.ckpt_dir = getattr(config, "ckpt_dir", None) or f"{self.model_name}/"
        self.tokenizer_path = (
            getattr(config, "tokenizer_path", None)
            or f"{self.model_name}/tokenizer.model"
        )
        # NOTE(review): some llama releases reject a max_seq_len this large —
        # confirm the default against the installed llama package.
        self.max_seq_len = getattr(config, "max_seq_len", 32_768)
        self.max_batch_size = getattr(config, "max_batch_size", 4)
        self.max_gen_len = config.max_tokens

        # Fail before touching os.environ so a missing package leaves the
        # process environment untouched.
        if _Llama is None:
            raise ImportError(
                "Llama package is not installed. Please install it from the official Meta repository."
            )

        # Configure environment variables for distributed inference
        self._setup_environment()

        try:
            self.generator = _Llama.build(
                ckpt_dir=self.ckpt_dir,
                tokenizer_path=self.tokenizer_path,
                max_seq_len=self.max_seq_len,
                max_batch_size=self.max_batch_size,
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Failed to load Llama model: {str(e)}") from e

    def _setup_environment(self):
        """Set up environment variables for distributed inference.

        Values already present in the environment are kept; missing ones
        are filled with single-process defaults.
        """
        defaults = {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12355",
            "WORLD_SIZE": "1",
            "RANK": "0",
        }
        for key, value in defaults.items():
            os.environ.setdefault(key, value)

    def validate_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Validate message format.

        Args:
            messages: List of message dictionaries

        Raises:
            ValueError: If ``messages`` is empty, a message lacks ``role``
                or ``content``, or a role is not one of system/user/assistant.
        """
        if not messages:
            raise ValueError("Messages cannot be empty")

        for msg in messages:
            if "role" not in msg:
                raise ValueError(f"Missing role in message: {msg}")
            if "content" not in msg:
                raise ValueError(f"Missing content in message: {msg}")
            if msg["role"] not in self._VALID_ROLES:
                raise ValueError(f"Invalid role: {msg['role']}")

    def format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format messages for Llama API.

        Args:
            messages: List of message dictionaries

        Returns:
            A new list: the configured system prompt (if any) as a leading
            ``system`` message, followed by ``messages`` unchanged.
        """
        formatted = []

        # Add system prompt if configured
        if self.config.system_prompt:
            formatted.append({"role": "system", "content": self.config.system_prompt})

        formatted.extend(messages)
        return formatted

    def complete(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        """Generate a completion.

        Args:
            messages: List of message dictionaries
            **kwargs: Per-call overrides. ``temperature`` and ``max_tokens``
                take precedence over the configured values; ``top_p``
                defaults to 0.9.

        Returns:
            Dict with ``content``, ``model``, ``usage`` (whitespace-word
            estimates, not real token counts) and ``finish_reason``.

        Raises:
            ValueError: If ``messages`` fails validation.
            RuntimeError: If inference fails; the original exception is
                attached as ``__cause__``.
        """
        self.validate_messages(messages)
        formatted_messages = self.format_messages(messages)

        # Convert to Llama Dialog format (a batch containing one dialog)
        dialogs: List[Dialog] = [formatted_messages]

        # Use an explicit None check: 0.0 is a legitimate temperature and
        # must not be replaced by the 1.0 fallback.
        temperature = kwargs.get("temperature", self.config.temperature)
        if temperature is None:
            temperature = 1.0

        params = {
            "max_gen_len": kwargs.get("max_tokens", self.max_gen_len),
            "temperature": temperature,
            "top_p": kwargs.get("top_p", 0.9),
        }

        try:
            results = self.generator.chat_completion(dialogs, **params)
            content = results[0]["generation"]["content"]
        except Exception as e:
            raise RuntimeError(f"Llama inference error: {str(e)}") from e

        # Estimate token counts by whitespace-separated words; str() keeps
        # the estimate from failing on non-string content.
        prompt_tokens = len(
            " ".join(str(msg["content"]) for msg in formatted_messages).split()
        )
        completion_tokens = len(str(content).split())

        return {
            "content": content,
            "model": self.model_name,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "finish_reason": "stop",
        }

    def stream(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Iterator[Dict[str, Any]]:
        """Stream a completion.

        Note: this simulates streaming — the full completion is generated
        first via ``complete`` and then yielded one character at a time.
        Because this is a generator, validation and inference errors are
        raised on first iteration, not when ``stream`` is called.

        Args:
            messages: List of message dictionaries
            **kwargs: Additional parameters, forwarded to ``complete``

        Yields:
            One dict per character (``content``, ``model``), then a final
            dict with empty ``content`` plus ``usage`` and ``finish_reason``.
        """
        # Get the full response
        response = self.complete(messages, **kwargs)

        # Simulate streaming by yielding characters
        for char in response["content"]:
            yield {
                "content": char,
                "model": self.model_name,
            }

        # Yield final chunk with usage info
        yield {
            "content": "",
            "model": self.model_name,
            "usage": response["usage"],
            "finish_reason": "stop",
        }
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-15 12:00:00"
|
|
4
|
+
# Author: Yusuke Watanabe (ywatanabe@alumni.u-tokyo.ac.jp)
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Mock provider for testing purposes.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import List, Dict, Any, Optional, Iterator, Generator
|
|
11
|
+
from .base_provider import BaseProvider, CompletionResponse, Provider
|
|
12
|
+
from .provider_factory import register_provider
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MockProvider(BaseProvider):
    """Provider stand-in for tests.

    No network calls are made: responses are built locally from the last
    message's content, and usage figures are character counts.
    """

    def __init__(self, config):
        """Copy the relevant config fields onto the instance.

        The client starts as ``None`` and is created on first use.
        """
        self.config = config
        self.api_key = config.api_key
        self.model = config.model or "mock-model"
        self.stream_mode = config.stream
        self.system_prompt = config.system_prompt
        self.temperature = config.temperature
        self.max_tokens = config.max_tokens
        self.seed = config.seed
        self.n_draft = config.n_draft
        self.client = None  # Populated by init_client()

    def init_client(self) -> Any:
        """Create, store and return the placeholder client object."""
        placeholder = {"mock": True}
        self.client = placeholder
        return placeholder

    def format_history(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return ``history`` untouched (the mock needs no reformatting)."""
        return history

    def call_static(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        """Build a canned, non-streaming response for the last message.

        Returns:
            A dict with a single-element ``choices`` list and a ``usage``
            dict whose counts are string lengths.
        """
        last_content = messages[-1]["content"]
        content = f"Mock response to: {last_content}"

        prompt_len = len(str(messages))
        completion_len = len(content)

        choice = {
            "message": {"content": content, "role": "assistant"},
            "finish_reason": "stop",
        }
        usage = {
            "prompt_tokens": prompt_len,
            "completion_tokens": completion_len,
            "total_tokens": prompt_len + completion_len,
        }
        return {"choices": [choice], "usage": usage}

    def call_stream(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Generator[str, None, None]:
        """Yield a canned response word by word, each followed by a space."""
        last_content = messages[-1]["content"]
        response = f"Mock streaming response to: {last_content}"
        words = response.split()
        for word in words:
            yield f"{word} "

    def complete(self, messages: List[Dict[str, Any]], **kwargs) -> CompletionResponse:
        """Generate a mock completion wrapped in a ``CompletionResponse``."""
        # Lazily create the client on first use
        if not self.client:
            self.init_client()

        history = self.format_history(messages)
        response = self.call_static(history, **kwargs)

        # Pull the pieces out of the canned response
        first_choice = response["choices"][0]
        content = first_choice["message"]["content"]
        usage = response.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        finish_reason = first_choice.get("finish_reason", "stop")

        return CompletionResponse(
            content=content,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            finish_reason=finish_reason,
            provider_response=response,
        )

    def stream(self, messages: List[Dict[str, Any]], **kwargs) -> Iterator[str]:
        """Stream mock completions, one text chunk at a time."""
        # Lazily create the client on first use
        if not self.client:
            self.init_client()

        history = self.format_history(messages)

        for piece in self.call_stream(history, **kwargs):
            yield piece

    def count_tokens(self, text: str) -> int:
        """Return a rough token estimate: one token per four characters."""
        return len(text) // 4

    @property
    def supports_images(self) -> bool:
        """The mock always reports image support."""
        return True

    @property
    def supports_streaming(self) -> bool:
        """The mock always reports streaming support."""
        return True

    @property
    def max_context_length(self) -> int:
        """Fixed context length reported by the mock."""
        return 4096
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Auto-register when module is imported
|
|
125
|
+
# Import-time side effect: passes MockProvider to register_provider under the
# Provider.MOCK.value key (presumably so the provider factory can build it by
# name — confirm against provider_factory).
register_provider(Provider.MOCK.value, MockProvider)
|
|
126
|
+
|
|
127
|
+
## EOF
|