crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
|
2
|
+
import base64
|
|
3
|
+
|
|
4
|
+
from helm.common.cache import CacheConfig, Cache
|
|
5
|
+
from helm.common.general import hlog
|
|
6
|
+
from helm.common.file_caches.file_cache import FileCache
|
|
7
|
+
from helm.common.media_object import MultimediaObject
|
|
8
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
9
|
+
from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
|
|
10
|
+
from helm.common.tokenization_request import (
|
|
11
|
+
TokenizationRequest,
|
|
12
|
+
TokenizationRequestResult,
|
|
13
|
+
DecodeRequest,
|
|
14
|
+
DecodeRequestResult,
|
|
15
|
+
)
|
|
16
|
+
from helm.clients.moderation_api_client import ModerationAPIClient
|
|
17
|
+
from helm.clients.client import Client, CachingClient
|
|
18
|
+
from .image_generation_client_utils import get_single_image_multimedia_object
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import openai
|
|
22
|
+
from openai import OpenAI
|
|
23
|
+
except ModuleNotFoundError as missing_module_exception:
|
|
24
|
+
handle_module_not_found_error(missing_module_exception, ["openai"])
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DALLE2Client(Client):
    """
    Client for OpenAI's DALL-E 2 image generation API.

    Prompts are screened with the Moderation API before generation. If the Moderation API flags a prompt,
    or the DALL-E API rejects it with one of the known safety/generation errors below, the client returns
    a successful `RequestResult` whose completions carry no image and a content-policy finish reason.
    Generated images are decoded from base64 and written to `file_cache`; only their file paths are cached.
    """

    MAX_PROMPT_LENGTH: int = 1000
    DEFAULT_IMAGE_SIZE_STR: str = "512x512"
    VALID_IMAGE_SIZES: List[str] = ["256x256", DEFAULT_IMAGE_SIZE_STR, "1024x1024"]

    # Set the finish reason to this if the prompt violates OpenAI's content policy
    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
        "The prompt violates OpenAI's content policy. "
        "See https://labs.openai.com/policies/content-policy for more information."
    )

    # The DALL-E API will respond with the following error messages (or even a substring of the message)
    # if it has any issues generating images for a particular prompt
    PROMPT_FLAGGED_ERROR: str = (
        "Your request was rejected as a result of our safety system. "
        "Your prompt may contain text that is not allowed by our safety system."
    )
    PROMPT_FLAGGED_ERROR2: str = (
        "Something went wrong with your generation. You may try again or ask for a different prompt"
    )
    PROMPT_FLAGGED_ERROR3: str = (
        "The server had an error while processing your request. Sorry about that! You can retry your request, "
        "or contact us through our help center at help.openai.com if the error persists."
    )

    def __init__(
        self,
        api_key: str,
        cache_config: CacheConfig,
        file_cache: FileCache,
        moderation_api_client: ModerationAPIClient,
        org_id: Optional[str] = None,
    ):
        """
        :param api_key: OpenAI API key.
        :param cache_config: Configuration for the response cache.
        :param file_cache: Where decoded generated images are stored.
        :param moderation_api_client: Used to pre-screen prompts before calling DALL-E.
        :param org_id: Optional OpenAI organization ID.
        """
        self.file_cache: FileCache = file_cache
        self._cache = Cache(cache_config)

        self.client = OpenAI(api_key=api_key, organization=org_id)
        self.moderation_api_client: ModerationAPIClient = moderation_api_client

    def get_content_policy_violated_result(self, request: Request) -> RequestResult:
        """
        Return a RequestResult with no images and a finish reason indicating that the prompt / generated images
        violate OpenAI's content policy.

        The result is marked `success=True` and contains `request.num_completions` empty completions.
        """
        no_image = GeneratedOutput(
            text="",
            logprob=0,
            tokens=[],
            multimodal_content=MultimediaObject(),
            finish_reason={"reason": self.CONTENT_POLICY_VIOLATED_FINISH_REASON},
        )
        return RequestResult(
            success=True,
            cached=False,
            request_time=0,
            completions=[no_image] * request.num_completions,
            embedding=[],
        )

    def get_size_str(self, request: Request) -> str:
        """
        Return the size string (e.g. "512x512") for the image generation request.
        If the request does not specify both a width and a height, return the default size.

        :raises AssertionError: if the request has no image generation parameters or the size is not in
            `VALID_IMAGE_SIZES`.
        """
        assert request.image_generation_parameters is not None
        w: Optional[int] = request.image_generation_parameters.output_image_width
        h: Optional[int] = request.image_generation_parameters.output_image_height
        if w is None or h is None:
            return self.DEFAULT_IMAGE_SIZE_STR

        image_dimensions: str = f"{w}x{h}"
        assert image_dimensions in self.VALID_IMAGE_SIZES, f"Valid image sizes are {self.VALID_IMAGE_SIZES}"
        return image_dimensions

    def fail_if_invalid_request(self, request: Request) -> None:
        """
        Validate the request to ensure it is a valid request for the DALL-E API.

        :raises ValueError: if the prompt exceeds `MAX_PROMPT_LENGTH` characters or `num_completions`
            is outside [1, 10].
        """
        assert request.image_generation_parameters is not None
        if len(request.prompt) > self.MAX_PROMPT_LENGTH:
            # Report the class's actual limit so subclasses with a different limit get an accurate message.
            raise ValueError(f"The maximum length of the prompt is {self.MAX_PROMPT_LENGTH} characters.")
        if request.num_completions < 1 or request.num_completions > 10:
            raise ValueError("`num_completions` must be between 1 and 10.")

    def handle_openai_error(self, request: Request, error: Exception) -> RequestResult:
        """
        Handle a thrown error from the DALL-E API.

        Known safety/generation errors are converted into a content-policy-violated result. Any other
        error produces a failed RequestResult carrying the error message.
        """
        error_message: str = str(error)
        prompt_flagged: bool = (
            # The API may return only a substring of the known message.
            # Guard against "" which is trivially a substring of anything.
            (bool(error_message) and error_message in self.PROMPT_FLAGGED_ERROR)
            # The SDK may also wrap the full message with extra context (e.g. the openai>=1.0 SDK
            # reports status errors with an "Error code: ..." prefix), so check containment too.
            or self.PROMPT_FLAGGED_ERROR in error_message
            or self.PROMPT_FLAGGED_ERROR2 in error_message
            or self.PROMPT_FLAGGED_ERROR3 in error_message
        )
        if prompt_flagged:
            # Some requests fail even if we check the prompt against the moderation API.
            # For example, "black" in Spanish (negro) causes requests to DALL-E to fail even
            # though the prompt does not get flagged by the Moderation API.
            hlog(f"Failed safety check: {request.prompt}")
            return self.get_content_policy_violated_result(request)

        return RequestResult(
            success=False, cached=False, error=f"DALL-E error: {error}", completions=[], embedding=[]
        )

    def generate_with_dalle_api(self, raw_request: Dict[str, Any]) -> Dict:
        """
        Makes a single request to generate the images with the DALL-E API.

        Each returned image's base64 payload is decoded and written to the file cache. Its path is stored
        under "file_path", and "b64_json" is removed so the (large) payload is not cached.
        """
        result = self.client.images.generate(**raw_request).model_dump(mode="json")
        assert "data" in result, f"Invalid response: {result} from prompt: {raw_request['prompt']}"

        for image in result["data"]:
            # Write out the image to a file and save the path. `image` is bound as a default argument so
            # the callable does not depend on the loop variable's late binding.
            image["file_path"] = self.file_cache.store(lambda image=image: base64.b64decode(image["b64_json"]))
            # Don't cache contents of `b64_json` as we already have the image stored
            image.pop("b64_json", None)
        return result

    def make_request(self, request: Request) -> RequestResult:
        """
        Generate `request.num_completions` images for `request.prompt` with DALL-E 2.
        """
        self.fail_if_invalid_request(request)

        # Use the Moderation API to check if the prompt violates OpenAI's content policy before generating images
        if self.moderation_api_client.will_be_flagged(request.prompt):
            return self.get_content_policy_violated_result(request)

        # https://beta.openai.com/docs/api-reference/images/create#images/create-response_format
        raw_request: Dict[str, Any] = {
            "prompt": request.prompt,
            "n": request.num_completions,
            "size": self.get_size_str(request),
            "response_format": "b64_json",  # Always set to b64_json as URLs are only valid for an hour
        }

        try:

            def do_it() -> Dict[str, Any]:
                # To maintain backwards compatibility, specify the model in the request but not in the cache key
                return self.generate_with_dalle_api({"model": "dall-e-2", **raw_request})

            cache_key = CachingClient.make_cache_key(raw_request, request)
            response, cached = self._cache.get(cache_key, wrap_request_time(do_it))
        except openai.OpenAIError as e:
            return self.handle_openai_error(request, e)

        completions: List[GeneratedOutput] = [
            GeneratedOutput(
                text="",
                logprob=0,
                tokens=[],
                multimodal_content=get_single_image_multimedia_object(generated_image["file_path"]),
            )
            for generated_image in response["data"]
        ]
        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            completions=completions,
            embedding=[],
        )

    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        raise NotImplementedError("This client does not support tokenizing.")

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError("This client does not support decoding.")
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from helm.common.cache import CacheConfig
|
|
4
|
+
from helm.common.file_caches.file_cache import FileCache
|
|
5
|
+
from helm.common.general import singleton
|
|
6
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
7
|
+
from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
|
|
8
|
+
from helm.clients.moderation_api_client import ModerationAPIClient
|
|
9
|
+
from helm.clients.client import CachingClient
|
|
10
|
+
from .dalle2_client import DALLE2Client
|
|
11
|
+
from .image_generation_client_utils import get_single_image_multimedia_object
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
import openai
|
|
15
|
+
except ModuleNotFoundError as missing_module_exception:
|
|
16
|
+
handle_module_not_found_error(missing_module_exception, ["openai"])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DALLE3Client(DALLE2Client):
    """
    Client for the OpenAI's DALL-E 3 API.
    DALL-E 3 cookbook with explanations for the different parameters:
    https://cookbook.openai.com/articles/what_is_new_with_dalle_3

    The HELM model engine selects the DALL-E 3 `quality` and `style` settings (see
    `MODEL_ENGINE_TO_SETTINGS`). Because the API returns one image per call, `num_completions`
    separate (individually cached) calls are made.
    """

    # The DALL-E 3 API accepts prompts of up to 4000 characters (DALL-E 2 allows only 1000).
    MAX_PROMPT_LENGTH: int = 4000
    DEFAULT_IMAGE_SIZE_STR: str = "1024x1024"
    VALID_IMAGE_SIZES: List[str] = [DEFAULT_IMAGE_SIZE_STR, "1792x1024", "1024x1792"]

    # Maps each supported model engine to the extra DALL-E 3 request parameters it implies.
    # Keys within each entry are listed in the order they are added to the raw request (quality, then style)
    # so that cache keys match those produced by earlier versions of this client.
    MODEL_ENGINE_TO_SETTINGS: Dict[str, Dict[str, str]] = {
        "dall-e-3": {"quality": "standard", "style": "vivid"},
        "dall-e-3-natural": {"quality": "standard", "style": "natural"},
        "dall-e-3-hd": {"quality": "hd", "style": "vivid"},
        "dall-e-3-hd-natural": {"quality": "hd", "style": "natural"},
    }

    def __init__(
        self,
        api_key: str,
        cache_config: CacheConfig,
        file_cache: FileCache,
        moderation_api_client: ModerationAPIClient,
        org_id: Optional[str] = None,
    ):
        super().__init__(api_key, cache_config, file_cache, moderation_api_client, org_id)

    def make_request(self, request: Request) -> RequestResult:
        """
        Generate `request.num_completions` images with DALL-E 3, one API call per image.

        :raises ValueError: if the request is invalid or `request.model_engine` is not a supported DALL-E 3 engine.
        """
        self.fail_if_invalid_request(request)

        # Resolve the engine before any network call so a misconfigured model fails fast.
        engine_settings: Optional[Dict[str, str]] = self.MODEL_ENGINE_TO_SETTINGS.get(request.model_engine)
        if engine_settings is None:
            raise ValueError(f"Invalid DALL-E 3 model: {request.model_engine}")

        if self.moderation_api_client.will_be_flagged(request.prompt):
            return self.get_content_policy_violated_result(request)

        raw_request: Dict[str, Any] = {
            "model": "dall-e-3",
            "prompt": request.prompt,
            "n": 1,  # As of December 2023, the DALL-E 3 API only supports a single generated image per request
            "size": self.get_size_str(request),
            "response_format": "b64_json",  # Always set to b64_json as URLs are only valid for an hour
        }
        raw_request.update(engine_settings)

        responses: List[Dict[str, Any]] = []
        all_cached: bool = True

        # Since the DALL-E 3 API only supports a single generated image, make `request.num_completions` requests
        for completion_index in range(request.num_completions):
            try:

                def do_it() -> Dict[str, Any]:
                    # Pass a copy so the API call cannot mutate the request used for the cache key.
                    return self.generate_with_dalle_api({**raw_request})

                # Include the completion index so each of the N images gets its own cache entry.
                cache_key = CachingClient.make_cache_key({"completion_index": completion_index, **raw_request}, request)
                response, cached = self._cache.get(cache_key, wrap_request_time(do_it))

                responses.append(response)
                all_cached = all_cached and cached
            except openai.OpenAIError as e:
                return self.handle_openai_error(request, e)

        completions: List[GeneratedOutput] = []
        total_request_time: float = 0
        for response in responses:
            image_response: Dict[str, Any] = singleton(response["data"])
            completions.append(
                GeneratedOutput(
                    # From https://cookbook.openai.com/articles/what_is_new_with_dalle_3,
                    # "a new feature in the latest DALL·E-3 API is prompt rewriting, where we use
                    # GPT-4 to optimize all of your prompts before they’re passed to DALL-E."
                    # Fall back to "" in case the field is missing or null in the response.
                    text=image_response.get("revised_prompt") or "",
                    multimodal_content=get_single_image_multimedia_object(image_response["file_path"]),
                    logprob=0,
                    tokens=[],
                )
            )
            total_request_time += response["request_time"]

        return RequestResult(
            success=True,
            cached=all_cached,
            request_time=total_request_time,
            completions=completions,
            embedding=[],
        )
|