crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/clients/gcs_client.py
@@ -0,0 +1,82 @@
+from dataclasses import asdict
+from typing import Dict, Optional
+import requests
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.hierarchical_logger import hlog
+from helm.common.file_upload_request import FileUploadRequest, FileUploadResult
+
+
+class GCSClientError(Exception):
+    pass
+
+
+class GCSClient:
+    """
+    Uploads files to GCS. Ensure the GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json
+    environment variable is set.
+    """
+
+    MAX_CHECK_ATTEMPTS: int = 10
+
+    def __init__(self, bucket_name: str, cache_config: CacheConfig):
+        try:
+            from google.cloud import storage  # type: ignore
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        self._bucket_name: str = bucket_name
+        self._cache = Cache(cache_config)
+        self._storage_client: Optional[storage.Client] = None
+
+    def upload(self, request: FileUploadRequest) -> FileUploadResult:
+        """Uploads a file to GCS."""
+        try:
+            from google.cloud import storage  # type: ignore
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        try:
+
+            def do_it():
+                if self._storage_client is None:
+                    self._storage_client = storage.Client()
+
+                bucket = self._storage_client.bucket(self._bucket_name)
+                file_path: str = request.path
+                blob = bucket.blob(file_path)
+
+                # Optional: set a generation-match precondition to avoid potential race conditions
+                # and data corruptions. The request to upload is aborted if the object's
+                # generation number does not match your precondition. For a destination
+                # object that does not yet exist, set the if_generation_match precondition to 0.
+                # If the destination object already exists in your bucket, set instead a
+                # generation-match precondition using its generation number.
+                generation_match_precondition: int = 0
+
+                blob.upload_from_filename(file_path, if_generation_match=generation_match_precondition)
+                url: str = self._get_url(file_path)
+
+                # Ensure the file was uploaded successfully
+                uploaded: bool = False
+                for _ in range(0, self.MAX_CHECK_ATTEMPTS):
+                    check_response = requests.head(url)
+                    if check_response.status_code == 200:
+                        uploaded = True
+                        break
+                assert uploaded, f"File {file_path} was not uploaded successfully."
+
+                hlog(f"File {file_path} uploaded and is available at {url}.")
+                return {"url": url}
+
+            cache_key: Dict = asdict(request)
+            result, cached = self._cache.get(cache_key, do_it)
+
+        except Exception as e:
+            raise GCSClientError(e)
+
+        return FileUploadResult(success=True, cached=cached, url=result["url"])
+
+    def _get_url(self, path: str) -> str:
+        return f"https://storage.googleapis.com/{self._bucket_name}/{path}"
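Below, for orientation, is a minimal usage sketch of the new `GCSClient` (not part of the diff; the bucket name and file path are invented, and passing the path as `FileUploadRequest(path=...)` is an assumption about that dataclass):

```python
# Hypothetical usage of helm.clients.gcs_client.GCSClient. Requires the
# GOOGLE_APPLICATION_CREDENTIALS environment variable and google-cloud-storage
# (installed via the "heim" extra).
from helm.common.cache import SqliteCacheConfig
from helm.common.file_upload_request import FileUploadRequest
from helm.clients.gcs_client import GCSClient

client = GCSClient(
    bucket_name="my-helm-bucket",  # invented bucket name
    cache_config=SqliteCacheConfig("prod_env/cache/gcs.sqlite"),
)
# upload() stores the object, verifies it is reachable with HEAD requests
# (up to MAX_CHECK_ATTEMPTS), and caches the resulting URL keyed on the request.
result = client.upload(FileUploadRequest(path="benchmark_output/images/example.png"))
print(result.url)
```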
helm/{proxy/clients → clients}/google_client.py
@@ -1,8 +1,7 @@
 from typing import List, Dict
 
 from helm.common.cache import CacheConfig
-from helm.common.request import Request, RequestResult, Sequence, Token
-from helm.proxy.tokenizers.tokenizer import Tokenizer
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
 from .client import CachingClient, truncate_sequence
 
 
@@ -28,12 +27,12 @@ class GoogleClient(CachingClient):
             "top_p": request.top_p,
         }
 
-    def __init__(self, tokenizer: Tokenizer, cache_config: CacheConfig):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
+    def __init__(self, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
 
     def make_request(self, request: Request) -> RequestResult:
         raw_request = GoogleClient.convert_to_raw_request(request)
-        cache_key
+        cache_key = CachingClient.make_cache_key(raw_request, request)
 
         try:
 
@@ -49,17 +48,17 @@ class GoogleClient(CachingClient):
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
         # Expect the result to be structured the same way as a response from OpenAI API.
-        completions: List[Sequence] = []
+        completions: List[GeneratedOutput] = []
         for raw_completion in response["choices"]:
             sequence_logprob = 0
             tokens: List[Token] = []
 
             raw_data = raw_completion["logprobs"]
             for text, logprob in zip(raw_data["tokens"], raw_data["token_logprobs"]):
-                tokens.append(Token(text=text, logprob=logprob or 0, top_logprobs={}))
+                tokens.append(Token(text=text, logprob=logprob or 0))
                 sequence_logprob += logprob or 0
 
-            completion = Sequence(
+            completion = GeneratedOutput(
                 text=raw_completion["text"],
                 logprob=sequence_logprob,
                 tokens=tokens,
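The `Sequence` → `GeneratedOutput` rename visible here recurs across nearly every client in this release. A minimal sketch of the renamed type, using only the fields exercised in the diff (`text`, `logprob`, `tokens`); note that `Token` no longer carries `top_logprobs`:

```python
from helm.common.request import GeneratedOutput, Token

# What clients now build instead of Sequence(...); the values are illustrative.
completion = GeneratedOutput(
    text="Paris",
    logprob=-0.42,
    tokens=[Token(text="Paris", logprob=-0.42)],
)
```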
helm/clients/google_translate_client.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from helm.common.cache import Cache, SqliteCacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+try:
+    from google.cloud import translate_v2 as translate  # type: ignore
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["heim"])
+
+
+class GoogleTranslateClient:
+    """
+    Client for Google Translate.
+    Follow the instructions at https://cloud.google.com/translate/docs/setup to use this client.
+
+    # TODO: add this as a central service
+    """
+
+    def __init__(self, cache_path: str = "prod_env/cache/google_translate.sqlite"):
+        self.translate_client: Optional[translate.Client] = None
+        self.cache = Cache(SqliteCacheConfig(cache_path))
+
+    def translate(self, text: str, target_language: str) -> str:
+        def do_it():
+            if self.translate_client is None:
+                self.translate_client = translate.Client()
+
+            result = self.translate_client.translate(text, target_language=target_language)
+            del result["input"]
+            assert "translatedText" in result, f"Invalid response: {result}"
+            return result
+
+        response, _ = self.cache.get({"text": text, "target_language": target_language}, do_it)
+        return response["translatedText"]
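A minimal usage sketch (not part of the diff), assuming Cloud Translation credentials are set up as the docstring's link describes:

```python
from helm.clients.google_translate_client import GoogleTranslateClient

# Responses are memoized in prod_env/cache/google_translate.sqlite, so the same
# (text, target_language) pair is only sent to the API once.
client = GoogleTranslateClient()
print(client.translate("The quick brown fox jumps over the lazy dog.", target_language="zh-CN"))
```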
helm/{proxy/clients → clients}/http_model_client.py
@@ -1,16 +1,16 @@
 import os
 from dataclasses import asdict
+from typing import Any, Dict
 
 from helm.common.cache import CacheConfig
 from helm.common.request import (
     wrap_request_time,
     Request,
     RequestResult,
-    Sequence,
+    GeneratedOutput,
     Token,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
 )
-from helm.proxy.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient
 
 import requests
@@ -21,13 +21,12 @@ class HTTPModelClient(CachingClient):
 
     def __init__(
         self,
-        tokenizer: Tokenizer,
         cache_config: CacheConfig,
         base_url: str = "http://localhost:8080",
         timeout: int = 3000,
         do_cache: bool = False,
     ):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
+        super().__init__(cache_config=cache_config)
         self.base_url = (
             base_url if not os.environ.get("HELM_HTTP_MODEL_BASE_URL") else os.environ["HELM_HTTP_MODEL_BASE_URL"]
         )
@@ -53,7 +52,7 @@ class HTTPModelClient(CachingClient):
 
         try:
 
-            def do_it():
+            def do_it() -> Dict[str, Any]:
                 url = f"{self.base_url}/process"
                 response = requests.post(url, json=raw_request, timeout=self.timeout)
                 response.raise_for_status()
@@ -65,11 +64,8 @@ class HTTPModelClient(CachingClient):
             else:
                 response, cached = do_it(), False
 
-            tokens = [
-                Token(text=token["text"], logprob=token["logprob"], top_logprobs={})
-                for token in response["tokens"]
-            ]
-            completions = [Sequence(text=response["text"], logprob=response["logprob"], tokens=tokens)]
+            tokens = [Token(text=token["text"], logprob=token["logprob"]) for token in response["tokens"]]
+            completions = [GeneratedOutput(text=response["text"], logprob=response["logprob"], tokens=tokens)]
 
             return RequestResult(
                 success=True,
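The new parsing code implies the JSON contract for `POST {base_url}/process`. A hedged sketch of the minimal response body it accepts (the values are made up; fields beyond `text`, `logprob`, and `tokens` are not shown in this hunk):

```python
from helm.common.request import GeneratedOutput, Token

response = {
    "text": "4",
    "logprob": -0.7,
    "tokens": [{"text": "4", "logprob": -0.7}],
}
# Mirrors the client's parsing: one Token per entry, one GeneratedOutput per response.
tokens = [Token(text=t["text"], logprob=t["logprob"]) for t in response["tokens"]]
completions = [GeneratedOutput(text=response["text"], logprob=response["logprob"], tokens=tokens)]
```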
helm/{proxy/clients → clients}/huggingface_client.py
@@ -1,11 +1,11 @@
 from copy import deepcopy
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
 from transformers.generation.stopping_criteria import (
     StoppingCriteria,
     StoppingCriteriaList,
 )
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict
 
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
@@ -14,12 +14,11 @@ from helm.common.request import (
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
     Request,
     RequestResult,
-    Sequence,
+    GeneratedOutput,
     Token,
 )
 from .client import CachingClient, truncate_sequence
-from helm.
-from helm.proxy.tokenizers.tokenizer import Tokenizer
+from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
 from threading import Lock
 
 
@@ -37,108 +36,127 @@ class StopAtSpecificTokenCriteria(StoppingCriteria):
         return bool(torch.all(current_sequence == stop_sequence_tensor).item())
 
 
+class HuggingFaceRequest(TypedDict):
+    """Data passed between make_request and serve_request. Used as the cache key."""
+
+    engine: str
+    prompt: str
+    temperature: float
+    num_return_sequences: int
+    max_new_tokens: int
+    top_p: float
+    echo_prompt: bool
+    top_k_per_token: int
+    stop_sequences: List
+
+
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
 
-    def __init__(self, pretrained_model_name_or_path: str, revision: Optional[str] = None):
+    def __init__(self, pretrained_model_name_or_path: str, **kwargs):
         if torch.cuda.is_available():
             hlog("CUDA is available, initializing with a GPU...")
             self.device: str = "cuda:0"
         else:
             self.device = "cpu"
-        model_kwargs = {}
-        if revision:
-            model_kwargs["revision"] = revision
         with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
             # WARNING this may fail if your GPU does not have enough memory
             self.model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_name_or_path, trust_remote_code=True, **model_kwargs
+                pretrained_model_name_or_path, trust_remote_code=True, **kwargs
             ).to(self.device)
         with htrack_block(f"Loading Hugging Face tokenizer for model {pretrained_model_name_or_path}"):
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                pretrained_model_name_or_path,
+            self.wrapped_tokenizer: WrappedPreTrainedTokenizer = HuggingFaceTokenizer.create_tokenizer(
+                pretrained_model_name_or_path, **kwargs
             )
 
-    def serve_request(self, raw_request: Dict):
-        encoded_input = self.tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
-            self.device
-        )
-        raw_request = deepcopy(raw_request)
-        raw_request["do_sample"] = True
-        raw_request["return_dict_in_generate"] = True
-        raw_request["output_scores"] = True
-        top_k_per_token: int = raw_request["top_k_per_token"]
-        del raw_request["top_k_per_token"]
+    def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
+        with self.wrapped_tokenizer as tokenizer:
+            encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
+                self.device
+            )
         stopping_criteria: Optional[StoppingCriteriaList] = None
+        optional_args = {}
         if len(raw_request["stop_sequences"]) > 0:
-            stop_sequence_ids = self.tokenizer(
-                raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
-            )
+            with self.wrapped_tokenizer as tokenizer:
+                stop_sequence_ids = tokenizer(
+                    raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
+                )
             if len(stop_sequence_ids.input_ids) == 1 and len(stop_sequence_ids.input_ids[0]) == 1:
-                raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
+                optional_args["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
             else:
                 stopping_criteria = StoppingCriteriaList()
                 for stop_sequence_input_ids in stop_sequence_ids.input_ids:
                     stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_input_ids))
-            del raw_request["stop_sequences"]
 
-        # Strip out irrelevant parameters
-        relevant_raw_request = {
-            key: raw_request[key]
-            for key in raw_request
-            if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
-        }
-
-        # Use HuggingFace's `generate` method.
-        output = self.model.generate(
-            **encoded_input,
-            **relevant_raw_request,
-            stopping_criteria=stopping_criteria,
+        # Check if we need to compute the perplexity of the prompt (#1497)
+        compute_logprobs_only = (
+            raw_request["max_new_tokens"] == 0
+            and raw_request["num_return_sequences"] == 1
+            and raw_request["echo_prompt"]
         )
-        sequences = output.sequences
-        scores = output.scores
 
-        # Compute logprobs for each completed sequence.
-        all_generated_tokens_logprobs = []
-        all_top_logprobs_dicts = []
+        # Use HuggingFace's `generate` method.
+        if compute_logprobs_only:
+            with torch.no_grad():
+                output = self.model(encoded_input["input_ids"])
+            sequences = encoded_input["input_ids"]
+            scores = output.logits
+        else:
+            output = self.model.generate(
+                **encoded_input,
+                temperature=raw_request["temperature"],
+                num_return_sequences=raw_request["num_return_sequences"],
+                max_new_tokens=raw_request["max_new_tokens"],
+                top_p=raw_request["top_p"],
+                do_sample=True,
+                return_dict_in_generate=True,
+                output_scores=True,
+                **optional_args,
+                stopping_criteria=stopping_criteria,
            )
+            sequences = output.sequences
+            scores = output.scores
+
+        prompt_tokens_logprobs = []
+        if compute_logprobs_only:
+            # Append the logprob of the first token of the prompt.
+            prompt_tokens_logprobs.append(0.0)
+
+            # Compute logprobs of prompt tokens.
+            for completion_id in range(raw_request["num_return_sequences"]):
+                for i in range(len(sequences[completion_id]) - 1):
+                    logprobs = torch.nn.functional.log_softmax(scores[completion_id][i], dim=0)
+                    prompt_tokens_logprobs.append(logprobs[sequences[completion_id][i + 1]].item())
+
+        # Compute logprobs of generated tokens for each completed sequence.
+        all_generated_tokens_logprobs = []
         for completion_id in range(raw_request["num_return_sequences"]):
-            generated_tokens_logprobs = []
-            top_logprobs_dicts = []
+            generated_tokens_logprobs = []
             for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])):
                 logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0)
-
-                # Get top tokens in terms of log probability.
-                topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
-                top_logprobs_dicts.append(
-                    {
-                        self.tokenizer.convert_ids_to_tokens(k.item()): v.item()
-                        for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)
-                    }
-                )
-
                 # Get log probability of chosen token.
                 j = i + len(encoded_input.input_ids[0])
-                generated_tokens_logprobs.append(logprobs[sequences[completion_id][j]].item())
-            all_generated_tokens_logprobs.append(generated_tokens_logprobs)
-            all_top_logprobs_dicts.append(top_logprobs_dicts)
+                generated_tokens_logprobs.append(logprobs[sequences[completion_id][j]].item())
+            all_generated_tokens_logprobs.append(generated_tokens_logprobs)
 
         # Remove prompt from the start of each sequence if echo_prompt is False.
         if not raw_request["echo_prompt"]:
             sequences = [sequence[len(encoded_input.input_ids[0]) :] for sequence in sequences]
 
-        all_tokens = [[self.tokenizer.decode(token) for token in sequence_tokens] for sequence_tokens in sequences]
-        all_decoded_text = self.tokenizer.batch_decode(sequences)
+        with self.wrapped_tokenizer as tokenizer:
+            all_tokens = [[tokenizer.decode(token) for token in sequence_tokens] for sequence_tokens in sequences]
+            all_decoded_text = tokenizer.batch_decode(sequences)
 
         completions = []
-        for decoded_text, tokens, generated_tokens_logprobs, top_logprobs_dicts in zip(
-            all_decoded_text, all_tokens, all_generated_tokens_logprobs, all_top_logprobs_dicts
+        for decoded_text, tokens, generated_tokens_logprobs in zip(
+            all_decoded_text, all_tokens, all_generated_tokens_logprobs
         ):
             completions.append(
                 {
                     "text": decoded_text,
                     "tokens": tokens,
-                    "logprobs": generated_tokens_logprobs,
-                    "top_logprobs_dicts": top_logprobs_dicts,
+                    "logprobs": generated_tokens_logprobs,
+                    "prompt_logprobs": prompt_tokens_logprobs,
                 }
             )
 
@@ -152,7 +170,7 @@ class HuggingFaceServerFactory:
     _servers_lock: Lock = Lock()
 
     @staticmethod
-    def get_server(helm_model_name: str, pretrained_model_name_or_path: str, revision: Optional[str] = None) -> Any:
+    def get_server(helm_model_name: str, pretrained_model_name_or_path: str, **kwargs) -> Any:
         """
         Checks if the desired HuggingFaceModel is cached. Creates the HuggingFaceModel if it's not cached.
         Returns the HuggingFaceModel.
@@ -160,34 +178,53 @@ class HuggingFaceServerFactory:
         with HuggingFaceServerFactory._servers_lock:
             if helm_model_name not in HuggingFaceServerFactory._servers:
                 with htrack_block(
-                    f"Loading {pretrained_model_name_or_path} (revision={revision}) "
+                    f"Loading {pretrained_model_name_or_path} (kwargs={kwargs}) "
                     f"for HELM model {helm_model_name} with Hugging Face Transformers"
                 ):
                     HuggingFaceServerFactory._servers[helm_model_name] = HuggingFaceServer(
-                        pretrained_model_name_or_path, revision=revision
+                        pretrained_model_name_or_path, **kwargs
                     )
 
         return HuggingFaceServerFactory._servers[helm_model_name]
 
 
+TORCH_DTYPE_KEY = "torch_dtype"
+TORCH_DTYPE_VALUE_PREFIX = "torch."
+
+
+def _process_huggingface_client_kwargs(raw_kwargs: Dict[str, Any]):
+    """Process the kwargs for HuggingFaceClient.
+
+    The kwargs passed to HuggingFaceClient will eventually be passed to AutoModel.from_pretrained().
+    Since the kwargs from HuggingFaceClient may be derived from configuration YAML,
+    they may contain primitive types instead of the unserializable types that
+    AutoModel.from_pretrained() expects (e.g. torch_dtype). This function converts values of
+    primitive types to values of the unserializable types."""
+    processed_kwargs = deepcopy(raw_kwargs)
+
+    # Convert torch_dtype string value to actual dtypes
+    # e.g. the string "torch.bfloat16" is converted to torch.bfloat16
+    torch_dtype = processed_kwargs.get(TORCH_DTYPE_KEY)
+    if torch_dtype and isinstance(torch_dtype, str):
+        if not torch_dtype.startswith(TORCH_DTYPE_VALUE_PREFIX):
+            raise ValueError(f'Unknown dtype "{torch_dtype}"; expected a string such as "torch.bfloat16"')
+        processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
+
+    return processed_kwargs
+
+
 class HuggingFaceClient(CachingClient):
-    def __init__(
-        self,
-        tokenizer: Tokenizer,
-        cache_config: CacheConfig,
-        pretrained_model_name_or_path: Optional[str] = None,
-        revision: Optional[str] = None,
-    ):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
+    def __init__(self, cache_config: CacheConfig, pretrained_model_name_or_path: Optional[str] = None, **kwargs):
+        super().__init__(cache_config=cache_config)
         self._pretrained_model_name_or_path = pretrained_model_name_or_path
-        self._revision = revision
+        self._kwargs = _process_huggingface_client_kwargs(kwargs)
 
     def make_request(self, request: Request) -> RequestResult:
         # Embedding not supported for this model
         if request.embedding:
             return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
 
-        raw_request = {
+        raw_request: HuggingFaceRequest = {
             "engine": request.model_engine,
             "prompt": request.prompt,
             "temperature": 1e-7 if request.temperature == 0 else request.temperature,
@@ -199,20 +236,18 @@ class HuggingFaceClient(CachingClient):
             "stop_sequences": request.stop_sequences,
         }
 
-        pretrained_model_name_or_path: str
-        if self._pretrained_model_name_or_path:
-            pretrained_model_name_or_path = self._pretrained_model_name_or_path
-        else:
-            pretrained_model_name_or_path = resolve_alias(request.model)
+        pretrained_model_name_or_path = (
+            self._pretrained_model_name_or_path if self._pretrained_model_name_or_path else request.model
+        )
         huggingface_model: HuggingFaceServer = HuggingFaceServerFactory.get_server(
             helm_model_name=request.model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            revision=self._revision,
+            **self._kwargs,
         )
 
         try:
 
-            def do_it():
+            def do_it() -> Dict[str, Any]:
                 return huggingface_model.serve_request(raw_request)
 
             cache_key = CachingClient.make_cache_key(raw_request, request)
@@ -229,19 +264,26 @@ class HuggingFaceClient(CachingClient):
             if request.echo_prompt:
                 # Add prompt to list of generated tokens.
                 generated_tokens = raw_completion["tokens"][response["input_length"] :]
-                for token_text in raw_completion["tokens"][: response["input_length"]]:
-                    tokens.append(Token(text=token_text, logprob=0.0, top_logprobs={}))
+                if raw_completion.get("prompt_logprobs"):
+                    for token_text, logprob in zip(
+                        raw_completion["tokens"][: response["input_length"]],
+                        raw_completion["prompt_logprobs"][: response["input_length"]],
+                    ):
+                        tokens.append(Token(text=token_text, logprob=logprob))
+                        sequence_logprob += logprob
+                else:
+                    for token_text in raw_completion["tokens"][: response["input_length"]]:
+                        tokens.append(Token(text=token_text, logprob=0.0))
 
             else:
                 generated_tokens = raw_completion["tokens"]
 
             # Compute logprob for the entire sequence.
-            for token_text, logprob, top_logprobs_dict in zip(
-                generated_tokens, raw_completion["logprobs"], raw_completion["top_logprobs_dicts"]
-            ):
-                tokens.append(Token(text=token_text, logprob=logprob, top_logprobs=top_logprobs_dict))
+            for token_text, logprob in zip(generated_tokens, raw_completion["logprobs"]):
+                tokens.append(Token(text=token_text, logprob=logprob))
                 sequence_logprob += logprob
 
-            completion = Sequence(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
+            completion = GeneratedOutput(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
             completion = truncate_sequence(completion, request)
             completions.append(completion)
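A small sketch of what the new `_process_huggingface_client_kwargs` helper does with kwargs derived from configuration YAML (e.g. model_deployments.yaml); the values below are illustrative:

```python
import torch

# Strings from YAML become real objects before they reach
# AutoModelForCausalLM.from_pretrained(); other keys pass through untouched.
raw_kwargs = {"torch_dtype": "torch.bfloat16", "revision": "main"}
processed = _process_huggingface_client_kwargs(raw_kwargs)
assert processed["torch_dtype"] is torch.bfloat16
assert processed["revision"] == "main"
```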
File without changes
helm/clients/image_generation/adobe_vision_client.py
@@ -0,0 +1,78 @@
+from typing import List, Dict
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.request import Request, RequestResult, GeneratedOutput
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+from helm.clients.client import Client, CachingClient
+from .image_generation_client_utils import get_single_image_multimedia_object
+
+
+class AdobeVisionClient(Client):
+    """
+    Client for Adobe vision models. Offline eval only.
+    """
+
+    SUPPORTED_MODELS: List[str] = ["giga-gan", "firefly"]
+
+    @staticmethod
+    def convert_to_raw_request(request: Request) -> Dict:
+        # Use default hyperparameters for everything else
+        raw_request: Dict = {
+            "request_type": "image-model-inference",
+            "model": request.model_engine,
+            "prompt": request.prompt,
+            "n": request.num_completions,
+        }
+        if request.random is not None:
+            raw_request["random"] = request.random
+        return raw_request
+
+    def __init__(self, cache_config: CacheConfig):
+        self._cache = Cache(cache_config)
+        self._promptist_model = None
+        self._promptist_tokenizer = None
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.model_engine not in self.SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported model: {request.model_engine}")
+
+        raw_request = AdobeVisionClient.convert_to_raw_request(request)
+        raw_request.pop("random", None)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        try:
+
+            def fail():
+                raise RuntimeError(
+                    f"The result has not been uploaded to the cache for the following request: {cache_key}"
+                )
+
+            response, cached = self._cache.get(cache_key, fail)
+        except RuntimeError as e:
+            error: str = f"Adobe Vision Client error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = [
+            GeneratedOutput(
+                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_path)
+            )
+            for file_path in response["images"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            completions=completions,
+            embedding=[],
+        )
+
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        raise NotImplementedError("This client does not support tokenizing.")
+
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        raise NotImplementedError("This client does not support decoding.")