crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/benchmark/adaptation/adapter_spec.py
CHANGED
@@ -1,6 +1,26 @@
 from dataclasses import dataclass, field
 from typing import List, Optional
 
+from helm.common.image_generation_parameters import ImageGenerationParameters
+
+
+# Adaptation methods
+ADAPT_GENERATION: str = "generation"
+ADAPT_LANGUAGE_MODELING: str = "language_modeling"
+ADAPT_MULTIPLE_CHOICE_JOINT: str = "multiple_choice_joint"
+ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL: str = "multiple_choice_separate_original"
+ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED: str = "multiple_choice_separate_calibrated"
+ADAPT_RANKING_BINARY: str = "ranking_binary"
+
+ADAPT_MULTIPLE_CHOICE_SEPARATE_METHODS: List[str] = [
+    ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
+    ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
+]
+
+# Multimodal adaptation methods
+ADAPT_GENERATION_MULTIMODAL: str = "generation_multimodal"
+ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL: str = "multiple_choice_joint_multimodal"
+
 
 @dataclass(frozen=True)
 class Substitution:
@@ -71,6 +91,9 @@ class AdapterSpec:
     # set of training instances. Used to compute error bars.
     num_train_trials: int = 1
 
+    # Number of trials, where we query the model with the same requests, but different random seeds
+    num_trials: int = 1
+
     # If true, randomly sample N training examples; if false, select N consecutive training examples
     sample_train: bool = True
 
@@ -79,8 +102,7 @@ class AdapterSpec:
     # Model deployment to make the request to (need to fill in)
     model_deployment: str = ""
 
-    #
-    # TODO: Remove this once we do not wish to support backward compatibility anymore.
+    # Model to make the request to
     model: str = ""
 
     # Temperature to use
@@ -96,5 +118,11 @@ class AdapterSpec:
     random: Optional[str] = None
 
     # If true, for instances with multiple correct reference, the gold answer should be considered
-    # to be all
+    # to be all the correct references rather than any of the correct references.
     multi_label: bool = False
+
+    # Parameters for image generation
+    image_generation_parameters: Optional[ImageGenerationParameters] = None
+
+    # The splits from which evaluation instances will be drawn (set hash=False to make `AdapterSpec` hashable)
+    eval_splits: Optional[List[str]] = field(default=None, hash=False)
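
For orientation, a minimal sketch (not part of the diff) of an `AdapterSpec` exercising the fields added in 0.5.0; the deployment name is a placeholder:

    from helm.benchmark.adaptation.adapter_spec import ADAPT_GENERATION, AdapterSpec

    # The ADAPT_* constants and the dataclass now live together in adapter_spec.py (see above).
    spec = AdapterSpec(
        method=ADAPT_GENERATION,
        model_deployment="openai/ada",  # placeholder deployment name
        num_trials=3,                   # new: re-issue the same requests under random seeds "1", "2"
        eval_splits=["test"],           # new: draw evaluation instances only from these splits
    )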
helm/benchmark/adaptation/adapters/adapter.py
CHANGED
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from typing import List
 
 from helm.benchmark.adaptation.adapter_spec import AdapterSpec
-from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.scenarios.scenario import Instance
 from helm.benchmark.window_services.tokenizer_service import TokenizerService
 from helm.benchmark.window_services.window_service import WindowService
@@ -22,7 +22,7 @@ class Adapter(ABC):
     )
 
     @abstractmethod
-    def adapt(self, instances: List[Instance], parallelism: int) -> ScenarioState:
+    def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]:
         """
         Takes a a list of `Instance`s and returns a `ScenarioState` with the
         list of corresponding `RequestState`s.
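
The abstract signature change above ripples through every adapter in this release: `adapt()` now returns the request states directly rather than a `ScenarioState`. A hedged migration sketch for 0.4.0-style callers (`adapter`, `adapter_spec`, and `instances` are assumed to exist; the positional `ScenarioState` constructor mirrors the old usage and is an assumption):

    from helm.benchmark.adaptation.scenario_state import ScenarioState

    request_states = adapter.adapt(instances, parallelism=1)      # 0.5.0: List[RequestState]
    scenario_state = ScenarioState(adapter_spec, request_states)  # rebuild explicitly if still needed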
helm/benchmark/adaptation/adapters/adapter_factory.py
CHANGED
@@ -1,31 +1,26 @@
-from typing import List
-
-from helm.benchmark.adaptation.adapter_spec import AdapterSpec
-from helm.benchmark.window_services.tokenizer_service import TokenizerService
-from .adapter import Adapter
-from .generation_adapter import GenerationAdapter
-from .language_modeling_adapter import LanguageModelingAdapter
-from .multiple_choice_joint_adapter import MultipleChoiceJointAdapter
-from .multiple_choice_separate_adapter import MultipleChoiceSeparateAdapter
-from .multiple_choice_calibrated_adapter import MultipleChoiceCalibratedAdapter
-from .binary_ranking_adapter import BinaryRankingAdapter
-from .multimodal.generation_multimodal_adapter import GenerationMultimodalAdapter
-
-# Adaptation methods
-ADAPT_GENERATION: str = "generation"
-ADAPT_LANGUAGE_MODELING: str = "language_modeling"
-ADAPT_MULTIPLE_CHOICE_JOINT: str = "multiple_choice_joint"
-ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL: str = "multiple_choice_separate_original"
-ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED: str = "multiple_choice_separate_calibrated"
-ADAPT_RANKING_BINARY: str = "ranking_binary"
-
-ADAPT_MULTIPLE_CHOICE_SEPARATE_METHODS: List[str] = [
-    ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
+from helm.benchmark.adaptation.adapter_spec import (
+    ADAPT_GENERATION,
+    ADAPT_GENERATION_MULTIMODAL,
+    ADAPT_LANGUAGE_MODELING,
+    ADAPT_MULTIPLE_CHOICE_JOINT,
+    ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL,
     ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
-]
-
-# Multimodal adaptation methods
-ADAPT_GENERATION_MULTIMODAL: str = "generation_multimodal"
+    ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
+    ADAPT_RANKING_BINARY,
+    AdapterSpec,
+)
+from helm.benchmark.adaptation.adapters.adapter import Adapter
+from helm.benchmark.adaptation.adapters.binary_ranking_adapter import BinaryRankingAdapter
+from helm.benchmark.adaptation.adapters.generation_adapter import GenerationAdapter
+from helm.benchmark.adaptation.adapters.language_modeling_adapter import LanguageModelingAdapter
+from helm.benchmark.adaptation.adapters.multimodal.generation_multimodal_adapter import GenerationMultimodalAdapter
+from helm.benchmark.adaptation.adapters.multimodal.multiple_choice_joint_multimodal_adapter import (
+    MultipleChoiceJointMultimodalAdapter,
+)
+from helm.benchmark.adaptation.adapters.multiple_choice_calibrated_adapter import MultipleChoiceCalibratedAdapter
+from helm.benchmark.adaptation.adapters.multiple_choice_joint_adapter import MultipleChoiceJointAdapter
+from helm.benchmark.adaptation.adapters.multiple_choice_separate_adapter import MultipleChoiceSeparateAdapter
+from helm.benchmark.window_services.tokenizer_service import TokenizerService
 
 
 class AdapterFactory:
@@ -51,6 +46,8 @@ class AdapterFactory:
             adapter = BinaryRankingAdapter(adapter_spec, tokenizer_service)
         elif method == ADAPT_GENERATION_MULTIMODAL:
             adapter = GenerationMultimodalAdapter(adapter_spec, tokenizer_service)
+        elif method == ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL:
+            adapter = MultipleChoiceJointMultimodalAdapter(adapter_spec, tokenizer_service)
         else:
             raise ValueError(f"Invalid adaptation method: {method}")
 
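
A hedged usage sketch of the new dispatch branch (`tokenizer_service` is assumed to be constructed elsewhere; the deployment name is a placeholder, not taken from this diff):

    from helm.benchmark.adaptation.adapter_spec import ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL, AdapterSpec
    from helm.benchmark.adaptation.adapters.adapter_factory import AdapterFactory

    adapter_spec = AdapterSpec(
        method=ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL,
        model_deployment="example/vlm-deployment",  # placeholder
    )
    # Returns a MultipleChoiceJointMultimodalAdapter via the new elif branch above.
    adapter = AdapterFactory.get_adapter(adapter_spec, tokenizer_service)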
helm/benchmark/adaptation/adapters/generation_adapter.py
CHANGED
@@ -46,6 +46,7 @@ class GenerationAdapter(InContextLearningAdapter):
             max_tokens=self.adapter_spec.max_tokens,
             stop_sequences=self.adapter_spec.stop_sequences,
             random=self.adapter_spec.random,
+            image_generation_parameters=self.adapter_spec.image_generation_parameters,
         )
         request_state = RequestState(
             instance=eval_instance,
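
This forwards the new `image_generation_parameters` field of the `AdapterSpec` into every generated `Request`. A hypothetical sketch of populating it (the `ImageGenerationParameters` field names below are assumptions, not taken from this diff):

    from helm.benchmark.adaptation.adapter_spec import ADAPT_GENERATION, AdapterSpec
    from helm.common.image_generation_parameters import ImageGenerationParameters

    spec = AdapterSpec(
        method=ADAPT_GENERATION,
        model_deployment="openai/dall-e-2",  # the deployment used by the new test further below
        max_tokens=0,
        image_generation_parameters=ImageGenerationParameters(
            output_image_width=512,   # assumed field name
            output_image_height=512,  # assumed field name
        ),
    )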
helm/benchmark/adaptation/adapters/in_context_learning_adapter.py
CHANGED
@@ -7,9 +7,9 @@ from typing import List, Dict, Optional
 
 from helm.benchmark.adaptation.prompt import Prompt
 from helm.benchmark.adaptation.request_state import RequestState
-from helm.benchmark.adaptation.scenario_state import ScenarioState
 from helm.benchmark.scenarios.scenario import Instance, TRAIN_SPLIT, EVAL_SPLITS, Reference
 from helm.common.general import parallel_map
+from helm.common.request import Request
 from helm.common.hierarchical_logger import hlog, htrack, htrack_block
 from .adapter import Adapter
 
@@ -30,7 +30,7 @@ class InContextLearningAdapter(Adapter, ABC):
         pass
 
     @htrack(None)
-    def adapt(self, instances: List[Instance], parallelism: int) -> ScenarioState:
+    def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]:
         """
         Takes a list of `Instance`s and builds a list of corresponding `RequestState`s.
         The reason we don't do this per eval instance is that we create a common set of
@@ -64,7 +64,7 @@ class InContextLearningAdapter(Adapter, ABC):
         )
 
         hlog(f"{len(all_request_states)} requests")
-        return ScenarioState(self.adapter_spec, all_request_states)
+        return all_request_states
 
     def _adapt_trial_index(
         self,
@@ -101,7 +101,23 @@ class InContextLearningAdapter(Adapter, ABC):
                 hlog(line)
 
         # Flatten and return
-        return [request_state for result in results for request_state in result]
+        all_request_states: List[RequestState] = [request_state for result in results for request_state in result]
+        return self._add_trials(all_request_states)
+
+    def _add_trials(self, request_states: List[RequestState]) -> List[RequestState]:
+        """Expand the request states by adding trials."""
+        if self.adapter_spec.num_trials <= 1:
+            return request_states
+
+        all_request_states: List[RequestState] = request_states.copy()
+        for i in range(1, self.adapter_spec.num_trials):
+            seed: str = str(i)
+            for request_state in request_states:
+                request: Request = replace(request_state.request, random=seed)
+                all_request_states.append(replace(request_state, request=request))
+
+        assert len(all_request_states) == len(request_states) * self.adapter_spec.num_trials
+        return all_request_states
 
     def sample_examples(
         self, all_train_instances: List[Instance], seed: int, sample_train: bool = True
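
The trial-expansion logic added here is small enough to replicate standalone. A self-contained sketch (the simplified `Req` class stands in for HELM's `Request`; names are illustrative only) showing the invariant the assert enforces, n requests × k trials:

    from dataclasses import dataclass, replace
    from typing import List, Optional

    @dataclass(frozen=True)
    class Req:
        prompt: str
        random: Optional[str] = None  # trial 0 keeps the original seed (None)

    def add_trials(requests: List[Req], num_trials: int) -> List[Req]:
        if num_trials <= 1:
            return requests
        expanded = requests.copy()
        for i in range(1, num_trials):
            # each extra trial re-issues every request with random seed str(i)
            expanded.extend(replace(r, random=str(i)) for r in requests)
        return expanded

    assert len(add_trials([Req("a"), Req("b")], num_trials=3)) == 6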
helm/benchmark/adaptation/adapters/language_modeling_adapter.py
CHANGED
@@ -1,7 +1,6 @@
 from typing import List, Tuple, Optional
 
 from helm.benchmark.adaptation.request_state import RequestState
-from helm.benchmark.adaptation.scenario_state import ScenarioState
 from helm.benchmark.scenarios.scenario import Instance, EVAL_SPLITS
 from helm.benchmark.window_services.window_service import EncodeResult
 from helm.common.general import flatten_list, parallel_map
@@ -26,7 +25,7 @@ class LanguageModelingAdapter(Adapter):
     """
 
     @htrack(None)
-    def adapt(self, instances: List[Instance], parallelism: int) -> ScenarioState:
+    def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]:
         """
         Takes a list of `Instance`s and builds a list of corresponding `RequestState`s.
         Only requires eval instances.
@@ -46,7 +45,7 @@ class LanguageModelingAdapter(Adapter):
         )
         hlog(f"{len(all_request_states)} requests")
 
-        return ScenarioState(self.adapter_spec, all_request_states)
+        return all_request_states
 
     def _generate_requests(self, eval_instance: Instance) -> List[RequestState]:
         """
helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py
ADDED
@@ -0,0 +1,104 @@
+from abc import ABC
+from typing import Dict, List, Optional
+
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.scenarios.scenario import Instance
+from helm.common.media_object import MediaObject, MultimediaObject
+from helm.common.request import Request
+from helm.benchmark.adaptation.adapters.multimodal.in_context_learning_multimodal_adapter import (
+    InContextLearningMultimodalAdapter,
+)
+from .multimodal_prompt import MultimodalPrompt
+
+
+class MultipleChoiceJointMultimodalAdapter(InContextLearningMultimodalAdapter, ABC):
+    """
+    An `Adapter`, guided by the `AdapterSpec`, takes a `Scenario` and produces
+    a `ScenarioState`. This `Adapter` has additional logic to support in-context
+    learning for multimodal models.
+    """
+
+    @staticmethod
+    def get_reference_prefix(prefix: str, i: int) -> str:
+        """
+        Example: prefix = "\nA. ", i = 2, return "\nC. "
+        """
+        return prefix.replace("A", chr(ord("A") + i))
+
+    def generate_requests(
+        self, eval_instance: Instance, train_trial_index: int, training_instances: List[Instance]
+    ) -> List[RequestState]:
+        prompt: MultimodalPrompt = self.construct_prompt(
+            training_instances, eval_instance, include_output=False, reference_index=None
+        )
+        output_mapping: Dict[str, str] = dict(
+            (self.get_reference_prefix("A", reference_index), reference.output.text)
+            for reference_index, reference in enumerate(eval_instance.references)
+        )
+        request = Request(
+            model=self.adapter_spec.model,
+            model_deployment=self.adapter_spec.model_deployment,
+            multimodal_prompt=prompt.multimedia_object,
+            num_completions=self.adapter_spec.num_outputs,
+            temperature=self.adapter_spec.temperature,
+            max_tokens=self.adapter_spec.max_tokens,
+            stop_sequences=[],
+            random=self.adapter_spec.random,
+        )
+        request_state = RequestState(
+            instance=eval_instance,
+            reference_index=None,
+            request_mode=None,
+            train_trial_index=train_trial_index,
+            output_mapping=output_mapping,
+            request=request,
+            result=None,
+            num_train_instances=prompt.num_train_instances,
+            prompt_truncated=False,
+        )
+        return [request_state]
+
+    def construct_example_multimodal_prompt(
+        self, instance: Instance, include_output: bool, reference_index: Optional[int]
+    ) -> MultimediaObject:
+        """
+        Returns a single example of the prompt. `include_output` controls whether the gold output is included.
+        """
+        # Input
+        assert instance.input.multimedia_content is not None
+        result: MultimediaObject = instance.input.multimedia_content.add_textual_prefix(self.adapter_spec.input_prefix)
+        result = result.add_textual_suffix(self.adapter_spec.input_suffix)
+
+        # Include the references
+        delimiter: str = ", "
+        no_correct_references: str = "n/a"
+        output: str = no_correct_references
+        for reference_index, reference in enumerate(instance.references):
+            prefix = self.get_reference_prefix(self.adapter_spec.reference_prefix, reference_index)
+
+            if reference.output.multimedia_content is not None:
+                reference_output_content: MultimediaObject = reference.output.multimedia_content
+                reference_output_content = reference_output_content.add_textual_prefix(prefix)
+                reference_output_content = reference_output_content.add_textual_suffix(
+                    self.adapter_spec.reference_suffix
+                )
+                result = result.combine(reference_output_content)
+            else:
+                result = result.add_textual_suffix(prefix + reference.output.text + self.adapter_spec.reference_suffix)
+
+            if reference.is_correct:
+                if output == no_correct_references:
+                    output = self.get_reference_prefix("A", reference_index)
+                elif self.adapter_spec.multi_label:
+                    output += delimiter
+                    output += self.get_reference_prefix("A", reference_index)
+
+        if include_output:
+            output_content: MultimediaObject = MultimediaObject([MediaObject(text=output, content_type="text/plain")])
+            output_content = output_content.add_textual_prefix(self.adapter_spec.output_prefix)
+            output_content = output_content.add_textual_suffix(self.adapter_spec.output_suffix)
+            result = result.combine(output_content)
+        else:
+            result = result.add_textual_suffix(self.adapter_spec.output_prefix.rstrip())
+
+        return result
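
The letter arithmetic in `get_reference_prefix` is worth seeing in isolation; a standalone copy of the method with its documented example:

    def get_reference_prefix(prefix: str, i: int) -> str:
        # shift the letter "A" in the template by i positions: 0 -> A, 1 -> B, 2 -> C
        return prefix.replace("A", chr(ord("A") + i))

    assert get_reference_prefix("\nA. ", 2) == "\nC. "
    assert get_reference_prefix("A", 1) == "B"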
helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py
CHANGED
@@ -1,6 +1,7 @@
 import shutil
 import tempfile
 import unittest
+from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
 
 from helm.common.media_object import MediaObject, MultimediaObject
 from helm.benchmark.scenarios.scenario import Instance, Reference, Input, Output, TEST_SPLIT, TRAIN_SPLIT, CORRECT_TAG
@@ -14,7 +15,7 @@ from .multimodal_prompt import MultimodalPrompt
 class TestInContextLearningMultimodalAdapter(unittest.TestCase):
     def setup_method(self, _):
         self._path: str = tempfile.mkdtemp()
-        self._tokenizer_service = get_tokenizer_service(self._path)
+        self._tokenizer_service = get_tokenizer_service(self._path, BlackHoleCacheBackendConfig())
 
     def teardown_method(self, _):
         shutil.rmtree(self._path)
helm/benchmark/adaptation/adapters/test_adapter.py
CHANGED
@@ -2,6 +2,7 @@ import shutil
 import tempfile
 
 from helm.common.authentication import Authentication
+from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
 from helm.proxy.services.server_service import ServerService
 from helm.benchmark.window_services.tokenizer_service import TokenizerService
 
@@ -13,7 +14,7 @@ class TestAdapter:
 
     def setup_method(self):
         self.path: str = tempfile.mkdtemp()
-        service = ServerService(base_path=self.path, root_mode=True)
+        service = ServerService(base_path=self.path, root_mode=True, cache_backend_config=BlackHoleCacheBackendConfig())
         self.tokenizer_service = TokenizerService(service, Authentication("test"))
 
     def teardown_method(self, _):
helm/benchmark/adaptation/adapters/test_generation_adapter.py
CHANGED
@@ -11,24 +11,27 @@ from helm.benchmark.scenarios.scenario import (
     Input,
     Output,
 )
-from helm.benchmark.run_specs import get_simple1_spec
+from helm.benchmark.run_specs.simple_run_specs import get_simple1_spec
 from helm.benchmark.adaptation.prompt import Prompt
 from helm.benchmark.adaptation.adapter_spec import AdapterSpec
 from .adapter_factory import AdapterFactory, ADAPT_GENERATION
+from .generation_adapter import GenerationAdapter
 from .test_adapter import TestAdapter
 
 
 class TestGenerationAdapter(TestAdapter):
     def test_adapt(self):
-
-
+        run_spec = get_simple1_spec()
+        scenario = create_scenario(run_spec.scenario_spec)
+        adapter_spec = run_spec.adapter_spec
         adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
-
+        instances = scenario.get_instances(output_path="")
+        request_states = adapter.adapt(instances, parallelism=1)
+        non_train_instances = [instance for instance in instances if instance.split != TRAIN_SPLIT]
 
         # Make sure we generated the right number of request_states:
         # For each trial, instance and reference (+ 1 for free-form generation).
-
-        assert num_instances * adapter_spec.num_train_trials == len(scenario_state.request_states)
+        assert len(non_train_instances) * adapter_spec.num_train_trials == len(request_states)
@@ -194,7 +197,7 @@ class TestGenerationAdapter(TestAdapter):
             ],
             split=TEST_SPLIT,
         )
-        actual_instances = adapter.adapt(train_instances + [eval_instance], parallelism=1).request_states
+        actual_instances = adapter.adapt(train_instances + [eval_instance], parallelism=1)
         assert len(actual_instances) == 1
         assert actual_instances[0].request.prompt == (
             "Input: Second reference is correct\n"
@@ -244,7 +247,7 @@ class TestGenerationAdapter(TestAdapter):
             ],
             split=TEST_SPLIT,
        )
-        actual_instances = adapter.adapt(train_instances + [eval_instance], parallelism=1).request_states
+        actual_instances = adapter.adapt(train_instances + [eval_instance], parallelism=1)
         assert len(actual_instances) == 1
         assert actual_instances[0].request.prompt == (
             "Input: Second reference is correct\n"
@@ -254,3 +257,24 @@ class TestGenerationAdapter(TestAdapter):
             "Input: First reference is correct\n"
             "Output:"
         )
+
+    def test_construct_prompt_image_generation(self):
+        adapter_spec = AdapterSpec(
+            model_deployment="openai/dall-e-2",
+            method=ADAPT_GENERATION,
+            input_prefix="",
+            input_suffix="",
+            output_prefix="",
+            output_suffix="",
+            max_train_instances=0,
+            num_outputs=1,
+            max_tokens=0,
+        )
+        adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
+        assert isinstance(adapter, GenerationAdapter)
+
+        eval_instance = Instance(Input(text="a blue dog"), references=[])
+        prompt: Prompt = adapter.construct_prompt([], eval_instance, include_output=False, reference_index=None)
+
+        assert adapter.window_service.fits_within_context_window(prompt.text)
+        assert prompt.text == "a blue dog"
helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py
CHANGED
@@ -1,6 +1,5 @@
 # mypy: check_untyped_defs = False
 from typing import List
-from helm.benchmark.window_services.gpt2_window_service import GPT2WindowService
 
 from helm.common.tokenization_request import TokenizationToken
 from helm.benchmark.adaptation.request_state import RequestState
@@ -11,18 +10,6 @@ from .test_adapter import TestAdapter
 from helm.benchmark.scenarios.scenario import TEST_SPLIT, Instance, Input, Reference
 
 
-class MockGPT2Window(GPT2WindowService):
-    """Utility for overriding properties of a GPT2WindowService for test purposes."""
-
-    def __init__(self, service, *, max_sequence_length):
-        super().__init__(service)
-        self._max_sequence_length = max_sequence_length
-
-    @property
-    def max_sequence_length(self) -> int:
-        return self._max_sequence_length
-
-
 class TestLanguageModelingAdapter(TestAdapter):
     def test_construct_language_modeling_prompt(self):
         adapter_spec = AdapterSpec(
@@ -100,7 +87,7 @@ class TestLanguageModelingAdapter(TestAdapter):
             split=TEST_SPLIT,
         )
         # Ensure the adapter returns the correct prompt
-        request_states: List[RequestState] = adapter.adapt([instance], parallelism=1).request_states
+        request_states: List[RequestState] = adapter.adapt([instance], parallelism=1)
         request: Request = request_states[0].request
         # The prompt should be "<|endoftext|>Excuse me, do you have the time?"
         assert request.prompt == "<|endoftext|>Excuse me, do you have the time?"
@@ -112,7 +99,7 @@ class TestLanguageModelingAdapter(TestAdapter):
             references=[reference],
             split=TEST_SPLIT,
         )
-        request_states_long: List[RequestState] = adapter.adapt([instance_long], parallelism=1).request_states
+        request_states_long: List[RequestState] = adapter.adapt([instance_long], parallelism=1)
         request_long: Request = request_states_long[0].request
         # Count the number of tokens of the prompt
         num_tokens = len(adapter.window_service.encode(request_long.prompt).token_values)
@@ -130,7 +117,7 @@ class TestLanguageModelingAdapter(TestAdapter):
         adapter_2 = AdapterFactory.get_adapter(adapter_spec_2_, self.tokenizer_service)
 
         # Step 2.1. Check that if the prompt is not too long, it is not truncated
-        request_state_2: List[RequestState] = adapter_2.adapt([instance], parallelism=1).request_states
+        request_state_2: List[RequestState] = adapter_2.adapt([instance], parallelism=1)
         request_2: Request = request_state_2[0].request
         # The prompt should be unchanged
         assert request_2.prompt == "<|endoftext|>Excuse me, do you have the time?"
@@ -138,7 +125,7 @@ class TestLanguageModelingAdapter(TestAdapter):
 
         # Step 2.2. Check that if the prompt + max_tokens is too long, it is truncated
         # but that we keep the same number of tokens as in the previous test
-        request_states_long_2: List[RequestState] = adapter_2.adapt([instance_long], parallelism=1).request_states
+        request_states_long_2: List[RequestState] = adapter_2.adapt([instance_long], parallelism=1)
         request_long_2: Request = request_states_long_2[0].request
         # Count the number of tokens of the prompt
         num_tokens_2 = len(adapter_2.window_service.encode(request_long_2.prompt).token_values)
@@ -159,12 +146,13 @@ class TestLanguageModelingAdapter(TestAdapter):
         )
         adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
         # Monkey patch the window service to have really short max sequences.
-        adapter.window_service =
+        adapter.window_service._max_sequence_length = max_sequence_length
+        adapter.window_service._max_request_length = max_sequence_length + 1
         input_text = Input(text=" ".join(str(i) for i in range(input_tokens)))
         instance = Instance(input=input_text, references=[], split=TEST_SPLIT)
 
         # Generate the requests
-        request_states: List[RequestState] = adapter.adapt([instance], parallelism=1).request_states
+        request_states: List[RequestState] = adapter.adapt([instance], parallelism=1)
         # A smaller window service creates more requests
         assert len(request_states) == 3
         assert request_states[0].request.prompt == "<|endoftext|>0 1 2 3 4 5 6 7 8 9"
helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py
CHANGED
@@ -1,10 +1,23 @@
 # mypy: check_untyped_defs = False
+from typing import List, Set
 from helm.benchmark.scenarios.scenario import TEST_SPLIT, TRAIN_SPLIT, Instance, Input, Output, Reference, CORRECT_TAG
 from helm.benchmark.adaptation.adapter_spec import AdapterSpec
 from .adapter_factory import AdapterFactory, ADAPT_MULTIPLE_CHOICE_JOINT
 from .test_adapter import TestAdapter
 
 
+def _make_instance(
+    text: str, reference_texts: List[str], correct_references: Set[int], is_eval: bool = False
+) -> Instance:
+    references = []
+    for i, reference_text in enumerate(reference_texts):
+        tags = [CORRECT_TAG] if i in correct_references else []
+        references.append(Reference(Output(text=reference_text), tags=tags))
+
+    split = TEST_SPLIT if is_eval else TRAIN_SPLIT
+    return Instance(Input(text=text), references=references, split=split)
+
+
 class TestMultipleChoiceJointAdapter(TestAdapter):
     def test_sample_examples(self):
         adapter_spec = AdapterSpec(
@@ -53,6 +66,47 @@ class TestMultipleChoiceJointAdapter(TestAdapter):
         examples = adapter.sample_examples(all_train_instances, seed=0)
         assert len(examples) == 3
 
+    def test_sample_examples_unique_labels(self):
+        """This is a demonstration of behavior reported in issue #2224."""
+        adapter_spec = AdapterSpec(
+            method=ADAPT_MULTIPLE_CHOICE_JOINT, model="openai/ada", model_deployment="openai/ada", max_train_instances=3
+        )
+        adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
+        all_train_instances = [
+            # Three with 0 being correct.
+            _make_instance("one", ["0", "1"], correct_references={0}),
+            _make_instance("two", ["2", "3"], correct_references={0}),
+            _make_instance("three", ["4", "5"], correct_references={0}),
+            # Two with 1 being correct.
+            _make_instance("four", ["6", "7"], correct_references={1}),
+            _make_instance("five", ["8", "9"], correct_references={1}),
+        ]
+        eval_instance = _make_instance("eval", ["10", "11"], correct_references={1}, is_eval=True)
+        request_states = adapter.adapt(all_train_instances + [eval_instance], parallelism=1)
+        assert len(request_states) == 1
+        # In every case, we are showing that model that Output should be "A".
+        assert request_states[0].request.prompt == (
+            "Input: three\n"
+            "A. 4\n"
+            "B. 5\n"
+            "Output: A\n"
+            "\n"
+            "Input: two\n"
+            "A. 2\n"
+            "B. 3\n"
+            "Output: A\n"
+            "\n"
+            "Input: one\n"
+            "A. 0\n"
+            "B. 1\n"
+            "Output: A\n"
+            "\n"
+            "Input: eval\n"
+            "A. 10\n"
+            "B. 11\n"
+            "Output:"
+        )
+
     def test_multiple_correct_reference(self):
         adapter_spec = AdapterSpec(
             method=ADAPT_MULTIPLE_CHOICE_JOINT,
@@ -91,9 +145,9 @@ class TestMultipleChoiceJointAdapter(TestAdapter):
             ],
             split=TEST_SPLIT,
         )
-        scenario_state = adapter.adapt(train_instances + [eval_instance], parallelism=1)
-        assert len(scenario_state.request_states) == 1
-        assert scenario_state.request_states[0].request.prompt == (
+        request_states = adapter.adapt(train_instances + [eval_instance], parallelism=1)
+        assert len(request_states) == 1
+        assert request_states[0].request.prompt == (
             "Input: Second reference is correct\n"
             "A. First\n"
             "B. Second\n"
@@ -150,9 +204,9 @@ class TestMultipleChoiceJointAdapter(TestAdapter):
             ],
             split=TEST_SPLIT,
        )
-        scenario_state = adapter.adapt(train_instances + [eval_instance], parallelism=1)
-        assert len(scenario_state.request_states) == 1
-        assert scenario_state.request_states[0].request.prompt == (
+        request_states = adapter.adapt(train_instances + [eval_instance], parallelism=1)
+        assert len(request_states) == 1
+        assert request_states[0].request.prompt == (
             "Input: Second reference is correct\n"
             "A. First\n"
             "B. Second\n"