crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from threading import Lock
|
|
4
|
+
from typing import Any, Dict, Optional, List, Union
|
|
5
|
+
|
|
6
|
+
from helm.common.cache import CacheConfig
|
|
7
|
+
from helm.common.hierarchical_logger import hlog
|
|
8
|
+
from helm.common.media_object import TEXT_TYPE
|
|
9
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
10
|
+
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
|
|
11
|
+
from helm.clients.client import CachingClient, truncate_sequence, generate_uid_for_multimodal_prompt
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
import vertexai
|
|
15
|
+
from vertexai.language_models import TextGenerationModel, TextGenerationResponse # PaLM2
|
|
16
|
+
from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Candidate, Part, Image # Gemini
|
|
17
|
+
from google.cloud.aiplatform_v1beta1.types import SafetySetting, HarmCategory
|
|
18
|
+
except ModuleNotFoundError as e:
|
|
19
|
+
handle_module_not_found_error(e, ["google"])
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
_models_lock: Lock = Lock()
|
|
23
|
+
_models: Dict[str, Any] = {}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class VertexAIContentBlockedError(Exception):
    """Raised when a Vertex AI response carries no usable content because it was blocked."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class VertexAIClient(CachingClient, ABC):
    """Abstract base class shared by the Vertex AI clients.

    Stores the project/location pair, prepares the safety settings that the
    subclasses pass along with their requests, and initializes the Vertex AI SDK.
    """

    def __init__(self, cache_config: CacheConfig, project_id: str, location: str) -> None:
        """Set up caching, remember the project and location, and initialize `vertexai`."""
        super().__init__(cache_config=cache_config)
        self.project_id = project_id
        self.location = location

        # VertexAI's default safety filter is overly sensitive, so we disable it
        # by mapping every harm category to BLOCK_NONE.
        thresholds: Dict[HarmCategory, SafetySetting.HarmBlockThreshold] = {}
        for category in iter(HarmCategory):
            thresholds[category] = SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
        self.safety_settings: Dict[HarmCategory, SafetySetting.HarmBlockThreshold] = thresholds

        vertexai.init(project=self.project_id, location=self.location)

    @abstractmethod
    def make_request(self, request: Request) -> RequestResult:
        """Execute `request`; concrete subclasses must override this."""
        raise NotImplementedError
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class VertexAITextClient(VertexAIClient):
    """Client for Vertex AI text models

    This client is used for PaLM2 for example."""

    def make_request(self, request: Request) -> RequestResult:
        """Run a text-generation request and return one completion per candidate.

        The raw response is cached under a key built from the model engine, the
        prompt and the generation parameters. A `requests` exception or an
        `AssertionError` raised while fetching the response is reported as an
        unsuccessful `RequestResult` instead of being propagated.
        """
        parameters = {
            "temperature": request.temperature,
            "max_output_tokens": request.max_tokens,
            "top_k": request.top_k_per_token,
            "top_p": request.top_p,
            "stop_sequences": request.stop_sequences,
            "candidate_count": request.num_completions,
            # TODO #2084: Add support for these parameters.
            # The parameters "echo", "frequency_penalty", and "presence_penalty" are supposed to be supported
            # in an HTTP request (See https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text),
            # but they are not supported in the Python SDK:
            # https://github.com/googleapis/python-aiplatform/blob/beae48f63e40ea171c3f1625164569e7311b8e5a/vertexai/language_models/_language_models.py#L968C1-L980C1
            # "frequency_penalty": request.frequency_penalty,
            # "presence_penalty": request.presence_penalty,
            # "echo": request.echo_prompt,
        }

        completions: List[GeneratedOutput] = []
        model_name: str = request.model_engine

        try:

            def do_it() -> Dict[str, Any]:
                model = TextGenerationModel.from_pretrained(model_name)
                response = model.predict(request.prompt, **parameters)
                candidates: List[TextGenerationResponse] = response.candidates
                # Build one prediction dict per candidate. (Placing the `for` inside the
                # braces would create a single dict that keeps only the last candidate.)
                response_dict = {
                    "predictions": [{"text": candidate.text} for candidate in candidates],
                }  # TODO: Extract more information from the response
                return response_dict

            # We need to include the engine's name to differentiate among requests made for different model
            # engines since the engine name is not included in the request itself.
            # Same for the prompt.
            cache_key = CachingClient.make_cache_key(
                {
                    "engine": request.model_engine,
                    "prompt": request.prompt,
                    **parameters,
                },
                request,
            )

            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except (requests.exceptions.RequestException, AssertionError) as e:
            error: str = f"VertexAITextClient error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        for prediction in response["predictions"]:
            response_text = prediction["text"]

            # The Python SDK does not support echo
            text: str = request.prompt + response_text if request.echo_prompt else response_text

            # TODO #2085: Add support for log probs.
            # Once again, log probs seem to be supported by the API but not by the Python SDK.
            # HTTP Response body reference:
            # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text#response_body
            # Python SDK reference:
            # https://github.com/googleapis/python-aiplatform/blob/beae48f63e40ea171c3f1625164569e7311b8e5a/vertexai/language_models/_language_models.py#L868
            completion = GeneratedOutput(text=text, logprob=0, tokens=[])
            sequence = truncate_sequence(completion, request, print_warning=True)
            completions.append(sequence)

        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            request_datetime=response["request_datetime"],
            completions=completions,
            embedding=[],
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class VertexAIChatClient(VertexAIClient):
    """Client for Vertex AI chat models (e.g., Gemini). Supports multimodal prompts."""

    # Set the finish reason to this if the prompt violates the content policy
    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = "The prompt violates Google's content policy."

    # Gemini returns this error for certain valid requests
    CONTENT_HAS_NO_PARTS_ERROR: str = "Content has no parts."

    # Enum taken from:
    # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#google.cloud.aiplatform.v1beta1.Candidate.FinishReason
    # We don't directly import this enum because it can differ between different Vertex AI library versions.
    CONTENT_BLOCKED_FINISH_REASONS: List[int] = [
        3,  # SAFETY
        4,  # RECITATION
        6,  # BLOCKLIST
        7,  # PROHIBITED_CONTENT
        8,  # SPII (Sensitive Personally Identifiable Information)
    ]

    @staticmethod
    def get_model(model_name: str) -> Any:
        """Return the `GenerativeModel` for `model_name`, creating and memoizing it on first use."""
        # The module-level names are only read / mutated in place, so no `global` statement is needed.
        with _models_lock:
            if model_name not in _models:
                _models[model_name] = GenerativeModel(model_name)
            return _models[model_name]

    def make_request(self, request: Request) -> RequestResult:
        """Run a generation request and return one completion per candidate.

        Requests carrying a `multimodal_prompt` are delegated to `_make_multimodal_request`.
        For text prompts the raw response is cached under a key built from the model name,
        the prompt and the generation parameters.

        Returns an unsuccessful, non-retriable and non-fatal `RequestResult` when the content
        was blocked, and a plain unsuccessful `RequestResult` when a `requests` exception or
        an `AssertionError` is raised while fetching the response.
        """
        # For the multimodal case, build up the content with the media objects of `request.multimodal_prompt`
        if request.multimodal_prompt is not None:
            return self._make_multimodal_request(request)

        contents: str = request.prompt

        parameters = {
            "temperature": request.temperature,
            "max_output_tokens": request.max_tokens,
            "top_k": request.top_k_per_token,
            "top_p": request.top_p,
            "stop_sequences": request.stop_sequences,
            "candidate_count": request.num_completions,
            # TODO #2084: Add support for these parameters.
            # The parameters "echo", "frequency_penalty", and "presence_penalty" are supposed to be supported
            # in an HTTP request (See https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text),
            # but they are not supported in the Python SDK:
            # https://github.com/googleapis/python-aiplatform/blob/beae48f63e40ea171c3f1625164569e7311b8e5a/vertexai/language_models/_language_models.py#L968C1-L980C1
            # "frequency_penalty": request.frequency_penalty,
            # "presence_penalty": request.presence_penalty,
            # "echo": request.echo_prompt,
        }

        completions: List[GeneratedOutput] = []
        model_name: str = request.model_engine
        model = self.get_model(model_name)

        try:

            def do_it() -> Dict[str, Any]:
                # Here we differ from Vertex AI's tutorial.
                # https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/send-chat-prompts-gemini#send_chat_prompts  # noqa: E501
                # It would advise to use model.start_chat() but since we do not want to use Chat capabilities of
                # Vertex AI, we use model.generate_content() instead. Furthermore, chat.send_message() restricts the
                # output to only one candidate.
                # chat: ChatSession = model.start_chat()
                # See: https://github.com/googleapis/python-aiplatform/blob/e8c505751b10a9dc91ae2e0d6d13742d2abf945c/vertexai/generative_models/_generative_models.py#L812  # noqa: E501
                response: GenerationResponse = model.generate_content(
                    contents, generation_config=parameters, safety_settings=self.safety_settings
                )
                candidates: List[Candidate] = response.candidates

                # Depending on the version of the Vertex AI library and the type of content blocking,
                # content blocking can show up in many ways, so this defensively handles most of these ways
                if not candidates:
                    raise VertexAIContentBlockedError("No candidates in response due to content blocking")
                predictions: List[Dict[str, Any]] = []
                for candidate in candidates:
                    if (
                        candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS
                        or not candidate.content.parts
                    ):
                        # The prediction was either blocked due to safety settings or the model stopped and returned
                        # nothing (which also happens when the model is blocked).
                        # For now, we don't cache blocked requests, because we are trying to get the
                        # content blocking removed.
                        raise VertexAIContentBlockedError("Content has no parts due to content blocking")
                    predictions.append({"text": candidate.content.text})
                    # TODO: Extract more information from the response
                return {"predictions": predictions}

            # We need to include the engine's name to differentiate among requests made for different model
            # engines since the engine name is not included in the request itself.
            # Same for the prompt.
            cache_key = CachingClient.make_cache_key(
                {
                    "model_name": model_name,
                    "prompt": request.prompt,
                    **parameters,
                },
                request,
            )

            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except VertexAIContentBlockedError:
            return RequestResult(
                success=False,
                cached=False,
                error="Response was empty due to content moderation filter",
                completions=[],
                embedding=[],
                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
            )
        except (requests.exceptions.RequestException, AssertionError) as e:
            # Label the failure with this client's own name so it is not mistaken for the text client.
            error: str = f"VertexAIChatClient error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        # Handle cached responses with blocked content from old versions of HELM.
        if response["predictions"] is None:
            return RequestResult(
                success=False,
                cached=False,
                error="Response was empty due to content moderation filter",
                completions=[],
                embedding=[],
                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
                request_time=response["request_time"],
                request_datetime=response["request_datetime"],
            )

        for prediction in response["predictions"]:
            # Handle cached responses with blocked content from old versions of HELM.
            if "text" not in prediction:
                return RequestResult(
                    success=False,
                    cached=False,
                    error="Response was empty due to content moderation filter",
                    completions=[],
                    embedding=[],
                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
                    request_time=response["request_time"],
                    request_datetime=response["request_datetime"],
                )
            response_text = prediction["text"]

            # The Python SDK does not support echo
            text: str = request.prompt + response_text if request.echo_prompt else response_text
            completion = GeneratedOutput(text=text, logprob=0, tokens=[])
            sequence = truncate_sequence(completion, request, print_warning=True)
            completions.append(sequence)

        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            request_datetime=response["request_datetime"],
            completions=completions,
            embedding=[],
        )

    def _make_multimodal_request(self, request: Request) -> RequestResult:
        """Run a request whose prompt is a sequence of image and text media objects.

        One single-candidate call is issued per requested completion, each cached under
        its own key. If any call reports blocked content, or fails with the
        "Content has no parts." error, a successful result made of empty completions is
        returned immediately and completions gathered so far are not included.

        Raises:
            ValueError: if a text media object has no text, or a media object is neither
                an image with a location nor text.
        """

        def complete_for_valid_error(error_message: str) -> RequestResult:
            # Report the condition through `finish_reason` on empty completions rather than as a failure.
            empty_completion = GeneratedOutput(
                text="",
                logprob=0,
                tokens=[],
                finish_reason={"reason": error_message},
            )
            return RequestResult(
                success=True,
                cached=False,
                request_time=0,
                completions=[empty_completion] * request.num_completions,
                embedding=[],
            )

        # Used to generate a unique cache key for this specific request
        assert request.multimodal_prompt is not None
        prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)

        # For the multimodal case, build up the content with the media objects of `request.multimodal_prompt`.
        # Contents is a list of multimodal content made up of text, images or other content.
        contents: List[Union[str, Any]] = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                contents.append(Part.from_image(Image.load_from_file(media_object.location)))
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                contents.append(media_object.text)
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")

        parameters = {
            "temperature": request.temperature,
            "max_output_tokens": request.max_tokens,
            "top_k": request.top_k_per_token,
            "top_p": request.top_p,
            "stop_sequences": request.stop_sequences,
            "candidate_count": 1,
        }

        completions: List[GeneratedOutput] = []
        model_name: str = request.model_engine
        model = self.get_model(model_name)

        request_time = 0
        request_datetime: Optional[int] = None
        all_cached = True

        # Gemini Vision only supports generating 1-2 candidates at a time, so make `request.num_completions` requests
        for completion_index in range(request.num_completions):
            try:

                def do_it() -> Dict[str, Any]:
                    raw_response = model.generate_content(
                        contents, generation_config=parameters, safety_settings=self.safety_settings
                    )
                    if raw_response._raw_response.prompt_feedback.block_reason != 0:
                        hlog(f"Content blocked for prompt: {request.multimodal_prompt}")
                        return {"error": self.CONTENT_POLICY_VIOLATED_FINISH_REASON}

                    return {"predictions": [{"text": raw_response.candidates[0].text}]}

                raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
                # The first completion keeps the plain key; later ones get a distinct key per index.
                if completion_index > 0:
                    raw_cache_key["completion_index"] = completion_index

                cache_key = CachingClient.make_cache_key(raw_cache_key, request)
                response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
            except (requests.exceptions.RequestException, ValueError) as e:
                if str(e) == self.CONTENT_HAS_NO_PARTS_ERROR:
                    return complete_for_valid_error(self.CONTENT_HAS_NO_PARTS_ERROR)

                error: str = f"Gemini Vision error: {e}"
                return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

            if "error" in response:
                return complete_for_valid_error(response["error"])

            response_text = response["predictions"][0]["text"]
            completion = GeneratedOutput(text=response_text, logprob=0, tokens=[])
            sequence = truncate_sequence(completion, request, print_warning=True)
            completions.append(sequence)

            request_time += response["request_time"]
            # Use the datetime from the first completion because that's when the request was fired
            request_datetime = request_datetime or response.get("request_datetime")
            all_cached = all_cached and cached

        return RequestResult(
            success=True,
            cached=all_cached,
            request_time=request_time,
            request_datetime=request_datetime,
            completions=completions,
            embedding=[],
        )
|
|
File without changes
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from threading import Lock
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from transformers import pipeline
|
|
5
|
+
from transformers.pipelines import ImageToTextPipeline
|
|
6
|
+
|
|
7
|
+
from helm.common.cache import CacheConfig
|
|
8
|
+
from helm.common.images_utils import open_image
|
|
9
|
+
from helm.common.media_object import TEXT_TYPE
|
|
10
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
11
|
+
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
|
|
12
|
+
from helm.common.tokenization_request import (
|
|
13
|
+
TokenizationRequest,
|
|
14
|
+
TokenizationRequestResult,
|
|
15
|
+
)
|
|
16
|
+
from helm.common.request import wrap_request_time
|
|
17
|
+
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
|
|
18
|
+
from helm.tokenizers.tokenizer import Tokenizer
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from PIL import Image
|
|
22
|
+
except ModuleNotFoundError as e:
|
|
23
|
+
handle_module_not_found_error(e, ["images"])
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HuggingFaceVLMClient(CachingClient):
    """
    General client for VLM models from HuggingFace.
    """

    # Class-level state: the lock guards `_models`, which memoizes one pipeline per model id.
    _models_lock: Lock = Lock()
    _models: Dict[str, ImageToTextPipeline] = {}
    # Maps a deployment name to the model id handed to `pipeline`.
    _models_aliases: Dict[str, str] = {
        "huggingface/llava-1.5-7b-hf": "llava-hf/llava-1.5-7b-hf",
        "huggingface/llava-1.5-13b-hf": "llava-hf/llava-1.5-13b-hf",
        "huggingface/bakLlava-v1-hf": "llava-hf/bakLlava-v1-hf",
    }

    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
        """Set up caching and keep the tokenizer used to tokenize generated text."""
        super().__init__(cache_config=cache_config)
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name

    def _get_model(self, model_name: str) -> ImageToTextPipeline:
        """Return the image-to-text pipeline for `model_name`, building it on first use."""
        with self._models_lock:
            # Names without an alias are used as the model id unchanged.
            resolved_id: str = self._models_aliases.get(model_name, model_name)
            if resolved_id in self._models:
                return self._models[resolved_id]
            self._models[resolved_id] = pipeline("image-to-text", model=resolved_id)
            return self._models[resolved_id]

    def make_request(self, request: Request) -> RequestResult:
        """Generate text for a multimodal prompt made of text and at most one image.

        A `RuntimeError` raised while computing the cache key or fetching the result is
        reported as an unsuccessful `RequestResult`.

        Raises:
            ValueError: if the prompt has more than one image, a text media object has no
                text, or a media object is neither an image with a location nor text.
        """
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"

        # Build the prompt: every text object contributes a newline followed by its text.
        text_chunks: List[str] = []
        image: Optional[Image.Image] = None
        for item in request.multimodal_prompt.media_objects:
            if item.is_type("image") and item.location:
                # TODO #2235: Figure out if some HuggingFace models support multiple images
                if image is not None:
                    raise ValueError("Only one image is supported in the multimodal prompt")
                image = open_image(item.location)
                continue
            if not item.is_type(TEXT_TYPE):
                raise ValueError(f"Unsupported media object type: {item.type}")
            if item.text is None:
                raise ValueError("MediaObject of text type has missing text field value")
            text_chunks.append(f"\n{item.text}")
        prompt: str = "".join(text_chunks)

        # Generate
        generation_args = {
            "max_new_tokens": request.max_tokens,
        }
        try:

            def do_it() -> Dict[str, Any]:
                vlm: ImageToTextPipeline = self._get_model(request.model_deployment)
                generated = vlm(image, prompt=prompt, generate_kwargs=generation_args)
                # Only the first output of the pipeline is kept.
                return generated[0]

            raw_request = {
                "model": request.model,
                "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
                **generation_args,
            }
            cache_key = CachingClient.make_cache_key(raw_request=raw_request, request=request)
            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as e:
            return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])

        # Tokenize the generated text so the completion carries its tokens.
        generated_text = result["generated_text"]
        tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
            TokenizationRequest(generated_text, tokenizer=self.tokenizer_name)
        )
        tokens: List[Token] = []
        for raw_token in tokenization_result.raw_tokens:
            tokens.append(Token(text=str(raw_token), logprob=0))
        completion = GeneratedOutput(text=generated_text, logprob=0, tokens=tokens)

        return RequestResult(
            success=True,
            cached=cached,
            request_time=result["request_time"],
            completions=[completion],
            embedding=[],
        )
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from threading import Lock
|
|
2
|
-
from typing import Dict, List, Optional, Union
|
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from dataclasses import dataclass
|
|
@@ -8,17 +8,14 @@ from transformers import IdeficsForVisionText2Text, AutoProcessor, IdeficsProces
|
|
|
8
8
|
from helm.common.cache import CacheConfig
|
|
9
9
|
from helm.common.images_utils import open_image
|
|
10
10
|
from helm.common.gpu_utils import get_torch_device_name
|
|
11
|
-
from helm.common.hierarchical_logger import hlog
|
|
11
|
+
from helm.common.hierarchical_logger import hlog, htrack_block
|
|
12
12
|
from helm.common.media_object import TEXT_TYPE
|
|
13
13
|
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
14
|
-
from helm.common.request import Request, RequestResult,
|
|
15
|
-
from helm.common.tokenization_request import
|
|
16
|
-
TokenizationRequest,
|
|
17
|
-
TokenizationRequestResult,
|
|
18
|
-
)
|
|
14
|
+
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
|
|
15
|
+
from helm.common.tokenization_request import TokenizationRequest
|
|
19
16
|
from helm.common.request import wrap_request_time
|
|
20
|
-
from helm.
|
|
21
|
-
from helm.
|
|
17
|
+
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
|
|
18
|
+
from helm.tokenizers.tokenizer import Tokenizer
|
|
22
19
|
|
|
23
20
|
try:
|
|
24
21
|
from PIL import Image
|
|
@@ -54,8 +51,12 @@ class IDEFICSClient(CachingClient):
|
|
|
54
51
|
END_OF_UTTERANCE_TOKEN: str = "<end_of_utterance>"
|
|
55
52
|
BAD_WORD_TOKENS: List[str] = ["<image>", "<fake_token_around_image>"]
|
|
56
53
|
|
|
57
|
-
|
|
58
|
-
|
|
54
|
+
ASSISTANT_PREFIX: str = "Assistant: "
|
|
55
|
+
|
|
56
|
+
def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
|
|
57
|
+
super().__init__(cache_config=cache_config)
|
|
58
|
+
self.tokenizer = tokenizer
|
|
59
|
+
self.tokenizer_name = tokenizer_name
|
|
59
60
|
self._device: str = get_torch_device_name()
|
|
60
61
|
|
|
61
62
|
def _get_model(self, checkpoint: str) -> LoadedIDEFICSModelProcessor:
|
|
@@ -67,8 +68,8 @@ class IDEFICSClient(CachingClient):
|
|
|
67
68
|
loaded_model_processor = _models[checkpoint]
|
|
68
69
|
if loaded_model_processor is None:
|
|
69
70
|
hlog(f"Loading model {checkpoint} and caching in memory...")
|
|
70
|
-
model = IdeficsForVisionText2Text.from_pretrained(
|
|
71
|
-
|
|
71
|
+
model = IdeficsForVisionText2Text.from_pretrained(
|
|
72
|
+
checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
|
|
72
73
|
)
|
|
73
74
|
processor = AutoProcessor.from_pretrained(checkpoint)
|
|
74
75
|
_models[checkpoint] = LoadedIDEFICSModelProcessor(model, processor)
|
|
@@ -78,10 +79,10 @@ class IDEFICSClient(CachingClient):
|
|
|
78
79
|
return loaded_model_processor
|
|
79
80
|
|
|
80
81
|
def make_request(self, request: Request) -> RequestResult:
|
|
81
|
-
assert request.
|
|
82
|
+
assert request.model_deployment in _models, f"Not a valid model for this client: {request.model_deployment}"
|
|
82
83
|
assert request.multimodal_prompt is not None, "Multimodal prompt is required"
|
|
83
84
|
|
|
84
|
-
loaded_model_processor: LoadedIDEFICSModelProcessor = self._get_model(request.
|
|
85
|
+
loaded_model_processor: LoadedIDEFICSModelProcessor = self._get_model(request.model_deployment)
|
|
85
86
|
model = loaded_model_processor.model
|
|
86
87
|
processor = loaded_model_processor.processor
|
|
87
88
|
|
|
@@ -110,43 +111,49 @@ class IDEFICSClient(CachingClient):
|
|
|
110
111
|
raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
|
|
111
112
|
prompt_text: str = request.multimodal_prompt.text.replace(self.END_OF_UTTERANCE_TOKEN, " ")
|
|
112
113
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
#
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
114
|
+
completions: List[GeneratedOutput] = []
|
|
115
|
+
with htrack_block(f"Generating for prompt: {prompt_text}"):
|
|
116
|
+
try:
|
|
117
|
+
|
|
118
|
+
def do_it() -> Dict[str, Any]:
|
|
119
|
+
inputs = processor([multimodal_prompt] * request.num_completions, **input_args).to(self._device)
|
|
120
|
+
generated_ids = model.generate(**inputs, **generation_args)
|
|
121
|
+
generated_text: List[str] = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
|
122
|
+
return {"output": generated_text}
|
|
123
|
+
|
|
124
|
+
# Include the prompt and model name in the cache key
|
|
125
|
+
cache_key = CachingClient.make_cache_key(
|
|
126
|
+
raw_request={
|
|
127
|
+
"n": request.num_completions,
|
|
128
|
+
"model": request.model,
|
|
129
|
+
"prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
|
|
130
|
+
**generation_args,
|
|
131
|
+
},
|
|
132
|
+
request=request,
|
|
133
|
+
)
|
|
134
|
+
result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
|
135
|
+
except RuntimeError as model_error:
|
|
136
|
+
return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
|
|
137
|
+
|
|
138
|
+
for text in result["output"]:
|
|
139
|
+
hlog(f"Generated text: {text}")
|
|
140
|
+
|
|
141
|
+
# Truncate the output text as IDEFICS outputs the entire sequence including the prompt
|
|
142
|
+
if "instruct" in request.model:
|
|
143
|
+
assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
|
|
144
|
+
text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
|
|
145
|
+
else:
|
|
146
|
+
# Best we can do is to remove the text portion of the prompt from the output
|
|
147
|
+
text = text[len(prompt_text) :]
|
|
148
|
+
|
|
149
|
+
# Tokenize truncated text to get the list of tokens
|
|
150
|
+
hlog(f"Truncated: {text}")
|
|
151
|
+
tokenization_result = self.tokenizer.tokenize(
|
|
152
|
+
TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
|
|
153
|
+
)
|
|
154
|
+
tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
|
|
155
|
+
completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
|
|
156
|
+
|
|
150
157
|
return RequestResult(
|
|
151
158
|
success=True,
|
|
152
159
|
cached=cached,
|
|
File without changes
|