crfm-helm 0.5.6__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +60 -125
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +293 -229
- helm/benchmark/adaptation/adapter_spec.py +5 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
- helm/benchmark/annotation/aci_bench_annotator.py +11 -22
- helm/benchmark/annotation/air_bench_annotator.py +1 -1
- helm/benchmark/annotation/alrage_annotator.py +90 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
- helm/benchmark/annotation/dischargeme_annotator.py +11 -22
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/med_dialog_annotator.py +11 -22
- helm/benchmark/annotation/medalign_annotator.py +11 -22
- helm/benchmark/annotation/medi_qa_annotator.py +11 -22
- helm/benchmark/annotation/medication_qa_annotator.py +11 -22
- helm/benchmark/annotation/mental_health_annotator.py +11 -22
- helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
- helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
- helm/benchmark/annotation/model_as_judge.py +23 -18
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
- helm/benchmark/metrics/air_bench_metrics.py +3157 -1
- helm/benchmark/metrics/alrage_metric.py +35 -0
- helm/benchmark/metrics/basic_metrics.py +267 -2
- helm/benchmark/metrics/classification_metrics.py +19 -1
- helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
- helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
- helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
- helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
- helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
- helm/benchmark/metrics/comet_metric.py +1 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
- helm/benchmark/metrics/copyright_metrics.py +1 -1
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
- helm/benchmark/metrics/dry_run_metrics.py +30 -1
- helm/benchmark/metrics/efficiency_metrics.py +74 -0
- helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
- helm/benchmark/metrics/evaluate_reference_metrics.py +300 -1
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
- helm/benchmark/metrics/ifeval_metrics.py +13 -1
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
- helm/benchmark/metrics/language_modeling_metrics.py +13 -1
- helm/benchmark/metrics/live_qa_metrics.py +13 -1
- helm/benchmark/metrics/llm_jury_metrics.py +13 -1
- helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
- helm/benchmark/metrics/lmkt_metrics.py +47 -0
- helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
- helm/benchmark/metrics/medec_metrics.py +25 -2
- helm/benchmark/metrics/melt_toxicity_metric.py +1 -1
- helm/benchmark/metrics/metric.py +25 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
- helm/benchmark/metrics/omni_math_metrics.py +13 -1
- helm/benchmark/metrics/seahelm_metrics.py +14 -1
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/summarization_metrics.py +129 -1
- helm/benchmark/metrics/toxicity_metrics.py +31 -1
- helm/benchmark/metrics/wildbench_metrics.py +21 -1
- helm/benchmark/model_deployment_registry.py +11 -19
- helm/benchmark/presentation/create_plots.py +11 -2
- helm/benchmark/presentation/schema.py +10 -22
- helm/benchmark/presentation/summarize.py +189 -14
- helm/benchmark/presentation/taxonomy_info.py +20 -0
- helm/benchmark/presentation/test_create_plots.py +4 -1
- helm/benchmark/run.py +7 -1
- helm/benchmark/run_expander.py +4 -0
- helm/benchmark/run_specs/arabic_run_specs.py +191 -0
- helm/benchmark/run_specs/bluex_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +2 -55
- helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
- helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
- helm/benchmark/run_specs/heim_run_specs.py +3 -1
- helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
- helm/benchmark/run_specs/long_context_run_specs.py +48 -1
- helm/benchmark/run_specs/medhelm/__init__.py +0 -0
- helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
- helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +5 -11
- helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
- helm/benchmark/scenarios/air_bench_scenario.py +21 -0
- helm/benchmark/scenarios/alghafa_scenario.py +126 -0
- helm/benchmark/scenarios/alrage_scenario.py +54 -0
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
- helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +82 -0
- helm/benchmark/scenarios/aratrust_scenario.py +95 -0
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/{ultra_suite_asr_classification.py → ultra_suite_asr_classification_scenario.py} +9 -8
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +13 -5
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +13 -5
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +13 -5
- helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
- helm/benchmark/scenarios/bbq_scenario.py +15 -0
- helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
- helm/benchmark/scenarios/bluex_scenario.py +70 -0
- helm/benchmark/scenarios/bold_scenario.py +15 -0
- helm/benchmark/scenarios/boolq_scenario.py +20 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
- helm/benchmark/scenarios/clear_scenario.py +23 -0
- helm/benchmark/scenarios/cleva_scenario.py +480 -1
- helm/benchmark/scenarios/code_scenario.py +28 -0
- helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
- helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
- helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
- helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
- helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
- helm/benchmark/scenarios/commonsense_scenario.py +26 -0
- helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
- helm/benchmark/scenarios/copyright_scenario.py +35 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
- helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
- helm/benchmark/scenarios/disinformation_scenario.py +22 -0
- helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
- helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
- helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
- helm/benchmark/scenarios/gpqa_scenario.py +18 -0
- helm/benchmark/scenarios/grammar_scenario.py +20 -1
- helm/benchmark/scenarios/gsm_scenario.py +15 -0
- helm/benchmark/scenarios/headqa_scenario.py +22 -0
- helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
- helm/benchmark/scenarios/ice_scenario.py +21 -1
- helm/benchmark/scenarios/ifeval_scenario.py +18 -0
- helm/benchmark/scenarios/imdb_scenario.py +15 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +1 -1
- helm/benchmark/scenarios/koala_scenario.py +21 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
- helm/benchmark/scenarios/legal_support_scenario.py +13 -0
- helm/benchmark/scenarios/legalbench_scenario.py +20 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
- helm/benchmark/scenarios/lextreme_scenario.py +11 -0
- helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
- helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
- helm/benchmark/scenarios/math_scenario.py +47 -20
- helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
- helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
- helm/benchmark/scenarios/med_qa_scenario.py +14 -0
- helm/benchmark/scenarios/medalign_scenario.py +23 -0
- helm/benchmark/scenarios/medalign_scenario_helper.py +19 -125
- helm/benchmark/scenarios/medbullets_scenario.py +22 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
- helm/benchmark/scenarios/medec_scenario.py +23 -0
- helm/benchmark/scenarios/medhallu_scenario.py +23 -0
- helm/benchmark/scenarios/medhelm/__init__.py +0 -0
- helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
- helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
- helm/benchmark/scenarios/melt_scenarios.py +2 -2
- helm/benchmark/scenarios/mental_health_scenario.py +23 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +25 -1
- helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
- helm/benchmark/scenarios/mmlu_scenario.py +15 -0
- helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
- helm/benchmark/scenarios/msmarco_scenario.py +30 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
- helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
- helm/benchmark/scenarios/omni_math_scenario.py +18 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
- helm/benchmark/scenarios/quac_scenario.py +14 -0
- helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
- helm/benchmark/scenarios/raft_scenario.py +15 -0
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
- helm/benchmark/scenarios/scenario.py +31 -0
- helm/benchmark/scenarios/seahelm_scenario.py +350 -2
- helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
- helm/benchmark/scenarios/situation_prompts.yaml +49 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
- helm/benchmark/scenarios/summarization_scenario.py +37 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
- helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
- helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
- helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
- helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
- helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
- helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
- helm/benchmark/scenarios/the_pile_scenario.py +13 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
- helm/benchmark/scenarios/vicuna_scenario.py +21 -1
- helm/benchmark/scenarios/wikifact_scenario.py +20 -0
- helm/benchmark/scenarios/wildbench_scenario.py +18 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
- helm/benchmark/slurm_jobs.py +1 -2
- helm/benchmark/slurm_runner.py +8 -1
- helm/benchmark/static/schema_arabic.yaml +271 -0
- helm/benchmark/static/schema_classic.yaml +0 -17
- helm/benchmark/static/schema_long_context.yaml +24 -6
- helm/benchmark/static/schema_medhelm.yaml +36 -0
- helm/benchmark/static/schema_slp.yaml +219 -0
- helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
- helm/benchmark/static_build/assets/index-9352595e.css +1 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
- helm/clients/audio_language/llama_omni/arguments.py +61 -0
- helm/clients/audio_language/llama_omni/constants.py +9 -0
- helm/clients/audio_language/llama_omni/conversation.py +213 -0
- helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
- helm/clients/audio_language/llama_omni/model/builder.py +88 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
- helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
- helm/clients/audio_language/llama_omni/preprocess.py +295 -0
- helm/clients/audio_language/llama_omni/utils.py +202 -0
- helm/clients/audio_language/qwen2_5_omni_client.py +19 -7
- helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
- helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
- helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
- helm/clients/huggingface_client.py +2 -2
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
- helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
- helm/clients/openai_client.py +33 -20
- helm/clients/openai_responses_client.py +34 -8
- helm/clients/openrouter_client.py +31 -0
- helm/clients/test_huggingface_client.py +3 -3
- helm/clients/test_openrouter_client.py +69 -0
- helm/clients/together_client.py +48 -13
- helm/clients/vertexai_client.py +19 -11
- helm/clients/vllm_client.py +43 -7
- helm/clients/vllm_granite_thinking_client.py +56 -0
- helm/common/critique_request.py +0 -1
- helm/common/hierarchical_logger.py +83 -34
- helm/common/object_spec.py +23 -8
- helm/common/test_logging.py +94 -0
- helm/config/model_deployments.yaml +525 -172
- helm/config/model_metadata.yaml +185 -10
- helm/config/tokenizer_configs.yaml +100 -2
- helm/proxy/cli.py +1 -1
- helm/proxy/example_queries.py +8 -8
- helm/proxy/retry.py +5 -0
- helm/proxy/server.py +2 -1
- helm/proxy/static/index.css +4 -0
- helm/proxy/static/index.js +7 -1
- helm/tokenizers/grok_tokenizer.py +2 -0
- helm/benchmark/metrics/aci_bench_metrics.py +0 -14
- helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
- helm/benchmark/metrics/dischargeme_metrics.py +0 -14
- helm/benchmark/metrics/med_dialog_metrics.py +0 -14
- helm/benchmark/metrics/medalign_metrics.py +0 -14
- helm/benchmark/metrics/medi_qa_metrics.py +0 -14
- helm/benchmark/metrics/medication_qa_metrics.py +0 -14
- helm/benchmark/metrics/mental_health_metrics.py +0 -14
- helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
- helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
- helm/benchmark/metrics/numeracy_metrics.py +0 -72
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
- helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
- helm/benchmark/scenarios/numeracy_scenario.py +0 -794
- helm/benchmark/static_build/assets/index-94295e78.js +0 -10
- helm/benchmark/static_build/assets/index-b9779128.css +0 -1
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0
helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py
ADDED
@@ -0,0 +1,380 @@
+from __future__ import annotations
+
+import base64
+import logging
+import math
+import os
+import sys
+import time
+import warnings
+from functools import lru_cache
+from io import BytesIO
+
+import requests
+import torch
+import torchvision
+from packaging import version
+from PIL import Image
+from torchvision import io, transforms
+from torchvision.transforms import InterpolationMode
+from typing import List, Optional, Union
+
+
+logger = logging.getLogger(__name__)
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768
+
+# Set the maximum number of video token inputs.
+# Here, 128K represents the maximum number of input tokens for the VLLM model.
+# Remember to adjust it according to your own configuration.
+VIDEO_TOTAL_PIXELS = int(float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9)))
+logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(int(height / beta), factor)
+        w_bar = floor_by_factor(int(width / beta), factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(int(height * beta), factor)
+        w_bar = ceil_by_factor(int(width * beta), factor)
+    return h_bar, w_bar
+
+
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+    if pil_image.mode == "RGBA":
+        white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+        white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
+        return white_background
+    else:
+        return pil_image.convert("RGB")
+
+
+def fetch_image(ele, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    if "image" in ele:
+        image = ele["image"]
+    else:
+        image = ele["image_url"]
+    image_obj = None
+    if isinstance(image, Image.Image):
+        image_obj = image
+    elif image.startswith("http://") or image.startswith("https://"):
+        response = requests.get(image, stream=True)
+        image_obj = Image.open(BytesIO(response.content))
+    elif image.startswith("file://"):
+        image_obj = Image.open(image[7:])
+    elif image.startswith("data:image"):
+        if "base64," in image:
+            _, base64_data = image.split("base64,", 1)
+            data = base64.b64decode(base64_data)
+            image_obj = Image.open(BytesIO(data))
+    else:
+        image_obj = Image.open(image)
+    if image_obj is None:
+        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+    image = to_rgb(image_obj)
+    # resize
+    if "resized_height" in ele and "resized_width" in ele:
+        resized_height, resized_width = smart_resize(
+            int(ele["resized_height"]),
+            int(ele["resized_width"]),
+            factor=size_factor,
+        )
+    else:
+        width, height = image.size
+        min_pixels = int(ele.get("min_pixels", MIN_PIXELS))
+        max_pixels = int(ele.get("max_pixels", MAX_PIXELS))
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=size_factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    image = image.resize((resized_width, resized_height))
+
+    return image
+
+
+def smart_nframes(
+    ele: dict,
+    total_frames: int,
+    video_fps: Union[int, float],
+) -> int:
+    """calculate the number of frames for video used for model inputs.
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support either `fps` or `nframes`:
+                - nframes: the number of frames to extract for model inputs.
+                - fps: the fps to extract frames for model inputs.
+                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
+        total_frames (int): the original total number of frames of the video.
+        video_fps (int | float): the original fps of the video.
+
+    Raises:
+        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+    Returns:
+        int: the number of frames for video used for model inputs.
+    """
+    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+    if "nframes" in ele:
+        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+    else:
+        fps = ele.get("fps", FPS)
+        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+        max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
+        nframes = total_frames / video_fps * fps
+        if nframes > total_frames:
+            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
+        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+        nframes = floor_by_factor(nframes, FRAME_FACTOR)
+    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
+    return nframes
+
+
+def _read_video_torchvision(
+    ele: dict,
+):
+    """read video using torchvision.io.read_video
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support keys:
+                - video: the path of video. support "file://", "http://", "https://" and local path.
+                - video_start: the start time of video.
+                - video_end: the end time of video.
+    Returns:
+        torch.Tensor: the video tensor with shape (T, C, H, W).
+    """
+    video_path = ele["video"]
+    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
+        if "http://" in video_path or "https://" in video_path:
+            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
+        if "file://" in video_path:
+            video_path = video_path[7:]
+    st = time.time()
+    video, audio, info = io.read_video(
+        video_path,
+        start_pts=ele.get("video_start", 0.0),
+        end_pts=ele.get("video_end", None),
+        pts_unit="sec",
+        output_format="TCHW",
+    )
+    total_frames, video_fps = video.size(0), info["video_fps"]
+    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
+    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+    video = video[idx]
+    return video, sample_fps
+
+
+def is_decord_available() -> bool:
+    import importlib.util
+
+    return importlib.util.find_spec("decord") is not None
+
+
+def _read_video_decord(
+    ele: dict,
+):
+    """read video using decord.VideoReader
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support keys:
+                - video: the path of video. support "file://", "http://", "https://" and local path.
+                - video_start: the start time of video.
+                - video_end: the end time of video.
+    Returns:
+        torch.Tensor: the video tensor with shape (T, C, H, W).
+    """
+    import decord
+
+    video_path = ele["video"]
+    st = time.time()
+    vr = decord.VideoReader(video_path)
+    # TODO: support start_pts and end_pts
+    if "video_start" in ele or "video_end" in ele:
+        raise NotImplementedError("not support start_pts and end_pts in decord for now.")
+    total_frames, video_fps = len(vr), vr.get_avg_fps()
+    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+    video = vr.get_batch(idx).asnumpy()
+    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+    return video, sample_fps
+
+
+VIDEO_READER_BACKENDS = {
+    "decord": _read_video_decord,
+    "torchvision": _read_video_torchvision,
+}
+
+FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
+
+
+@lru_cache(maxsize=1)
+def get_video_reader_backend() -> str:
+    if FORCE_QWENVL_VIDEO_READER is not None:
+        video_reader_backend = FORCE_QWENVL_VIDEO_READER
+    elif is_decord_available():
+        video_reader_backend = "decord"
+    else:
+        video_reader_backend = "torchvision"
+    print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
+    return video_reader_backend
+
+
+def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False):
+    if isinstance(ele["video"], str):
+        video_reader_backend = get_video_reader_backend()
+        try:
+            video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
+        except Exception as e:
+            logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
+            video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
+
+        nframes, _, height, width = video.shape
+        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
+        max_pixels_supposed = ele.get("max_pixels", max_pixels)
+        if max_pixels_supposed > max_pixels:
+            logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
+        max_pixels = min(max_pixels_supposed, max_pixels)
+        if "resized_height" in ele and "resized_width" in ele:
+            resized_height, resized_width = smart_resize(
+                ele["resized_height"],
+                ele["resized_width"],
+                factor=image_factor,
+            )
+        else:
+            resized_height, resized_width = smart_resize(
+                height,
+                width,
+                factor=image_factor,
+                min_pixels=min_pixels,
+                max_pixels=max_pixels,
+            )
+        video = transforms.functional.resize(
+            video,
+            [resized_height, resized_width],
+            interpolation=InterpolationMode.BICUBIC,
+            antialias=True,
+        ).float()
+        if return_video_sample_fps:
+            return video, sample_fps
+        return video
+    else:
+        assert isinstance(ele["video"], (list, tuple))
+        process_info = ele.copy()
+        process_info.pop("type", None)
+        process_info.pop("video", None)
+        images = [
+            fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
+            for video_element in ele["video"]
+        ]
+        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
+        if len(images) < nframes:
+            images.extend([images[-1]] * (nframes - len(images)))
+        if return_video_sample_fps:
+            return images, process_info.pop("fps", 2.0)
+        return images
+
+
+def extract_vision_info(conversations) -> list[dict]:
+    vision_infos = []
+    if isinstance(conversations[0], dict):
+        conversations_p = [conversations]
+    for conversation in conversations_p:
+        for message in conversation:
+            if isinstance(message["content"], list):
+                for ele in message["content"]:
+                    if (
+                        "image" in ele
+                        or "image_url" in ele
+                        or "video" in ele
+                        or ele["type"] in ("image", "image_url", "video")
+                    ):
+                        vision_infos.append(ele)
+    return vision_infos
+
+
+def process_vision_info(
+    conversations: list[dict] | list[list[dict]],
+    return_video_kwargs: bool = False,
+):
+
+    vision_infos = extract_vision_info(conversations)
+    # Read images or videos
+    image_inputs: Optional[List] = []
+    video_inputs: Optional[List] = []
+    video_sample_fps_list = []
+    for vision_info in vision_infos:
+        if "image" in vision_info or "image_url" in vision_info:
+            assert image_inputs is not None
+            image_inputs.append(fetch_image(vision_info))
+        elif "video" in vision_info:
+            assert video_inputs is not None
+            video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
+            video_sample_fps_list.append(video_sample_fps)
+            video_inputs.append(video_input)
+        else:
+            raise ValueError("image, image_url or video should in content.")
+    if image_inputs is not None and len(image_inputs) == 0:
+        image_inputs = None
+    if video_inputs is not None and len(video_inputs) == 0:
+        video_inputs = None
+    if return_video_kwargs:
+        return image_inputs, video_inputs, {"fps": video_sample_fps_list}
+    return image_inputs, video_inputs
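For orientation, here is a minimal usage sketch of the helpers added above (not part of the diff). It assumes the qwen_omni extras and their dependencies (Pillow, torch, torchvision) are installed, that the referenced image path is a hypothetical local file, and that the module path follows the file list at the top of this page.

from helm.clients.audio_language.qwen_omni.qwen2_5_omni_utils.v2_5.vision_process import (
    process_vision_info,
    smart_resize,
)

# smart_resize snaps both dimensions to multiples of IMAGE_FACTOR (28) while
# keeping the pixel count between MIN_PIXELS and MAX_PIXELS.
print(smart_resize(1080, 1920))  # (1092, 1932)

# process_vision_info walks a chat-style conversation and loads any image/video entries.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///tmp/example.png"},  # hypothetical local file
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
image_inputs, video_inputs = process_vision_info(conversation)
print(len(image_inputs or []), video_inputs)  # 1 None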
helm/clients/huggingface_client.py
CHANGED
@@ -293,12 +293,12 @@ class HuggingFaceClient(CachingClient):
         if self._apply_chat_template:
             with self._wrapped_tokenizer as tokenizer:
                 if request.messages:
-                    prompt = tokenizer.apply_chat_template(request.messages, tokenize=False)
+                    prompt = tokenizer.apply_chat_template(request.messages, tokenize=False, add_generation_prompt=True)
                     assert isinstance(prompt, str)
                     return prompt
                 else:
                     prompt = tokenizer.apply_chat_template(
-                        [{"role": "user", "content": request.prompt}], tokenize=False
+                        [{"role": "user", "content": request.prompt}], tokenize=False, add_generation_prompt=True
                     )
                     assert isinstance(prompt, str)
                     return prompt
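The only change here is add_generation_prompt=True, which makes the rendered chat template end with the assistant turn header so the model continues as the assistant rather than the prompt stopping after the user turn. A rough illustration with a Hugging Face tokenizer (a sketch; the exact markers depend on the model's chat template, shown here for a ChatML-style model):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat-tuned model with a ChatML-style template
messages = [{"role": "user", "content": "Hello"}]

without_gen = tok.apply_chat_template(messages, tokenize=False)
with_gen = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# For ChatML-style templates the generation prompt is purely appended,
# e.g. "<|im_start|>assistant\n" at the end of the rendered string.
print(with_gen[len(without_gen):])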
helm/clients/image_generation/mindalle/models/stage1/layers.py
CHANGED
@@ -141,7 +141,7 @@ class Encoder(nn.Module):
         in_channels: int,
         resolution: int,
         z_channels: int,
-        double_z: Optional[bool] = None
+        double_z: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self.ch = ch
@@ -232,7 +232,7 @@ class Decoder(nn.Module):
         in_channels: int,
         resolution: int,
         z_channels: int,
-        double_z: bool
+        double_z: bool,
     ) -> None:
         super().__init__()
         self.ch = ch
helm/clients/openai_client.py
CHANGED
@@ -33,9 +33,12 @@ class OpenAIClientUtils:
     @classmethod
     def is_reasoning_model(cls, model_engine: str) -> bool:
         # All OpenAI reasoning models start "o[somenumber]", so we regexp for that to future proof things
-        return bool(re.match(r"^o\d+", model_engine))
+        return bool(re.match(r"^o\d+", model_engine)) or bool(re.match(r"^gpt-5", model_engine))

     # Error OpenAI throws when the image in the prompt violates their content policy
+    HARMFUL_INFORMATION_ERROR: str = (
+        "Invalid prompt: we've limited access to this content for safety reasons. This type of information may be used to benefit or to harm people."  # noqa: E501
+    )
     INAPPROPRIATE_IMAGE_ERROR: str = "Your input image may contain content that is not allowed by our safety system"
     INAPPROPRIATE_PROMPT_ERROR: str = "Invalid prompt: your prompt was flagged"
     INAPPROPRIATE_PROMPT_AZURE_ERROR: str = (
@@ -44,12 +47,10 @@ class OpenAIClientUtils:
     INAPPROPRIATE_PROMPT_MICROSOFT_ERROR: str = (
         "The response was filtered due to the prompt triggering Microsoft's content management policy."
     )
-
-    #
-
-
-        "or contact us through our help center at help.openai.com if you keep seeing this error."
-    )
+    # Grok content safety guidelines error message
+    # TODO: Refactor so that this is owned by the Grok client instead.
+    SAFETY_GUIDELINES_GROK_ERROR: str = "Content violates safety guidelines."
+    USAGE_GUIDELINES_GROK_ERROR: str = "Content violates usage guidelines."

     # Set the finish reason to this if the prompt violates OpenAI's content policy
     CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
@@ -74,21 +75,14 @@ class OpenAIClientUtils:
                 completions=[empty_completion] * request.num_completions,
                 embedding=[],
             )
-        elif cls.
-            # Handle these errors by returning an empty completion to unblock
-            hwarn(f"OpenAI server error for request: {str(request)}")
-            empty_completion = GeneratedOutput(
-                text="",
-                logprob=0,
-                tokens=[],
-                finish_reason={"reason": cls.OPENAI_SERVER_ERROR},
-            )
+        elif cls.HARMFUL_INFORMATION_ERROR in str(e):
             return RequestResult(
-                success=
+                success=False,
                 cached=False,
-
-                completions=[
+                error="Prompt blocked by OpenAI's safety filter",
+                completions=[],
                 embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
             )
         elif cls.INAPPROPRIATE_PROMPT_AZURE_ERROR in str(e) or cls.INAPPROPRIATE_PROMPT_MICROSOFT_ERROR in str(e):
             return RequestResult(
@@ -99,6 +93,24 @@ class OpenAIClientUtils:
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
             )
+        elif cls.SAFETY_GUIDELINES_GROK_ERROR in str(e):
+            return RequestResult(
+                success=False,
+                cached=False,
+                error="Grok API error: Content violates safety guidelines",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
+        elif cls.USAGE_GUIDELINES_GROK_ERROR in str(e):
+            return RequestResult(
+                success=False,
+                cached=False,
+                error="Grok API error: Content violates usage guidelines",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )

         error: str = f"OpenAI error: {e}"
         return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
@@ -118,11 +130,12 @@ class OpenAIClient(CachingClient):
         reasoning_effort: Optional[str] = None,
         openai_model_name: Optional[str] = None,
         output_processor: Optional[str] = None,
+        **kwargs,
     ):
         super().__init__(cache_config=cache_config)
         self.tokenizer = tokenizer
         self.tokenizer_name = tokenizer_name
-        self.client = OpenAI(api_key=api_key, organization=org_id, base_url=base_url)
+        self.client = OpenAI(api_key=api_key, organization=org_id, base_url=base_url, **kwargs)
         self.reasoning_effort = reasoning_effort
         self.openai_model_name = openai_model_name
         self.output_processor: Optional[Callable[[str], str]] = (
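A quick sanity check of the widened reasoning-model detection (a standalone sketch that mirrors the regex in the first hunk above):

import re

def is_reasoning_model(model_engine: str) -> bool:
    # Same pattern as the updated OpenAIClientUtils.is_reasoning_model
    return bool(re.match(r"^o\d+", model_engine)) or bool(re.match(r"^gpt-5", model_engine))

assert is_reasoning_model("o1")
assert is_reasoning_model("o3-mini")
assert is_reasoning_model("gpt-5")
assert is_reasoning_model("gpt-5-mini")
assert not is_reasoning_model("gpt-4o")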
helm/clients/openai_responses_client.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union

 from helm.clients.openai_client import OpenAIClientUtils
 from helm.common.cache import CacheConfig
+from helm.common.hierarchical_logger import hwarn
 from helm.common.media_object import TEXT_TYPE
 from helm.common.request import (
     Thinking,
@@ -60,7 +61,28 @@ class OpenAIResponseClient(CachingClient):

     def _make_raw_request(self, request: Request) -> dict[str, Any]:
         input: Union[str, List[Dict[str, Any]]]
-
+
+        if (
+            (request.prompt and request.messages)
+            or (request.prompt and request.multimodal_prompt)
+            or (request.messages and request.multimodal_prompt)
+        ):
+            raise ValueError(
+                f"More than one of `prompt`, `messages` and `multimodal_prompt` was set in request: {request}"
+            )
+
+        if request.messages is not None:
+            # Checks that all messages have a role and some content
+            for message in request.messages:
+                if not message.get("role") or not message.get("content"):
+                    raise ValueError("All messages must have a role and content")
+            # Checks that the last role is "user"
+            if request.messages[-1]["role"] != "user":
+                raise ValueError("Last message must have role 'user'")
+            if request.prompt != "":
+                hwarn("Since message is set, prompt will be ignored")
+            input = request.messages
+        elif request.multimodal_prompt is not None:
             content = []
             request.validate()
             for media_object in request.multimodal_prompt.media_objects:
@@ -101,6 +123,8 @@ class OpenAIResponseClient(CachingClient):
         # Plus other changes
         model_engine: str = request.model_engine
         if OpenAIClientUtils.is_reasoning_model(model_engine):
+            if "reasoning" not in raw_request:
+                raw_request["reasoning"] = {}
             raw_request["reasoning"]["summary"] = "detailed"
             # Avoid error:
             # "Error code: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is
@@ -145,13 +169,15 @@ class OpenAIResponseClient(CachingClient):
         if request.echo_prompt:
             text_output += request.prompt
         for output in response["output"]:
-            output_type = output[
-
-
-
-
-
-
+            output_type = output[
+                "type"
+            ]  # one of "message" or "reasoning" from API observation, but can also include tool calls
+
+            if output_type == "reasoning":
+                reasoning_output += "\n\n".join([raw_output["text"] for raw_output in output["summary"]])
+            elif output_type == "message":
+                text_output += "\n\n".join([raw_output["text"] for raw_output in output["content"]])
+            # (Other output types are ignored)

         completion = truncate_and_tokenize_response_text(
             text_output,
helm/clients/openrouter_client.py
ADDED
@@ -0,0 +1,31 @@
+import os
+from typing import Optional
+from helm.clients.openai_client import OpenAIClient
+from helm.common.cache import CacheConfig
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class OpenRouterClient(OpenAIClient):
+    def __init__(
+        self,
+        tokenizer_name: str,
+        tokenizer: Tokenizer,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+        model_name: Optional[str] = None,
+        output_processor: Optional[str] = None,
+    ):
+        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+        self.base_url = "https://openrouter.ai/api/v1/"
+        super().__init__(
+            tokenizer,
+            tokenizer_name,
+            cache_config=cache_config,
+            output_processor=output_processor,
+            base_url=self.base_url,
+            api_key=self.api_key,
+        )
+        self.model_name = model_name
+
+    def _get_model_for_request(self, request):
+        return self.model_name or request.model
helm/clients/test_huggingface_client.py
CHANGED
@@ -9,7 +9,7 @@ from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
 class TestHuggingFaceClient:
     def test_gpt2(self):
         tokenizer = HuggingFaceTokenizer(
-            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai-community/gpt2"
         )
         client = HuggingFaceClient(
             cache_config=BlackHoleCacheConfig(),
@@ -36,7 +36,7 @@ class TestHuggingFaceClient:
     @pytest.mark.skip(reason="GPT-J 6B is 22 GB and extremely slow without a GPU.")
     def test_gptj_6b(self):
         tokenizer = HuggingFaceTokenizer(
-            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai-community/gpt2"
         )
         client = HuggingFaceClient(
             cache_config=BlackHoleCacheConfig(),
@@ -57,7 +57,7 @@ class TestHuggingFaceClient:

     def test_logprob(self):
         tokenizer = HuggingFaceTokenizer(
-            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai-community/gpt2"
         )
         client = HuggingFaceClient(
             cache_config=BlackHoleCacheConfig(),