crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
Selected new files under `helm/clients/image_generation/` are shown in full below.

```diff
--- /dev/null
+++ b/helm/clients/image_generation/dalle_mini_client.py
@@ -0,0 +1,190 @@
+from typing import Any, Dict, List
+
+import numpy as np
+from functools import partial
+
+from helm.common.cache import CacheConfig, Cache
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.common.tokenization_request import (
+    DecodeRequest,
+    DecodeRequestResult,
+    TokenizationRequest,
+    TokenizationRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class DALLEMiniClient(Client):
+    """
+    Source: https://github.com/borisdayma/dalle-mini, https://github.com/patil-suraj/vqgan-jax
+    """
+
+    VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
+    VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
+
+    def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
+        self._cache = Cache(cache_config)
+        self._file_cache: FileCache = file_cache
+
+        self._model_engine_to_model = {}
+
+    def _get_model(self, model_engine: str):
+        """
+        Initialize the model based on the model name.
+        Cache the model so it doesn't get reinitialized for each new request.
+        """
+        try:
+            import jax.numpy as jnp
+            from flax.jax_utils import replicate
+
+            from helm.clients.image_generation.dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+            from helm.clients.image_generation.dalle_mini import DalleBart, DalleBartProcessor
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        if model_engine not in self._model_engine_to_model:
+            model_name: str
+            if model_engine == "dalle-mini":
+                model_name = "dalle-mini/dalle-mini/mini-1:v0"
+            elif model_engine == "dalle-mega":
+                model_name = "dalle-mini/dalle-mini/mega-1-fp16:latest"
+            else:
+                raise ValueError(f"Unhandled model: {model_engine}")
+
+            model, params = DalleBart.from_pretrained(model_name, revision=None, dtype=jnp.float16, _do_init=False)
+            processor = DalleBartProcessor.from_pretrained(model_name, revision=None)
+            vqgan, vqgan_params = VQModel.from_pretrained(
+                self.VQGAN_REPO, revision=self.VQGAN_COMMIT_ID, _do_init=False
+            )
+            params = replicate(params)
+            vqgan_params = replicate(vqgan_params)
+            self._model_engine_to_model[model_engine] = [model, params, processor, vqgan, vqgan_params]
+        return self._model_engine_to_model[model_engine]
+
+    def make_request(self, request: Request) -> RequestResult:
+        try:
+            import jax
+            from flax.training.common_utils import shard_prng_key
+            from flax.jax_utils import replicate
+            from PIL import Image
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        raw_request = {
+            "prompt": request.prompt,
+            "top_k": None,
+            "top_p": None,
+            "temperature": None,
+            "condition_scale": 10.0,
+        }
+
+        try:
+
+            def _inference(
+                model, params, vqgan, vqgan_params, tokenized_prompt, subkey, top_k, top_p, temperature, condition_scale
+            ):
+                @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
+                def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
+                    return model.generate(
+                        **tokenized_prompt,
+                        prng_key=key,
+                        params=params,
+                        top_k=top_k,
+                        top_p=top_p,
+                        temperature=temperature,
+                        condition_scale=condition_scale,
+                    )
+
+                @partial(jax.pmap, axis_name="batch")
+                def p_decode(indices, params):
+                    return vqgan.decode_code(indices, params=params)
+
+                # generate images
+                encoded_images = p_generate(
+                    tokenized_prompt,
+                    shard_prng_key(subkey),
+                    params,
+                    top_k,
+                    top_p,
+                    temperature,
+                    condition_scale,
+                )
+                # remove BOS
+                encoded_images = encoded_images.sequences[..., 1:]
+                # decode images
+                decoded_images = p_decode(encoded_images, vqgan_params)
+                decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
+                return decoded_images
+
+            def do_it() -> Dict[str, Any]:
+                prompt: str = request.prompt
+
+                with htrack_block(f"Generating images for prompt: {prompt}"):
+                    model, params, processor, vqgan, vqgan_params = self._get_model(request.model_engine)
+                    tokenized_prompts = processor([prompt])
+                    tokenized_prompt = replicate(tokenized_prompts)
+
+                    images: List[Image.Image] = []
+                    key = jax.random.PRNGKey(0)
+                    for _ in range(request.num_completions):
+                        key, subkey = jax.random.split(key)
+                        image = _inference(
+                            model,
+                            params,
+                            vqgan,
+                            vqgan_params,
+                            tokenized_prompt,
+                            subkey,
+                            raw_request["top_k"],
+                            raw_request["top_p"],
+                            raw_request["temperature"],
+                            raw_request["condition_scale"],
+                        )[0]
+                        image = Image.fromarray(np.asarray(image * 255, dtype=np.uint8))
+                        images.append(image)
+
+                    assert (
+                        len(images) == request.num_completions
+                    ), f"Expected {request.num_completions} images, but got {len(images)}"
+
+                    result = {"file_locations": []}
+                    for image in images:
+                        # Write out the image to a file and save the path
+                        file_location: str = self._file_cache.get_unique_file_location()
+                        image.save(file_location)
+                        hlog(f"Image saved at {file_location}.")
+                        result["file_locations"].append(file_location)
+                    return result
+
+            # Include the model name and number of completions in the cache key
+            cache_key = CachingClient.make_cache_key(
+                {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
+            )
+            results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"DALLEMiniClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_location)
+            )
+            for file_location in results["file_locations"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=results["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
```
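For orientation, here is a minimal usage sketch of the new client. The cache and file-cache wiring, paths, and the model name are assumptions for illustration; in practice HELM constructs clients from `helm/config/model_deployments.yaml` via `AutoClient`:

```python
# Hypothetical wiring; HELM normally builds clients from model_deployments.yaml.
from helm.common.cache import SqliteCacheConfig
from helm.common.file_caches.local_file_cache import LocalFileCache
from helm.common.request import Request
from helm.clients.image_generation.dalle_mini_client import DALLEMiniClient

client = DALLEMiniClient(
    cache_config=SqliteCacheConfig(path="prod_env/cache/craiyon.sqlite"),  # assumed path
    file_cache=LocalFileCache("prod_env/cache/output", file_extension="png"),  # assumed signature
)
# "craiyon/dalle-mini" is an assumed model name; model_engine resolves to "dalle-mini".
result = client.make_request(Request(model="craiyon/dalle-mini", prompt="an oil painting of a lighthouse"))
for completion in result.completions:
    print(completion.multimodal_content)  # one file-backed image MediaObject per completion
```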
```diff
--- /dev/null
+++ b/helm/clients/image_generation/deep_floyd_client.py
@@ -0,0 +1,78 @@
+from typing import List, Dict
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.request import Request, RequestResult, GeneratedOutput
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class DeepFloydClient(Client):
+    """
+    Client for [DeepFloyd image generation models](https://huggingface.co/docs/diffusers/v0.16.0/api/pipelines/ifs).
+    We rely on offline eval for now due to conflicting dependencies (e.g., Transformers).
+    """
+
+    SUPPORTED_MODELS: List[str] = ["IF-I-M-v1.0", "IF-I-L-v1.0", "IF-I-XL-v1.0"]
+
+    @staticmethod
+    def convert_to_raw_request(request: Request) -> Dict:
+        # Use default hyperparameters for everything else
+        raw_request: Dict = {
+            "model": request.model_engine,
+            "n": request.num_completions,
+            "prompt": request.prompt,
+            "request_type": "image-model-inference",
+        }
+        if request.random is not None:
+            raw_request["random"] = request.random
+        return raw_request
+
+    def __init__(self, cache_config: CacheConfig):
+        self._cache = Cache(cache_config)
+        self._promptist_model = None
+        self._promptist_tokenizer = None
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.model_engine not in self.SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported model: {request.model_engine}")
+
+        raw_request = DeepFloydClient.convert_to_raw_request(request)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        try:
+
+            def fail():
+                raise RuntimeError(
+                    f"The result has not been uploaded to the cache for the following request: {cache_key}"
+                )
+
+            response, cached = self._cache.get(cache_key, fail)
+        except RuntimeError as e:
+            error: str = f"DeepFloyd Client error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_path)
+            )
+            for file_path in response["images"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["total_inference_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
```
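Because DeepFloyd is evaluated offline, `make_request` only succeeds once a result has been uploaded under the matching cache key; otherwise the `fail` compute function raises. A sketch of how one might pre-populate the cache (paths and result values are hypothetical; as the DALLE-mini client above shows, `Cache.get` runs the supplied function on a miss and stores its result):

```python
from helm.common.cache import Cache, SqliteCacheConfig
from helm.common.request import Request
from helm.clients.client import CachingClient
from helm.clients.image_generation.deep_floyd_client import DeepFloydClient

request = Request(model="DeepFloyd/IF-I-XL-v1.0", prompt="a red cube on a blue sphere")
cache_key = CachingClient.make_cache_key(DeepFloydClient.convert_to_raw_request(request), request)

cache = Cache(SqliteCacheConfig(path="prod_env/cache/DeepFloyd.sqlite"))  # assumed path
# On a miss, Cache.get runs the function and stores its result under the key,
# so a later make_request for the same request finds it instead of failing.
cache.get(cache_key, lambda: {"images": ["output/if_xl_0.png"], "total_inference_time": 12.3})
```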
```diff
--- /dev/null
+++ b/helm/clients/image_generation/huggingface_diffusers_client.py
@@ -0,0 +1,249 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+from helm.common.cache import CacheConfig, Cache
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.gpu_utils import get_torch_device_name, is_cuda_available
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.common.tokenization_request import (
+    DecodeRequest,
+    DecodeRequestResult,
+    TokenizationRequest,
+    TokenizationRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Any] = {}
+
+
+class HuggingFaceDiffusersClient(Client):
+    def __init__(self, hf_auth_token: str, cache_config: CacheConfig, file_cache: FileCache):
+        self._hf_auth_token: str = hf_auth_token
+        self._cache = Cache(cache_config)
+        self._file_cache: FileCache = file_cache
+
+        self._promptist_model = None
+        self._promptist_tokenizer = None
+
+    def _get_diffuser(self, request: Request):
+        """
+        Initialize the Diffusion Pipeline based on the model name.
+        Cache the model so it doesn't get reinitialized for each new request.
+        """
+        try:
+            from diffusers import DiffusionPipeline
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        global _models_lock
+        global _models
+
+        with _models_lock:
+            model_engine: str = request.model_engine
+
+            if model_engine not in _models:
+                huggingface_model_name: str
+                if model_engine in ["stable-diffusion-v1-4", "promptist-stable-diffusion-v1-4"]:
+                    huggingface_model_name = "CompVis/stable-diffusion-v1-4"
+                elif model_engine == "stable-diffusion-v1-5":
+                    huggingface_model_name = "runwayml/stable-diffusion-v1-5"
+                elif model_engine == "stable-diffusion-v2-base":
+                    huggingface_model_name = "stabilityai/stable-diffusion-2-base"
+                elif model_engine == "stable-diffusion-v2-1-base":
+                    huggingface_model_name = "stabilityai/stable-diffusion-2-1-base"
+                elif model_engine == "dreamlike-diffusion-v1-0":
+                    huggingface_model_name = "dreamlike-art/dreamlike-diffusion-1.0"
+                elif model_engine == "dreamlike-photoreal-v2-0":
+                    huggingface_model_name = "dreamlike-art/dreamlike-photoreal-2.0"
+                elif model_engine == "openjourney-v1-0":
+                    huggingface_model_name = "prompthero/openjourney"
+                elif model_engine == "openjourney-v2-0":
+                    huggingface_model_name = "prompthero/openjourney-v2"
+                elif model_engine == "redshift-diffusion":
+                    huggingface_model_name = "nitrosocke/redshift-diffusion"
+                elif "stable-diffusion-safe" in model_engine:
+                    huggingface_model_name = "AIML-TUDA/stable-diffusion-safe"
+                elif model_engine == "vintedois-diffusion-v0-1":
+                    huggingface_model_name = "22h/vintedois-diffusion-v0-1"
+                elif model_engine == "SSD-1B":
+                    huggingface_model_name = "segmind/SSD-1B"
+                else:
+                    huggingface_model_name = request.model
+
+                pipeline = DiffusionPipeline.from_pretrained(
+                    huggingface_model_name,
+                    torch_dtype=torch.float16 if is_cuda_available() else torch.float,
+                    use_auth_token=self._hf_auth_token,
+                )
+                _models[model_engine] = pipeline.to(get_torch_device_name())
+            return _models[model_engine]
+
+    def make_request(self, request: Request) -> RequestResult:
+        try:
+            from diffusers import DiffusionPipeline
+            from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        raw_request = {
+            "prompt": request.prompt,
+            # Setting this to a higher value can cause CUDA OOM
+            # Fix it to 1 and generate an image `request.num_completions` times
+            "num_images_per_prompt": 1,
+        }
+
+        assert request.image_generation_parameters is not None
+        if request.image_generation_parameters.guidance_scale is not None:
+            raw_request["guidance_scale"] = request.image_generation_parameters.guidance_scale
+        if request.image_generation_parameters.diffusion_denoising_steps is not None:
+            raw_request["num_inference_steps"] = request.image_generation_parameters.diffusion_denoising_steps
+        if request.image_generation_parameters.output_image_width is not None:
+            raw_request["width"] = request.image_generation_parameters.output_image_width
+        if request.image_generation_parameters.output_image_height is not None:
+            raw_request["height"] = request.image_generation_parameters.output_image_height
+
+        # Add the additional pre-configured parameters for Safe Stable Diffusion
+        if request.model_engine == "stable-diffusion-safe-weak":
+            raw_request = {**raw_request, **SafetyConfig.WEAK}
+        elif request.model_engine == "stable-diffusion-safe-medium":
+            raw_request = {**raw_request, **SafetyConfig.MEDIUM}
+        elif request.model_engine == "stable-diffusion-safe-strong":
+            raw_request = {**raw_request, **SafetyConfig.STRONG}
+        elif request.model_engine == "stable-diffusion-safe-max":
+            raw_request = {**raw_request, **SafetyConfig.MAX}
+
+        try:
+
+            def replace_prompt(request_to_update: Dict, new_prompt: str) -> Dict:
+                new_request: Dict = dict(request_to_update)
+                assert "prompt" in new_request
+                new_request["prompt"] = new_prompt
+                return new_request
+
+            def do_it() -> Dict[str, Any]:
+                prompt: str = request.prompt
+
+                with htrack_block(f"Generating images for prompt: {prompt}"):
+                    diffuser: DiffusionPipeline = self._get_diffuser(request)
+                    promptist_prompt: Optional[str] = None
+
+                    images = []
+                    for _ in range(request.num_completions):
+                        if request.model_engine == "promptist-stable-diffusion-v1-4":
+                            promptist_prompt = self._generate_promptist_prompt(prompt)
+                            hlog(f"Promptist: {prompt} -> {promptist_prompt}")
+                            image = diffuser(**replace_prompt(raw_request, promptist_prompt)).images[0]  # type: ignore
+                        elif request.model_engine == "openjourney-v1-0":
+                            # It is required to include "mdjrny-v4 style" in prompt for Openjourney v1
+                            image = diffuser(
+                                **replace_prompt(raw_request, f"mdjrny-v4 style {prompt}")  # type: ignore
+                            ).images[0]
+                        elif request.model_engine == "redshift-diffusion":
+                            # It is required to include "redshift style" to generate 3D images
+                            image = diffuser(
+                                **replace_prompt(raw_request, f"redshift style {prompt}")  # type: ignore
+                            ).images[0]
+                        else:
+                            image = diffuser(**raw_request).images[0]  # type: ignore
+                        images.append(image)
+
+                    assert (
+                        len(images) == request.num_completions
+                    ), f"Expected {request.num_completions} images, but got {len(images)}"
+
+                    result: Dict = {"file_locations": []}
+                    if promptist_prompt is not None:
+                        # Save the Promptist version of the prompts in the cache, just in case we need it later
+                        result["promptist_prompt"] = promptist_prompt
+
+                    for image in images:
+                        # Write out the image to a file and save the path
+                        file_location: str = self._file_cache.generate_unique_new_file_path()  # type: ignore
+                        image.save(file_location)
+                        hlog(f"Image saved at {file_location}")
+                        result["file_locations"].append(file_location)
+                    return result
+
+            # Include the model name and number of completions in the cache key
+            cache_key = CachingClient.make_cache_key(
+                {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
+            )
+            results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as ex:
+            error: str = f"HuggingFaceDiffusersClient error: {ex}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_location)
+            )
+            for file_location in results["file_locations"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=results["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def _generate_promptist_prompt(self, prompt: str) -> str:
+        """
+        Generate a better version of the prompt with Promptist.
+        Promptist was trained specifically with CompVis/stable-diffusion-v1-4.
+        Adapted from https://huggingface.co/spaces/microsoft/Promptist/blob/main/app.py.
+        """
+
+        def load_promptist():
+            prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+            tokenizer = AutoTokenizer.from_pretrained("gpt2")
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.padding_side = "left"
+            return prompter_model, tokenizer
+
+        def generate(plain_text: str) -> str:
+            if self._promptist_model is None or self._promptist_tokenizer is None:
+                self._promptist_model, self._promptist_tokenizer = load_promptist()
+            assert self._promptist_model is not None
+            assert self._promptist_tokenizer is not None
+
+            input_ids = self._promptist_tokenizer(f"{plain_text.strip()} Rephrase:", return_tensors="pt").input_ids
+            eos_id = self._promptist_tokenizer.eos_token_id
+            # Used the same hyperparameters from the example
+            outputs = self._promptist_model.generate(
+                input_ids,
+                do_sample=False,
+                max_new_tokens=75,
+                num_beams=8,
+                num_return_sequences=8,
+                eos_token_id=eos_id,
+                pad_token_id=eos_id,
+                length_penalty=-1.0,
+            )
+            output_texts: List[str] = self._promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+            for output_text in output_texts:
+                res: str = output_text.replace(f"{plain_text} Rephrase:", "").strip()
+                # The Promptist model sometimes generates empty string results.
+                # Return the first non-empty string result.
+                if len(res) > 0:
+                    return res
+
+            # If all else fails, just return the original text.
+            return plain_text
+
+        return generate(prompt)
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
```
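The Promptist rewrite above can be reproduced outside the client; this sketch mirrors `_generate_promptist_prompt` with the same models and generation hyperparameters (only the example prompt is invented):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

prompt = "a cat sitting on a windowsill"
input_ids = tokenizer(f"{prompt.strip()} Rephrase:", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=75,
    num_beams=8,
    num_return_sequences=8,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    length_penalty=-1.0,
)
# Take the first non-empty rewrite, as the client does.
for text in tokenizer.batch_decode(outputs, skip_special_tokens=True):
    rewrite = text.replace(f"{prompt} Rephrase:", "").strip()
    if rewrite:
        print(rewrite)
        break
```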
```diff
--- /dev/null
+++ b/helm/clients/image_generation/image_generation_client_utils.py
@@ -0,0 +1,9 @@
+from helm.common.media_object import MediaObject, MultimediaObject
+
+
+def get_single_image_multimedia_object(image_location: str) -> MultimediaObject:
+    """
+    Returns a `MultimediaObject` containing a single image file used for text-to-image generation clients.
+    """
+    file_extension: str = image_location.split(".")[-1]
+    return MultimediaObject([MediaObject(content_type=f"image/{file_extension}", location=image_location)])
```
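Usage is straightforward; the content type is derived from the file extension (the path below is just an example):

```python
from helm.clients.image_generation.image_generation_client_utils import get_single_image_multimedia_object

obj = get_single_image_multimedia_object("/tmp/output/cat.png")
# MultimediaObject([MediaObject(content_type="image/png", location="/tmp/output/cat.png")])
```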
```diff
--- /dev/null
+++ b/helm/clients/image_generation/lexica_client.py
@@ -0,0 +1,86 @@
+from typing import Any, List, Dict, Union
+import base64
+import requests
+import urllib.parse
+
+from helm.common.cache import CacheConfig, Cache
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.images_utils import encode_base64
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class LexicaClient(Client):
+    """
+    Client for Lexica API. Does not support image generation.
+    """
+
+    def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
+        self.cache = Cache(cache_config)
+        self.file_cache: FileCache = file_cache
+
+    def make_request(self, request: Request) -> RequestResult:
+        """
+        Retrieves images through Lexica's search API (https://lexica.art/docs).
+        The search API is powered by CLIP to fetch the most relevant images for a given query.
+        """
+        if request.model_engine != "search-stable-diffusion-1.5":
+            # Only Stable Diffusion 1.5 is supported at the moment
+            raise ValueError(f"Invalid model: {request.model_engine}")
+
+        raw_request: Dict[str, Union[str, int]] = {
+            "model": request.model_engine,
+            "prompt": request.prompt,
+            "n": request.num_completions,
+        }
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        try:
+
+            def do_it() -> Dict[str, Any]:
+                num_completions: int = int(raw_request["n"])
+                result = requests.get(
+                    f"https://lexica.art/api/v1/search?{urllib.parse.urlencode({'q': request.prompt})}"
+                ).json()
+                assert "images" in result, f"Invalid response: {result} from prompt: {request.prompt}"
+                assert len(result["images"]) >= num_completions, "Did not retrieve enough images"
+
+                image_locations: List[str] = []
+                # Most relevant images are at the top of the list
+                for image in result["images"][:num_completions]:
+                    # Write out the image to a file and save the location
+                    image_base64: str = encode_base64(image["src"])
+                    image_locations.append(self.file_cache.store(lambda: base64.b64decode(image_base64)))
+                return {"image_locations": image_locations}
+
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"LexicaClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(location)
+            )
+            for location in response["image_locations"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")
```
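The underlying search call the client wraps is a plain GET request; this sketch shows the response fields the client relies on (`images` and each image's `src`), with an invented query:

```python
import urllib.parse
import requests

result = requests.get(
    f"https://lexica.art/api/v1/search?{urllib.parse.urlencode({'q': 'a lighthouse at dusk'})}"
).json()
print(len(result["images"]), result["images"][0]["src"])
```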