crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they were released to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the registry.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
(In the hunks below, "…" marks removed-line content that the diff viewer elided.)

helm/{proxy/clients → clients}/anthropic_client.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict, Union, cast
 import json
 import requests
 import time
@@ -6,13 +6,14 @@ import urllib.parse

 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
+from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
     Request,
     RequestResult,
-    Sequence,
+    GeneratedOutput,
     Token,
     ErrorFlags,
 )
@@ -20,16 +21,30 @@ from helm.common.tokenization_request import (
     TokenizationRequest,
     TokenizationRequestResult,
 )
-from helm.proxy.tokenizers.tokenizer import Tokenizer
-from .client import CachingClient, truncate_sequence
+from helm.proxy.retry import NonRetriableException
+from helm.tokenizers.tokenizer import Tokenizer
+from helm.clients.client import CachingClient, truncate_sequence, truncate_and_tokenize_response_text

 try:
-    import anthropic
+    from anthropic import Anthropic, BadRequestError
+    from anthropic.types import MessageParam
+    from anthropic.types.image_block_param import ImageBlockParam
+    from anthropic.types.text_block_param import TextBlockParam
     import websocket
 except ModuleNotFoundError as e:
     handle_module_not_found_error(e, ["anthropic"])


+class AnthropicCompletionRequest(TypedDict):
+    prompt: str
+    stop_sequences: List[str]
+    model: str
+    max_tokens_to_sample: int
+    temperature: float
+    top_p: float
+    top_k: int
+
+
 class AnthropicClient(CachingClient):
     """
     Client for the Anthropic models (https://arxiv.org/abs/2204.05862).
@@ -63,12 +78,12 @@ class AnthropicClient(CachingClient):
         self.tokenizer = tokenizer
         self.tokenizer_name = tokenizer_name
         self.api_key: Optional[str] = api_key
-        self.…
+        self.client = Anthropic(api_key=api_key)

-    def _send_request(self, raw_request: Dict[str, Any]) -> Dict[str, Any]:
+    def _send_request(self, raw_request: AnthropicCompletionRequest) -> Dict[str, Any]:
         if self.api_key is None:
             raise Exception("API key is not set. Please set it in the HELM config file.")
-        result = self.…
+        result = self.client.completions.create(**raw_request).model_dump()
         assert "error" not in result, f"Request failed with error: {result['error']}"
         return result

@@ -103,7 +118,7 @@ class AnthropicClient(CachingClient):
         if request.max_tokens == 0 and not request.echo_prompt:
             raise ValueError("echo_prompt must be True when max_tokens=0.")

-        raw_request = {
+        raw_request: AnthropicCompletionRequest = {
             "prompt": request.prompt,
             "stop_sequences": request.stop_sequences,
             "model": request.model_engine,
@@ -113,7 +128,7 @@ class AnthropicClient(CachingClient):
             "top_k": request.top_k_per_token,
         }

-        completions: List[Sequence] = []
+        completions: List[GeneratedOutput] = []

         # `num_completions` is not supported, so instead make `num_completions` separate requests.
         for completion_index in range(request.num_completions):
@@ -172,11 +187,9 @@
             )

             # Log probs are not currently not supported by the Anthropic, so set to 0 for now.
-            tokens: List[Token] = [
-                Token(text=str(text), logprob=0, top_logprobs={}) for text in tokenization_result.raw_tokens
-            ]
+            tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]

-            completion = Sequence(text=response["completion"], logprob=0, tokens=tokens)
+            completion = GeneratedOutput(text=response["completion"], logprob=0, tokens=tokens)
             # See NOTE() in _filter_completion() to understand why warnings are printed for truncation.
             # TODO(#1512): Fix this with post-processing.
             sequence = truncate_sequence(completion, request, print_warning=True)
@@ -192,6 +205,179 @@
         )


+def _is_content_moderation_failure(response: Dict) -> bool:
+    """Return whether a a response failed because of the content moderation filter."""
+    if (
+        "error" in response
+        and "message" in response["error"]
+        and response["error"]["message"] == "Output blocked by content filtering policy"
+    ):
+        hlog(f"Anthropic - output blocked by content filtering policy: {response}")
+        return True
+    return False
+
+
+class AnthropicMessagesRequest(TypedDict, total=False):
+    messages: List[MessageParam]
+    model: str
+    stop_sequences: List[str]
+    system: str
+    max_tokens: int
+    temperature: float
+    top_k: int
+    top_p: float
+
+
+class AnthropicMessagesRequestError(NonRetriableException):
+    pass
+
+
+class AnthropicMessagesResponseError(Exception):
+    pass
+
+
+class AnthropicMessagesClient(CachingClient):
+    # Source: https://docs.anthropic.com/claude/docs/models-overview
+    MAX_OUTPUT_TOKENS = 4096
+
+    def __init__(
+        self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.client = Anthropic(api_key=api_key)
+        self.api_key: Optional[str] = api_key
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.max_tokens > AnthropicMessagesClient.MAX_OUTPUT_TOKENS:
+            raise AnthropicMessagesRequestError(
+                f"Request.max_tokens must be <= {AnthropicMessagesClient.MAX_OUTPUT_TOKENS}"
+            )
+
+        messages: List[MessageParam] = []
+        system_message: Optional[MessageParam] = None
+
+        if request.messages is not None:
+            # TODO(#2439): Refactor out Request validation
+            if request.multimodal_prompt is not None or request.prompt:
+                raise AnthropicMessagesRequestError(
+                    "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set"
+                )
+            messages = cast(List[MessageParam], request.messages)
+            if messages[0]["role"] == "system":
+                system_message = messages[0]
+                messages = messages[1:]
+
+        elif request.multimodal_prompt is not None:
+            # TODO(#2439): Refactor out Request validation
+            if request.messages is not None or request.prompt:
+                raise AnthropicMessagesRequestError(
+                    "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set"
+                )
+            blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
+            for media_object in request.multimodal_prompt.media_objects:
+                if media_object.is_type(IMAGE_TYPE):
+                    # TODO(#2439): Refactor out Request validation
+                    if not media_object.location:
+                        raise Exception("MediaObject of image type has missing location field value")
+
+                    from helm.common.images_utils import encode_base64
+
+                    base64_image: str = encode_base64(media_object.location, format="JPEG")
+                    image_block: ImageBlockParam = {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": base64_image,
+                        },
+                    }
+                    blocks.append(image_block)
+                if media_object.is_type(TEXT_TYPE):
+                    # TODO(#2439): Refactor out Request validation
+                    if media_object.text is None:
+                        raise ValueError("MediaObject of text type has missing text field value")
+                    text_block: TextBlockParam = {
+                        "type": "text",
+                        "text": media_object.text,
+                    }
+                    blocks.append(text_block)
+            messages = [{"role": "user", "content": blocks}]
+
+        else:
+            messages = [{"role": "user", "content": request.prompt}]
+
+        raw_request: AnthropicMessagesRequest = {
+            "messages": messages,
+            "model": request.model_engine,
+            "stop_sequences": request.stop_sequences,
+            "max_tokens": request.max_tokens,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "top_k": request.top_k_per_token,
+        }
+        if system_message is not None:
+            raw_request["system"] = cast(str, system_message["content"])
+        completions: List[GeneratedOutput] = []

+        # `num_completions` is not supported, so instead make `num_completions` separate requests.
+        for completion_index in range(request.num_completions):
+
+            def do_it() -> Dict[str, Any]:
+                try:
+                    result = self.client.messages.create(**raw_request).model_dump()
+                    if "content" not in result or not result["content"]:
+                        raise AnthropicMessagesResponseError(f"Anthropic response has empty content: {result}")
+                    elif "text" not in result["content"][0]:
+                        raise AnthropicMessagesResponseError(f"Anthropic response has non-text content: {result}")
+                    return result
+                except BadRequestError as e:
+                    response = e.response.json()
+                    if _is_content_moderation_failure(response):
+                        return response
+                    raise
+
+            cache_key = CachingClient.make_cache_key(
+                {
+                    "completion_index": completion_index,
+                    **raw_request,
+                },
+                request,
+            )
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+
+            if _is_content_moderation_failure(response):
+                hlog(
+                    f"WARNING: Returning empty request for {request.model_deployment} "
+                    "due to content moderation filter"
+                )
+                return RequestResult(
+                    success=False,
+                    cached=cached,
+                    error=response["error"]["message"],
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                    request_time=response["request_time"],
+                    request_datetime=response["request_datetime"],
+                )
+
+            completion = truncate_and_tokenize_response_text(
+                response["content"][0]["text"], request, self.tokenizer, self.tokenizer_name, original_finish_reason=""
+            )
+            completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
+
+
 class AnthropicRequestError(Exception):
     pass

@@ -394,7 +580,7 @@ class AnthropicLegacyClient(CachingClient):

         # Since Anthropic doesn't support multiple completions, we have to manually call it multiple times,
         # and aggregate the results into `completions` and `request_time`.
-        completions: List[Sequence] = []
+        completions: List[GeneratedOutput] = []
         all_cached = True
         request_time = 0
         request_datetime: Optional[int] = None
@@ -427,8 +613,7 @@
         for text, token_logprob, all_logprobs, all_tokens in zip(
             log_probs["tokens"], log_probs["logprobs"], log_probs["topk_logprobs"], log_probs["topk_tokens"]
         ):
-            …
-            tokens.append(Token(text=text, logprob=token_logprob, top_logprobs=top_logprobs))
+            tokens.append(Token(text=text, logprob=token_logprob))
             sequence_logprob += token_logprob

         finish_reason: str = response["stop_reason"]
@@ -436,7 +621,7 @@
         if finish_reason == AnthropicLegacyClient.STOP_SEQUENCE_STOP_REASON:
             finish_reason = "stop"

-        completion = Sequence(
+        completion = GeneratedOutput(
             text=response["text"],
             logprob=sequence_logprob,
             tokens=tokens,
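
The main addition above is `AnthropicMessagesClient`, which targets Anthropic's Messages API rather than the legacy text-completions endpoint, converting HELM `MediaObject`s into base64 image blocks along the way. A minimal sketch of the SDK call pattern it wraps, assuming the `anthropic` Python package; the API key and model id below are illustrative placeholders, not values from the diff:

    from anthropic import Anthropic

    client = Anthropic(api_key="sk-ant-...")  # placeholder key
    # The HELM client caches plain dicts, hence .model_dump() on the pydantic response object.
    result = client.messages.create(
        model="claude-3-opus-20240229",  # illustrative model id
        max_tokens=256,
        messages=[{"role": "user", "content": "Say hello."}],
    ).model_dump()
    text = result["content"][0]["text"]

Because the Messages API takes the system prompt as a top-level `system` field rather than as a `messages` entry, the client peels a leading `role == "system"` message off the list before sending the request.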
helm/{proxy/clients → clients}/auto_client.py
@@ -1,22 +1,23 @@
-import os
 from dataclasses import replace
+import os
 from typing import Any, Dict, Mapping, Optional

 from retrying import Attempt, RetryError

 from helm.benchmark.model_deployment_registry import ModelDeployment, get_model_deployment
-from helm.common.cache_utils import build_cache_config
+from helm.common.file_caches.file_cache import FileCache
+from helm.common.file_caches.local_file_cache import LocalFileCache
 from helm.common.credentials_utils import provide_api_key
-from helm.common.cache import CacheConfig
+from helm.common.cache_backend_config import CacheBackendConfig, CacheConfig
 from helm.common.hierarchical_logger import hlog
 from helm.common.object_spec import create_object, inject_object_spec_args
 from helm.common.request import Request, RequestResult
-from helm.proxy.clients.client import Client
+from helm.clients.client import Client
+from helm.clients.moderation_api_client import ModerationAPIClient
 from helm.proxy.critique.critique_client import CritiqueClient
-from helm.proxy.clients.huggingface_client import HuggingFaceClient
-from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
+from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
 from helm.proxy.retry import NonRetriableException, retry_request
-from helm.proxy.tokenizers.auto_tokenizer import AutoTokenizer
+from helm.tokenizers.auto_tokenizer import AutoTokenizer


 class AuthenticationError(NonRetriableException):
@@ -26,18 +27,17 @@ class AuthenticationError(NonRetriableException):

 class AutoClient(Client):
     """Automatically dispatch to the proper `Client` based on the model deployment name."""

-    def __init__(self, credentials: Mapping[str, Any], cache_path: str, mongo_uri: str = ""):
-        self._auto_tokenizer = AutoTokenizer(credentials, cache_path, mongo_uri)
+    def __init__(
+        self, credentials: Mapping[str, Any], file_storage_path: str, cache_backend_config: CacheBackendConfig
+    ):
+        self._auto_tokenizer = AutoTokenizer(credentials, cache_backend_config)
         self.credentials = credentials
-        self.cache_path = cache_path
-        self.mongo_uri = mongo_uri
+        self.file_storage_path = file_storage_path
+        self.cache_backend_config = cache_backend_config
         self.clients: Dict[str, Client] = {}
-        # self._huggingface_client is lazily instantiated by get_huggingface_client()
-        self._huggingface_client: Optional[HuggingFaceClient] = None
-        # self._critique_client is lazily instantiated by get_critique_client()
         self._critique_client: Optional[CritiqueClient] = None
-        hlog(f"AutoClient: cache_path = {cache_path}")
-        hlog(f"AutoClient: mongo_uri = {mongo_uri}")
+        hlog(f"AutoClient: file_storage_path = {file_storage_path}")
+        hlog(f"AutoClient: cache_backend_config = {cache_backend_config}")

     def _get_client(self, model_deployment_name: str) -> Client:
         """Return a client based on the model, creating it if necessary."""
@@ -64,11 +64,14 @@

         # Prepare a cache
         host_organization: str = model_deployment.host_organization
-        cache_config: CacheConfig = build_cache_config(self.cache_path, self.mongo_uri, host_organization)
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config(host_organization)

         client_spec = inject_object_spec_args(
             model_deployment.client_spec,
-            constant_bindings={"cache_config": cache_config},
+            constant_bindings={
+                "cache_config": cache_config,
+                "tokenizer_name": model_deployment.tokenizer_name,
+            },
             provider_bindings={
                 "api_key": lambda: provide_api_key(self.credentials, host_organization, model_deployment_name),
                 "tokenizer": lambda: self._auto_tokenizer._get_tokenizer(
@@ -77,9 +80,14 @@
                 "org_id": lambda: self.credentials.get(
                     host_organization + "OrgId", None
                 ),  # OpenAI, GooseAI, Microsoft
-                "…
+                "moderation_api_client": lambda: self.get_moderation_api_client(),  # OpenAI DALL-E
+                "lock_file_path": lambda: os.path.join(
+                    self.file_storage_path, f"{host_organization}.lock"
+                ),  # Microsoft
                 "project_id": lambda: self.credentials.get(host_organization + "ProjectId", None),  # VertexAI
                 "location": lambda: self.credentials.get(host_organization + "Location", None),  # VertexAI
+                "hf_auth_token": lambda: self.credentials.get("huggingfaceAuthToken", None),  # HuggingFace
+                "file_cache": lambda: self._get_file_cache(host_organization),  # Text-to-image models
             },
         )
         client = create_object(client_spec)
@@ -117,13 +125,37 @@
         # Notify our user that we failed to make the request even after retrying.
         return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")

+    def get_gcs_client(self):
+        from .gcs_client import GCSClient
+
+        bucket_name: str = self.credentials["gcsBucketName"]
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("gcs")
+        return GCSClient(bucket_name, cache_config)
+
+    def get_nudity_check_client(self):
+        from helm.clients.image_generation.nudity_check_client import NudityCheckClient
+
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("nudity")
+        return NudityCheckClient(cache_config)
+
+    def get_clip_score_client(self):
+        from .clip_score_client import CLIPScoreClient
+
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("clip_score")
+        return CLIPScoreClient(cache_config)
+
     def get_toxicity_classifier_client(self) -> ToxicityClassifierClient:
         """Get the toxicity classifier client. We currently only support Perspective API."""
-        from helm.proxy.clients.perspective_api_client import PerspectiveAPIClient
+        from helm.clients.perspective_api_client import PerspectiveAPIClient

-        cache_config: CacheConfig = build_cache_config(self.cache_path, self.mongo_uri, "perspectiveapi")
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("perspectiveapi")
         return PerspectiveAPIClient(self.credentials.get("perspectiveApiKey", ""), cache_config)

+    def get_moderation_api_client(self) -> ModerationAPIClient:
+        """Get the ModerationAPI client."""
+        cache_config: CacheConfig = self.cache_backend_config.get_cache_config("ModerationAPI")
+        return ModerationAPIClient(self.credentials.get("openaiApiKey", ""), cache_config)
+
     def get_critique_client(self) -> CritiqueClient:
         """Get the critique client."""
         if self._critique_client:
@@ -148,7 +180,7 @@ class AutoClient(Client):
             if not surgeai_credentials:
                 raise ValueError("surgeaiApiKey credentials are required for SurgeAICritiqueClient")
             self._critique_client = SurgeAICritiqueClient(
-                surgeai_credentials, build_cache_config(self.cache_path, self.mongo_uri, "surgeai")
+                surgeai_credentials, self.cache_backend_config.get_cache_config("surgeai")
             )
         elif critique_type == "model":
             from helm.proxy.critique.model_critique_client import ModelCritiqueClient
@@ -168,7 +200,7 @@ class AutoClient(Client):
             if not scale_credentials:
                 raise ValueError("scaleApiKey is required for ScaleCritiqueClient")
             self._critique_client = ScaleCritiqueClient(
-                scale_credentials, build_cache_config(self.cache_path, self.mongo_uri, "scale"), scale_project
+                scale_credentials, self.cache_backend_config.get_cache_config("scale"), scale_project
             )
         else:
             raise ValueError(
@@ -177,11 +209,7 @@ class AutoClient(Client):
             )
         return self._critique_client

-    def get_huggingface_client(self) -> HuggingFaceClient:
-        …
-        …
-        if self._huggingface_client:
-            return self._huggingface_client
-        cache_config = build_cache_config(self.cache_path, self.mongo_uri, "huggingface")
-        self._huggingface_client = HuggingFaceClient(cache_config=cache_config)
-        return self._huggingface_client
+    def _get_file_cache(self, host_organization: str) -> FileCache:
+        # Initialize `FileCache` for text-to-image model APIs
+        local_file_cache_path: str = os.path.join(self.file_storage_path, "output", host_organization)
+        return LocalFileCache(local_file_cache_path, file_extension="png")
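
`AutoClient` no longer hard-codes a Hugging Face client or raw cache paths; every client is now constructed from its model deployment's `client_spec`, with `cache_config` and `tokenizer_name` bound as constants and credentials bound through lazy providers. A rough sketch of that binding pattern, using a hypothetical `construct_with_bindings` helper rather than the actual `helm.common.object_spec` implementation:

    import inspect
    from typing import Any, Callable, Dict, Type

    def construct_with_bindings(
        cls: Type,
        constant_bindings: Dict[str, Any],
        provider_bindings: Dict[str, Callable[[], Any]],
    ) -> Any:
        # Only pass a binding if the client's constructor actually declares that parameter.
        params = inspect.signature(cls.__init__).parameters
        kwargs: Dict[str, Any] = {name: value for name, value in constant_bindings.items() if name in params}
        for name, provider in provider_bindings.items():
            if name in params:
                kwargs[name] = provider()  # evaluated lazily, e.g. to look up an API key
        return cls(**kwargs)

The lazy providers matter because some bindings (API keys, GCS buckets, file caches) should only be resolved for the clients that actually declare them.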
helm/clients/bedrock_client.py (new file)
@@ -0,0 +1,128 @@
+from abc import abstractmethod
+from copy import deepcopy
+import json
+import os
+from typing import Any, Dict, List, Mapping, Optional
+
+from helm.common.cache import CacheConfig
+from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
+from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
+from helm.clients.bedrock_utils import get_bedrock_client
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+JSON_CONTENT_TYPE = "application/json"
+
+
+class BedrockClient(CachingClient):
+    @abstractmethod
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        raise NotImplementedError()
+
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        bedrock_model_id: Optional[str] = None,
+        assumed_role: Optional[str] = None,
+        region: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.bedrock_model_id = bedrock_model_id
+        self.bedrock_client = get_bedrock_client(
+            assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
+            region=region or os.environ.get("AWS_DEFAULT_REGION", None),
+        )
+
+    def make_request(self, request: Request) -> RequestResult:
+        # model_id should be something like "amazon.titan-tg1-large"
+        model_id = self.bedrock_model_id if self.bedrock_model_id else request.model.replace("/", ".")
+        raw_request = self.convert_request_to_raw_request(request)
+
+        # modelId isn't part of raw_request, so it must be explicitly passed into the input to
+        raw_request_for_cache: Dict = {"modelId": model_id, **deepcopy(raw_request)}
+        cache_key: Mapping = CachingClient.make_cache_key(raw_request_for_cache, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self.bedrock_client.invoke_model(
+                body=json.dumps(raw_request), modelId=model_id, accept=JSON_CONTENT_TYPE, contentType=JSON_CONTENT_TYPE
+            )
+            return json.loads(response.get("body").read())
+
+        try:
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        completions = self.convert_raw_response_to_completions(response, request)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
+
+
+class BedrockTitanClient(BedrockClient):
+    _COMPLETION_REASON_TO_FINISH_REASON = {
+        "LENGTH": "length",
+        "FINISH": "endoftext",
+    }
+
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        # TODO: Support the following:
+        # - top_k_per_token
+        # - echo_prompt
+        # - num_completions
+        return {
+            "inputText": request.prompt,
+            "textGenerationConfig": {
+                "maxTokenCount": request.max_tokens,
+                # We ignore stop sequences in the request and always set stop sequences to the empty list.
+                # This is because:
+                #
+                # 1. The only permitted stop sequences are "|" and "User:"
+                #    - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
+                #    - https://github.com/boto/boto3/issues/3993
+                #    - https://github.com/aws/aws-sdk/issues/692
+                #
+                # 2. Titan has the tendency to emit "\n" as the first token in the generated text output,
+                #    which would cause the output to stop immediately if "\n" is in the stop_sequences.
+                "stopSequences": [],
+                "temperature": request.temperature,
+                "topP": request.top_p,
+            },
+        }
+
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        # TODO: Support the following:
+        # - tokens
+        # - logprob
+        completions: List[GeneratedOutput] = []
+        for raw_completion in response["results"]:
+            output_text = raw_completion["outputText"]
+            # Call lstrip() Titan has the tendency to emit "\n" as the first token in the generated text output.
+            finish_reason = BedrockTitanClient._COMPLETION_REASON_TO_FINISH_REASON.get(
+                raw_completion["completionReason"], raw_completion["completionReason"].lower()
+            )
+            completion = truncate_and_tokenize_response_text(
+                output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
+            )
+            completions.append(completion)
+        return completions
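
`BedrockTitanClient.convert_request_to_raw_request` builds the JSON body that Amazon Bedrock's Titan text models expect. A standalone sketch of the same invocation using boto3 directly; the region, prompt, and sampling values are illustrative (the model id comes from the comment in `make_request` above):

    import json
    import boto3

    bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")
    body = {
        "inputText": "Explain HELM in one sentence.",
        "textGenerationConfig": {
            "maxTokenCount": 100,
            "stopSequences": [],  # Titan only permits "|" and "User:", so HELM always sends []
            "temperature": 0.5,
            "topP": 0.9,
        },
    }
    response = bedrock_runtime.invoke_model(
        body=json.dumps(body),
        modelId="amazon.titan-tg1-large",
        accept="application/json",
        contentType="application/json",
    )
    result = json.loads(response["body"].read())
    print(result["results"][0]["outputText"].lstrip())  # lstrip() drops Titan's leading "\n"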
helm/clients/bedrock_utils.py (new file)
@@ -0,0 +1,72 @@
+"""Helper utilities for working with Amazon Bedrock."""
+
+import os
+from typing import Optional
+
+from helm.common.hierarchical_logger import hlog
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+try:
+    import boto3
+    from botocore.config import Config
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["aws"])
+
+
+# From https://github.com/aws-samples/amazon-bedrock-workshop/blob/main/01_Generation/00_generate_w_bedrock.ipynb
+# MIT-0 Licensed
+def get_bedrock_client(
+    assumed_role: Optional[str] = None,
+    region: Optional[str] = None,
+    runtime: Optional[bool] = True,
+):
+    """Create a boto3 client for Amazon Bedrock, with optional configuration overrides
+
+    Parameters
+    ----------
+    assumed_role :
+        Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not
+        specified, the current active credentials will be used.
+    region :
+        Optional name of the AWS Region in which the service should be called (e.g. "us-east-1").
+        If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used.
+    runtime :
+        Optional choice of getting different client to perform operations with the Amazon Bedrock service.
+    """
+    if region is None:
+        target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
+    else:
+        target_region = region
+
+    session_kwargs = {"region_name": target_region}
+    client_kwargs = {**session_kwargs}
+
+    profile_name = os.environ.get("AWS_PROFILE")
+    if profile_name:
+        session_kwargs["profile_name"] = profile_name
+
+    retry_config = Config(
+        region_name=target_region,
+        retries={
+            "max_attempts": 10,
+            "mode": "standard",
+        },
+    )
+    session = boto3.Session(**session_kwargs)
+
+    if assumed_role:
+        sts = session.client("sts")
+        response = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")
+        client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
+        client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"]
+        client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
+
+    if runtime:
+        service_name = "bedrock-runtime"
+    else:
+        service_name = "bedrock"
+
+    bedrock_client = session.client(service_name=service_name, config=retry_config, **client_kwargs)
+
+    hlog(f"Amazon Bedrock client successfully created with endpoint {bedrock_client._endpoint}")
+    return bedrock_client
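
Example usage of the helper above; the role ARN is a hypothetical placeholder:

    from helm.clients.bedrock_utils import get_bedrock_client

    client = get_bedrock_client(
        assumed_role="arn:aws:iam::123456789012:role/BedrockCaller",  # placeholder ARN
        region="us-east-1",
    )

With the default `runtime=True`, the returned client is the "bedrock-runtime" data-plane client that exposes `invoke_model`; passing `runtime=False` returns the "bedrock" control-plane client instead (e.g. for `list_foundation_models`).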