crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0

helm/clients/bedrock_client.py (new file)

```diff
@@ -0,0 +1,128 @@
+from abc import abstractmethod
+from copy import deepcopy
+import json
+import os
+from typing import Any, Dict, List, Mapping, Optional
+
+from helm.common.cache import CacheConfig
+from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.clients.bedrock_utils import get_bedrock_client
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+JSON_CONTENT_TYPE = "application/json"
+
+
+class BedrockClient(CachingClient):
+    @abstractmethod
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        raise NotImplementedError()
+
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        bedrock_model_id: Optional[str] = None,
+        assumed_role: Optional[str] = None,
+        region: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.bedrock_model_id = bedrock_model_id
+        self.bedrock_client = get_bedrock_client(
+            assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
+            region=region or os.environ.get("AWS_DEFAULT_REGION", None),
+        )
+
+    def make_request(self, request: Request) -> RequestResult:
+        # model_id should be something like "amazon.titan-tg1-large"
+        model_id = self.bedrock_model_id if self.bedrock_model_id else request.model.replace("/", ".")
+        raw_request = self.convert_request_to_raw_request(request)
+
+        # modelId isn't part of raw_request, so it must be explicitly passed into the input to
+        raw_request_for_cache: Dict = {"modelId": model_id, **deepcopy(raw_request)}
+        cache_key: Mapping = CachingClient.make_cache_key(raw_request_for_cache, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self.bedrock_client.invoke_model(
+                body=json.dumps(raw_request), modelId=model_id, accept=JSON_CONTENT_TYPE, contentType=JSON_CONTENT_TYPE
+            )
+            return json.loads(response.get("body").read())
+
+        try:
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        completions = self.convert_raw_response_to_completions(response, request)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
+
+
+class BedrockTitanClient(BedrockClient):
+    _COMPLETION_REASON_TO_FINISH_REASON = {
+        "LENGTH": "length",
+        "FINISH": "endoftext",
+    }
+
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        # TODO: Support the following:
+        # - top_k_per_token
+        # - echo_prompt
+        # - num_completions
+        return {
+            "inputText": request.prompt,
+            "textGenerationConfig": {
+                "maxTokenCount": request.max_tokens,
+                # We ignore stop sequences in the request and always set stop sequences to the empty list.
+                # This is because:
+                #
+                # 1. The only permitted stop sequences are "|" and "User:"
+                #    - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
+                #    - https://github.com/boto/boto3/issues/3993
+                #    - https://github.com/aws/aws-sdk/issues/692
+                #
+                # 2. Titan has the tendency to emit "\n" as the first token in the generated text output,
+                #    which would cause the output to stop immediately if "\n" is in the stop_sequences.
+                "stopSequences": [],
+                "temperature": request.temperature,
+                "topP": request.top_p,
+            },
+        }
+
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        # TODO: Support the following:
+        # - tokens
+        # - logprob
+        completions: List[GeneratedOutput] = []
+        for raw_completion in response["results"]:
+            output_text = raw_completion["outputText"]
+            # Call lstrip() Titan has the tendency to emit "\n" as the first token in the generated text output.
+            finish_reason = BedrockTitanClient._COMPLETION_REASON_TO_FINISH_REASON.get(
+                raw_completion["completionReason"], raw_completion["completionReason"].lower()
+            )
+            completion = truncate_and_tokenize_response_text(
+                output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
+            )
+            completions.append(completion)
+        return completions
```
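For orientation (not part of the packaged diff): the payload built by `BedrockTitanClient.convert_request_to_raw_request` is the standard Titan text-generation body, and `make_request` sends it through boto3's `invoke_model`. A minimal standalone sketch of the same call, assuming AWS credentials are already configured; the prompt and generation settings are hypothetical:

```python
import json

import boto3

# Illustrative sketch only: mirrors the request shape used by BedrockTitanClient above.
bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")
body = {
    "inputText": "Summarize the HELM benchmark in one sentence.",  # hypothetical prompt
    "textGenerationConfig": {"maxTokenCount": 50, "stopSequences": [], "temperature": 0.7, "topP": 0.9},
}
# Same invoke_model call shape as BedrockClient.make_request.
response = bedrock_runtime.invoke_model(
    body=json.dumps(body),
    modelId="amazon.titan-tg1-large",
    accept="application/json",
    contentType="application/json",
)
print(json.loads(response["body"].read())["results"][0]["outputText"])
```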

helm/clients/bedrock_utils.py (new file)

```diff
@@ -0,0 +1,72 @@
+"""Helper utilities for working with Amazon Bedrock."""
+
+import os
+from typing import Optional
+
+from helm.common.hierarchical_logger import hlog
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+try:
+    import boto3
+    from botocore.config import Config
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["aws"])
+
+
+# From https://github.com/aws-samples/amazon-bedrock-workshop/blob/main/01_Generation/00_generate_w_bedrock.ipynb
+# MIT-0 Licensed
+def get_bedrock_client(
+    assumed_role: Optional[str] = None,
+    region: Optional[str] = None,
+    runtime: Optional[bool] = True,
+):
+    """Create a boto3 client for Amazon Bedrock, with optional configuration overrides
+
+    Parameters
+    ----------
+    assumed_role :
+        Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not
+        specified, the current active credentials will be used.
+    region :
+        Optional name of the AWS Region in which the service should be called (e.g. "us-east-1").
+        If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used.
+    runtime :
+        Optional choice of getting different client to perform operations with the Amazon Bedrock service.
+    """
+    if region is None:
+        target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
+    else:
+        target_region = region
+
+    session_kwargs = {"region_name": target_region}
+    client_kwargs = {**session_kwargs}
+
+    profile_name = os.environ.get("AWS_PROFILE")
+    if profile_name:
+        session_kwargs["profile_name"] = profile_name
+
+    retry_config = Config(
+        region_name=target_region,
+        retries={
+            "max_attempts": 10,
+            "mode": "standard",
+        },
+    )
+    session = boto3.Session(**session_kwargs)
+
+    if assumed_role:
+        sts = session.client("sts")
+        response = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")
+        client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
+        client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"]
+        client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
+
+    if runtime:
+        service_name = "bedrock-runtime"
+    else:
+        service_name = "bedrock"
+
+    bedrock_client = session.client(service_name=service_name, config=retry_config, **client_kwargs)
+
+    hlog(f"Amazon Bedrock client successfully created with endpoint {bedrock_client._endpoint}")
+    return bedrock_client
```
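A usage sketch for the new helper (not part of the diff), assuming the AWS extra (boto3, botocore) is installed; the role ARN is a placeholder:

```python
from helm.clients.bedrock_utils import get_bedrock_client

# Create a "bedrock-runtime" client for model invocation; pass runtime=False for the control-plane "bedrock" client.
client = get_bedrock_client(
    assumed_role="arn:aws:iam::123456789012:role/BedrockCaller",  # placeholder ARN; omit to use the active credentials
    region="us-east-1",
)
```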

helm/{proxy/clients → clients}/client.py

```diff
@@ -1,49 +1,16 @@
 import json
 from abc import ABC, abstractmethod
-from typing import
+from typing import List, Mapping, Optional, cast
 
 from helm.common.hierarchical_logger import hlog
 from helm.common.media_object import MultimediaObject, TEXT_TYPE
-from helm.common.request import Request, RequestResult,
-from helm.common.tokenization_request import (
-    TokenizationRequest,
-    TokenizationRequestResult,
-    DecodeRequest,
-    DecodeRequestResult,
-)
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
 from helm.common.cache import Cache, CacheConfig
-from helm.
+from helm.common.tokenization_request import DecodeRequest, TokenizationRequest
+from helm.tokenizers.tokenizer import Tokenizer
 
 
 class Client(ABC):
-    # TODO: This method should be removed.
-    # This only kept for the AutoClient. Eventually, we should introduce an
-    # AutoTokenizer or TokenizerFactory class.
-    @abstractmethod
-    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
-        """Tokenizes `request.text` using `request.tokenizer`.
-
-        This simply calls the `tokenize` method of the tokenizer.
-        Some exceptions can be made (but should be avoided).
-        This is the case for the auto client, which needs to handle
-        tokenization for multiple tokenizers.
-        """
-        pass
-
-    # TODO: This method should be removed.
-    # This only kept for the AutoClient. Eventually, we should introduce an
-    # AutoTokenizer or TokenizerFactory class.
-    @abstractmethod
-    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
-        """Decodes `request.tokens` using `request.tokenizer`.
-
-        This simply calls the `decode` method of the tokenizer.
-        Some exceptions can be made (but should be avoided).
-        This is the case for the auto client, which needs to handle
-        tokenization for multiple tokenizers.
-        """
-        pass
-
     @abstractmethod
     def make_request(self, request: Request) -> RequestResult:
         """Makes a request to the model.
@@ -54,7 +21,7 @@ class Client(ABC):
 
 
 class CachingClient(Client):
-    def __init__(self, cache_config: CacheConfig
+    def __init__(self, cache_config: CacheConfig) -> None:
         """Initializes the client.
 
         For most clients, both the cache config and tokenizer are required.
@@ -63,37 +30,30 @@ class CachingClient(Client):
         the request is made.
         """
         self.cache = Cache(cache_config) if cache_config is not None else None
-        self.tokenizer = tokenizer
 
     @staticmethod
-    def make_cache_key(raw_request:
+    def make_cache_key(raw_request: Mapping, request: Request) -> Mapping:
        """
        Construct the key for the cache using the raw request.
        Add `request.random` to the key, if defined.
        """
        if request.random is not None:
            assert "random" not in raw_request
-           cache_key = {**raw_request, "random": request.random}
+           cache_key: Mapping = {**raw_request, "random": request.random}
        else:
            cache_key = raw_request
        return cache_key
 
-    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
-        # Deprecated - use `self.tokenizer.tokenize` instead. Warn the user.
-        hlog("WARNING: CachingClient.tokenize is deprecated, use self.tokenizer.tokenize instead")
-        return self.tokenizer.tokenize(request)
-
-    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
-        # Deprecated - use `self.tokenizer.decode` instead. Warn the user.
-        hlog("WARNING: CachingClient.decode is deprecated, use self.tokenizer.decode instead")
-        return self.tokenizer.decode(request)
-
 
-def truncate_sequence(sequence:
+def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning: bool = True) -> GeneratedOutput:
     """
     Certain providers have bugs where they aren't respecting max_tokens,
     stop_sequences and the end of text token, so as a hack, we have to manually
     truncate the suffix of `sequence` and `tokens` as a post-hoc process.
+
+    This method is unsafe and may produce warnings or incorrect results.
+    Prefer using the safer truncate_and_tokenize_response_text() method instead
+    if your use case satisfies its requirements.
     """
     # TODO: if echo_prompt, then we should only ignore the prompt, but we don't
     # know how many tokens the prompt takes up.
@@ -133,7 +93,7 @@ def truncate_sequence(sequence: Sequence, request: Request, print_warning: bool
         if print_warning:
             hlog(f"WARNING: truncate_sequence needs to strip {json.dumps(stop)}")
 
-        sequence =
+        sequence = GeneratedOutput(text=new_text, logprob=new_logprob, tokens=new_tokens)
 
     # Truncate based on the max number of tokens.
     if len(sequence.tokens) > request.max_tokens:
@@ -150,11 +110,63 @@ def truncate_sequence(sequence: Sequence, request: Request, print_warning: bool
 
         new_logprob = sum(token.logprob for token in new_tokens)
 
-        sequence =
+        sequence = GeneratedOutput(text=new_text, logprob=new_logprob, tokens=new_tokens)
 
     return sequence
 
 
+def truncate_and_tokenize_response_text(
+    text: str, request: Request, tokenizer: Tokenizer, tokenizer_name: str, original_finish_reason: str = "endoftext"
+) -> GeneratedOutput:
+    """Truncate a string-only response to respect stop_sequences and max_tokens.
+
+    This can only be used if all of the following conditions are true:
+
+    - You have access to the tokenizer.
+    - The request has echo_prompt = False.
+    - The tokenizer supports encoding and decoding.
+    - The tokenizer's tokenize() method supports truncation.
+    - The model's response is text-only.
+    - The model's response not already provide the tokenized text.
+    - The model's response does not provide logprobs.
+
+    This method is safer than truncate_sequence() and should be preferred if the above conditions are met.
+    Unlike truncate_sequence(), this method will not produce warnings or incorrect results.
+    This is because the the tokens are derived from the truncated text using the tokenizer,
+    so the text and the tokens in the resulting result are guranteed to match."""
+    # Finish reason strings are token from basic_metrics._compute_finish_reason_metrics()
+    finish_reason: str = original_finish_reason
+    if request.echo_prompt:
+        raise Exception("truncate_and_tokenize_response_text() does not support requests with echo_prompt = True")
+
+    for stop_sequence in request.stop_sequences:
+        try:
+            text = text[: text.index(stop_sequence)]
+            finish_reason = "stop"
+        except ValueError:
+            pass
+
+    token_strings = cast(
+        List[str], tokenizer.tokenize(TokenizationRequest(text=text, tokenizer=tokenizer_name)).raw_tokens
+    )
+    if len(token_strings) > request.max_tokens:
+        encoded_ints = cast(
+            List[int],
+            tokenizer.tokenize(
+                TokenizationRequest(
+                    text=text, tokenizer=tokenizer_name, encode=True, truncation=True, max_length=request.max_tokens
+                )
+            ).raw_tokens,
+        )
+        text = tokenizer.decode(DecodeRequest(encoded_ints, tokenizer_name)).text
+        token_strings = cast(
+            List[str], tokenizer.tokenize(TokenizationRequest(text=text, tokenizer=tokenizer_name)).raw_tokens
+        )
+        finish_reason = "length"
+    tokens = [Token(text=token_string, logprob=0.0) for token_string in token_strings]
+    return GeneratedOutput(text=text, logprob=0.0, tokens=tokens, finish_reason={"reason": finish_reason})
+
+
 def cleanup_str(token: str, tokenizer_name: Optional[str] = None) -> str:
     """
     Certain tokenizers introduce special characters to represent spaces, such as
@@ -171,7 +183,7 @@ def cleanup_str(token: str, tokenizer_name: Optional[str] = None) -> str:
         "together",
     ]:
         return token.replace("▁", " ")
-    elif tokenizer_name is not None and tokenizer_name.startswith("huggingface"):
+    elif tokenizer_name is not None and (tokenizer_name.startswith("huggingface") or tokenizer_name.endswith("gpt2")):
         return token.replace("Ġ", " ")
     return token
 
```
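To make the control flow of the new `truncate_and_tokenize_response_text()` easier to follow (an illustration, not HELM code): it first cuts the text at the stop sequences, then enforces the token budget by re-encoding with truncation. A toy sketch of those two steps, using whitespace "tokens" instead of a real HELM tokenizer:

```python
from typing import List, Tuple


def truncate_plain_text(text: str, stop_sequences: List[str], max_tokens: int) -> Tuple[str, str]:
    """Toy illustration of the two truncation steps: stop sequences first, then the token budget."""
    finish_reason = "endoftext"
    for stop in stop_sequences:
        if stop in text:
            text = text[: text.index(stop)]
            finish_reason = "stop"
    tokens = text.split()  # stand-in for tokenizer.tokenize(...)
    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
        text = " ".join(tokens)  # stand-in for tokenizer.decode(...)
        finish_reason = "length"
    return text, finish_reason


print(truncate_plain_text("one two three four User: five", ["User:"], 3))  # ('one two three', 'length')
```

The real implementation instead round-trips through `tokenizer.tokenize(..., truncation=True)` and `tokenizer.decode(...)`, which is what keeps the returned text and tokens in agreement.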

helm/clients/clip_score_client.py (new file)

```diff
@@ -0,0 +1,49 @@
+from typing import Dict, Optional
+from dataclasses import asdict
+
+from helm.common.cache import Cache, CacheConfig
+from helm.common.clip_score_request import DEFAULT_CLIP_SCORE_MODEL, CLIPScoreRequest, CLIPScoreResult
+from helm.clients.clip_scorers.base_clip_scorer import BaseCLIPScorer
+
+
+class CLIPScoreClientError(Exception):
+    pass
+
+
+class CLIPScoreClient:
+    def __init__(self, cache_config: CacheConfig):
+        self.cache = Cache(cache_config)
+        self._clip_scorer: Optional[BaseCLIPScorer] = None
+
+    def compute_score(self, request: CLIPScoreRequest) -> CLIPScoreResult:
+        """
+        Compute a CLIPScore for a given caption and image.
+        """
+        # TODO: support multilingual CLIPScore and other CLIP models.
+        assert request.model == DEFAULT_CLIP_SCORE_MODEL, f"Unsupported model: {request.model}"
+        assert not request.multilingual
+
+        try:
+
+            def do_it():
+                if self._clip_scorer is None:
+                    from helm.clients.clip_scorers.clip_scorer import CLIPScorer
+
+                    self._clip_scorer = CLIPScorer()
+
+                score: float = self._clip_scorer.compute_score(
+                    caption=request.caption, image_location=request.image_location
+                )
+                return {"score": score}
+
+            cache_key: Dict = asdict(request)
+            results, cached = self.cache.get(cache_key, do_it)
+
+        except Exception as e:
+            raise CLIPScoreClientError(e)
+
+        return CLIPScoreResult(
+            success=True,
+            cached=cached,
+            score=results["score"],
+        )
```

helm/clients/clip_scorers/__init__.py (file without changes)

helm/clients/clip_scorers/base_clip_scorer.py (new file)

```diff
@@ -0,0 +1,18 @@
+from abc import abstractmethod, ABC
+from typing import List
+
+
+class BaseCLIPScorer(ABC):
+    @abstractmethod
+    def compute_score(self, caption: str, image_location: str) -> float:
+        pass
+
+    def select_best_image(self, caption: str, image_locations: List[str]) -> str:
+        """Selects the image from a list of images with the highest CLIPScore given the caption."""
+        assert len(image_locations) > 0, "Need at least one image"
+
+        if len(image_locations) == 1:
+            return image_locations[0]
+
+        scores: List[float] = [self.compute_score(caption, image_location) for image_location in image_locations]
+        return image_locations[scores.index(max(scores))]
```
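Not part of the diff: a minimal sketch of how a concrete scorer plugs into `select_best_image`, with a dummy `compute_score` so it runs without any CLIP dependencies (the class name and scoring rule are made up for illustration):

```python
from helm.clients.clip_scorers.base_clip_scorer import BaseCLIPScorer


class FileNameLengthScorer(BaseCLIPScorer):
    """Toy scorer: 'scores' an image by the length of its file path (illustration only)."""

    def compute_score(self, caption: str, image_location: str) -> float:
        return float(len(image_location))


scorer = FileNameLengthScorer()
best = scorer.select_best_image("a cat", ["cat1.png", "cat_on_sofa.png"])
print(best)  # cat_on_sofa.png: select_best_image returns the highest-scoring image location
```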

helm/clients/clip_scorers/clip_scorer.py (new file)

```diff
@@ -0,0 +1,50 @@
+from typing import Literal
+
+from torchvision import transforms
+import torch
+
+from helm.common.gpu_utils import get_torch_device
+from helm.common.images_utils import open_image
+from helm.common.optional_dependencies import handle_module_not_found_error
+from .base_clip_scorer import BaseCLIPScorer
+
+
+_ = torch.manual_seed(42)
+
+
+class CLIPScorer(BaseCLIPScorer):
+    """
+    CLIPScore is a reference free metric that can be used to evaluate the correlation between an image
+    caption and the content of the image. It has been found to be highly correlated with human judgement.
+    Paper: https://arxiv.org/abs/2104.08718
+
+    We use the TorchMetrics implementation:
+    https://torchmetrics.readthedocs.io/en/stable/multimodal/clip_score.html.
+    The score is bound between 0 and 100, where a score closer to 100 is better.
+
+    Verified implementation against the scores of image-caption pairs from
+    https://wandb.ai/dalle-mini/dalle-mini/reports/OpenAI-CLIP-Score-exploration--VmlldzoxNjMwODM1.
+    """
+
+    def __init__(
+        self,
+        model_name: Literal[
+            "openai/clip-vit-base-patch16",
+            "openai/clip-vit-base-patch32",
+            "openai/clip-vit-large-patch14-336",
+            "openai/clip-vit-large-patch14",
+        ] = "openai/clip-vit-large-patch14",
+    ):
+        try:
+            from torchmetrics.multimodal import CLIPScore
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        self._device: torch.device = get_torch_device()
+        self._metric = CLIPScore(model_name_or_path=model_name).to(self._device)
+
+    def compute_score(self, caption: str, image_location: str) -> float:
+        image = open_image(image_location)
+        image_tensor: torch.Tensor = transforms.ToTensor()(image).to(self._device)
+        score: float = self._metric(image_tensor, caption).detach().item()
+        return score
```
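A usage sketch (not in the diff), assuming the heim extras (torchmetrics, torchvision, and related dependencies) are installed; `cat.png` is a hypothetical local image path:

```python
from helm.clients.clip_scorers.clip_scorer import CLIPScorer

scorer = CLIPScorer()  # defaults to "openai/clip-vit-large-patch14"
score = scorer.compute_score(caption="a photo of a cat", image_location="cat.png")  # hypothetical path
print(f"CLIPScore: {score:.1f}")  # bounded between 0 and 100; higher means a better caption-image match
```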

helm/clients/clip_scorers/multilingual_clip_scorer.py (new file)

```diff
@@ -0,0 +1,50 @@
+import torch
+import transformers
+
+from helm.common.gpu_utils import get_torch_device, get_torch_device_name
+from helm.common.images_utils import open_image
+from helm.common.optional_dependencies import handle_module_not_found_error
+from .base_clip_scorer import BaseCLIPScorer
+
+_ = torch.manual_seed(42)
+
+
+class MultilingualCLIPScorer(BaseCLIPScorer):
+    """
+    Multilingual-CLIP extends OpenAI's English text encoders to multiple other languages.
+    Adapted from https://huggingface.co/M-CLIP/XLM-Roberta-Large-Vit-L-14
+    """
+
+    TEXT_MODEL_NAME: str = "M-CLIP/XLM-Roberta-Large-Vit-L-14"
+    IMAGE_MODEL_NAME: str = "ViT-L/14"
+
+    def __init__(self):
+        try:
+            import clip
+            from multilingual_clip import pt_multilingual_clip
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        super().__init__()
+        self._device: torch.device = get_torch_device()
+        self._text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(self.TEXT_MODEL_NAME)
+        self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.TEXT_MODEL_NAME)
+        self._model, self._preprocess = clip.load(self.IMAGE_MODEL_NAME, device=get_torch_device_name())
+
+    def compute_score(self, caption: str, image_location: str) -> float:
+        # Get text features
+        text_features = self._text_model.forward(caption, self._tokenizer)
+        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
+        text_features = text_features.to(self._device)
+
+        image = open_image(image_location)
+        image = self._preprocess(image).unsqueeze(0).to(self._device)
+
+        # Get image features
+        with torch.no_grad():
+            image_features = self._model.encode_image(image)
+        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
+
+        # Compute score using text and image features
+        score = 100 * (image_features * text_features).sum(axis=-1)
+        return score.detach().item()
```

helm/{proxy/clients → clients}/cohere_client.py

```diff
@@ -8,11 +8,9 @@ from helm.common.request import (
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
     Request,
     RequestResult,
-
+    GeneratedOutput,
     Token,
 )
-from helm.proxy.models import get_models_by_organization
-from helm.proxy.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient, truncate_sequence
 from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
 
@@ -21,8 +19,8 @@ class CohereClient(CachingClient):
     ORGANIZATION: str = "cohere"
     GENERATE_ENDPOINT: str = "generate"
 
-    def __init__(self, api_key: str,
-        super().__init__(cache_config=cache_config
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
         self.api_key: str = api_key
 
     def make_request(self, request: Request) -> RequestResult:
@@ -44,8 +42,6 @@ class CohereClient(CachingClient):
         # so `max_tokens` has to be greater than 0 when `return_likelihoods` is set to "GENERATION".
         assert request.max_tokens > 0, "max_tokens can only be 0 if echo_prompt=True"
 
-        # model: "Currently available models are small, medium, large, xlarge"
-        assert request.model in get_models_by_organization("cohere")
         # temperature: "min value of 0.0, max value of 5.0"
         assert 0.0 <= request.temperature <= 5.0, f"Invalid temperature: {request.temperature}. Valid range: [0,5]"
         # num_generations: "min value of 1, max value of 5"
@@ -124,7 +120,7 @@ class CohereClient(CachingClient):
             error: str = f"CohereClient error: {e}"
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
-        completions: List[
+        completions: List[GeneratedOutput] = []
         for generation in response["generations"]:
             # From https://docs.cohere.ai/generate-reference, "the likelihood refers to the average log-likelihood
             # of the entire specified string..." What we want is the sum of the log probabilities of all tokens.
@@ -136,14 +132,7 @@ class CohereClient(CachingClient):
                 logprob: float = token_likelihood.get("likelihood", 0)
                 sequence_logprob += logprob
 
-                tokens.append(
-                    Token(
-                        text=token_likelihood["token"],
-                        logprob=logprob,
-                        # Cohere does not include the top log probs in the response
-                        top_logprobs={},
-                    )
-                )
+                tokens.append(Token(text=token_likelihood["token"], logprob=logprob))
 
             sequence_text: str = generation["text"]
             if request.echo_prompt and request.max_tokens > 0:
@@ -151,7 +140,7 @@ class CohereClient(CachingClient):
                 # `return_likelihoods` is "ALL" and `max_tokens` is greater than 0.
                 sequence_text = request.prompt + sequence_text
 
-            completion:
+            completion: GeneratedOutput = GeneratedOutput(text=sequence_text, logprob=sequence_logprob, tokens=tokens)
             completion = truncate_sequence(completion, request)
             completions.append(completion)
 
```