crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of crfm-helm has been flagged as potentially problematic.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
- crfm_helm-0.5.1.dist-info/RECORD +654 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +25 -3
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +41 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +213 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +41 -1
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +205 -35
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +163 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +757 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +823 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +233 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +301 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +104 -73
- helm/clients/vertexai_client.py +400 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +111 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +33 -3
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1159 -538
- helm/config/model_metadata.yaml +868 -41
- helm/config/tokenizer_configs.yaml +149 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_deployment_definition.py +0 -92
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/clients/image_generation/mindalle/utils/sampling.py
@@ -0,0 +1,149 @@
+# ------------------------------------------------------------------------------------
+# minDALL-E
+# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------
+
+import torch
+from typing import Optional
+from tqdm import tqdm
+from torch.nn import functional as F
+
+
+def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
+    if k is None:
+        return logits
+    else:
+        v, ix = torch.topk(logits, k)
+        out = logits.clone()
+        out[out < v[:, [-1]]] = -float("Inf")
+        return out
+
+
+def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
+    if p is None:
+        return probs
+    else:
+        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
+        cum_probs = torch.cumsum(sorted_probs, dim=-1)
+
+        sorted_idx_remove_cond = cum_probs >= p
+
+        sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
+        sorted_idx_remove_cond[..., 0] = 0
+
+        indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
+        probs = probs.masked_fill(indices_to_remove, 0.0)
+        norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
+        return norm_probs
+
+
+def get_positional_encoding(inputs: torch.LongTensor, mode: str = "1d") -> torch.LongTensor:
+    device = inputs.device
+    if mode == "1d":
+        B, N = inputs.shape
+        xs_pos = torch.arange(N, device=device).repeat((B, 1))
+    elif mode == "2d":
+        B, H, W = inputs.shape
+        xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
+        xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
+        xs_pos = (xs_pos_h, xs_pos_w)
+    else:
+        raise ValueError("%s positional encoding invalid" % mode)
+    return xs_pos
+
+
+@torch.no_grad()
+def sampling(
+    model: torch.nn.Module,
+    tokens: torch.LongTensor,
+    top_k: Optional[float] = None,
+    top_p: Optional[float] = None,
+    softmax_temperature: float = 1.0,
+    is_tqdm: bool = True,
+    use_fp16: bool = True,
+    max_seq_len: int = 256,
+) -> torch.LongTensor:
+    code = None
+    past = None
+
+    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
+    pos_enc_tokens = get_positional_encoding(tokens, mode="1d")
+
+    for cnt, h in enumerate(pbar):
+        if code is None:
+            code_ = None
+            pos_enc_code_ = None
+        else:
+            code_ = code.clone().detach()
+            pos_enc_code_ = get_positional_encoding(code_, mode="1d")
+            code_ = code_[:, cnt - 1].unsqueeze(-1)
+            pos_enc_code_ = pos_enc_code_[:, cnt - 1].unsqueeze(-1)
+
+        logits, present = model.sampling(
+            images=code_, texts=tokens, pos_images=pos_enc_code_, pos_texts=pos_enc_tokens, use_fp16=use_fp16, past=past
+        )
+        logits = logits.to(dtype=torch.float32)
+        logits = logits / softmax_temperature
+
+        present = torch.stack(present).clone().detach()
+        if past is None:
+            past = [present]
+        else:
+            past.append(present)
+
+        logits = cutoff_topk_logits(logits, top_k)
+        probs = F.softmax(logits, dim=-1)
+        probs = cutoff_topp_probs(probs, top_p)
+
+        idx = torch.multinomial(probs, num_samples=1).clone().detach()
+        code = idx if code is None else torch.cat([code, idx], axis=1)
+
+    del past
+    return code
+
+
+@torch.no_grad()
+def sampling_igpt(
+    model: torch.nn.Module,
+    sos: torch.FloatTensor,
+    top_k: Optional[float] = None,
+    top_p: Optional[float] = None,
+    softmax_temperature: float = 1.0,
+    is_tqdm: bool = True,
+    use_fp16: bool = True,
+    max_seq_len: int = 256,
+) -> torch.LongTensor:
+    code = None
+    past = None
+    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
+
+    for cnt, h in enumerate(pbar):
+        if code is None:
+            code_ = None
+            pos_enc_code_ = None
+        else:
+            code_ = code.clone().detach()
+            pos_enc_code_ = get_positional_encoding(code_, mode="1d")
+            code_ = code_[:, cnt - 1].unsqueeze(-1)
+            pos_enc_code_ = pos_enc_code_[:, cnt - 1].unsqueeze(-1)
+
+        logits, present = model.sampling(sos=sos, codes=code_, pos_codes=pos_enc_code_, use_fp16=use_fp16, past=past)
+        logits = logits.to(dtype=torch.float32)
+        logits = logits / softmax_temperature
+
+        present = torch.stack(present).clone().detach()
+        if past is None:
+            past = [present]
+        else:
+            past.append(present)
+
+        logits = cutoff_topk_logits(logits, top_k)
+        probs = F.softmax(logits, dim=-1)
+        probs = cutoff_topp_probs(probs, top_p)
+
+        idx = torch.multinomial(probs, num_samples=1).clone().detach()
+        code = idx if code is None else torch.cat([code, idx], axis=1)
+
+    del past
+    return code
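The `cutoff_topk_logits` and `cutoff_topp_probs` helpers above are standard top-k and nucleus (top-p) filters applied before `torch.multinomial` in the sampling loop. A minimal sketch of their effect on a toy distribution, assuming the `heim` extras are installed so the module is importable:

import torch
from torch.nn import functional as F

from helm.clients.image_generation.mindalle.utils.sampling import (
    cutoff_topk_logits,
    cutoff_topp_probs,
)

# Toy batch of logits over a 5-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

# Top-k: everything outside the 2 largest logits becomes -inf.
filtered = cutoff_topk_logits(logits, k=2)

# Top-p: zero out the tail once cumulative probability reaches 0.9,
# then renormalize -- exactly what the sampling loop does each step.
probs = cutoff_topp_probs(F.softmax(filtered, dim=-1), p=0.9)
print(probs)  # only the surviving tokens carry nonzero mass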
helm/clients/image_generation/mindalle/utils/utils.py
@@ -0,0 +1,89 @@
+# ------------------------------------------------------------------------------------
+# minDALL-E
+# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------
+
+import os
+import random
+import urllib
+import hashlib
+import tarfile
+import torch
+import numpy as np
+from torch.nn import functional as F
+from tqdm import tqdm
+
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+@torch.no_grad()
+def clip_score(
+    prompt: str, images: np.ndarray, model_clip: torch.nn.Module, preprocess_clip, device: str
+) -> np.ndarray:
+    try:
+        import clip
+        from PIL import Image
+    except ModuleNotFoundError as e:
+        handle_module_not_found_error(e, ["heim"])
+
+    images = [preprocess_clip(Image.fromarray((image * 255).astype(np.uint8))) for image in images]
+    images = torch.stack(images, dim=0).to(device=device)
+    texts = clip.tokenize(prompt).to(device=device)
+    texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
+
+    image_features = model_clip.encode_image(images)
+    text_features = model_clip.encode_text(texts)
+
+    scores = F.cosine_similarity(image_features, text_features).squeeze()
+    rank = torch.argsort(scores, descending=True).cpu().numpy()
+    return rank
+
+
+def download(url: str, root: str) -> str:
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+    pathname = filename[: -len(".tar.gz")]
+
+    expected_md5 = url.split("/")[-2]
+    download_target = os.path.join(root, filename)
+    result_path = os.path.join(root, pathname)
+
+    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
+        return result_path
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if hashlib.md5(open(download_target, "rb").read()).hexdigest() != expected_md5:
+        raise RuntimeError("Model has been downloaded but the md5 checksum does not match")
+
+    with tarfile.open(download_target, "r:gz") as f:
+        pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
+        for member in pbar:
+            pbar.set_description(f"extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)")
+            f.extract(member=member, path=root)
+
+    return result_path
+
+
+def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
+    if urllib.parse.urlparse(url_or_path).scheme in ("http", "https"):
+        return download(url_or_path, root)
+    return url_or_path
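Note that `download` above derives the expected MD5 checksum from the URL itself (the second-to-last path segment), and `realpath_url_or_path` only triggers a download for http(s) URLs. A hedged sketch of the intended call pattern; the URL below is illustrative only, not a real weight location:

from helm.clients.image_generation.mindalle.utils.utils import realpath_url_or_path

# Hypothetical URL: "<md5>" stands in for the checksum path segment that
# download() reads via url.split("/")[-2].
weights = realpath_url_or_path(
    "https://example.com/mindalle/<md5>/1.3B.tar.gz", root="/tmp/mindalle"
)

# Plain local paths pass through unchanged, with no download attempted.
assert realpath_url_or_path("/opt/models/mindalle") == "/opt/models/mindalle"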
helm/clients/image_generation/mindalle_client.py
@@ -0,0 +1,115 @@
+from typing import Any, Dict, List
+
+import numpy as np
+
+from helm.common.cache import CacheConfig, Cache
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.gpu_utils import get_torch_device_name
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.common.tokenization_request import (
+    DecodeRequest,
+    DecodeRequestResult,
+    TokenizationRequest,
+    TokenizationRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+try:
+    from PIL import Image
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["heim"])
+
+
+class MinDALLEClient(Client):
+    """
+    Source: https://github.com/kakaobrain/mindall-e
+    """
+
+    def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
+        self._cache = Cache(cache_config)
+        self._file_cache: FileCache = file_cache
+
+        self._model = None
+
+    def _get_model(self):
+        try:
+            from helm.clients.image_generation.mindalle.models import Dalle
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        if self._model is None:
+            self._model = Dalle.from_pretrained("minDALL-E/1.3B")
+            self._model = self._model.to(get_torch_device_name())
+        return self._model
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = {
+            "prompt": request.prompt,
+            # Setting this to a higher value can cause CUDA OOM
+            # Fix it to 1 and generate an image `request.num_completions` times
+            "num_candidates": 1,
+            "softmax_temperature": 1.0,
+            "top_k": 256,  # It is recommended that top_k is set lower than 256.
+            "top_p": None,
+            "device": "cuda",
+        }
+
+        try:
+
+            def do_it() -> Dict[str, Any]:
+                prompt: str = request.prompt
+
+                with htrack_block(f"Generating images for prompt: {prompt}"):
+                    model = self._get_model()
+
+                    images: List[Image] = []
+                    for _ in range(request.num_completions):
+                        output = model.sampling(**raw_request).cpu().numpy()
+                        output = np.transpose(output, (0, 2, 3, 1))
+                        image = Image.fromarray(np.asarray(output[0] * 255, dtype=np.uint8))
+                        images.append(image)
+
+                    assert (
+                        len(images) == request.num_completions
+                    ), f"Expected {request.num_completions} images, but got {len(images)}"
+
+                    result = {"file_locations": []}
+                    for image in images:
+                        # Write out the image to a file and save the path
+                        file_location: str = self._file_cache.get_unique_file_location()
+                        image.save(file_location)
+                        hlog(f"Image saved at {file_location}.")
+                        result["file_locations"].append(file_location)
+                    return result
+
+            # Include the model name and number of completions in the cache key
+            cache_key: Dict = CachingClient.make_cache_key(
+                {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
+            )
+            results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as ex:
+            error: str = f"MinDALLEClient error: {ex}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(location)
+            )
+            for location in results["file_locations"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=results["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
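`MinDALLEClient.make_request` pins `num_candidates` to 1 and instead loops `request.num_completions` times, trading throughput for avoiding CUDA OOM, and returns each saved image as multimodal content. A hypothetical wiring of the client; the `SqliteCacheConfig` and `LocalFileCache` constructor arguments here are assumptions for illustration, not taken from this diff:

from helm.common.cache import SqliteCacheConfig
from helm.common.file_caches.local_file_cache import LocalFileCache
from helm.common.request import Request
from helm.clients.image_generation.mindalle_client import MinDALLEClient

# Assumed constructor signatures, for illustration only.
client = MinDALLEClient(
    cache_config=SqliteCacheConfig("mindalle.sqlite"),
    file_cache=LocalFileCache("output/images", file_extension="png"),
)
result = client.make_request(
    Request(model="kakaobrain/mindall-e", prompt="a watercolor fox", num_completions=2)
)
for completion in result.completions:
    # Text is empty; the generated image travels as multimodal_content.
    print(completion.multimodal_content)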
helm/clients/image_generation/nudity_check_client.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, Optional
+import os
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.general import ensure_file_downloaded, ensure_directory_exists
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.nudity_check_request import NudityCheckRequest, NudityCheckResult
+
+
+class NudityCheckClientError(Exception):
+    pass
+
+
+class NudityCheckClient:
+    MODEL_DOWNLOAD_URL: str = (
+        "https://worksheets.codalab.org/rest/bundles/0x81f7d1febb374fdcad6d9ae970f46750/contents/blob/"
+    )
+
+    def __init__(self, cache_config: CacheConfig):
+        try:
+            from nudenet import NudeClassifier
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        self.cache = Cache(cache_config)
+        self._nudity_classifier: Optional[NudeClassifier] = None
+
+    def check_nudity(self, request: NudityCheckRequest) -> NudityCheckResult:
+        """Check for nudity for a batch of images using NudeNet."""
+        try:
+            from nudenet import NudeClassifier
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        try:
+
+            def do_it() -> Dict[str, Any]:
+                if self._nudity_classifier is None:
+                    # The NudeNet library does not automatically provide model weights that work, so
+                    # manually download them. The path is hardcoded in the NudeNet library.
+                    base_path: str = os.path.join(os.path.expanduser("~"), ".NudeNet")
+                    ensure_directory_exists(base_path)
+                    model_path: str = os.path.join(base_path, "classifier_model.onnx")
+                    ensure_file_downloaded(source_url=self.MODEL_DOWNLOAD_URL, target_path=model_path)
+                    self._nudity_classifier = NudeClassifier()
+
+                path_to_nudity_scores: Dict[str, Dict[str, float]] = self._nudity_classifier.classify(
+                    request.image_locations
+                )
+                return path_to_nudity_scores
+
+            results, cached = self.cache.get({"locations": sorted(request.image_locations)}, do_it)
+        except Exception as e:
+            raise NudityCheckClientError(e)
+
+        nudity_results: Dict[str, bool] = {
+            image_location: nudity_result["unsafe"] > nudity_result["safe"]
+            for image_location, nudity_result in results.items()
+        }
+        return NudityCheckResult(
+            success=True,
+            cached=cached,
+            image_to_nudity=nudity_results,
+        )
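The classifier's per-image scores are reduced to a boolean by comparing its "unsafe" and "safe" outputs. A minimal sketch of the request/response shape, assuming `client` is a `NudityCheckClient` already constructed with a cache config as in `__init__` above:

from helm.common.nudity_check_request import NudityCheckRequest

result = client.check_nudity(
    NudityCheckRequest(image_locations=["gen/img_0.png", "gen/img_1.png"])
)
# image_to_nudity maps each path to True when the unsafe score exceeds the safe score.
flagged = [path for path, is_nude in result.image_to_nudity.items() if is_nude]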
helm/clients/image_generation/together_image_generation_client.py
@@ -0,0 +1,111 @@
+from typing import Any, Dict, List, Optional
+import base64
+import requests
+
+from helm.common.cache import CacheConfig, Cache
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+
+from helm.clients.client import CachingClient, Client
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class TogetherImageGenerationClient(Client):
+    """
+    Client for image generation via the Together API.
+    """
+
+    DEFAULT_IMAGE_HEIGHT: int = 512
+    DEFAULT_IMAGE_WIDTH: int = 512
+
+    DEFAULT_GUIDANCE_SCALE: float = 7.5
+    DEFAULT_STEPS: int = 50
+
+    INFERENCE_ENDPOINT: str = "https://api.together.xyz/api/inference"
+
+    def __init__(self, cache_config: CacheConfig, file_cache: FileCache, api_key: Optional[str] = None):
+        self._cache = Cache(cache_config)
+        self.file_cache: FileCache = file_cache
+
+        self._promptist_model = None
+        self._promptist_tokenizer = None
+
+        self.api_key: Optional[str] = api_key
+
+    def make_request(self, request: Request) -> RequestResult:
+        # Following https://docs.together.xyz/en/api
+        assert request.image_generation_parameters is not None
+        raw_request = {
+            "request_type": "image-model-inference",
+            "model": request.model_engine,
+            "prompt": request.prompt,
+            "n": request.num_completions,
+            "guidance_scale": (
+                request.image_generation_parameters.guidance_scale
+                if request.image_generation_parameters.guidance_scale is not None
+                else self.DEFAULT_GUIDANCE_SCALE
+            ),
+            "steps": (
+                request.image_generation_parameters.diffusion_denoising_steps
+                if request.image_generation_parameters.diffusion_denoising_steps is not None
+                else self.DEFAULT_STEPS
+            ),
+        }
+
+        if (
+            request.image_generation_parameters.output_image_width is None
+            or request.image_generation_parameters.output_image_height is None
+        ):
+            raw_request["width"] = self.DEFAULT_IMAGE_WIDTH
+            raw_request["height"] = self.DEFAULT_IMAGE_HEIGHT
+        else:
+            raw_request["width"] = request.image_generation_parameters.output_image_width
+            raw_request["height"] = request.image_generation_parameters.output_image_height
+
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        try:
+
+            def do_it() -> Dict[str, Any]:
+                result = requests.post(self.INFERENCE_ENDPOINT, json=raw_request).json()
+                assert "output" in result, f"Invalid response: {result} from prompt: {request.prompt}"
+
+                for choice in result["output"]["choices"]:
+                    # Write out the image to a file and save the path
+                    choice["file_path"] = self.file_cache.store(lambda: base64.b64decode(choice["image_base64"]))
+                    choice.pop("image_base64", None)
+                return result["output"]
+
+            response, cached = self._cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"TogetherVisionClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="",
+                logprob=0,
+                tokens=[],
+                multimodal_content=get_single_image_multimedia_object(choice["file_path"]),
+            )
+            for choice in response["choices"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
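Width and height fall back together here: if either output dimension is unset, the client uses its 512x512 default for both. A sketch of building a request with explicit parameters; the engine name is a placeholder, while the field names follow the client code above:

from helm.common.image_generation_parameters import ImageGenerationParameters
from helm.common.request import Request

request = Request(
    model="togethercomputer/stable-diffusion",  # placeholder engine name
    prompt="a lighthouse at dusk, watercolor",
    num_completions=4,
    image_generation_parameters=ImageGenerationParameters(
        output_image_width=768,
        output_image_height=768,
        guidance_scale=7.5,            # client default if left as None
        diffusion_denoising_steps=30,  # client default is 50
    ),
)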
helm/{proxy/clients → clients}/lit_gpt_client.py
@@ -9,8 +9,8 @@ import torch

 from helm.common.cache import CacheConfig
 from helm.common.optional_dependencies import OptionalDependencyNotInstalled
-from helm.common.request import Request, RequestResult,
-from helm.
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.tokenizers.tokenizer import Tokenizer

 from .client import CachingClient
 from .lit_gpt_generate import generate  # type: ignore
@@ -156,8 +156,8 @@ class LitGPTClient(CachingClient):

         generated_tokens = []
         for token in tokens:
-            generated_tokens.append(Token(text=tokenizer.decode(token), logprob=0
-        completions = [
+            generated_tokens.append(Token(text=tokenizer.decode(token), logprob=0))
+        completions = [GeneratedOutput(text=output, logprob=0, tokens=generated_tokens)]

         return RequestResult(
             success=True,
helm/{proxy/clients → clients}/megatron_client.py
@@ -9,12 +9,12 @@ from helm.common.request import (
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
     Request,
     RequestResult,
-
+    GeneratedOutput,
     Token,
 )
 from helm.common.tokenization_request import TokenizationRequest
-from helm.
-from helm.
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.tokenizers.tokenizer import Tokenizer


 class MegatronClient(CachingClient):
@@ -52,7 +52,7 @@ class MegatronClient(CachingClient):
         tokenized_text = self.tokenizer.tokenize(TokenizationRequest(text, tokenizer=self.tokenizer_name))

         # TODO(tgale): Support logprobs.
-        tokens = [Token(text=str(token), logprob=0
+        tokens = [Token(text=str(token), logprob=0) for token in tokenized_text.raw_tokens]
         return tokens

     def _make_request(self, request: Request) -> RequestResult:
@@ -87,7 +87,7 @@ class MegatronClient(CachingClient):

         # NOTE: Megatron returns the de-tokenized response. Re-tokenize.
         tokens = self._tokenize_response(generated_text)
-        completion =
+        completion = GeneratedOutput(text=generated_text, logprob=0, tokens=tokens)
         completion = truncate_sequence(completion, request, print_warning=True)

         return RequestResult(
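The last two files show the two mechanical changes repeated throughout this release: client and tokenizer modules moved out of `helm.proxy` (see the `helm/{proxy/clients → clients}` and `helm/{proxy/tokenizers → tokenizers}` renames in the file list), and completions are now built as `GeneratedOutput` objects. A sketch of the import updates downstream code would need; the commented-out old paths are inferred from the 0.4.0 layout shown in the file list:

# 0.4.0 layout (inferred from the rename entries above):
# from helm.proxy.clients.client import CachingClient
# from helm.proxy.tokenizers.tokenizer import Tokenizer

# 0.5.1 layout:
from helm.clients.client import CachingClient, truncate_sequence
from helm.tokenizers.tokenizer import Tokenizer
from helm.common.request import GeneratedOutput, Token

# Completions are now GeneratedOutput values carrying Token objects:
completion = GeneratedOutput(text="hello", logprob=0, tokens=[Token(text="hello", logprob=0)])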