scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
scitex/ai/genai/groq.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-28 02:47:54 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/genai/groq.py
|
|
5
|
+
|
|
6
|
+
# Path of this module. Derived from __file__ instead of a hard-coded,
# machine-specific path (the previous value pointed at a different file,
# scitex/ai/_gen_ai/_Groq.py, on one developer's home directory).
THIS_FILE = __file__
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
Functionality:
|
|
10
|
+
- Implements the Groq AI interface
|
|
11
|
+
- Handles both streaming and static text generation
|
|
12
|
+
Input:
|
|
13
|
+
- User prompts and chat history
|
|
14
|
+
- Model configurations and API credentials
|
|
15
|
+
Output:
|
|
16
|
+
- Generated text responses
|
|
17
|
+
- Token usage statistics
|
|
18
|
+
Prerequisites:
|
|
19
|
+
- Groq API key (GROQ_API_KEY environment variable)
|
|
20
|
+
- groq package
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
"""Imports"""
|
|
24
|
+
import os
|
|
25
|
+
import sys
|
|
26
|
+
import warnings
|
|
27
|
+
from typing import Any, Dict, Generator, List, Optional, Union
|
|
28
|
+
|
|
29
|
+
from groq import Groq as _Groq
|
|
30
|
+
import matplotlib.pyplot as plt
|
|
31
|
+
|
|
32
|
+
from .base_genai import BaseGenAI
|
|
33
|
+
|
|
34
|
+
"""Functions & Classes"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Groq(BaseGenAI):
    """Deprecated Groq chat-completion wrapper built on ``BaseGenAI``.

    Constructing an instance emits a ``DeprecationWarning``; new code should
    use ``GenAI(provider='groq')`` instead.
    """

    # Ceiling applied to ``max_tokens``; larger requests are clamped to it.
    _MAX_TOKENS_CAP = 8000

    def __init__(
        self,
        system_setting: str = "",
        api_key: Optional[str] = None,
        model: str = "llama3-8b-8192",
        stream: bool = False,
        seed: Optional[int] = None,
        n_keep: int = 1,
        temperature: float = 0.5,
        chat_history: Optional[List[Dict[str, str]]] = None,
        max_tokens: int = 8000,
    ) -> None:
        """Create a Groq wrapper.

        Args:
            system_setting: System prompt, forwarded to ``BaseGenAI``.
            api_key: Groq API key. When ``None`` (the default) the
                ``GROQ_API_KEY`` environment variable is read at construction
                time.
            model: Groq model identifier, forwarded to ``BaseGenAI``.
            stream: Forwarded to ``BaseGenAI``.
            seed: Optional sampling seed. When not ``None`` it is sent with
                every chat-completion request.
            n_keep: Forwarded to ``BaseGenAI``.
            temperature: Forwarded to ``BaseGenAI``; ``self.temperature`` is
                sent with each request.
            chat_history: Optional initial history, forwarded to ``BaseGenAI``.
            max_tokens: Requested completion length, clamped to 8000.

        Raises:
            ValueError: If no API key is given and ``GROQ_API_KEY`` is unset
                or empty.
        """
        warnings.warn(
            "Groq class is deprecated. Use GenAI(provider='groq') instead. "
            "Example: from scitex.ai.genai import GenAI; ai = GenAI(provider='groq', model='llama3-8b-8192')",
            DeprecationWarning,
            stacklevel=2,
        )
        max_tokens = min(max_tokens, self._MAX_TOKENS_CAP)

        # Resolve the key here rather than in the signature: a default of
        # os.getenv(...) is evaluated once, when the module is imported, and
        # would miss a variable exported afterwards.
        if api_key is None:
            api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY environment variable not set")

        super().__init__(
            system_setting=system_setting,
            model=model,
            api_key=api_key,
            stream=stream,
            n_keep=n_keep,
            temperature=temperature,
            provider="Groq",
            chat_history=chat_history,
            max_tokens=max_tokens,
        )
        # Previously accepted but silently discarded; now forwarded to the API.
        self.seed = seed

    def _init_client(self) -> Any:
        """Return a Groq SDK client authenticated with ``self.api_key``."""
        return _Groq(api_key=self.api_key)

    def _request_kwargs(self, stream: bool) -> Dict[str, Any]:
        """Build the chat-completion arguments shared by both call modes."""
        request: Dict[str, Any] = {
            "model": self.model,
            "messages": self.history,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }
        # getattr: tolerate a call made before __init__ has assigned seed.
        seed = getattr(self, "seed", None)
        if seed is not None:
            request["seed"] = seed
        return request

    def _api_call_static(self) -> str:
        """Send one non-streaming request and return the reply text.

        Adds the reported prompt/completion token counts to
        ``self.input_tokens`` / ``self.output_tokens``.
        """
        output = self.client.chat.completions.create(
            **self._request_kwargs(stream=False)
        )
        out_text = output.choices[0].message.content

        self.input_tokens += output.usage.prompt_tokens
        self.output_tokens += output.usage.completion_tokens

        return out_text

    def _api_call_stream(self) -> Generator[str, None, None]:
        """Send a streaming request and yield each non-empty text delta."""
        stream = self.client.chat.completions.create(
            **self._request_kwargs(stream=True)
        )

        for chunk in stream:
            # Skip chunks that carry no choices instead of raising IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# EOF
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-25 12:00:00"
|
|
4
|
+
# Author: Yusuke Watanabe (ywatanabe@alumni.u-tokyo.ac.jp)
|
|
5
|
+
# scitex/src/scitex/ai/genai/groq_provider.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Groq provider implementation for GenAI.
|
|
9
|
+
|
|
10
|
+
Provides access to Groq's API with models like Llama, Mixtral, etc.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import List, Dict, Any, Optional, Generator
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
from .base_provider import BaseProvider, CompletionResponse, Provider
|
|
18
|
+
from .provider_factory import register_provider
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GroqProvider(BaseProvider):
|
|
24
|
+
"""
|
|
25
|
+
Groq provider implementation.
|
|
26
|
+
|
|
27
|
+
Supports Llama 3, Mixtral, and other models available through Groq.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
SUPPORTED_MODELS = [
|
|
31
|
+
"llama3-8b-8192",
|
|
32
|
+
"llama3-70b-8192",
|
|
33
|
+
"llama2-70b-4096",
|
|
34
|
+
"mixtral-8x7b-32768",
|
|
35
|
+
"gemma-7b-it",
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
DEFAULT_MODEL = "llama3-8b-8192"
|
|
39
|
+
|
|
40
|
+
def __init__(self, config):
    """Initialize the Groq provider from a provider config.

    Args:
        config: Object exposing ``api_key``, ``model`` and ``kwargs``.
            A falsy ``api_key`` falls back to the ``GROQ_API_KEY``
            environment variable, a falsy ``model`` to ``DEFAULT_MODEL``,
            and a falsy ``kwargs`` to an empty dict.

    Raises:
        ValueError: If no API key is available.
        ImportError: If the ``groq`` package cannot be imported.
    """
    self.config = config
    self.api_key = config.api_key or os.getenv("GROQ_API_KEY")
    self.model = config.model or self.DEFAULT_MODEL
    self.kwargs = config.kwargs or {}

    if not self.api_key:
        raise ValueError("GROQ_API_KEY not provided and not found in environment")

    # Imported here rather than at module top, so a missing ``groq`` package
    # only fails when this provider is constructed. The try covers the import
    # alone so errors raised while building the client are not misreported
    # as a missing package; ``from err`` keeps the original cause visible.
    try:
        from groq import Groq as GroqClient
    except ImportError as err:
        raise ImportError(
            "Groq package not installed. Install with: pip install groq"
        ) from err

    self.client = GroqClient(api_key=self.api_key)
|
|
59
|
+
|
|
60
|
+
def complete(self, messages: List[Dict[str, Any]], **kwargs) -> CompletionResponse:
|
|
61
|
+
"""
|
|
62
|
+
Generate completion using Groq API.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
messages: List of message dictionaries
|
|
66
|
+
**kwargs: Additional parameters (temperature, max_tokens, etc.)
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
CompletionResponse with generated text and usage info
|
|
70
|
+
"""
|
|
71
|
+
# Validate messages
|
|
72
|
+
if not self.validate_messages(messages):
|
|
73
|
+
raise ValueError("Invalid message format")
|
|
74
|
+
|
|
75
|
+
# Format messages for Groq (same as OpenAI format)
|
|
76
|
+
formatted_messages = self.format_messages(messages)
|
|
77
|
+
|
|
78
|
+
# Prepare API parameters
|
|
79
|
+
api_params = {
|
|
80
|
+
"model": self.model,
|
|
81
|
+
"messages": formatted_messages,
|
|
82
|
+
"stream": False,
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
# Add optional parameters
|
|
86
|
+
for param in ["temperature", "max_tokens", "top_p", "stop", "seed"]:
|
|
87
|
+
if param in kwargs:
|
|
88
|
+
api_params[param] = kwargs[param]
|
|
89
|
+
|
|
90
|
+
# Groq has a max token limit of 8000
|
|
91
|
+
if "max_tokens" in api_params:
|
|
92
|
+
api_params["max_tokens"] = min(api_params["max_tokens"], 8000)
|
|
93
|
+
|
|
94
|
+
try:
|
|
95
|
+
# Make API call
|
|
96
|
+
response = self.client.chat.completions.create(**api_params)
|
|
97
|
+
|
|
98
|
+
# Extract content and usage
|
|
99
|
+
content = response.choices[0].message.content
|
|
100
|
+
usage = response.usage
|
|
101
|
+
|
|
102
|
+
return CompletionResponse(
|
|
103
|
+
content=content,
|
|
104
|
+
input_tokens=usage.prompt_tokens,
|
|
105
|
+
output_tokens=usage.completion_tokens,
|
|
106
|
+
finish_reason=response.choices[0].finish_reason,
|
|
107
|
+
provider_response=response,
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
except Exception as e:
|
|
111
|
+
logger.error(f"Groq API error: {str(e)}")
|
|
112
|
+
raise
|
|
113
|
+
|
|
114
|
+
def format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
115
|
+
"""
|
|
116
|
+
Format messages for Groq API.
|
|
117
|
+
|
|
118
|
+
Groq uses the same message format as OpenAI.
|
|
119
|
+
"""
|
|
120
|
+
formatted_messages = []
|
|
121
|
+
|
|
122
|
+
for msg in messages:
|
|
123
|
+
formatted_msg = {"role": msg["role"], "content": msg["content"]}
|
|
124
|
+
formatted_messages.append(formatted_msg)
|
|
125
|
+
|
|
126
|
+
return formatted_messages
|
|
127
|
+
|
|
128
|
+
def validate_messages(self, messages: List[Dict[str, Any]]) -> bool:
|
|
129
|
+
"""Validate message format."""
|
|
130
|
+
if not messages:
|
|
131
|
+
return False
|
|
132
|
+
|
|
133
|
+
for msg in messages:
|
|
134
|
+
if not isinstance(msg, dict):
|
|
135
|
+
return False
|
|
136
|
+
if "role" not in msg or "content" not in msg:
|
|
137
|
+
return False
|
|
138
|
+
if msg["role"] not in ["system", "user", "assistant"]:
|
|
139
|
+
return False
|
|
140
|
+
|
|
141
|
+
return True
|
|
142
|
+
|
|
143
|
+
def stream(
|
|
144
|
+
self, messages: List[Dict[str, Any]], **kwargs
|
|
145
|
+
) -> Generator[str, None, CompletionResponse]:
|
|
146
|
+
"""Generate a streaming completion.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
messages: List of messages in standard format
|
|
150
|
+
**kwargs: Additional parameters
|
|
151
|
+
|
|
152
|
+
Yields:
|
|
153
|
+
Text chunks during streaming
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
Final CompletionResponse when complete
|
|
157
|
+
"""
|
|
158
|
+
# Validate messages
|
|
159
|
+
if not self.validate_messages(messages):
|
|
160
|
+
raise ValueError("Invalid message format")
|
|
161
|
+
|
|
162
|
+
# Format messages for Groq
|
|
163
|
+
formatted_messages = self.format_messages(messages)
|
|
164
|
+
|
|
165
|
+
# Prepare API parameters
|
|
166
|
+
api_params = {
|
|
167
|
+
"model": self.model,
|
|
168
|
+
"messages": formatted_messages,
|
|
169
|
+
"stream": True,
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
# Add optional parameters
|
|
173
|
+
for param in ["temperature", "max_tokens", "top_p", "stop", "seed"]:
|
|
174
|
+
if param in kwargs:
|
|
175
|
+
api_params[param] = kwargs[param]
|
|
176
|
+
|
|
177
|
+
# Groq has a max token limit of 8000
|
|
178
|
+
if "max_tokens" in api_params:
|
|
179
|
+
api_params["max_tokens"] = min(api_params["max_tokens"], 8000)
|
|
180
|
+
|
|
181
|
+
try:
|
|
182
|
+
# Make streaming API call
|
|
183
|
+
stream = self.client.chat.completions.create(**api_params)
|
|
184
|
+
|
|
185
|
+
# Track content
|
|
186
|
+
full_content = ""
|
|
187
|
+
|
|
188
|
+
for chunk in stream:
|
|
189
|
+
if chunk.choices[0].delta.content:
|
|
190
|
+
content = chunk.choices[0].delta.content
|
|
191
|
+
full_content += content
|
|
192
|
+
yield content
|
|
193
|
+
|
|
194
|
+
# Estimate tokens for streaming (Groq doesn't provide usage in stream)
|
|
195
|
+
input_tokens = self.count_tokens(str(formatted_messages))
|
|
196
|
+
output_tokens = self.count_tokens(full_content)
|
|
197
|
+
|
|
198
|
+
# Return final response
|
|
199
|
+
return CompletionResponse(
|
|
200
|
+
content=full_content,
|
|
201
|
+
input_tokens=input_tokens,
|
|
202
|
+
output_tokens=output_tokens,
|
|
203
|
+
finish_reason="stop",
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
except Exception as e:
|
|
207
|
+
logger.error(f"Groq streaming error: {str(e)}")
|
|
208
|
+
raise
|
|
209
|
+
|
|
210
|
+
def count_tokens(self, text: str) -> int:
|
|
211
|
+
"""Count tokens in the given text.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
text: Text to count tokens for
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
Number of tokens (estimated)
|
|
218
|
+
"""
|
|
219
|
+
# Groq doesn't provide a token counter, so estimate
|
|
220
|
+
# Llama tokenization is roughly similar to GPT
|
|
221
|
+
return len(text.split()) * 4 // 3
|
|
222
|
+
|
|
223
|
+
@property
|
|
224
|
+
def supports_images(self) -> bool:
|
|
225
|
+
"""Check if this provider/model supports image inputs."""
|
|
226
|
+
# Groq doesn't currently support multimodal inputs
|
|
227
|
+
return False
|
|
228
|
+
|
|
229
|
+
@property
|
|
230
|
+
def supports_streaming(self) -> bool:
|
|
231
|
+
"""Check if this provider/model supports streaming."""
|
|
232
|
+
return True
|
|
233
|
+
|
|
234
|
+
@property
|
|
235
|
+
def max_context_length(self) -> int:
|
|
236
|
+
"""Get maximum context length for this model."""
|
|
237
|
+
context_lengths = {
|
|
238
|
+
"llama3-8b-8192": 8192,
|
|
239
|
+
"llama3-70b-8192": 8192,
|
|
240
|
+
"llama2-70b-4096": 4096,
|
|
241
|
+
"mixtral-8x7b-32768": 32768,
|
|
242
|
+
"gemma-7b-it": 8192,
|
|
243
|
+
}
|
|
244
|
+
return context_lengths.get(self.model, 8192)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Auto-register when module is imported: the provider factory maps the
# identifier Provider.GROQ.value to the GroqProvider class.
register_provider(Provider.GROQ.value, GroqProvider)
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2025-05-31 10:25:00"
|
|
4
|
+
# Author: ywatanabe
|
|
5
|
+
# File: ./src/scitex/ai/genai/image_processor.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Handles image processing for multimodal AI inputs.
|
|
9
|
+
|
|
10
|
+
This module provides image processing functionality including:
|
|
11
|
+
- Image resizing to fit token limits
|
|
12
|
+
- Base64 encoding for API transmission
|
|
13
|
+
- Multiple format support (file path, bytes, PIL Image)
|
|
14
|
+
- Format validation
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import base64
|
|
18
|
+
import io
|
|
19
|
+
from typing import Union, Tuple, Optional
|
|
20
|
+
from PIL import Image
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ImageProcessor:
    """Processes images for multimodal AI inputs.

    Example
    -------
    >>> processor = ImageProcessor()
    >>> # Process image from file
    >>> base64_str = processor.process_image("path/to/image.jpg", max_size=512)
    >>> print(base64_str[:50])
    /9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAgGBgcGBQgHBw...

    >>> # Process PIL Image
    >>> from PIL import Image
    >>> img = Image.new('RGB', (100, 100), color='red')
    >>> base64_str = processor.process_image(img)
    """

    def __init__(self):
        """Initialize image processor."""
        # Lower-case file extensions (with leading dot) accepted by validate_image.
        self.supported_formats = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}

    def process_image(
        self, image: Union[str, bytes, Image.Image], max_size: int = 512
    ) -> str:
        """Process an image for API transmission.

        Parameters
        ----------
        image : Union[str, bytes, Image.Image]
            Image as file path, data URL / base64 string, bytes, or PIL Image
        max_size : int
            Maximum dimension (width or height) in pixels; larger images are
            downscaled preserving aspect ratio

        Returns
        -------
        str
            Base64 encoded JPEG image string
        """
        pil_image = self._to_pil_image(image)

        if max(pil_image.size) > max_size:
            pil_image = self.resize_image(pil_image, max_size)

        return self.to_base64(pil_image)

    def _to_pil_image(self, image: Union[str, bytes, Image.Image]) -> Image.Image:
        """Convert various image formats to PIL Image.

        Parameters
        ----------
        image : Union[str, bytes, Image.Image]
            A PIL Image (returned as-is), raw bytes/bytearray, a
            ``data:image/...;base64,`` URL, a file path, or a bare base64 string

        Returns
        -------
        Image.Image
            PIL Image object

        Raises
        ------
        ValueError
            If the input type is unsupported, a data URL is malformed, or a
            string can be loaded neither as a path nor as base64.
        """
        if isinstance(image, Image.Image):
            return image

        if isinstance(image, (bytes, bytearray)):
            return Image.open(io.BytesIO(image))

        if isinstance(image, str):
            if image.startswith("data:image"):
                # Data URL: everything after the first comma is the base64 payload.
                _, sep, base64_data = image.partition(",")
                if not sep:
                    raise ValueError("Malformed data URL: missing ',' separator")
                return Image.open(io.BytesIO(base64.b64decode(base64_data)))

            # Treat the string as a file path first, then fall back to raw base64.
            try:
                return Image.open(image)
            except Exception as path_error:
                try:
                    image_bytes = base64.b64decode(image)
                    return Image.open(io.BytesIO(image_bytes))
                except Exception:
                    raise ValueError(
                        f"Could not load image from string: {path_error}"
                    ) from path_error

        raise ValueError(f"Unsupported image type: {type(image)}")

    def resize_image(self, image: Image.Image, max_size: int) -> Image.Image:
        """Resize image to fit within max_size while maintaining aspect ratio.

        The longer side becomes ``max_size``; the shorter side is truncated to
        an int but never drops below 1 pixel.

        Parameters
        ----------
        image : Image.Image
            PIL Image to resize
        max_size : int
            Maximum dimension in pixels

        Returns
        -------
        Image.Image
            Resized PIL Image

        Raises
        ------
        ValueError
            If the image has a zero width or height.
        """
        width, height = image.size
        if width <= 0 or height <= 0:
            raise ValueError(f"Cannot resize an empty image of size {image.size}")
        aspect_ratio = width / height

        # Clamp to >= 1: very elongated images would otherwise yield a
        # 0-pixel side, which Pillow refuses to resize to.
        if width > height:
            new_width = max_size
            new_height = max(1, int(max_size / aspect_ratio))
        else:
            new_height = max_size
            new_width = max(1, int(max_size * aspect_ratio))

        # Use high-quality resampling
        return image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    def to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
        """Convert PIL Image to base64 string.

        Parameters
        ----------
        image : Image.Image
            PIL Image to encode
        format : str
            Output format (JPEG, PNG, etc.). For JPEG, images with
            transparency are flattened onto a white background and saved with
            quality 95.

        Returns
        -------
        str
            Base64 encoded image string
        """
        is_jpeg = format.upper() == "JPEG"

        # JPEG cannot store alpha: composite onto white using the alpha channel
        # as the paste mask (covers RGBA, LA, and palette images).
        if is_jpeg and image.mode in ("RGBA", "LA", "P", "PA"):
            rgba = image.convert("RGBA")
            background = Image.new("RGB", image.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.getchannel("A"))
            image = background

        # Pass `quality` only to JPEG; other encoders (e.g. WEBP) would
        # receive an explicit None otherwise.
        save_kwargs = {"quality": 95} if is_jpeg else {}
        buffer = io.BytesIO()
        image.save(buffer, format=format, **save_kwargs)

        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_info(self, image: Union[str, bytes, Image.Image]) -> dict:
        """Get information about an image.

        Parameters
        ----------
        image : Union[str, bytes, Image.Image]
            Image to analyze (any input accepted by ``process_image``)

        Returns
        -------
        dict
            Keys ``width``, ``height``, ``mode``, ``format`` (as reported by
            PIL; may be None for in-memory images) and ``size_mb`` (a rough
            uncompressed-size estimate)
        """
        pil_image = self._to_pil_image(image)

        return {
            "width": pil_image.width,
            "height": pil_image.height,
            "mode": pil_image.mode,
            "format": pil_image.format,
            "size_mb": self._estimate_size_mb(pil_image),
        }

    def _estimate_size_mb(self, image: Image.Image) -> float:
        """Estimate image size in megabytes.

        Parameters
        ----------
        image : Image.Image
            PIL Image

        Returns
        -------
        float
            Estimated uncompressed size in MB
        """
        # Heuristic: one byte per character of the mode string (e.g. "RGB" -> 3);
        # inexact for modes such as "1", "I", "F" or "I;16".
        bytes_per_pixel = len(image.mode)
        total_bytes = image.width * image.height * bytes_per_pixel
        return total_bytes / (1024 * 1024)

    def validate_image(self, image_path: str) -> bool:
        """Validate if a file is a supported image format.

        Parameters
        ----------
        image_path : str
            Path to image file

        Returns
        -------
        bool
            True if the extension is in ``supported_formats`` and PIL can
            open and verify the file; False otherwise
        """
        if not isinstance(image_path, str):
            return False

        # Check file extension (text after the last dot, case-insensitive)
        ext = image_path.lower().rsplit(".", 1)[-1]
        if f".{ext}" not in self.supported_formats:
            return False

        # Open with a context manager so the file handle is always released.
        try:
            with Image.open(image_path) as img:
                img.verify()
        except Exception:
            return False
        return True

    def __repr__(self) -> str:
        """String representation of ImageProcessor."""
        return f"ImageProcessor(supported_formats={self.supported_formats})"
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# EOF
|