crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two released versions.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0

@@ -0,0 +1,267 @@
+"""
+Based on: https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+from einops_exts import rearrange_many
+from torch import einsum, nn
+
+
+def exists(val):
+    return val is not None
+
+
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, n1, D)
+            latent (torch.Tensor): latent features
+                shape (b, T, n2, D)
+        """
+        x = self.norm_media(x)
+        latents = self.norm_latents(latents)
+
+        h = self.heads
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
+        q = q * self.scale
+
+        # attention
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
+        return self.to_out(out)
+
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=6,
+        dim_head=64,
+        heads=8,
+        num_latents=64,
+        max_num_media=None,
+        max_num_frames=None,
+        ff_mult=4,
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
+        self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, F, v, D)
+        Returns:
+            shape (b, T, n, D) where n is self.num_latents
+        """
+        b, T, F, v = x.shape[:4]
+
+        # frame and media time embeddings
+        if exists(self.frame_embs):
+            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
+            x = x + frame_embs
+        x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
+        if exists(self.media_time_embs):
+            x = x + self.media_time_embs[:T]
+
+        # blocks
+        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        return self.norm(latents)
+
+
+# gated cross attention
+class MaskedCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+        # whether for text to only attend to immediate preceding image, or all previous images
+        self.only_attend_immediate_media = only_attend_immediate_media
+
+    def forward(self, x, media, media_locations=None, use_cached_media=False):
+        """
+        Args:
+            x (torch.Tensor): text features
+                shape (B, T_txt, D_txt)
+            media (torch.Tensor): image features
+                shape (B, T_img, n, D_img) where n is the dim of the latents
+            media_locations: boolean mask identifying the media tokens in x
+                shape (B, T_txt)
+            use_cached_media: bool
+                If true, treat all of x as if they occur after the last media
+                registered in media_locations. T_txt does not need to exactly
+                equal media_locations.shape[1] in this case
+        """
+
+        if not use_cached_media:
+            assert (
+                media_locations.shape[1] == x.shape[1]
+            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
+
+        T_txt = x.shape[1]
+        _, T_img, n = media.shape[:3]
+        h = self.heads
+
+        x = self.norm(x)
+
+        q = self.to_q(x)
+        media = rearrange(media, "b t n d -> b (t n) d")
+
+        k, v = self.to_kv(media).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+        q = q * self.scale
+
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+        if exists(media_locations):
+            media_time = torch.arange(T_img, device=x.device) + 1
+
+            if use_cached_media:
+                # text time is set to the last cached media location
+                text_time = repeat(
+                    torch.count_nonzero(media_locations, dim=1),
+                    "b -> b i",
+                    i=T_txt,
+                )
+            else:
+                # at each boolean of True, increment the time counter (relative to media time)
+                text_time = media_locations.cumsum(dim=-1)
+
+            # text time must equal media time if only attending to most immediate image
+            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(
+                rearrange(text_time, "b i -> b 1 i 1"),
+                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
+            )
+            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
+
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        if exists(media_locations) and self.only_attend_immediate_media:
+            # any text without a preceding media needs to have attention zeroed out
+            text_without_media_mask = text_time == 0
+            text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
+            attn = attn.masked_fill(text_without_media_mask, 0.0)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class GatedCrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.attn = MaskedCrossAttention(
+            dim=dim,
+            dim_visual=dim_visual,
+            dim_head=dim_head,
+            heads=heads,
+            only_attend_immediate_media=only_attend_immediate_media,
+        )
+        self.attn_gate = nn.Parameter(torch.tensor([0.0]))
+
+        self.ff = FeedForward(dim, mult=ff_mult)
+        self.ff_gate = nn.Parameter(torch.tensor([0.0]))
+
+    def forward(
+        self,
+        x,
+        media,
+        media_locations=None,
+        use_cached_media=False,
+    ):
+        x = (
+            self.attn(
+                x,
+                media,
+                media_locations=media_locations,
+                use_cached_media=use_cached_media,
+            )
+            * self.attn_gate.tanh()
+            + x
+        )
+        x = self.ff(x) * self.ff_gate.tanh() + x
+
+        return x
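The hunk above adds helm/clients/vision_language/open_flamingo/src/helpers.py (+267 lines in the listing). As an illustration only, not code shipped in the wheel, the new PerceiverResampler maps a variable number of visual tokens to a fixed set of latent vectors; the dimensions below are made up:

# Illustrative sketch (not part of the package): dummy shapes for PerceiverResampler.
import torch

from helm.clients.vision_language.open_flamingo.src.helpers import PerceiverResampler

resampler = PerceiverResampler(dim=1024, depth=2, num_latents=64)
# (batch=2, T=1 media item, F=1 frame, v=256 visual tokens, D=1024 features)
vision_features = torch.randn(2, 1, 1, 256, 1024)
latents = resampler(vision_features)
assert latents.shape == (2, 1, 64, 1024)  # (b, T, num_latents, dim)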

@@ -0,0 +1,47 @@
+"""
+Source: https://github.com/mlfoundations/open_flamingo
+"""
+
+
+def extend_instance(obj, mixin):
+    """Apply mixins to a class instance after creation"""
+    base_cls = obj.__class__
+    base_cls_name = obj.__class__.__name__
+    obj.__class__ = type(
+        base_cls_name, (mixin, base_cls), {}
+    )  # mixin needs to go first for our forward() logic to work
+
+
+def getattr_recursive(obj, att):
+    """
+    Return nested attribute of obj
+    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+    """
+    if att == "":
+        return obj
+    i = att.find(".")
+    if i < 0:
+        return getattr(obj, att)
+    else:
+        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+def setattr_recursive(obj, att, val):
+    """
+    Set nested attribute of obj
+    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+    """
+    if "." in att:
+        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+    setattr(obj, att.split(".")[-1], val)
+
+
+def apply_with_stopping_condition(module, apply_fn, apply_condition=None, stopping_condition=None, **other_args):
+    if stopping_condition(module):
+        return
+    if apply_condition(module):
+        apply_fn(module, **other_args)
+    for child in module.children():
+        apply_with_stopping_condition(
+            child, apply_fn, apply_condition=apply_condition, stopping_condition=stopping_condition, **other_args
+        )
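The hunk above adds helm/clients/vision_language/open_flamingo/src/utils.py (+47 lines). A minimal sketch, with invented class names, of the mixin pattern that extend_instance enables (the mixin is placed first in the MRO so its methods run before the wrapped class's):

# Illustrative sketch (not part of the package); Base and LoggingMixin are invented names.
from helm.clients.vision_language.open_flamingo.src.utils import extend_instance


class Base:
    def greet(self):
        return "hello"


class LoggingMixin:
    def greet(self):
        print("mixin runs first")
        return super().greet()


obj = Base()
extend_instance(obj, LoggingMixin)  # obj.__class__ becomes type("Base", (LoggingMixin, Base), {})
assert obj.greet() == "hello"  # prints "mixin runs first", then falls through to Base.greet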

@@ -0,0 +1,155 @@
+from threading import Lock
+from typing import List, Optional, Tuple
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from helm.common.cache import CacheConfig
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.images_utils import open_image
+from helm.common.gpu_utils import get_torch_device_name
+from helm.common.media_object import TEXT_TYPE
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.request import wrap_request_time
+from helm.clients.vision_language.open_flamingo import create_model_and_transforms
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+
+try:
+    from PIL import Image
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["images"])
+
+
+class OpenFlamingoClient(CachingClient):
+    """
+    OpenFlamingo is an open source implementation of DeepMind's Flamingo models.
+    Implementation following:
+        https://github.com/mlfoundations/open_flamingo
+        https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b
+    """
+
+    END_OF_CHUNK_TOKEN: str = "<|endofchunk|>"
+    IMAGE_TOKEN: str = "<image>"
+
+    _model_lock: Lock = Lock()
+
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        checkpoint_path: Optional[str] = None,
+        tokenizer_name: Optional[str] = None,
+        cross_attn_every_n_layers: int = 4,
+    ):
+        super().__init__(cache_config)
+        self._device: str = get_torch_device_name()
+        self._checkpoint_path: Optional[str] = checkpoint_path
+        self._tokenizer_name: Optional[str] = tokenizer_name
+        self._cross_attn_every_n_layers: int = cross_attn_every_n_layers
+
+        # Model
+        # The model is only initialized when the first request is made
+        # This is to avoid loading the model if it is not used
+        self._model: Optional[torch.nn.Module] = None
+
+    def _get_model(self):
+        if not self._checkpoint_path:
+            raise ValueError("OpenFlamingoClient requires a checkpoint path")
+        if not self._tokenizer_name:
+            raise ValueError("OpenFlamingoClient requires a tokenizer name")
+        with htrack_block("Initializing OpenFlamingo model"):
+            with self._model_lock:
+                self._model, self.image_processor, self.tokenizer = create_model_and_transforms(
+                    clip_vision_encoder_path="ViT-L-14",
+                    clip_vision_encoder_pretrained="openai",
+                    lang_encoder_path=self._tokenizer_name,
+                    tokenizer_path=self._tokenizer_name,
+                    cross_attn_every_n_layers=self._cross_attn_every_n_layers,
+                )
+                self.tokenizer.padding_side = "left"
+                checkpoint_path = hf_hub_download(self._checkpoint_path, "checkpoint.pt")
+                self._model.load_state_dict(torch.load(checkpoint_path), strict=False)
+                self._model = self._model.to(self._device)
+                hlog(f"Loaded model to {self._device}.")
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        # Load model if needed
+        if self._model is None:
+            self._get_model()
+
+        # Build the prompt
+        prompt_text: str = ""
+        images: List[Image.Image] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                images.append(open_image(media_object.location))
+                prompt_text += self.IMAGE_TOKEN
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+                prompt_text += media_object.text
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        # Preprocess
+        vision_x: torch.Tensor = torch.cat([self.image_processor(image).unsqueeze(0) for image in images], dim=0)
+        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+        lang_x = self.tokenizer([prompt_text], return_tensors="pt")
+
+        # Generate
+        try:
+            generation_args = {
+                "max_new_tokens": request.max_tokens,
+                "n": request.num_completions,
+            }
+
+            def do_it():
+                tensors = self._model.generate(
+                    vision_x=vision_x.to(self._device),
+                    lang_x=lang_x["input_ids"].to(self._device),
+                    attention_mask=lang_x["attention_mask"].to(self._device),
+                    max_new_tokens=generation_args["max_new_tokens"],
+                    num_beams=generation_args["n"],
+                    num_return_sequences=generation_args["n"],
+                )
+                generated_completions: List[Tuple[str, List[str]]] = []
+                for tensor in tensors:
+                    generated_text: str = self.tokenizer.decode(tensor)
+                    raw_tokens: List[str] = self.tokenizer.tokenize(generated_text)
+                    generated_completions.append((generated_text, raw_tokens))
+
+                return {"output": generated_completions}
+
+            cache_key = CachingClient.make_cache_key(
+                raw_request={
+                    "model": request.model,
+                    "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                    **generation_args,
+                },
+                request=request,
+            )
+            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as ex:
+            return RequestResult(success=False, cached=False, error=str(ex), completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        for text, tokens in result["output"]:
+            # Remove the prompt from the generated text
+            text = (
+                text[len(prompt_text) :].replace(self.END_OF_CHUNK_TOKEN, "").strip()
+                if len(text) >= len(prompt_text)
+                else text[-1]
+            )
+            completions.append(
+                GeneratedOutput(text=text, logprob=0, tokens=[Token(text=token, logprob=0) for token in tokens])
+            )
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=result["request_time"],
+            completions=completions,
+            embedding=[],
+        )

@@ -0,0 +1,171 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+from dataclasses import dataclass
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+from helm.common.cache import CacheConfig
+from helm.common.gpu_utils import get_torch_device_name
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.request import wrap_request_time
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+
+
+@dataclass(frozen=True)
+class LoadedQwenModelProcessor:
+    """Loaded model and processor for Qwen."""
+
+    model: AutoModelForCausalLM
+    tokenizer: AutoTokenizer
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Optional[LoadedQwenModelProcessor]] = {
+    "Qwen/Qwen-VL": None,
+    "Qwen/Qwen-VL-Chat": None,
+}
+
+
+class QwenVLMClient(CachingClient):
+    """
+    From https://huggingface.co/Qwen/Qwen-VL,
+    Qwen-VL (Qwen Large Vision Language Model) is the visual multimodal version of the large model series,
+    Qwen (abbr. Tongyi Qianwen), proposed by Alibaba Cloud. Qwen-VL accepts image, text, and bounding box
+    as inputs, outputs text and bounding box.
+    Alibaba released Qwen-VL and Qwen-VL-Chat, which is a chatbot model based on Qwen-VL.
+
+    Paper: https://arxiv.org/abs/2308.12966
+    """
+
+    END_OF_TEXT_TOKEN: str = "<|endoftext|>"
+
+    def __init__(self, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, helm_model_name: str) -> LoadedQwenModelProcessor:
+        global _models_lock
+        global _models
+
+        model_name: str
+        if helm_model_name == "qwen-vl-chat":
+            model_name = "Qwen/Qwen-VL-Chat"
+        elif helm_model_name == "qwen-vl":
+            model_name = "Qwen/Qwen-VL"
+        else:
+            raise ValueError(f"Unhandled model name: {helm_model_name}")
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            loaded_model_processor = _models[model_name]
+            if loaded_model_processor is None:
+                hlog(f"Loading model {model_name} and caching in memory...")
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_name, device_map=self._device, trust_remote_code=True, bf16=True
+                ).eval()
+                if model_name == "Qwen/Qwen-VL-Chat":
+                    model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)
+                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+                _models[model_name] = LoadedQwenModelProcessor(model, tokenizer)
+                loaded_model_processor = _models[model_name]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: LoadedQwenModelProcessor = self._get_model(request.model_engine)
+        model = loaded_model_processor.model
+        tokenizer = loaded_model_processor.tokenizer
+
+        generation_args = {
+            "max_length": request.max_tokens,
+        }
+
+        query: List[Dict[str, str]] = []
+        prompt_text: str = ""
+
+        image_index: int = 1
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                query.append({"image": media_object.location})
+                prompt_text += f"Picture {image_index}: <img>{media_object.location}</img>\n"
+                image_index += 1
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+
+                query.append({"text": media_object.text})
+                prompt_text += media_object.text
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        completions: List[GeneratedOutput] = []
+        request_time: float = 0
+        request_datetime: Optional[int] = None
+        all_cached: bool = True
+
+        with htrack_block(f"Generating for prompt: {prompt_text}"):
+            for completion_index in range(request.num_completions):
+                try:
+
+                    def do_it() -> Dict[str, Any]:
+                        if request.model_engine == "qwen-vl-chat":
+                            completion, _ = model.chat(tokenizer, query=tokenizer.from_list_format(query), history=None)
+                        else:
+                            inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")
+                            inputs = inputs.to(self._device)
+                            pred = model.generate(**inputs, **generation_args)
+                            completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
+
+                        tokens: List[str] = tokenizer.tokenize(completion)
+                        return {"output": (completion, tokens)}
+
+                    # Include the prompt and model name in the cache key
+                    cache_key = CachingClient.make_cache_key(
+                        raw_request={
+                            "completion_index": completion_index,
+                            "model": request.model,
+                            "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                            **generation_args,
+                        },
+                        request=request,
+                    )
+                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+                except RuntimeError as model_error:
+                    return RequestResult(
+                        success=False, cached=False, error=str(model_error), completions=[], embedding=[]
+                    )
+
+                text, tokens = result["output"]
+
+                # Truncate the output text as the original Qwen includes the prompt in the output sequence
+                if request.model_engine == "qwen-vl":
+                    text = text[len(prompt_text) :]
+                    text = text.replace(self.END_OF_TEXT_TOKEN, "")
+                    hlog(f"Truncated: {text}")
+
+                # Tokenize truncated text to get the list of tokens
+                completions.append(
+                    GeneratedOutput(
+                        text=text, logprob=0, tokens=[Token(text=str(token), logprob=0) for token in tokens]
+                    )
+                )
+
+                request_time += result["request_time"]
+                # Use the datetime from the first completion because that's when the request was fired
+                request_datetime = request_datetime or result.get("request_datetime")
+                all_cached = all_cached and cached
+
+        return RequestResult(
+            success=True,
+            cached=all_cached,
+            request_time=request_time,
+            request_datetime=request_datetime,
+            completions=completions,
+            embedding=[],
+        )

@@ -0,0 +1,46 @@
+from typing import Any, Dict, Optional
+
+from helm.common.cache import CacheConfig
+from helm.common.request import Request
+from helm.clients.openai_client import OpenAIClient
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class VLLMClient(OpenAIClient):
+    """Sends request to a vLLM server using the OpenAI-compatible API.
+
+    See: https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server"""
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        base_url: Optional[str] = None,
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key="EMPTY",
+            org_id=None,
+            base_url=base_url,
+        )
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+
+    def _is_chat_model_engine(self, model_engine: str) -> bool:
+        # Only support vLLM completion models for now.
+        return False
+
+    def _get_model_for_request(self, request: Request) -> str:
+        # The `model` parameter for vLLM should be the whole model name including the creator organization,
+        # unlike OpenAI which only uses the model engine.
+        return request.model
+
+    def _to_raw_completion_request(self, request: Request) -> Dict[str, Any]:
+        raw_request = super()._to_raw_completion_request(request)
+        # This avoids the error: best_of must be 1 when using greedy sampling
+        if "best_of" in raw_request and raw_request["best_of"] > 1:
+            raw_request["best_of"] = 1
+        return raw_request
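The final hunk adds helm/clients/vllm_client.py (+46 lines). A hedged sketch of how this client might be constructed against a locally running vLLM server with the OpenAI-compatible API enabled; the tokenizer object, cache path, and URL below are placeholders rather than values shipped in this release:

# Illustrative sketch only: every name and path here is a placeholder.
from helm.common.cache import SqliteCacheConfig
from helm.clients.vllm_client import VLLMClient

client = VLLMClient(
    tokenizer=some_tokenizer,  # stand-in for any helm.tokenizers.tokenizer.Tokenizer implementation
    tokenizer_name="my-org/my-model",
    cache_config=SqliteCacheConfig(path="prod_env/cache/vllm.sqlite"),
    base_url="http://localhost:8000/v1",  # vLLM's OpenAI-compatible endpoint
)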