scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Time-stamp: "2024-09-07 01:09:38 (ywatanabe)"
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import scitex
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EarlyStopping:
    """
    Early stops the training if the validation score doesn't improve after a given patience period.

    Call the instance once per validation round with the current score.
    When the score beats the best one seen so far by more than ``delta``,
    the given models are checkpointed and the patience counter is reset;
    otherwise the counter is incremented. The call returns ``True`` once the
    counter reaches ``patience`` and ``False`` in every other case.
    """

    def __init__(self, patience=7, verbose=False, delta=1e-5, direction="minimize"):
        """
        Args:
            patience (int): How many consecutive non-improving calls to tolerate
                            before signalling a stop.
                            Default: 7
            verbose (bool): If True, prints a message for each best-score update,
                            each counter increment, and the final stop.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as
                            an improvement. Only its magnitude is used.
                            Default: 1e-5
            direction (str): "minimize" treats lower scores as better; any other
                            value treats higher scores as better.
                            Default: "minimize"
        """
        self.patience = patience
        self.verbose = verbose
        self.direction = direction

        self.delta = delta

        # Running state. The best score starts at the worst possible value
        # for the chosen direction, so the first finite score always counts
        # as an improvement.
        self.counter = 0
        self.best_score = np.inf if direction == "minimize" else -np.inf
        self.best_i_global = None
        self.models_spaths_dict = {}

    def is_best(self, val_score):
        """Return True if ``val_score`` beats the best score by more than ``|delta|``."""
        margin = abs(self.delta)
        if self.direction == "minimize":
            return val_score < self.best_score - margin
        return self.best_score + margin < val_score

    def __call__(self, current_score, models_spaths_dict, i_global):
        """
        Record one validation score and report whether training should stop.

        Args:
            current_score (float): Validation score of the current round.
            models_spaths_dict (dict): Passed to ``save`` when the score improves.
            i_global: Step index stored as ``best_i_global`` on improvement.

        Returns:
            bool: True once ``patience`` consecutive calls have failed to
            improve the best score, otherwise False.
        """
        # Defensive branch: the constructor never leaves best_score as None,
        # so this only runs if a caller reset it by hand.
        if self.best_score is None:
            self.save(current_score, models_spaths_dict, i_global)
            return False

        if self.is_best(current_score):
            self.save(current_score, models_spaths_dict, i_global)
            self.counter = 0
            return False

        self.counter += 1
        if self.verbose:
            print(
                f"\nEarlyStopping counter: {self.counter} out of {self.patience}\n"
            )
        if self.counter >= self.patience:
            if self.verbose:
                scitex.gen.print_block("Early-stopped.", c="yellow")
            return True

        # Patience not exhausted yet: answer with an explicit False instead
        # of falling off the end and returning None.
        return False

    def save(self, current_score, models_spaths_dict, i_global):
        """
        Record a new best score and checkpoint the models.

        Args:
            current_score (float): The new best score.
            models_spaths_dict (dict): Iterated as ``(model, spath)`` pairs;
                ``model.state_dict()`` is written to ``spath`` with
                ``scitex.io.save``.
            i_global: Step index stored as ``best_i_global``.
        """
        if self.verbose:
            # best_score may be None if a caller reset it; formatting None
            # with ":.6f" would raise, so spell it out instead.
            if self.best_score is None:
                previous = "None"
            else:
                previous = f"{self.best_score:.6f}"
            print(
                f"\nUpdate the best score: ({previous} --> {current_score:.6f})"
            )

        self.best_score = current_score
        self.best_i_global = i_global

        for model, spath in models_spaths_dict.items():
            scitex.io.save(model.state_dict(), spath)

        self.models_spaths_dict = models_spaths_dict
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
if __name__ == "__main__":
|
|
84
|
+
pass
|
|
85
|
+
# # starts the current fold's loop
|
|
86
|
+
# i_global = 0
|
|
87
|
+
# lc_logger = scitex.ml.LearningCurveLogger()
|
|
88
|
+
# early_stopping = utils.EarlyStopping(patience=50, verbose=True)
|
|
89
|
+
# for i_epoch, epoch in enumerate(tqdm(range(merged_conf["MAX_EPOCHS"]))):
|
|
90
|
+
|
|
91
|
+
# dlf.fill(i_fold, reset_fill_counter=False)
|
|
92
|
+
|
|
93
|
+
# step_str = "Validation"
|
|
94
|
+
# for i_batch, batch in enumerate(dlf.dl_val):
|
|
95
|
+
# _, loss_diag_val = utils.base_step(
|
|
96
|
+
# step_str,
|
|
97
|
+
# model,
|
|
98
|
+
# mtl,
|
|
99
|
+
# batch,
|
|
100
|
+
# device,
|
|
101
|
+
# i_fold,
|
|
102
|
+
# i_epoch,
|
|
103
|
+
# i_batch,
|
|
104
|
+
# i_global,
|
|
105
|
+
# lc_logger,
|
|
106
|
+
# no_mtl=args.no_mtl,
|
|
107
|
+
# print_batch_interval=False,
|
|
108
|
+
# )
|
|
109
|
+
# lc_logger.print(step_str)
|
|
110
|
+
|
|
111
|
+
# step_str = "Training"
|
|
112
|
+
# for i_batch, batch in enumerate(dlf.dl_tra):
|
|
113
|
+
# optimizer.zero_grad()
|
|
114
|
+
# loss, _ = utils.base_step(
|
|
115
|
+
# step_str,
|
|
116
|
+
# model,
|
|
117
|
+
# mtl,
|
|
118
|
+
# batch,
|
|
119
|
+
# device,
|
|
120
|
+
# i_fold,
|
|
121
|
+
# i_epoch,
|
|
122
|
+
# i_batch,
|
|
123
|
+
# i_global,
|
|
124
|
+
# lc_logger,
|
|
125
|
+
# no_mtl=args.no_mtl,
|
|
126
|
+
# print_batch_interval=False,
|
|
127
|
+
# )
|
|
128
|
+
# loss.backward()
|
|
129
|
+
# optimizer.step()
|
|
130
|
+
# i_global += 1
|
|
131
|
+
# lc_logger.print(step_str)
|
|
132
|
+
|
|
133
|
+
# bACC_val = np.array(lc_logger.logged_dict["Validation"]["bACC_diag_plot"])[
|
|
134
|
+
# np.array(lc_logger.logged_dict["Validation"]["i_epoch"]) == i_epoch
|
|
135
|
+
# ].mean()
|
|
136
|
+
|
|
137
|
+
# model_spath = (
|
|
138
|
+
# merged_conf["sdir"]
|
|
139
|
+
# + f"checkpoints/model_fold#{i_fold}_epoch#{i_epoch:03d}.pth"
|
|
140
|
+
# )
|
|
141
|
+
# mtl_spath = model_spath.replace("model_fold", "mtl_fold")
|
|
142
|
+
# models_spaths_dict = {model: model_spath, mtl: mtl_spath}
|
|
143
|
+
|
|
144
|
+
# early_stopping(loss_diag_val, models_spaths_dict, i_global)
|
|
145
|
+
# # early_stopping(-bACC_val, models_spaths_dict, i_global)
|
|
146
|
+
|
|
147
|
+
# if early_stopping.counter >= early_stopping.patience:
|
|
148
|
+
# print("Early stopping")
|
|
149
|
+
# break
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-20 10:53:22 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/feature_extraction/__init__.py
|
|
5
|
+
|
|
6
|
+
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/feature_extraction/__init__.py"
|
|
7
|
+
|
|
8
|
+
#!/usr/bin/env python3
|
|
9
|
+
# -*- coding: utf-8 -*-
|
|
10
|
+
# Time-stamp: "2024-10-22 19:51:47 (ywatanabe)"
|
|
11
|
+
# File: __init__.py
|
|
12
|
+
|
|
13
|
+
import os as __os
|
|
14
|
+
import importlib as __importlib
|
|
15
|
+
import inspect as __inspect
|
|
16
|
+
import warnings as __warnings
|
|
17
|
+
|
|
18
|
+
# Get the current directory
|
|
19
|
+
# Directory that holds this package's submodules.
current_dir = __os.path.dirname(__file__)

# Re-export the public functions and classes of every non-dunder ``*.py``
# file that sits next to this ``__init__``. The listing is sorted because
# ``os.listdir`` returns entries in arbitrary order; a fixed order makes it
# deterministic which module wins when two of them export the same name
# (the one whose filename sorts last).
for filename in sorted(__os.listdir(current_dir)):
    if filename.endswith(".py") and not filename.startswith("__"):
        module_name = filename[:-3]  # Remove .py extension
        try:
            module = __importlib.import_module(f".{module_name}", package=__name__)

            # Import only functions and classes from the module
            for name, obj in __inspect.getmembers(module):
                if __inspect.isfunction(obj) or __inspect.isclass(obj):
                    if not name.startswith("_"):
                        globals()[name] = obj
        except ImportError as e:
            # Warn about modules that couldn't be imported due to missing dependencies
            __warnings.warn(
                f"Could not import {module_name} from scitex.ai.feature_extraction: {str(e)}. "
                f"Some functionality may be unavailable. "
                f"Consider installing missing dependencies if you need this module.",
                ImportWarning,
                stacklevel=2
            )

# Clean up temporary variables so they do not leak into the package namespace.
del __os, __importlib, __inspect, __warnings, current_dir
# The loop variables exist only if the corresponding loop body ran at least
# once, so remove them tolerantly instead of with a bare ``del``.
for __leftover in ("filename", "module_name", "module", "name", "obj"):
    globals().pop(__leftover, None)
del __leftover
|
|
55
|
+
|
|
56
|
+
# EOF
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-27 21:36:51 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/feature_extraction/vit.py
|
|
5
|
+
|
|
6
|
+
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/feature_extraction/vit.py"
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
Functionality:
|
|
10
|
+
Extracts features from images using Vision Transformer (ViT) models
|
|
11
|
+
Input:
|
|
12
|
+
Image arrays of arbitrary dimensions
|
|
13
|
+
Output:
|
|
14
|
+
Feature vectors (1000-dimensional embeddings)
|
|
15
|
+
Prerequisites:
|
|
16
|
+
torch, PIL, torchvision
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import os as _os
|
|
20
|
+
from typing import Tuple, Union
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
import torch as _torch
|
|
24
|
+
from pytorch_pretrained_vit import ViT
|
|
25
|
+
from torchvision import transforms as _transforms
|
|
26
|
+
|
|
27
|
+
# from ...decorators import batch_torch_fn
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _setup_device(device: Union[str, None]) -> str:
    """Return ``device`` as given, or choose a default when it is ``None``.

    The default is ``"cuda"`` when ``torch.cuda.is_available()`` reports a
    usable GPU and ``"cpu"`` otherwise.
    """
    if device is not None:
        return device
    return "cuda" if _torch.cuda.is_available() else "cpu"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class VitFeatureExtractor:
    """Extract feature vectors from 2-D images with a pretrained ViT model.

    Every 2-D slice of the input (selected by two "image" axes) is replicated
    to three channels, resized and normalised, and passed through the model.
    The model output for each slice becomes the last axis of the result; all
    other input axes are preserved.

    Args:
        model_name: One of the names listed in ``valid_models``.
        torch_home: Existing directory exported as the ``TORCH_HOME``
            environment variable (presumably where pretrained weights are
            cached -- confirm against ``pytorch_pretrained_vit``).
        device: Torch device string; ``None`` selects CUDA when available.

    Raises:
        ValueError: If ``model_name`` is not in ``valid_models``.
        FileNotFoundError: If ``torch_home`` does not exist.
    """

    def __init__(
        self,
        model_name="B_16",
        torch_home="./models",
        device=None,
    ):
        self.valid_models = [
            "B_16",
            "B_32",
            "L_16",
            "L_32",
            "B_16_imagenet1k",
            "B_32_imagenet1k",
            "L_16_imagenet1k",
            "L_32_imagenet1k",
        ]
        self.model_name = model_name
        self.torch_home = torch_home
        self.device = _setup_device(device)

        # Validate first so that invalid arguments do not leave a modified
        # process environment behind.
        self._validate_inputs()
        _os.environ["TORCH_HOME"] = torch_home

        self.model = ViT(model_name, pretrained=True).to(self.device).eval()
        self.transform = _transforms.Compose(
            [
                _transforms.ToPILImage(),
                _transforms.Resize(self.model.image_size),
                _transforms.ToTensor(),
                _transforms.Normalize(0.5, 0.5),
            ]
        )

    def _validate_inputs(self):
        """Check the model name and the existence of the model directory."""
        if self.model_name not in self.valid_models:
            raise ValueError(f"Invalid model name. Choose from: {self.valid_models}")
        if not _os.path.exists(self.torch_home):
            raise FileNotFoundError(f"Model directory not found: {self.torch_home}")

    def _preprocess_array(
        self,
        arr: _torch.Tensor,
        dim: Tuple[int, int],
        channel_dim: Union[int, None],
    ) -> Tuple[_torch.Tensor, _torch.Size]:
        """Turn an N-D tensor into a stack of transformed 3-channel images.

        Args:
            arr: Tensor holding the images.
            dim: The two axes of ``arr`` that span each image; negative
                indices count from the end.
            channel_dim: Currently unused.

        Returns:
            ``(images, batch_shape)`` where ``images`` stacks one transformed
            image per 2-D slice and ``batch_shape`` is the shape of the
            remaining (non-image) axes, in their original order.
        """
        n_dims = len(arr.shape)
        # Normalise negative axis indices.
        dim = tuple(d if d >= 0 else n_dims + d for d in dim)

        # Move the image axes to the end, keeping every other axis in place.
        perm = list(range(n_dims))
        for d in sorted(dim):
            perm.remove(d)
            perm.append(d)
        arr = arr.permute(perm)

        # Flatten all dimensions except the last two (spatial dimensions).
        batch_shape = arr.shape[:-2]
        spatial_shape = arr.shape[-2:]
        arr = arr.reshape(-1, *spatial_shape)

        # Each slice is single-channel: add a channel axis and replicate it
        # three times before applying the image transform.
        transformed = [
            self.transform(img.unsqueeze(0).repeat(3, 1, 1)) for img in arr
        ]
        return _torch.stack(transformed), batch_shape

    # @batch_method
    # @torch_method
    # @batch_torch_fn
    def extract_features(
        self,
        arr,
        axis=(-2, -1),
        dim=None,
        channel_dim=None,
        batch_size=None,
        device="cuda",
    ):
        """Run the model on every 2-D slice of ``arr``.

        Args:
            arr: Tensor, or anything ``torch.as_tensor`` accepts (NumPy array,
                nested list), containing the images.
            axis: The two axes of ``arr`` that span each image.
            dim: Accepted for interface compatibility; currently unused.
            channel_dim: Forwarded to ``_preprocess_array``, which ignores it.
            batch_size: Maximum number of images per forward pass. ``None``
                (or a value below 1) processes all images in one pass.
            device: Accepted for interface compatibility; the device chosen at
                construction time (``self.device``) is the one actually used.

        Returns:
            CPU tensor of shape ``(*batch_shape, n_features)``, where
            ``batch_shape`` is the shape of the non-image axes of ``arr``.
        """
        # Accept array-likes as well as tensors; the preprocessing below
        # relies on tensor methods such as ``permute``.
        if not _torch.is_tensor(arr):
            arr = _torch.as_tensor(arr)

        processed_arr, batch_shape = self._preprocess_array(
            arr,
            axis,
            channel_dim,
        )

        if batch_size is None or batch_size < 1:
            chunks = (processed_arr,)
        else:
            # Bound peak device memory by feeding the model in slices.
            chunks = processed_arr.split(batch_size)

        with _torch.no_grad():
            features = _torch.cat(
                [self.model(chunk.to(self.device)).cpu() for chunk in chunks]
            )

        return features.reshape(*batch_shape, -1)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
if __name__ == "__main__":
    # Demo: extract features from random data given as a tensor and as a
    # NumPy array.
    import numpy as np  # used only by this demo; not imported at module level

    import scitex

    extractor = scitex.ai.feature_extraction.VitFeatureExtractor(
        model_name="B_16_imagenet1k"
    )
    tensor = torch.randn(3, 2, 4, 5, 32, 32)
    processed = extractor.extract_features(tensor, (-2, -1), None)
    print(processed.shape)

    arr = np.random.rand(3, 2, 4, 5, 32, 32)
    # The preprocessing uses tensor methods, so convert the NumPy array first.
    processed = extractor.extract_features(torch.as_tensor(arr), (-2, -1), None)
    print(processed.shape)
    # Expected: torch.Size([3, 2, 4, 5, F]) -- the four leading axes are kept
    # and the two image axes are replaced by one feature axis of width F
    # (1000 according to the module docstring).
|
|
147
|
+
|
|
148
|
+
# EOF
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-25 12:00:00"
|
|
4
|
+
# Author: Yusuke Watanabe (ywatanabe@alumni.u-tokyo.ac.jp)
|
|
5
|
+
# scitex/src/scitex/ai/genai/__init__.py
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
GenAI module for unified access to multiple AI providers.
|
|
9
|
+
|
|
10
|
+
This module provides a consistent interface for interacting with various
|
|
11
|
+
AI providers (OpenAI, Anthropic, Google, etc.) with built-in cost tracking,
|
|
12
|
+
chat history management, and error handling.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import List, Dict, Any, Optional, Union
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from .provider_factory import Provider, create_provider, GenAI as GenAIFactory
|
|
19
|
+
from .auth_manager import AuthManager
|
|
20
|
+
from .chat_history import ChatHistory
|
|
21
|
+
from .cost_tracker import CostTracker
|
|
22
|
+
from .response_handler import ResponseHandler
|
|
23
|
+
from .base_provider import BaseProvider, CompletionResponse
|
|
24
|
+
|
|
25
|
+
# Import legacy providers for backward compatibility
|
|
26
|
+
from .anthropic import Anthropic
|
|
27
|
+
from .openai import OpenAI
|
|
28
|
+
from .google import Google
|
|
29
|
+
from .groq import Groq
|
|
30
|
+
from .deepseek import DeepSeek
|
|
31
|
+
from .llama import Llama
|
|
32
|
+
from .perplexity import Perplexity
|
|
33
|
+
from .genai_factory import genai_factory
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class GenAI:
    """
    Unified interface for multiple AI providers.

    This class provides a consistent API for interacting with various AI providers
    while handling authentication, chat history, cost tracking, and response processing.

    Args:
        provider: Provider name (e.g., 'openai', 'anthropic', 'google')
        api_key: Optional API key (if not provided, the AuthManager's key is used)
        model: Model name (if not provided, will use provider's default)
        system_prompt: Optional system prompt added to the chat history first
        **kwargs: Additional provider-specific configuration

    Example:
        >>> from scitex.ai.genai import GenAI
        >>>
        >>> # Basic usage
        >>> ai = GenAI(provider="openai")
        >>> response = ai.complete("What is the capital of France?")
        >>> print(response)
        "The capital of France is Paris."
        >>>
        >>> # With specific model and system prompt
        >>> ai = GenAI(
        ...     provider="anthropic",
        ...     model="claude-3-opus-20240229",
        ...     system_prompt="You are a helpful geography expert."
        ... )
        >>> response = ai.complete("Tell me about Paris.")
        >>>
        >>> # Check costs
        >>> print(ai.get_cost_summary())
        "Total cost: $0.015 | Requests: 2 | Tokens: 1,234"
    """

    def __init__(
        self,
        provider: Union[str, Provider],
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs
    ):
        """Initialize GenAI with specified provider."""
        # Store the provider as a lowercase name; non-string providers are
        # expected to expose the name through ``.value``.
        if isinstance(provider, str):
            self.provider_name = provider.lower()
        else:
            self.provider_name = provider.value

        # Initialize components
        self.auth_manager = AuthManager(api_key, self.provider_name)
        self.chat_history = ChatHistory(n_keep=-1)  # Keep all messages by default
        self.response_handler = ResponseHandler()

        # Get API key from auth manager if not provided
        if api_key is None:
            api_key = self.auth_manager.api_key

        # Create provider instance
        self.provider = create_provider(
            provider=self.provider_name, api_key=api_key, model=model, **kwargs
        )

        # Initialize cost tracker with provider and model.
        # The provider instance may have resolved its own model during
        # initialization; prefer that, then the requested one.
        actual_model = getattr(self.provider, "model", None) or model or "unknown"
        self.cost_tracker = CostTracker(provider=self.provider_name, model=actual_model)

        # Add system prompt if provided
        if system_prompt:
            self.chat_history.add_message("system", system_prompt)

        logger.info("Initialized GenAI with provider: %s", self.provider_name)

    def complete(
        self, prompt: str, images: Optional[List[str]] = None, **kwargs
    ) -> str:
        """
        Generate a completion for the given prompt.

        The prompt is appended to the chat history, the whole history is sent
        to the provider, and the reply is appended to the history as well.

        Args:
            prompt: The input prompt
            images: Optional list of image URLs or base64 strings
            **kwargs: Additional provider-specific parameters

        Returns:
            The generated response text

        Raises:
            Exception: Any exception raised by the provider is logged and
                re-raised unchanged. This method performs no image-support
                check of its own.
        """
        # Add user message to history.
        # NOTE(review): the message stays in the history if the provider call
        # below fails -- confirm whether callers that retry expect that.
        self.chat_history.add_message("user", prompt, images)

        # Get messages for API call
        messages = [msg.to_dict() for msg in self.chat_history.get_messages()]

        # Call provider
        try:
            response: CompletionResponse = self.provider.complete(
                messages=messages, **kwargs
            )
        except Exception as e:
            logger.error("Provider %s failed: %s", self.provider_name, e)
            raise

        # Process response - CompletionResponse has a content attribute
        content = response.content

        # Add assistant message to history
        self.chat_history.add_message("assistant", content)

        # Track costs - CompletionResponse has input_tokens and output_tokens
        self.cost_tracker.update(
            input_tokens=response.input_tokens, output_tokens=response.output_tokens
        )

        return content

    def complete_async(self, prompt: str, images: Optional[List[str]] = None, **kwargs):
        """
        Async version of complete method (placeholder).

        Args:
            prompt: The input prompt
            images: Optional list of image URLs or base64 strings
            **kwargs: Additional provider-specific parameters

        Raises:
            NotImplementedError: Always; async completion is not implemented.
        """
        raise NotImplementedError("Async completion not yet implemented")

    def stream(self, prompt: str, images: Optional[List[str]] = None, **kwargs):
        """
        Stream completions for the given prompt (placeholder).

        Args:
            prompt: The input prompt
            images: Optional list of image URLs or base64 strings
            **kwargs: Additional provider-specific parameters

        Raises:
            NotImplementedError: Always; streaming is not implemented.
        """
        raise NotImplementedError("Streaming not yet implemented")

    def clear_history(self):
        """Clear the chat history."""
        self.chat_history.clear()
        logger.info("Chat history cleared")

    def get_history(self) -> List[Dict[str, str]]:
        """Get the current chat history.

        Returns ``chat_history.messages`` as-is.
        NOTE(review): ``complete`` calls ``to_dict()`` on the entries returned
        by ``get_messages()``; confirm that ``messages`` really holds dicts as
        the annotation states.
        """
        return self.chat_history.messages

    def get_cost_summary(self) -> str:
        """Get a summary of costs incurred."""
        return self.cost_tracker.get_summary()

    def get_detailed_costs(self) -> Dict[str, Any]:
        """Get detailed cost breakdown as a dict of the tracker's counters."""
        tracker = self.cost_tracker
        return {
            "total_cost": tracker.total_cost,
            "total_prompt_tokens": tracker.total_prompt_tokens,
            "total_completion_tokens": tracker.total_completion_tokens,
            "request_count": tracker.request_count,
            "cost_by_model": tracker.cost_by_model,
        }

    def reset_costs(self):
        """Reset cost tracking."""
        self.cost_tracker.reset()
        logger.info("Cost tracking reset")

    def __repr__(self) -> str:
        """String representation of GenAI instance."""
        # __init__ treats the provider's ``model`` attribute as optional, so
        # repr() must not fail when it is missing.
        model_name = getattr(self.provider, "model", "unknown")
        return (
            f"GenAI(provider='{self.provider_name}', "
            f"model='{model_name}', "
            f"requests={self.cost_tracker.request_count})"
        )
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# Convenience function for one-off completions
|
|
226
|
+
def complete(
    prompt: str,
    provider: Union[str, Provider] = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs
) -> str:
    """
    Run a single completion with a throwaway GenAI client.

    A new ``GenAI`` instance is created for the call and discarded afterwards,
    so no chat history or cost information is kept between calls.

    Args:
        prompt: The input prompt
        provider: Provider name or enum
        model: Optional model name
        api_key: Optional API key
        **kwargs: Extra parameters forwarded to ``GenAI.complete``

    Returns:
        The generated response text

    Example:
        >>> from scitex.ai.genai import complete
        >>> response = complete("What is 2+2?", provider="anthropic")
        >>> print(response)
        "2 + 2 = 4"
    """
    one_shot_client = GenAI(provider=provider, model=model, api_key=api_key)
    answer = one_shot_client.complete(prompt, **kwargs)
    return answer
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# Export public API
|
|
257
|
+
__all__ = [
    # New API. "GenAIFactory" is provider_factory.GenAI re-exported under an
    # alias so it does not clash with the GenAI class defined in this module.
    "GenAI",
    "GenAIFactory",
    "complete",
    "Provider",
    "create_provider",
    "AuthManager",
    "ChatHistory",
    "CostTracker",
    "ResponseHandler",
    # Legacy API for backward compatibility (per-provider classes and the
    # genai_factory function imported above).
    "genai_factory",
    "Anthropic",
    "OpenAI",
    "Google",
    "Groq",
    "DeepSeek",
    "Llama",
    "Perplexity",
]
|