crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
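The dominant structural change in this release is the move of client and tokenizer code out of the proxy package: `helm/proxy/clients/*` becomes `helm/clients/*` and `helm/proxy/tokenizers/*` becomes `helm/tokenizers/*`, per the rename entries above. Downstream imports need a one-line update, e.g.:

```python
# 0.4.0 (removed in 0.5.0):
# from helm.proxy.clients.huggingface_client import HuggingFaceClient
# from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

# 0.5.0:
from helm.clients.huggingface_client import HuggingFaceClient
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
```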
helm/{proxy/clients → clients}/huggingface_client.py

@@ -5,7 +5,7 @@ from transformers.generation.stopping_criteria import (
     StoppingCriteria,
     StoppingCriteriaList,
 )
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict
 
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
@@ -14,11 +14,11 @@ from helm.common.request import (
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
     Request,
     RequestResult,
-    Sequence,
+    GeneratedOutput,
     Token,
 )
 from .client import CachingClient, truncate_sequence
-from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
+from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
 from threading import Lock
 
 
@@ -36,6 +36,20 @@ class StopAtSpecificTokenCriteria(StoppingCriteria):
         return bool(torch.all(current_sequence == stop_sequence_tensor).item())
 
 
+class HuggingFaceRequest(TypedDict):
+    """Data passed between make_request and serve_request. Used as the cache key."""
+
+    engine: str
+    prompt: str
+    temperature: float
+    num_return_sequences: int
+    max_new_tokens: int
+    top_p: float
+    echo_prompt: bool
+    top_k_per_token: int
+    stop_sequences: List
+
+
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
 
@@ -55,30 +69,24 @@ class HuggingFaceServer:
             pretrained_model_name_or_path, **kwargs
         )
 
-    def serve_request(self, raw_request: Dict):
+    def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
         with self.wrapped_tokenizer as tokenizer:
             encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
                 self.device
             )
-        raw_request = deepcopy(raw_request)
-        raw_request["do_sample"] = True
-        raw_request["return_dict_in_generate"] = True
-        raw_request["output_scores"] = True
-        top_k_per_token: int = raw_request["top_k_per_token"]
-        del raw_request["top_k_per_token"]
         stopping_criteria: Optional[StoppingCriteriaList] = None
+        optional_args = {}
         if len(raw_request["stop_sequences"]) > 0:
             with self.wrapped_tokenizer as tokenizer:
                 stop_sequence_ids = tokenizer(
                     raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
                 )
             if len(stop_sequence_ids.input_ids) == 1 and len(stop_sequence_ids.input_ids[0]) == 1:
-                raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
+                optional_args["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
             else:
                 stopping_criteria = StoppingCriteriaList()
                 for stop_sequence_input_ids in stop_sequence_ids.input_ids:
                     stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_input_ids))
-            del raw_request["stop_sequences"]
 
         # Check if we need to compute the perplexity of the prompt (#1497)
         compute_logprobs_only = (
@@ -94,64 +102,42 @@ class HuggingFaceServer:
             sequences = encoded_input["input_ids"]
             scores = output.logits
         else:
-            # Strip out irrelevant parameters
-            relevant_raw_request = {
-                key: raw_request[key]
-                for key in raw_request
-                if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
-            }
-
             output = self.model.generate(
                 **encoded_input,
-                **relevant_raw_request,
+                temperature=raw_request["temperature"],
+                num_return_sequences=raw_request["num_return_sequences"],
+                max_new_tokens=raw_request["max_new_tokens"],
+                top_p=raw_request["top_p"],
+                do_sample=True,
+                return_dict_in_generate=True,
+                output_scores=True,
+                **optional_args,
                 stopping_criteria=stopping_criteria,
             )
             sequences = output.sequences
             scores = output.scores
 
         prompt_tokens_logprobs = []
-        prompt_tokens_top_logprobs_dicts: List[Dict] = []
         if compute_logprobs_only:
             # Append the logprob of the first token of the prompt.
             prompt_tokens_logprobs.append(0.0)
-            prompt_tokens_top_logprobs_dicts.append({})
 
             # Compute logprobs of prompt tokens.
             for completion_id in range(raw_request["num_return_sequences"]):
                 for i in range(len(sequences[completion_id]) - 1):
                     logprobs = torch.nn.functional.log_softmax(scores[completion_id][i], dim=0)
-                    topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
-                    with self.wrapped_tokenizer as tokenizer:
-                        prompt_tokens_top_logprobs_dicts.append(
-                            {
-                                tokenizer.convert_ids_to_tokens(k.item()): v.item()
-                                for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)
-                            }
-                        )
                     prompt_tokens_logprobs.append(logprobs[sequences[completion_id][i + 1]].item())
 
         # Compute logprobs of generated tokens for each completed sequence.
         all_generated_tokens_logprobs = []
-        all_generated_tokens_top_logprobs_dicts = []
         for completion_id in range(raw_request["num_return_sequences"]):
             generated_tokens_logprobs = []
-            generated_tokens_top_logprobs_dicts = []
             for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])):
                 logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0)
-                # Get top tokens in terms of log probability.
-                topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
-                with self.wrapped_tokenizer as tokenizer:
-                    generated_tokens_top_logprobs_dicts.append(
-                        {
-                            tokenizer.convert_ids_to_tokens(k.item()): v.item()
-                            for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)
-                        }
-                    )
                 # Get log probability of chosen token.
                 j = i + len(encoded_input.input_ids[0])
                 generated_tokens_logprobs.append(logprobs[sequences[completion_id][j]].item())
             all_generated_tokens_logprobs.append(generated_tokens_logprobs)
-            all_generated_tokens_top_logprobs_dicts.append(generated_tokens_top_logprobs_dicts)
 
         # Remove prompt from the start of each sequence if echo_prompt is False.
         if not raw_request["echo_prompt"]:
@@ -162,17 +148,15 @@ class HuggingFaceServer:
             all_decoded_text = tokenizer.batch_decode(sequences)
 
         completions = []
-        for decoded_text, tokens, generated_tokens_logprobs, generated_tokens_top_logprobs_dicts in zip(
-            all_decoded_text, all_tokens, all_generated_tokens_logprobs, all_generated_tokens_top_logprobs_dicts
+        for decoded_text, tokens, generated_tokens_logprobs in zip(
+            all_decoded_text, all_tokens, all_generated_tokens_logprobs
         ):
             completions.append(
                 {
                     "text": decoded_text,
                     "tokens": tokens,
                     "logprobs": generated_tokens_logprobs,
-                    "top_logprobs_dicts": generated_tokens_top_logprobs_dicts,
                     "prompt_logprobs": prompt_tokens_logprobs,
-                    "prompt_top_logprobs_dicts": prompt_tokens_top_logprobs_dicts,
                 }
             )
 
@@ -240,7 +224,7 @@ class HuggingFaceClient(CachingClient):
         if request.embedding:
             return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
 
-        raw_request = {
+        raw_request: HuggingFaceRequest = {
             "engine": request.model_engine,
             "prompt": request.prompt,
             "temperature": 1e-7 if request.temperature == 0 else request.temperature,
@@ -252,20 +236,18 @@ class HuggingFaceClient(CachingClient):
             "stop_sequences": request.stop_sequences,
         }
 
-        pretrained_model_name_or_path: str
-        if self._pretrained_model_name_or_path:
-            pretrained_model_name_or_path = self._pretrained_model_name_or_path
-        else:
-            pretrained_model_name_or_path = resolve_alias(request.model_deployment)
+        pretrained_model_name_or_path = (
+            self._pretrained_model_name_or_path if self._pretrained_model_name_or_path else request.model
+        )
         huggingface_model: HuggingFaceServer = HuggingFaceServerFactory.get_server(
-            helm_model_name=request.model_deployment,
+            helm_model_name=request.model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             **self._kwargs,
         )
 
         try:
 
-            def do_it():
+            def do_it() -> Dict[str, Any]:
                 return huggingface_model.serve_request(raw_request)
 
             cache_key = CachingClient.make_cache_key(raw_request, request)
@@ -282,29 +264,26 @@ class HuggingFaceClient(CachingClient):
         if request.echo_prompt:
             # Add prompt to list of generated tokens.
             generated_tokens = raw_completion["tokens"][response["input_length"] :]
-            if raw_completion.get("prompt_logprobs") and raw_completion.get("prompt_top_logprobs_dicts"):
-                for token_text, logprob, top_logprobs_dict in zip(
+            if raw_completion.get("prompt_logprobs"):
+                for token_text, logprob in zip(
                     raw_completion["tokens"][: response["input_length"]],
                     raw_completion["prompt_logprobs"][: response["input_length"]],
-                    raw_completion["prompt_top_logprobs_dicts"][: response["input_length"]],
                 ):
-                    tokens.append(Token(text=token_text, logprob=logprob, top_logprobs=top_logprobs_dict))
+                    tokens.append(Token(text=token_text, logprob=logprob))
                     sequence_logprob += logprob
             else:
                 for token_text in raw_completion["tokens"][: response["input_length"]]:
-                    tokens.append(Token(text=token_text, logprob=0.0, top_logprobs={}))
+                    tokens.append(Token(text=token_text, logprob=0.0))
 
         else:
             generated_tokens = raw_completion["tokens"]
 
         # Compute logprob for the entire sequence.
-        for token_text, logprob, top_logprobs_dict in zip(
-            generated_tokens, raw_completion["logprobs"], raw_completion["top_logprobs_dicts"]
-        ):
-            tokens.append(Token(text=token_text, logprob=logprob, top_logprobs=top_logprobs_dict))
+        for token_text, logprob in zip(generated_tokens, raw_completion["logprobs"]):
+            tokens.append(Token(text=token_text, logprob=logprob))
             sequence_logprob += logprob
 
-        completion = Sequence(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
+        completion = GeneratedOutput(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
         completion = truncate_sequence(completion, request)
         completions.append(completion)
 
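The substance of this refactor: the untyped `Dict` that was previously `deepcopy`-ed and mutated (`del raw_request[...]`) becomes a `HuggingFaceRequest` TypedDict that doubles as the cache key, and generation settings are passed to `model.generate` by name instead of being forwarded wholesale. A minimal standalone sketch of the TypedDict pattern; the literal values below are illustrative, not taken from the diff:

```python
from typing import List, TypedDict


class HuggingFaceRequest(TypedDict):
    """Mirrors the TypedDict introduced in the diff above."""

    engine: str
    prompt: str
    temperature: float
    num_return_sequences: int
    max_new_tokens: int
    top_p: float
    echo_prompt: bool
    top_k_per_token: int
    stop_sequences: List


# A type checker (e.g. mypy) now flags missing keys, extra keys, and wrong value
# types in this literal, which a plain Dict[str, Any] would silently accept:
raw_request: HuggingFaceRequest = {
    "engine": "gpt2",
    "prompt": "The quick brown fox",
    "temperature": 0.7,
    "num_return_sequences": 1,
    "max_new_tokens": 16,
    "top_p": 1.0,
    "echo_prompt": False,
    "top_k_per_token": 1,
    "stop_sequences": [],
}
```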
helm/clients/image_generation/__init__.py (file without changes)
helm/clients/image_generation/adobe_vision_client.py (new file)

@@ -0,0 +1,78 @@
+from typing import List, Dict
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.request import Request, RequestResult, GeneratedOutput
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class AdobeVisionClient(Client):
+    """
+    Client for Adobe vision models. Offline eval only.
+    """
+
+    SUPPORTED_MODELS: List[str] = ["giga-gan", "firefly"]
+
+    @staticmethod
+    def convert_to_raw_request(request: Request) -> Dict:
+        # Use default hyperparameters for everything else
+        raw_request: Dict = {
+            "request_type": "image-model-inference",
+            "model": request.model_engine,
+            "prompt": request.prompt,
+            "n": request.num_completions,
+        }
+        if request.random is not None:
+            raw_request["random"] = request.random
+        return raw_request
+
+    def __init__(self, cache_config: CacheConfig):
+        self._cache = Cache(cache_config)
+        self._promptist_model = None
+        self._promptist_tokenizer = None
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.model_engine not in self.SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported model: {request.model_engine}")
+
+        raw_request = AdobeVisionClient.convert_to_raw_request(request)
+        raw_request.pop("random", None)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        try:
+
+            def fail():
+                raise RuntimeError(
+                    f"The result has not been uploaded to the cache for the following request: {cache_key}"
+                )
+
+            response, cached = self._cache.get(cache_key, fail)
+        except RuntimeError as e:
+            error: str = f"Adobe Vision Client error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_path)
+            )
+            for file_path in response["images"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
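Both of the image-generation clients in this area are offline-eval only: `Cache.get` takes a fallback callable to compute the value on a miss, and because results are expected to have been uploaded ahead of time, the fallback simply raises. A minimal sketch of that control flow, using a hypothetical in-memory `DictCache` stand-in rather than the real `helm.common.cache.Cache`:

```python
from typing import Callable, Dict, Tuple


class DictCache:
    """Hypothetical in-memory stand-in for helm.common.cache.Cache (illustration only)."""

    def __init__(self) -> None:
        self._store: Dict[str, Dict] = {}

    def get(self, key: str, compute: Callable[[], Dict]) -> Tuple[Dict, bool]:
        # Returns (response, cached); the compute callable runs only on a miss.
        if key in self._store:
            return self._store[key], True
        response = compute()
        self._store[key] = response
        return response, False


cache = DictCache()


def fail() -> Dict:
    # Offline eval: a cache miss means the result was never uploaded, so fail loudly
    # instead of trying to call a model.
    raise RuntimeError("The result has not been uploaded to the cache")


try:
    response, cached = cache.get("image-model-inference:firefly:example-key", fail)
except RuntimeError as e:
    print(f"Client error: {e}")
```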
helm/clients/image_generation/aleph_alpha_image_generation_client.py (new file)

@@ -0,0 +1,98 @@
+from typing import List, Dict
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.request import Request, RequestResult, GeneratedOutput
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class AlephAlphaImageGenerationClient(Client):
+    """
+    Client for Aleph Alpha vision models. Offline eval only.
+    """
+
+    DEFAULT_IMAGE_HEIGHT: int = 512
+    DEFAULT_IMAGE_WIDTH: int = 512
+
+    DEFAULT_GUIDANCE_SCALE: float = 7.5
+    DEFAULT_STEPS: int = 50
+
+    @staticmethod
+    def convert_to_raw_request(request: Request) -> Dict:
+        raw_request: Dict = {
+            "request_type": "image-model-inference",
+            "model": request.model_engine,
+            "prompt": request.prompt,
+            "n": request.num_completions,
+            "guidance_scale": AlephAlphaImageGenerationClient.DEFAULT_GUIDANCE_SCALE,
+            "steps": AlephAlphaImageGenerationClient.DEFAULT_STEPS,
+            "width": AlephAlphaImageGenerationClient.DEFAULT_IMAGE_WIDTH,
+            "height": AlephAlphaImageGenerationClient.DEFAULT_IMAGE_HEIGHT,
+        }
+        if request.random is not None:
+            raw_request["random"] = request.random
+
+        assert request.image_generation_parameters is not None
+        if request.image_generation_parameters.guidance_scale is not None:
+            raw_request["guidance_scale"] = request.image_generation_parameters.guidance_scale
+        if request.image_generation_parameters.diffusion_denoising_steps is not None:
+            raw_request["steps"] = request.image_generation_parameters.diffusion_denoising_steps
+        if (
+            request.image_generation_parameters.output_image_width is not None
+            and request.image_generation_parameters.output_image_height is not None
+        ):
+            raw_request["width"] = request.image_generation_parameters.output_image_width
+            raw_request["height"] = request.image_generation_parameters.output_image_height
+
+        return raw_request
+
+    def __init__(self, cache_config: CacheConfig):
+        self._cache = Cache(cache_config)
+        self._promptist_model = None
+        self._promptist_tokenizer = None
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.model_engine != "m-vader":
+            raise ValueError(f"Unsupported model: {request.model_engine}")
+
+        raw_request = AlephAlphaImageGenerationClient.convert_to_raw_request(request)
+        raw_request.pop("random", None)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        try:
+
+            def fail():
+                raise RuntimeError(
+                    f"The result has not been uploaded to the cache for the following request: {cache_key}"
+                )
+
+            response, cached = self._cache.get(cache_key, fail)
+        except RuntimeError as e:
+            error: str = f"AlephAlphaVisionClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_path)
+            )
+            for file_path in response["images"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
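Unlike the Adobe client, this one lets `request.image_generation_parameters` override its defaults (512x512, guidance 7.5, 50 denoising steps). A sketch of what that parameters object plausibly looks like; the field names are read off the diff above, but the actual definition in `helm/common/image_generation_parameters.py` (added in this release, +25 lines) may differ:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ImageGenerationParameters:
    """Assumed shape; field names taken from the accesses in the diff above."""

    guidance_scale: Optional[float] = None
    diffusion_denoising_steps: Optional[int] = None
    output_image_width: Optional[int] = None
    output_image_height: Optional[int] = None


# Overriding only the guidance scale keeps the client defaults for steps and size,
# matching the None-checks in convert_to_raw_request above.
params = ImageGenerationParameters(guidance_scale=9.0)
```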
helm/clients/image_generation/cogview2/__init__.py (file without changes)
helm/clients/image_generation/cogview2/coglm_strategy.py (new file)

@@ -0,0 +1,96 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    :   coglm_strategy.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+"""
+
+# here put the import lib
+import os
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+class CoglmStrategy:
+    def __init__(
+        self, invalid_slices=[], temperature=1.0, top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, top_k_cluster=1.0
+    ):
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = False
+        self.outlier_count_down = 5
+        self.cluster_labels = torch.tensor(
+            np.load(f"{os.path.dirname(os.path.abspath(__file__))}/cluster_label.npy"),
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            dtype=torch.long,
+        )
+        self.top_k_cluster = top_k_cluster
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done
+
+    def forward(self, logits, tokens, mems, temperature=None):
+        if temperature is None:
+            temperature = self.temperature
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        rprobs = F.softmax(logits.float(), dim=-1)
+        c = self.cluster_labels.expand(*rprobs.shape)
+        cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
+        best_scores, best_clusters = cprobs.topk(self.topk)
+        bz = logits.shape[0]
+        for i in range(bz):
+            best_scores[i] = best_scores[i]  # ** 0.2
+            selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
+            logits[i, self.cluster_labels != selected_cluster] = -65504
+
+        probs = F.softmax(logits.float() / self.top_k_cluster, dim=-1)  # float is essential, due to a bug in Pytorch
+        pred = torch.multinomial(probs, num_samples=1)
+
+        if pred.numel() == 1 and pred.item() in self.end_tokens:
+            self._is_done = True
+        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        self._is_done = False
+        return tokens, mems
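`top_k_logits` above masks rejected logits to -65504, the most negative normal float16 value, so they contribute essentially zero probability after a softmax. A self-contained sketch of the top-k branch showing the end-to-end sampling effect; the vocabulary size is a toy value:

```python
import torch
import torch.nn.functional as F


def top_k_filter(logits: torch.Tensor, top_k: int, filter_value: float = -65504.0) -> torch.Tensor:
    # Same masking as the top_k branch of top_k_logits above: anything below the
    # k-th largest logit is pushed down to filter_value.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits[indices_to_remove] = filter_value
    return logits


logits = torch.randn(1, 20000)                       # fake [batch=1, vocab] scores
filtered = top_k_filter(logits.clone(), top_k=16)    # only 16 logits survive
probs = F.softmax(filtered, dim=-1)                  # masked entries get ~0 probability
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.item())
```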
helm/clients/image_generation/cogview2/coglm_utils.py (new file)

@@ -0,0 +1,82 @@
+import torch
+
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+try:
+    from SwissArmyTransformer.model import CachedAutoregressiveModel
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["heim"])
+
+
+def get_masks_and_position_ids_coglm(seq, context_length):
+    tokens = seq.unsqueeze(0)
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask[..., :context_length] = 1
+    attention_mask.unsqueeze_(1)
+    position_ids = torch.zeros(len(seq), device=tokens.device, dtype=torch.long)
+    torch.arange(0, context_length, out=position_ids[:context_length])
+    torch.arange(512, 512 + len(seq) - context_length, out=position_ids[context_length:])
+    position_ids = position_ids.unsqueeze(0)
+    return tokens, attention_mask, position_ids
+
+
+def get_recipe(name):
+    r = {
+        "attn_plus": 1.4,
+        "temp_all_gen": 1.15,
+        "topk_gen": 16,
+        "temp_cluster_gen": 1.0,
+        "temp_all_dsr": 1.5,
+        "topk_dsr": 100,
+        "temp_cluster_dsr": 0.89,
+        "temp_all_itersr": 1.3,
+        "topk_itersr": 16,
+        "query_template": "{}<start_of_image>",
+    }
+    if name == "none":
+        pass
+    elif name == "mainbody":
+        r["query_template"] = "{} 高清摄影 隔绝<start_of_image>"
+    elif name == "photo":
+        r["query_template"] = "{} 高清摄影<start_of_image>"
+    elif name == "flat":
+        r["query_template"] = "{} 平面风格<start_of_image>"
+        # r['attn_plus'] = 1.8
+        # r['temp_cluster_gen'] = 0.75
+        r["temp_all_gen"] = 1.1
+        r["topk_dsr"] = 5
+        r["temp_cluster_dsr"] = 0.4
+        r["temp_all_itersr"] = 1
+        r["topk_itersr"] = 5
+    elif name == "comics":
+        r["query_template"] = "{} 漫画 隔绝<start_of_image>"
+        r["topk_dsr"] = 5
+        r["temp_cluster_dsr"] = 0.4
+        r["temp_all_gen"] = 1.1
+        r["temp_all_itersr"] = 1
+        r["topk_itersr"] = 5
+    elif name == "oil":
+        r["query_template"] = "{} 油画风格<start_of_image>"
+        pass
+    elif name == "sketch":
+        r["query_template"] = "{} 素描风格<start_of_image>"
+        r["temp_all_gen"] = 1.1
+    elif name == "isometric":
+        r["query_template"] = "{} 等距矢量图<start_of_image>"
+        r["temp_all_gen"] = 1.1
+    elif name == "chinese":
+        r["query_template"] = "{} 水墨国画<start_of_image>"
+        r["temp_all_gen"] = 1.12
+    elif name == "watercolor":
+        r["query_template"] = "{} 水彩画风格<start_of_image>"
+    return r
+
+
+class InferenceModel(CachedAutoregressiveModel):
+    def final_forward(self, logits, **kwargs):
+        logits_parallel = logits
+        logits_parallel = torch.nn.functional.linear(
+            logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()
+        )
+        return logits_parallel
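`get_masks_and_position_ids_coglm` above builds a causal mask in which the text context stays fully visible to every position, and restarts position ids at 512 for the generated image tokens. A small self-contained sketch reproducing its steps on a toy 10-token sequence with a 4-token context:

```python
import torch

seq = torch.arange(10)          # toy token ids
context_length = 4              # first 4 tokens are the text context

tokens = seq.unsqueeze(0)                                    # [1, 10]
attention_mask = torch.ones((1, len(seq), len(seq)))
attention_mask.tril_()                                       # causal (lower-triangular)
attention_mask[..., :context_length] = 1                     # context columns stay visible
attention_mask.unsqueeze_(1)                                 # [1, 1, 10, 10]

position_ids = torch.zeros(len(seq), dtype=torch.long)
torch.arange(0, context_length, out=position_ids[:context_length])
torch.arange(512, 512 + len(seq) - context_length, out=position_ids[context_length:])
print(position_ids.tolist())    # [0, 1, 2, 3, 512, 513, 514, 515, 516, 517]
```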
helm/clients/image_generation/cogview2/sr_pipeline/__init__.py (new file)

@@ -0,0 +1,15 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    :   __init__.py
+@Time    :   2022/03/02 13:57:09
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+"""
+
+from .direct_sr import DirectSuperResolution
+from .iterative_sr import IterativeSuperResolution
+from .sr_group import SRGroup
+
+DirectSuperResolution
+IterativeSuperResolution
+SRGroup