crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/common/cache.py
CHANGED
|
@@ -1,19 +1,14 @@
|
|
|
1
|
-
|
|
2
|
-
from abc import abstractmethod
|
|
3
|
-
import contextlib
|
|
1
|
+
from collections import defaultdict
|
|
4
2
|
from dataclasses import dataclass
|
|
3
|
+
from typing import Dict, Callable, Generator, Mapping, Optional, Tuple
|
|
5
4
|
import json
|
|
6
|
-
from typing import Dict, Callable, Generator, Iterable, Optional, Tuple
|
|
7
|
-
from collections import defaultdict
|
|
8
|
-
import sqlite3
|
|
9
5
|
import threading
|
|
10
6
|
|
|
11
|
-
|
|
7
|
+
import sqlite3
|
|
8
|
+
|
|
12
9
|
from helm.common.general import hlog, htrack
|
|
10
|
+
from helm.common.key_value_store import BlackHoleKeyValueStore, KeyValueStore, SqliteKeyValueStore
|
|
13
11
|
from helm.proxy.retry import get_retry_decorator
|
|
14
|
-
from bson.son import SON
|
|
15
|
-
from bson.errors import InvalidDocument
|
|
16
|
-
from pymongo import MongoClient, ReplaceOne
|
|
17
12
|
|
|
18
13
|
try:
|
|
19
14
|
from cPickle import loads
|
|
@@ -21,31 +16,19 @@ except ImportError:
|
|
|
21
16
|
from pickle import loads
|
|
22
17
|
|
|
23
18
|
|
|
24
|
-
def request_to_key(request: Dict) -> str:
|
|
25
|
-
"""Normalize a `request` into a `key` so that we can hash using it."""
|
|
26
|
-
return json.dumps(request, sort_keys=True)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def key_to_request(key: str) -> Dict:
|
|
30
|
-
"""Convert the normalized version to the request."""
|
|
31
|
-
return json.loads(key)
|
|
32
|
-
|
|
33
|
-
|
|
34
19
|
def retry_if_write_failed(success: bool) -> bool:
|
|
35
20
|
"""Retries when the write fails."""
|
|
36
21
|
return not success
|
|
37
22
|
|
|
38
23
|
|
|
39
24
|
retry: Callable = get_retry_decorator(
|
|
40
|
-
"Write", max_attempts=
|
|
25
|
+
"Write", max_attempts=5, wait_exponential_multiplier_seconds=2, retry_on_result=retry_if_write_failed
|
|
41
26
|
)
|
|
42
27
|
|
|
43
28
|
|
|
44
29
|
class CacheConfig:
|
|
45
30
|
"""Configuration for a cache."""
|
|
46
31
|
|
|
47
|
-
pass
|
|
48
|
-
|
|
49
32
|
@property
|
|
50
33
|
def cache_stats_key(self) -> str:
|
|
51
34
|
"""The string key used by CacheStats to identify this cache."""
|
|
@@ -55,8 +38,6 @@ class CacheConfig:
|
|
|
55
38
|
class KeyValueStoreCacheConfig(CacheConfig):
|
|
56
39
|
"""Configuration for a cache backed by a key-value store."""
|
|
57
40
|
|
|
58
|
-
pass
|
|
59
|
-
|
|
60
41
|
|
|
61
42
|
@dataclass(frozen=True)
|
|
62
43
|
class SqliteCacheConfig(KeyValueStoreCacheConfig):
|
|
@@ -70,6 +51,16 @@ class SqliteCacheConfig(KeyValueStoreCacheConfig):
|
|
|
70
51
|
return self.path
|
|
71
52
|
|
|
72
53
|
|
|
54
|
+
@dataclass(frozen=True)
|
|
55
|
+
class BlackHoleCacheConfig(KeyValueStoreCacheConfig):
|
|
56
|
+
"""Configuration for a cache that does not save any data."""
|
|
57
|
+
|
|
58
|
+
@property
|
|
59
|
+
def cache_stats_key(self) -> str:
|
|
60
|
+
"""The string key used by CacheStats to identify this cache."""
|
|
61
|
+
return "disabled_cache"
|
|
62
|
+
|
|
63
|
+
|
|
73
64
|
@dataclass(frozen=True)
|
|
74
65
|
class MongoCacheConfig(KeyValueStoreCacheConfig):
|
|
75
66
|
"""Configuration for a cache backed by a MongoDB collection."""
|
|
@@ -105,156 +96,6 @@ class WithFollowerCacheConfig(CacheConfig):
|
|
|
105
96
|
return self.main.cache_stats_key
|
|
106
97
|
|
|
107
98
|
|
|
108
|
-
class KeyValueStore(contextlib.AbstractContextManager):
|
|
109
|
-
"""Key value store that persists writes."""
|
|
110
|
-
|
|
111
|
-
@property
|
|
112
|
-
def path(self):
|
|
113
|
-
return self._path
|
|
114
|
-
|
|
115
|
-
@abstractmethod
|
|
116
|
-
def contains(self, key: Dict) -> bool:
|
|
117
|
-
pass
|
|
118
|
-
|
|
119
|
-
@abstractmethod
|
|
120
|
-
def get(self, key: Dict) -> Optional[Dict]:
|
|
121
|
-
pass
|
|
122
|
-
|
|
123
|
-
@abstractmethod
|
|
124
|
-
def get_all(self) -> Generator[Tuple[Dict, Dict], None, None]:
|
|
125
|
-
pass
|
|
126
|
-
|
|
127
|
-
@abstractmethod
|
|
128
|
-
def put(self, key: Dict, value: Dict) -> None:
|
|
129
|
-
pass
|
|
130
|
-
|
|
131
|
-
@abstractmethod
|
|
132
|
-
def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
|
|
133
|
-
pass
|
|
134
|
-
|
|
135
|
-
@abstractmethod
|
|
136
|
-
def remove(self, key: Dict) -> None:
|
|
137
|
-
pass
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class _SqliteKeyValueStore(KeyValueStore):
|
|
141
|
-
"""Key value store backed by a SQLite file."""
|
|
142
|
-
|
|
143
|
-
def __init__(self, path: str):
|
|
144
|
-
self._sqlite_dict = SqliteDict(path)
|
|
145
|
-
super().__init__()
|
|
146
|
-
|
|
147
|
-
def __enter__(self) -> "_SqliteKeyValueStore":
|
|
148
|
-
self._sqlite_dict.__enter__()
|
|
149
|
-
return self
|
|
150
|
-
|
|
151
|
-
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
|
152
|
-
self._sqlite_dict.__exit__(exc_type, exc_value, traceback)
|
|
153
|
-
|
|
154
|
-
def contains(self, key: Dict) -> bool:
|
|
155
|
-
return request_to_key(key) in self._sqlite_dict
|
|
156
|
-
|
|
157
|
-
def get(self, key: Dict) -> Optional[Dict]:
|
|
158
|
-
key_string = request_to_key(key)
|
|
159
|
-
result = self._sqlite_dict.get(key_string)
|
|
160
|
-
if result is not None:
|
|
161
|
-
assert isinstance(result, dict)
|
|
162
|
-
return result
|
|
163
|
-
return None
|
|
164
|
-
|
|
165
|
-
def get_all(self) -> Generator[Tuple[Dict, Dict], None, None]:
|
|
166
|
-
for key, value in self._sqlite_dict.items():
|
|
167
|
-
yield (key, value)
|
|
168
|
-
|
|
169
|
-
def put(self, key: Dict, value: Dict) -> None:
|
|
170
|
-
key_string = request_to_key(key)
|
|
171
|
-
self._sqlite_dict[key_string] = value
|
|
172
|
-
self._sqlite_dict.commit()
|
|
173
|
-
|
|
174
|
-
def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
|
|
175
|
-
for key, value in pairs:
|
|
176
|
-
self.put(key, value)
|
|
177
|
-
|
|
178
|
-
def remove(self, key: Dict) -> None:
|
|
179
|
-
del self._sqlite_dict[key]
|
|
180
|
-
self._sqlite_dict.commit()
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
class _MongoKeyValueStore(KeyValueStore):
|
|
184
|
-
"""Key value store backed by a MongoDB database."""
|
|
185
|
-
|
|
186
|
-
# The number of documents to return per batch.
|
|
187
|
-
_BATCH_SIZE: int = 8
|
|
188
|
-
|
|
189
|
-
_REQUEST_KEY = "request"
|
|
190
|
-
_RESPONSE_KEY = "response"
|
|
191
|
-
|
|
192
|
-
def __init__(self, uri: str, collection_name: str):
|
|
193
|
-
# TODO: Create client in __enter__ and clean up client in __exit__
|
|
194
|
-
self._mongodb_client: MongoClient = MongoClient(uri)
|
|
195
|
-
self._database = self._mongodb_client.get_default_database()
|
|
196
|
-
self._collection = self._database.get_collection(collection_name)
|
|
197
|
-
self._collection.create_index(self._REQUEST_KEY, unique=True)
|
|
198
|
-
super().__init__()
|
|
199
|
-
|
|
200
|
-
def __enter__(self) -> "_MongoKeyValueStore":
|
|
201
|
-
return self
|
|
202
|
-
|
|
203
|
-
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
|
204
|
-
return
|
|
205
|
-
|
|
206
|
-
def _canonicalize_key(self, key: Dict) -> SON:
|
|
207
|
-
serialized = json.dumps(key, sort_keys=True)
|
|
208
|
-
return json.loads(serialized, object_pairs_hook=SON)
|
|
209
|
-
|
|
210
|
-
def contains(self, key: Dict) -> bool:
|
|
211
|
-
query = {self._REQUEST_KEY: self._canonicalize_key(key)}
|
|
212
|
-
return self._collection.find_one(query) is not None
|
|
213
|
-
|
|
214
|
-
def get(self, key: Dict) -> Optional[Dict]:
|
|
215
|
-
query = {self._REQUEST_KEY: self._canonicalize_key(key)}
|
|
216
|
-
document = self._collection.find_one(query)
|
|
217
|
-
if document is not None:
|
|
218
|
-
response = document[self._RESPONSE_KEY]
|
|
219
|
-
if isinstance(response, str):
|
|
220
|
-
return json.loads(response)
|
|
221
|
-
else:
|
|
222
|
-
return response
|
|
223
|
-
return None
|
|
224
|
-
|
|
225
|
-
def get_all(self) -> Generator[Tuple[Dict, Dict], None, None]:
|
|
226
|
-
for document in self._collection.find({}).batch_size(self._BATCH_SIZE):
|
|
227
|
-
request = document[self._REQUEST_KEY]
|
|
228
|
-
response = document[self._RESPONSE_KEY]
|
|
229
|
-
if isinstance(response, str):
|
|
230
|
-
yield (request, json.loads(response))
|
|
231
|
-
else:
|
|
232
|
-
yield (request, response)
|
|
233
|
-
|
|
234
|
-
def put(self, key: Dict, value: Dict) -> None:
|
|
235
|
-
request = self._canonicalize_key(key)
|
|
236
|
-
document = SON([(self._REQUEST_KEY, request), (self._RESPONSE_KEY, value)])
|
|
237
|
-
# The MongoDB collection should have a unique indexed on "request"
|
|
238
|
-
try:
|
|
239
|
-
self._collection.replace_one(filter={"request": request}, replacement=document, upsert=True)
|
|
240
|
-
except InvalidDocument:
|
|
241
|
-
# If the document is malformed e.g. because of null bytes in keys, instead store the response as a string.
|
|
242
|
-
alternate_document = SON([(self._REQUEST_KEY, request), (self._RESPONSE_KEY, json.dumps(value))])
|
|
243
|
-
self._collection.replace_one(filter={"request": request}, replacement=alternate_document, upsert=True)
|
|
244
|
-
|
|
245
|
-
def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
|
|
246
|
-
operations = []
|
|
247
|
-
for key, value in pairs:
|
|
248
|
-
request = self._canonicalize_key(key)
|
|
249
|
-
document = SON([(self._REQUEST_KEY, request), (self._RESPONSE_KEY, value)])
|
|
250
|
-
operations.append(ReplaceOne({self._REQUEST_KEY: request}, document, upsert=True))
|
|
251
|
-
# Note: unlike put, multi_put does not support documents with null bytes in keys.
|
|
252
|
-
self._collection.bulk_write(operations)
|
|
253
|
-
|
|
254
|
-
def remove(self, key: Dict) -> None:
|
|
255
|
-
self._collection.delete_one(key)
|
|
256
|
-
|
|
257
|
-
|
|
258
99
|
def get_all_from_sqlite(path: str) -> Generator[Tuple[Dict, Dict], None, None]:
|
|
259
100
|
"""Yields all decoded key, value pairs from the SQLite cache.
|
|
260
101
|
|
|
@@ -277,15 +118,19 @@ def create_key_value_store(config: KeyValueStoreCacheConfig) -> KeyValueStore:
|
|
|
277
118
|
"""Create a key value store from the given configuration."""
|
|
278
119
|
# TODO: Support creating _MongoKeyValueStore
|
|
279
120
|
if isinstance(config, MongoCacheConfig):
|
|
280
|
-
|
|
121
|
+
from helm.common.mongo_key_value_store import MongoKeyValueStore
|
|
122
|
+
|
|
123
|
+
return MongoKeyValueStore(config.uri, config.collection_name)
|
|
281
124
|
elif isinstance(config, SqliteCacheConfig):
|
|
282
|
-
return
|
|
125
|
+
return SqliteKeyValueStore(config.path)
|
|
126
|
+
elif isinstance(config, BlackHoleCacheConfig):
|
|
127
|
+
return BlackHoleKeyValueStore()
|
|
283
128
|
else:
|
|
284
129
|
raise ValueError(f"KeyValueStoreCacheConfig with unknown type: {config}")
|
|
285
130
|
|
|
286
131
|
|
|
287
132
|
@retry
|
|
288
|
-
def write_to_key_value_store(key_value_store: KeyValueStore, key:
|
|
133
|
+
def write_to_key_value_store(key_value_store: KeyValueStore, key: Mapping, response: Dict) -> bool:
|
|
289
134
|
"""
|
|
290
135
|
Write to the key value store with retry. Returns boolean indicating whether the write was successful or not.
|
|
291
136
|
"""
|
|
@@ -355,7 +200,7 @@ class Cache(object):
|
|
|
355
200
|
else:
|
|
356
201
|
raise ValueError(f"CacheConfig with unknown type: {config}")
|
|
357
202
|
|
|
358
|
-
def get(self, request:
|
|
203
|
+
def get(self, request: Mapping, compute: Callable[[], Dict]) -> Tuple[Dict, bool]:
|
|
359
204
|
"""Get the result of `request` (by calling `compute` as needed)."""
|
|
360
205
|
cache_stats.increment_query(self.config.cache_stats_key)
|
|
361
206
|
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from helm.common.cache import CacheConfig, MongoCacheConfig, BlackHoleCacheConfig, SqliteCacheConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CacheBackendConfig(ABC):
|
|
9
|
+
"""Config for a cache backend."""
|
|
10
|
+
|
|
11
|
+
@abstractmethod
|
|
12
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
13
|
+
"""Get a CacheConfig for the given shard."""
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
|
|
18
|
+
class MongoCacheBackendConfig(CacheBackendConfig):
|
|
19
|
+
"""Config for a MongoDB cache backend."""
|
|
20
|
+
|
|
21
|
+
uri: str
|
|
22
|
+
"""URL for the MongoDB database that contains the collection.
|
|
23
|
+
|
|
24
|
+
Example format: mongodb://[username:password@]host1[:port1]/[dbname]
|
|
25
|
+
For full format, see: https://www.mongodb.com/docs/manual/reference/connection-string/"""
|
|
26
|
+
|
|
27
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
28
|
+
return MongoCacheConfig(uri=self.uri, collection_name=shard_name)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
|
|
32
|
+
class BlackHoleCacheBackendConfig(CacheBackendConfig):
|
|
33
|
+
"""Config for a cache backend that does not save any data."""
|
|
34
|
+
|
|
35
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
36
|
+
return BlackHoleCacheConfig()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(frozen=True)
|
|
40
|
+
class SqliteCacheBackendConfig(CacheBackendConfig):
|
|
41
|
+
"""Config for a Sqlite cache backend."""
|
|
42
|
+
|
|
43
|
+
path: str
|
|
44
|
+
"""Path for the directory that will contain Sqlite files for caches."""
|
|
45
|
+
|
|
46
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
47
|
+
return SqliteCacheConfig(path=os.path.join(self.path, f"{shard_name}.sqlite"))
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
DEFAULT_CLIP_SCORE_MODEL = "openai/clip-vit-large-patch14"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True)
|
|
9
|
+
class CLIPScoreRequest:
|
|
10
|
+
"""
|
|
11
|
+
Computes a CLIPScore for a given caption and image.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
# Caption to compute CLIPScore for
|
|
15
|
+
caption: str
|
|
16
|
+
|
|
17
|
+
# Location of the image
|
|
18
|
+
image_location: str
|
|
19
|
+
|
|
20
|
+
# Which CLIP model to use
|
|
21
|
+
model: str = DEFAULT_CLIP_SCORE_MODEL
|
|
22
|
+
|
|
23
|
+
# Compute multilingual CLIPScore
|
|
24
|
+
multilingual: bool = False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
|
|
28
|
+
class CLIPScoreResult:
|
|
29
|
+
"""Result after sending a `CLIPScoreRequest`."""
|
|
30
|
+
|
|
31
|
+
# Whether the request was successful
|
|
32
|
+
success: bool
|
|
33
|
+
|
|
34
|
+
# Whether the request was cached
|
|
35
|
+
cached: bool
|
|
36
|
+
|
|
37
|
+
# The CLIPScore
|
|
38
|
+
score: float = 0.0
|
|
39
|
+
|
|
40
|
+
# If `success` is false, what was the error?
|
|
41
|
+
error: Optional[str] = None
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from contextlib import AbstractContextManager
|
|
2
|
+
from threading import Lock
|
|
3
|
+
from typing import TypeVar, Generic
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
T = TypeVar("T")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ThreadSafeWrapper(AbstractContextManager, Generic[T]):
|
|
10
|
+
"""A wrapper that makes thread-hostile objects thread-safe.
|
|
11
|
+
|
|
12
|
+
This provides a context manager that holds a lock for accessing the inner object.
|
|
13
|
+
|
|
14
|
+
Example usage:
|
|
15
|
+
|
|
16
|
+
wrapped_obj = wrapper(thread_hostile_obj)
|
|
17
|
+
with wrapped_obj as obj:
|
|
18
|
+
# Lock is automatically held in here
|
|
19
|
+
obj.do_stuff()
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(self, wrapped: T):
|
|
23
|
+
self._wrapped = wrapped
|
|
24
|
+
self._lock = Lock()
|
|
25
|
+
|
|
26
|
+
def __enter__(self) -> T:
|
|
27
|
+
self._lock.__enter__()
|
|
28
|
+
return self._wrapped
|
|
29
|
+
|
|
30
|
+
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
|
31
|
+
self._lock.__exit__(exc_type, exc_value, traceback)
|
|
32
|
+
pass
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Functions used for credentials."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Mapping, Optional
|
|
4
|
+
|
|
5
|
+
from helm.common.hierarchical_logger import hlog
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def provide_api_key(
|
|
9
|
+
credentials: Mapping[str, Any], host_organization: str, model: Optional[str] = None
|
|
10
|
+
) -> Optional[str]:
|
|
11
|
+
api_key_name = host_organization + "ApiKey"
|
|
12
|
+
if api_key_name in credentials:
|
|
13
|
+
hlog(f"Using host_organization api key defined in credentials.conf: {api_key_name}")
|
|
14
|
+
return credentials[api_key_name]
|
|
15
|
+
if "deployments" not in credentials:
|
|
16
|
+
hlog(
|
|
17
|
+
"WARNING: Could not find key 'deployments' in credentials.conf, "
|
|
18
|
+
f"therefore the API key {api_key_name} should be specified."
|
|
19
|
+
)
|
|
20
|
+
return None
|
|
21
|
+
deployment_api_keys = credentials["deployments"]
|
|
22
|
+
if model is None:
|
|
23
|
+
hlog(f"WARNING: Could not find key '{host_organization}' in credentials.conf and no model provided")
|
|
24
|
+
return None
|
|
25
|
+
if model not in deployment_api_keys:
|
|
26
|
+
hlog(f"WARNING: Could not find key '{model}' under key 'deployments' in credentials.conf")
|
|
27
|
+
return None
|
|
28
|
+
return deployment_api_keys[model]
|
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FileCache(ABC):
|
|
6
|
+
"""
|
|
7
|
+
Cache to store files.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
@abstractmethod
|
|
11
|
+
def store(self, compute: Callable[[], bytes]) -> str:
|
|
12
|
+
"""
|
|
13
|
+
Stores the output of `compute` as a file at a unique location.
|
|
14
|
+
Returns the location of the file.
|
|
15
|
+
"""
|
|
16
|
+
pass
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
from helm.common.general import ensure_directory_exists, generate_unique_id
|
|
5
|
+
from .file_cache import FileCache
|
|
6
|
+
|
|
7
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from PIL import Image
|
|
11
|
+
except ModuleNotFoundError as e:
|
|
12
|
+
handle_module_not_found_error(e, ["images"])
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LocalFileCache(FileCache):
|
|
16
|
+
def __init__(self, base_path: str, file_extension: str):
|
|
17
|
+
ensure_directory_exists(base_path)
|
|
18
|
+
self._location: str = base_path
|
|
19
|
+
self._file_extension: str = file_extension
|
|
20
|
+
|
|
21
|
+
def store(self, compute: Callable[[], bytes]) -> str:
|
|
22
|
+
"""
|
|
23
|
+
Stores the output of `compute` as a file at a unique path.
|
|
24
|
+
Returns the file path.
|
|
25
|
+
"""
|
|
26
|
+
file_path: str = self.generate_unique_new_file_path()
|
|
27
|
+
with open(file_path, "wb") as f:
|
|
28
|
+
f.write(compute())
|
|
29
|
+
|
|
30
|
+
return file_path
|
|
31
|
+
|
|
32
|
+
def generate_unique_new_file_path(self) -> str:
|
|
33
|
+
"""Generate an unique file name at `base_path`"""
|
|
34
|
+
|
|
35
|
+
def generate_one() -> str:
|
|
36
|
+
file_name: str = f"{generate_unique_id()}.{self._file_extension}"
|
|
37
|
+
return os.path.join(self._location, file_name)
|
|
38
|
+
|
|
39
|
+
file_path: str
|
|
40
|
+
while True:
|
|
41
|
+
file_path = generate_one()
|
|
42
|
+
if not os.path.exists(file_path):
|
|
43
|
+
break
|
|
44
|
+
return file_path
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LocalPILFileCache(LocalFileCache):
|
|
48
|
+
def __init__(self, base_path: str):
|
|
49
|
+
super().__init__(base_path, "png")
|
|
50
|
+
|
|
51
|
+
def store_image(self, compute: Callable[[], Image.Image]) -> str:
|
|
52
|
+
"""
|
|
53
|
+
Stores the output of `compute` as a file at a unique path.
|
|
54
|
+
Returns the file path.
|
|
55
|
+
"""
|
|
56
|
+
file_path: str = self.generate_unique_new_file_path()
|
|
57
|
+
compute().save(file_path)
|
|
58
|
+
return file_path
|
|
59
|
+
|
|
60
|
+
def load_image(self, file_path: str) -> Image.Image:
|
|
61
|
+
return Image.open(file_path).convert("RGB")
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
import tempfile
|
|
4
|
+
import unittest
|
|
5
|
+
|
|
6
|
+
from .local_file_cache import LocalFileCache
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestLocalFileCache(unittest.TestCase):
|
|
10
|
+
def setup_method(self, _):
|
|
11
|
+
self.path: str = tempfile.mkdtemp()
|
|
12
|
+
|
|
13
|
+
def teardown_method(self, _):
|
|
14
|
+
shutil.rmtree(self.path)
|
|
15
|
+
|
|
16
|
+
def test_get(self):
|
|
17
|
+
cache = LocalFileCache(self.path, file_extension="txt")
|
|
18
|
+
file_path1: str = cache.store(lambda: "hello.".encode())
|
|
19
|
+
|
|
20
|
+
# Verify the contents of the file
|
|
21
|
+
with open(file_path1, "r") as f:
|
|
22
|
+
assert f.read() == "hello."
|
|
23
|
+
|
|
24
|
+
cache.store(lambda: "bye.".encode())
|
|
25
|
+
assert len(os.listdir(self.path)) == 2
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass(frozen=True)
|
|
6
|
+
class FileUploadRequest:
|
|
7
|
+
"""Uploads a file at `path`."""
|
|
8
|
+
|
|
9
|
+
# Path of the file to upload
|
|
10
|
+
path: str
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
|
|
14
|
+
class FileUploadResult:
|
|
15
|
+
"""Result after sending a `FileUploadRequest`."""
|
|
16
|
+
|
|
17
|
+
# Whether the request was successful
|
|
18
|
+
success: bool
|
|
19
|
+
|
|
20
|
+
# Whether the request was cached
|
|
21
|
+
cached: bool
|
|
22
|
+
|
|
23
|
+
# URL of the uploaded file
|
|
24
|
+
url: str
|
|
25
|
+
|
|
26
|
+
# If `success` is false, what was the error?
|
|
27
|
+
error: Optional[str] = None
|
helm/common/general.py
CHANGED
|
@@ -7,7 +7,8 @@ import urllib
|
|
|
7
7
|
import uuid
|
|
8
8
|
import zstandard
|
|
9
9
|
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
|
10
|
-
from
|
|
10
|
+
from datetime import datetime, date
|
|
11
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
11
12
|
from tqdm import tqdm
|
|
12
13
|
|
|
13
14
|
import pyhocon
|
|
@@ -62,7 +63,7 @@ def shell(args: List[str]):
|
|
|
62
63
|
hlog(f"Executing: {cmd}")
|
|
63
64
|
exit_code = subprocess.call(args)
|
|
64
65
|
if exit_code != 0:
|
|
65
|
-
|
|
66
|
+
raise Exception(f"Failed with exit code {exit_code}: {cmd}")
|
|
66
67
|
|
|
67
68
|
|
|
68
69
|
@htrack(None)
|
|
@@ -160,6 +161,13 @@ def asdict_without_nones(obj: Any) -> Dict[str, Any]:
|
|
|
160
161
|
return asdict(obj, dict_factory=lambda x: {k: v for (k, v) in x if v is not None})
|
|
161
162
|
|
|
162
163
|
|
|
164
|
+
def serialize_dates(obj):
|
|
165
|
+
"""Serialize dates (pass deault=serialize_dates into json.dumps)."""
|
|
166
|
+
if isinstance(obj, (datetime, date)):
|
|
167
|
+
return obj.isoformat()
|
|
168
|
+
raise TypeError(f"Type {type(obj)} is not serializable")
|
|
169
|
+
|
|
170
|
+
|
|
163
171
|
def binarize_dict(d: Dict[str, int]) -> Dict[str, int]:
|
|
164
172
|
"""Binarize the dict by setting the values that are 1 to 0.
|
|
165
173
|
|
|
@@ -214,20 +222,14 @@ InT = TypeVar("InT")
|
|
|
214
222
|
OutT = TypeVar("OutT")
|
|
215
223
|
|
|
216
224
|
|
|
217
|
-
def parallel_map(
|
|
218
|
-
process: Callable[[InT], OutT], items: List[InT], parallelism: int, multiprocessing: bool = False
|
|
219
|
-
) -> List[OutT]:
|
|
225
|
+
def parallel_map(process: Callable[[InT], OutT], items: List[InT], parallelism: int) -> List[OutT]:
|
|
220
226
|
"""
|
|
221
227
|
A wrapper for applying `process` to all `items`.
|
|
222
228
|
"""
|
|
223
|
-
|
|
224
|
-
with htrack_block(f"Parallelizing computation on {len(items)} items over {parallelism} {units}"):
|
|
229
|
+
with htrack_block(f"Parallelizing computation on {len(items)} items over {parallelism} threads"):
|
|
225
230
|
results: List
|
|
226
231
|
if parallelism == 1:
|
|
227
232
|
results = list(tqdm(map(process, items), total=len(items), disable=None))
|
|
228
|
-
elif multiprocessing:
|
|
229
|
-
with ProcessPoolExecutor(max_workers=parallelism) as executor:
|
|
230
|
-
results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
|
|
231
233
|
else:
|
|
232
234
|
with ThreadPoolExecutor(max_workers=parallelism) as executor:
|
|
233
235
|
results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
|
|
@@ -320,3 +322,20 @@ def safe_symlink(src: str, dest: str) -> None:
|
|
|
320
322
|
def is_url(location: str) -> bool:
|
|
321
323
|
"""Return True if `location` is a url. False otherwise."""
|
|
322
324
|
return urllib.parse.urlparse(location).scheme in ["http", "https"]
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def assert_is_str(val: Any) -> str:
|
|
328
|
+
assert isinstance(val, str)
|
|
329
|
+
return val
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def assert_is_str_list(val: Any) -> List[str]:
|
|
333
|
+
assert isinstance(val, list)
|
|
334
|
+
for v in val:
|
|
335
|
+
assert isinstance(v, str)
|
|
336
|
+
return val
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def assert_present(val: Optional[InT]) -> InT:
|
|
340
|
+
assert val is not None
|
|
341
|
+
return val
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass(frozen=True)
|
|
6
|
+
class ImageGenerationParameters:
|
|
7
|
+
"""
|
|
8
|
+
Parameters for image generation.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
output_image_width: Optional[int] = None
|
|
12
|
+
"""Width of the generated image. The model will generate images with the model's
|
|
13
|
+
default dimensions when unspecified."""
|
|
14
|
+
|
|
15
|
+
output_image_height: Optional[int] = None
|
|
16
|
+
"""Height of the generated image. The model will generate images with the model's
|
|
17
|
+
default dimensions when unspecified."""
|
|
18
|
+
|
|
19
|
+
guidance_scale: Optional[float] = None
|
|
20
|
+
"""A non-negative number determining how much importance is given to the prompt
|
|
21
|
+
when generating images. Higher values will generate images that follow more
|
|
22
|
+
closely to the prompt. Currently only for diffusion models."""
|
|
23
|
+
|
|
24
|
+
diffusion_denoising_steps: Optional[int] = None
|
|
25
|
+
"""The number of denoising steps for diffusion models."""
|