crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
- crfm_helm-0.5.1.dist-info/RECORD +654 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +25 -3
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +41 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +213 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +41 -1
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +205 -35
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +163 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +757 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +823 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +233 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +301 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +104 -73
- helm/clients/vertexai_client.py +400 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +111 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +33 -3
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1159 -538
- helm/config/model_metadata.yaml +868 -41
- helm/config/tokenizer_configs.yaml +149 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_deployment_definition.py +0 -92
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
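The hunks below excerpt the most substantive code changes in this release: the rewritten Anthropic client (which gains a Messages API client with multimodal support), the reworked AutoClient dispatch (now parameterized by a cache backend config and a file storage path), and the new Amazon Bedrock client.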
helm/{proxy/clients → clients}/anthropic_client.py +233 -18

@@ -1,18 +1,20 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict, Union, cast
 import json
 import requests
+import tempfile
 import time
 import urllib.parse
 
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
+from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
     Request,
     RequestResult,
-
+    GeneratedOutput,
     Token,
     ErrorFlags,
 )
@@ -20,16 +22,30 @@ from helm.common.tokenization_request import (
     TokenizationRequest,
     TokenizationRequestResult,
 )
-from helm.proxy.
-from .
+from helm.proxy.retry import NonRetriableException
+from helm.tokenizers.tokenizer import Tokenizer
+from helm.clients.client import CachingClient, truncate_sequence, truncate_and_tokenize_response_text
 
 try:
-    import 
+    from anthropic import Anthropic, BadRequestError
+    from anthropic.types import MessageParam
+    from anthropic.types.image_block_param import ImageBlockParam
+    from anthropic.types.text_block_param import TextBlockParam
     import websocket
 except ModuleNotFoundError as e:
     handle_module_not_found_error(e, ["anthropic"])
 
 
+class AnthropicCompletionRequest(TypedDict):
+    prompt: str
+    stop_sequences: List[str]
+    model: str
+    max_tokens_to_sample: int
+    temperature: float
+    top_p: float
+    top_k: int
+
+
 class AnthropicClient(CachingClient):
     """
     Client for the Anthropic models (https://arxiv.org/abs/2204.05862).
@@ -53,6 +69,9 @@ class AnthropicClient(CachingClient):
     MAX_COMPLETION_LENGTH: int = (
         8192  # See https://docs.google.com/document/d/1vX6xgoA-KEKxqtMlBVAqYvE8KUfZ7ABCjTxAjf1T5kI/edit#
     )
+    # An Anthropic error message: "At least one of the image dimensions exceed max allowed size: 8000 pixels"
+    MAX_IMAGE_DIMENSION: int = 8000
+
     ADDITIONAL_TOKENS: int = 5
     PROMPT_ANSWER_START: str = "The answer is "
 
@@ -63,12 +82,12 @@ class AnthropicClient(CachingClient):
         self.tokenizer = tokenizer
         self.tokenizer_name = tokenizer_name
         self.api_key: Optional[str] = api_key
-        self.
+        self.client = Anthropic(api_key=api_key)
 
-    def _send_request(self, raw_request:
+    def _send_request(self, raw_request: AnthropicCompletionRequest) -> Dict[str, Any]:
         if self.api_key is None:
             raise Exception("API key is not set. Please set it in the HELM config file.")
-        result = self.
+        result = self.client.completions.create(**raw_request).model_dump()
         assert "error" not in result, f"Request failed with error: {result['error']}"
         return result
 
@@ -103,7 +122,7 @@ class AnthropicClient(CachingClient):
         if request.max_tokens == 0 and not request.echo_prompt:
             raise ValueError("echo_prompt must be True when max_tokens=0.")
 
-        raw_request = {
+        raw_request: AnthropicCompletionRequest = {
             "prompt": request.prompt,
             "stop_sequences": request.stop_sequences,
             "model": request.model_engine,
@@ -113,7 +132,7 @@ class AnthropicClient(CachingClient):
             "top_k": request.top_k_per_token,
         }
 
-        completions: List[
+        completions: List[GeneratedOutput] = []
 
         # `num_completions` is not supported, so instead make `num_completions` separate requests.
         for completion_index in range(request.num_completions):
@@ -172,11 +191,9 @@ class AnthropicClient(CachingClient):
             )
 
             # Log probs are not currently not supported by the Anthropic, so set to 0 for now.
-            tokens: List[Token] = [
-                Token(text=str(text), logprob=0, top_logprobs={}) for text in tokenization_result.raw_tokens
-            ]
+            tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
 
-            completion = 
+            completion = GeneratedOutput(text=response["completion"], logprob=0, tokens=tokens)
             # See NOTE() in _filter_completion() to understand why warnings are printed for truncation.
             # TODO(#1512): Fix this with post-processing.
             sequence = truncate_sequence(completion, request, print_warning=True)
@@ -192,6 +209,205 @@ class AnthropicClient(CachingClient):
         )
 
 
+def _is_content_moderation_failure(response: Dict) -> bool:
+    """Return whether a response failed because of the content moderation filter."""
+    if (
+        "error" in response
+        and "message" in response["error"]
+        and response["error"]["message"] == "Output blocked by content filtering policy"
+    ):
+        hlog(f"Anthropic - output blocked by content filtering policy: {response}")
+        return True
+    return False
+
+
+class AnthropicMessagesRequest(TypedDict, total=False):
+    messages: List[MessageParam]
+    model: str
+    stop_sequences: List[str]
+    system: str
+    max_tokens: int
+    temperature: float
+    top_k: int
+    top_p: float
+
+
+class AnthropicMessagesRequestError(NonRetriableException):
+    pass
+
+
+class AnthropicMessagesResponseError(Exception):
+    pass
+
+
+class AnthropicMessagesClient(CachingClient):
+    # Source: https://docs.anthropic.com/claude/docs/models-overview
+    MAX_OUTPUT_TOKENS: int = 4096
+
+    def __init__(
+        self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.client = Anthropic(api_key=api_key)
+        self.api_key: Optional[str] = api_key
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.max_tokens > AnthropicMessagesClient.MAX_OUTPUT_TOKENS:
+            raise AnthropicMessagesRequestError(
+                f"Request.max_tokens must be <= {AnthropicMessagesClient.MAX_OUTPUT_TOKENS}"
+            )
+
+        messages: List[MessageParam] = []
+        system_message: Optional[MessageParam] = None
+
+        if request.messages is not None:
+            # TODO(#2439): Refactor out Request validation
+            if request.multimodal_prompt is not None or request.prompt:
+                raise AnthropicMessagesRequestError(
+                    "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set"
+                )
+            messages = cast(List[MessageParam], request.messages)
+            if messages[0]["role"] == "system":
+                system_message = messages[0]
+                messages = messages[1:]
+
+        elif request.multimodal_prompt is not None:
+            # TODO(#2439): Refactor out Request validation
+            if request.messages is not None or request.prompt:
+                raise AnthropicMessagesRequestError(
+                    "Exactly one of Request.messages, Request.prompt or Request.multimodal_prompt should be set"
+                )
+            blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
+            for media_object in request.multimodal_prompt.media_objects:
+                if media_object.is_type(IMAGE_TYPE):
+                    # TODO(#2439): Refactor out Request validation
+                    if not media_object.location:
+                        raise Exception("MediaObject of image type has missing location field value")
+
+                    from helm.common.images_utils import encode_base64, get_dimensions, copy_image
+
+                    image_location: str = media_object.location
+                    base64_image: str
+
+                    image_width, image_height = get_dimensions(media_object.location)
+                    if (
+                        image_width > AnthropicClient.MAX_IMAGE_DIMENSION
+                        or image_height > AnthropicClient.MAX_IMAGE_DIMENSION
+                    ):
+                        hlog(
+                            f"WARNING: Image {image_location} exceeds max allowed size: "
+                            f"{AnthropicClient.MAX_IMAGE_DIMENSION} pixels"
+                        )
+                        # Save the resized image to a temporary file
+                        with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+                            hlog(f"Resizing image to temporary path: {temp_file.name}")
+                            copy_image(
+                                src=image_location,
+                                dest=temp_file.name,
+                                width=min(image_width, AnthropicClient.MAX_IMAGE_DIMENSION),
+                                height=min(image_height, AnthropicClient.MAX_IMAGE_DIMENSION),
+                            )
+                            base64_image = encode_base64(temp_file.name, format="JPEG")
+                    else:
+                        base64_image = encode_base64(image_location, format="JPEG")
+
+                    image_block: ImageBlockParam = {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": base64_image,
+                        },
+                    }
+                    blocks.append(image_block)
+                if media_object.is_type(TEXT_TYPE):
+                    # TODO(#2439): Refactor out Request validation
+                    if media_object.text is None:
+                        raise ValueError("MediaObject of text type has missing text field value")
+                    text_block: TextBlockParam = {
+                        "type": "text",
+                        "text": media_object.text,
+                    }
+                    # Anthropic does not support empty text blocks
+                    if media_object.text.strip():
+                        blocks.append(text_block)
+            messages = [{"role": "user", "content": blocks}]
+
+        else:
+            messages = [{"role": "user", "content": request.prompt}]
+
+        raw_request: AnthropicMessagesRequest = {
+            "messages": messages,
+            "model": request.model_engine,
+            "stop_sequences": request.stop_sequences,
+            "max_tokens": request.max_tokens,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "top_k": request.top_k_per_token,
+        }
+        if system_message is not None:
+            raw_request["system"] = cast(str, system_message["content"])
+        completions: List[GeneratedOutput] = []
+
+        # `num_completions` is not supported, so instead make `num_completions` separate requests.
+        for completion_index in range(request.num_completions):
+
+            def do_it() -> Dict[str, Any]:
+                try:
+                    result = self.client.messages.create(**raw_request).model_dump()
+                    if "content" not in result or not result["content"]:
+                        raise AnthropicMessagesResponseError(f"Anthropic response has empty content: {result}")
+                    elif "text" not in result["content"][0]:
+                        raise AnthropicMessagesResponseError(f"Anthropic response has non-text content: {result}")
+                    return result
+                except BadRequestError as e:
+                    response = e.response.json()
+                    if _is_content_moderation_failure(response):
+                        return response
+                    raise
+
+            cache_key = CachingClient.make_cache_key(
+                {
+                    "completion_index": completion_index,
+                    **raw_request,
+                },
+                request,
+            )
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+
+            if _is_content_moderation_failure(response):
+                hlog(
+                    f"WARNING: Returning empty request for {request.model_deployment} "
+                    "due to content moderation filter"
+                )
+                return RequestResult(
+                    success=False,
+                    cached=cached,
+                    error=response["error"]["message"],
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                    request_time=response["request_time"],
+                    request_datetime=response["request_datetime"],
+                )
+
+            completion = truncate_and_tokenize_response_text(
+                response["content"][0]["text"], request, self.tokenizer, self.tokenizer_name, original_finish_reason=""
+            )
+            completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
+
+
 class AnthropicRequestError(Exception):
     pass
 
@@ -394,7 +610,7 @@ class AnthropicLegacyClient(CachingClient):
 
         # Since Anthropic doesn't support multiple completions, we have to manually call it multiple times,
         # and aggregate the results into `completions` and `request_time`.
-        completions: List[
+        completions: List[GeneratedOutput] = []
        all_cached = True
        request_time = 0
        request_datetime: Optional[int] = None
@@ -427,8 +643,7 @@ class AnthropicLegacyClient(CachingClient):
            for text, token_logprob, all_logprobs, all_tokens in zip(
                log_probs["tokens"], log_probs["logprobs"], log_probs["topk_logprobs"], log_probs["topk_tokens"]
            ):
-
-                tokens.append(Token(text=text, logprob=token_logprob, top_logprobs=top_logprobs))
+                tokens.append(Token(text=text, logprob=token_logprob))
            sequence_logprob += token_logprob
 
        finish_reason: str = response["stop_reason"]
@@ -436,7 +651,7 @@ class AnthropicLegacyClient(CachingClient):
        if finish_reason == AnthropicLegacyClient.STOP_SEQUENCE_STOP_REASON:
            finish_reason = "stop"
 
-        completion = 
+        completion = GeneratedOutput(
            text=response["text"],
            logprob=sequence_logprob,
            tokens=tokens,
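The new AnthropicMessagesClient targets Anthropic's Messages API: a leading system turn is hoisted out of `messages` into the top-level `system` field, multimodal prompts become base64-encoded JPEG image blocks (resized when a dimension exceeds MAX_IMAGE_DIMENSION), and `num_completions` is emulated with one request per completion. A minimal sketch of the SDK call the client wraps, assuming the `anthropic` package is installed and ANTHROPIC_API_KEY is set (the model name is illustrative):

from anthropic import Anthropic

client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment

# One completion, in the same shape AnthropicMessagesClient builds per request:
# system prompt as a top-level field, user turns in `messages`.
response = client.messages.create(
    model="claude-3-opus-20240229",  # illustrative model name
    max_tokens=256,                  # the HELM client rejects requests above MAX_OUTPUT_TOKENS (4096)
    system="You are a helpful assistant.",
    messages=[{"role": "user", "content": "Name three HELM scenarios."}],
)
print(response.content[0].text)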
helm/{proxy/clients → clients}/auto_client.py +59 -31

@@ -1,22 +1,23 @@
-import os
 from dataclasses import replace
+import os
 from typing import Any, Dict, Mapping, Optional
 
 from retrying import Attempt, RetryError
 
 from helm.benchmark.model_deployment_registry import ModelDeployment, get_model_deployment
-from helm.common.
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.file_caches.local_file_cache import LocalFileCache
 from helm.common.credentials_utils import provide_api_key
-from helm.common.
+from helm.common.cache_backend_config import CacheBackendConfig, CacheConfig
 from helm.common.hierarchical_logger import hlog
 from helm.common.object_spec import create_object, inject_object_spec_args
 from helm.common.request import Request, RequestResult
-from helm.
+from helm.clients.client import Client
+from helm.clients.moderation_api_client import ModerationAPIClient
 from helm.proxy.critique.critique_client import CritiqueClient
-from helm.
-from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
+from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
 from helm.proxy.retry import NonRetriableException, retry_request
-from helm.
+from helm.tokenizers.auto_tokenizer import AutoTokenizer
 
 
 class AuthenticationError(NonRetriableException):
@@ -26,18 +27,17 @@ class AuthenticationError(NonRetriableException):
 class AutoClient(Client):
     """Automatically dispatch to the proper `Client` based on the model deployment name."""
 
-    def __init__(
-        self
+    def __init__(
+        self, credentials: Mapping[str, Any], file_storage_path: str, cache_backend_config: CacheBackendConfig
+    ):
+        self._auto_tokenizer = AutoTokenizer(credentials, cache_backend_config)
         self.credentials = credentials
-        self.
-        self.
+        self.file_storage_path = file_storage_path
+        self.cache_backend_config = cache_backend_config
         self.clients: Dict[str, Client] = {}
-        # self._huggingface_client is lazily instantiated by get_huggingface_client()
-        self._huggingface_client: Optional[HuggingFaceClient] = None
-        # self._critique_client is lazily instantiated by get_critique_client()
         self._critique_client: Optional[CritiqueClient] = None
-        hlog(f"AutoClient: 
-        hlog(f"AutoClient: 
+        hlog(f"AutoClient: file_storage_path = {file_storage_path}")
+        hlog(f"AutoClient: cache_backend_config = {cache_backend_config}")
 
     def _get_client(self, model_deployment_name: str) -> Client:
         """Return a client based on the model, creating it if necessary."""
@@ -64,11 +64,14 @@ class AutoClient(Client):
 
         # Prepare a cache
         host_organization: str = model_deployment.host_organization
-        cache_config: CacheConfig = 
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config(host_organization)
 
         client_spec = inject_object_spec_args(
             model_deployment.client_spec,
-            constant_bindings={
+            constant_bindings={
+                "cache_config": cache_config,
+                "tokenizer_name": model_deployment.tokenizer_name,
+            },
             provider_bindings={
                 "api_key": lambda: provide_api_key(self.credentials, host_organization, model_deployment_name),
                 "tokenizer": lambda: self._auto_tokenizer._get_tokenizer(
@@ -77,9 +80,14 @@ class AutoClient(Client):
                 "org_id": lambda: self.credentials.get(
                     host_organization + "OrgId", None
                 ),  # OpenAI, GooseAI, Microsoft
-                "
+                "moderation_api_client": lambda: self.get_moderation_api_client(),  # OpenAI DALL-E
+                "lock_file_path": lambda: os.path.join(
+                    self.file_storage_path, f"{host_organization}.lock"
+                ),  # Microsoft
                 "project_id": lambda: self.credentials.get(host_organization + "ProjectId", None),  # VertexAI
                 "location": lambda: self.credentials.get(host_organization + "Location", None),  # VertexAI
+                "hf_auth_token": lambda: self.credentials.get("huggingfaceAuthToken", None),  # HuggingFace
+                "file_cache": lambda: self._get_file_cache(host_organization),  # Text-to-image models
             },
         )
         client = create_object(client_spec)
@@ -117,13 +125,37 @@ class AutoClient(Client):
         # Notify our user that we failed to make the request even after retrying.
         return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")
 
+    def get_gcs_client(self):
+        from .gcs_client import GCSClient
+
+        bucket_name: str = self.credentials["gcsBucketName"]
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("gcs")
+        return GCSClient(bucket_name, cache_config)
+
+    def get_nudity_check_client(self):
+        from helm.clients.image_generation.nudity_check_client import NudityCheckClient
+
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("nudity")
+        return NudityCheckClient(cache_config)
+
+    def get_clip_score_client(self):
+        from .clip_score_client import CLIPScoreClient
+
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("clip_score")
+        return CLIPScoreClient(cache_config)
+
     def get_toxicity_classifier_client(self) -> ToxicityClassifierClient:
         """Get the toxicity classifier client. We currently only support Perspective API."""
-        from helm.
+        from helm.clients.perspective_api_client import PerspectiveAPIClient
 
-        cache_config: CacheConfig = 
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("perspectiveapi")
         return PerspectiveAPIClient(self.credentials.get("perspectiveApiKey", ""), cache_config)
 
+    def get_moderation_api_client(self) -> ModerationAPIClient:
+        """Get the ModerationAPI client."""
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("ModerationAPI")
+        return ModerationAPIClient(self.credentials.get("openaiApiKey", ""), cache_config)
+
     def get_critique_client(self) -> CritiqueClient:
         """Get the critique client."""
         if self._critique_client:
@@ -148,7 +180,7 @@ class AutoClient(Client):
             if not surgeai_credentials:
                 raise ValueError("surgeaiApiKey credentials are required for SurgeAICritiqueClient")
             self._critique_client = SurgeAICritiqueClient(
-                surgeai_credentials, 
+                surgeai_credentials, self.cache_backend_config.get_cache_config("surgeai")
             )
         elif critique_type == "model":
             from helm.proxy.critique.model_critique_client import ModelCritiqueClient
@@ -168,7 +200,7 @@ class AutoClient(Client):
             if not scale_credentials:
                 raise ValueError("scaleApiKey is required for ScaleCritiqueClient")
             self._critique_client = ScaleCritiqueClient(
-                scale_credentials, 
+                scale_credentials, self.cache_backend_config.get_cache_config("scale"), scale_project
             )
         else:
             raise ValueError(
@@ -177,11 +209,7 @@ class AutoClient(Client):
             )
         return self._critique_client
 
-    def 
-
-
-
-        return self._huggingface_client
-        cache_config = build_cache_config(self.cache_path, self.mongo_uri, "huggingface")
-        self._huggingface_client = HuggingFaceClient(cache_config=cache_config)
-        return self._huggingface_client
+    def _get_file_cache(self, host_organization: str) -> FileCache:
+        # Initialize `FileCache` for text-to-image model APIs
+        local_file_cache_path: str = os.path.join(self.file_storage_path, "output", host_organization)
+        return LocalFileCache(local_file_cache_path, file_extension="png")
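AutoClient now assembles every client through inject_object_spec_args with two kinds of bindings: constant_bindings are passed eagerly, while provider_bindings are zero-argument lambdas evaluated only if the target constructor declares the matching parameter, so a credential such as api_key is fetched and validated only for clients that actually need it. A rough, hypothetical re-implementation of that dispatch, for illustration only (not HELM's actual inject_object_spec_args):

import inspect
from typing import Any, Callable, Dict

def build(cls, constant_bindings: Dict[str, Any], provider_bindings: Dict[str, Callable[[], Any]]):
    """Instantiate cls with only the bindings its constructor declares."""
    wanted = inspect.signature(cls).parameters
    kwargs: Dict[str, Any] = {k: v for k, v in constant_bindings.items() if k in wanted}
    # Providers are thunks: a binding is only evaluated when it is actually injected.
    kwargs.update({k: thunk() for k, thunk in provider_bindings.items() if k in wanted})
    return cls(**kwargs)

class ToyClient:
    def __init__(self, cache_config: str, api_key: str):
        self.cache_config, self.api_key = cache_config, api_key

client = build(
    ToyClient,
    constant_bindings={"cache_config": "sqlite", "tokenizer_name": "unused"},
    provider_bindings={"api_key": lambda: "secret", "project_id": lambda: 1 / 0},  # 1/0 is never evaluated
)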
helm/clients/bedrock_client.py +128 -0

@@ -0,0 +1,128 @@
+from abc import abstractmethod
+from copy import deepcopy
+import json
+import os
+from typing import Any, Dict, List, Mapping, Optional
+
+from helm.common.cache import CacheConfig
+from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.clients.bedrock_utils import get_bedrock_client
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+JSON_CONTENT_TYPE = "application/json"
+
+
+class BedrockClient(CachingClient):
+    @abstractmethod
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        raise NotImplementedError()
+
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        bedrock_model_id: Optional[str] = None,
+        assumed_role: Optional[str] = None,
+        region: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.bedrock_model_id = bedrock_model_id
+        self.bedrock_client = get_bedrock_client(
+            assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
+            region=region or os.environ.get("AWS_DEFAULT_REGION", None),
+        )
+
+    def make_request(self, request: Request) -> RequestResult:
+        # model_id should be something like "amazon.titan-tg1-large"
+        model_id = self.bedrock_model_id if self.bedrock_model_id else request.model.replace("/", ".")
+        raw_request = self.convert_request_to_raw_request(request)
+
+        # modelId isn't part of raw_request, so it must be explicitly passed into the input to
+        raw_request_for_cache: Dict = {"modelId": model_id, **deepcopy(raw_request)}
+        cache_key: Mapping = CachingClient.make_cache_key(raw_request_for_cache, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self.bedrock_client.invoke_model(
+                body=json.dumps(raw_request), modelId=model_id, accept=JSON_CONTENT_TYPE, contentType=JSON_CONTENT_TYPE
+            )
+            return json.loads(response.get("body").read())
+
+        try:
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        completions = self.convert_raw_response_to_completions(response, request)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
+
+
+class BedrockTitanClient(BedrockClient):
+    _COMPLETION_REASON_TO_FINISH_REASON = {
+        "LENGTH": "length",
+        "FINISH": "endoftext",
+    }
+
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        # TODO: Support the following:
+        # - top_k_per_token
+        # - echo_prompt
+        # - num_completions
+        return {
+            "inputText": request.prompt,
+            "textGenerationConfig": {
+                "maxTokenCount": request.max_tokens,
+                # We ignore stop sequences in the request and always set stop sequences to the empty list.
+                # This is because:
+                #
+                # 1. The only permitted stop sequences are "|" and "User:"
+                #    - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
+                #    - https://github.com/boto/boto3/issues/3993
+                #    - https://github.com/aws/aws-sdk/issues/692
+                #
+                # 2. Titan has the tendency to emit "\n" as the first token in the generated text output,
+                #    which would cause the output to stop immediately if "\n" is in the stop_sequences.
+                "stopSequences": [],
+                "temperature": request.temperature,
+                "topP": request.top_p,
+            },
+        }
+
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        # TODO: Support the following:
+        # - tokens
+        # - logprob
+        completions: List[GeneratedOutput] = []
+        for raw_completion in response["results"]:
+            output_text = raw_completion["outputText"]
+            # Call lstrip() Titan has the tendency to emit "\n" as the first token in the generated text output.
+            finish_reason = BedrockTitanClient._COMPLETION_REASON_TO_FINISH_REASON.get(
+                raw_completion["completionReason"], raw_completion["completionReason"].lower()
+            )
+            completion = truncate_and_tokenize_response_text(
+                output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
+            )
+            completions.append(completion)
+        return completions
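BedrockTitanClient maps a HELM Request onto Bedrock's invoke_model JSON protocol for Amazon Titan. A minimal sketch of the same round trip, assuming boto3 is installed and AWS credentials with Bedrock access are configured (the model id is illustrative):

import json
import boto3

runtime = boto3.client("bedrock-runtime")  # region and credentials come from the AWS environment

body = {
    "inputText": "Summarize HELM in one sentence.",
    "textGenerationConfig": {
        "maxTokenCount": 100,
        "stopSequences": [],  # Titan only permits "|" and "User:", so the client always sends []
        "temperature": 0.0,
        "topP": 1.0,
    },
}
response = runtime.invoke_model(
    body=json.dumps(body),
    modelId="amazon.titan-text-express-v1",  # illustrative model id
    accept="application/json",
    contentType="application/json",
)
result = json.loads(response["body"].read())
print(result["results"][0]["outputText"].lstrip())  # lstrip(): Titan often emits a leading "\n"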