crfm-helm 0.5.7__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.

This release of crfm-helm has been flagged as potentially problematic.

Files changed (243):
  1. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +5 -77
  2. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +228 -197
  3. helm/benchmark/adaptation/adapter_spec.py +5 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
  6. helm/benchmark/annotation/aci_bench_annotator.py +11 -22
  7. helm/benchmark/annotation/alrage_annotator.py +90 -0
  8. helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
  9. helm/benchmark/annotation/dischargeme_annotator.py +11 -22
  10. helm/benchmark/annotation/med_dialog_annotator.py +11 -22
  11. helm/benchmark/annotation/medalign_annotator.py +11 -22
  12. helm/benchmark/annotation/medi_qa_annotator.py +11 -22
  13. helm/benchmark/annotation/medication_qa_annotator.py +11 -22
  14. helm/benchmark/annotation/mental_health_annotator.py +11 -22
  15. helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
  16. helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
  17. helm/benchmark/annotation/model_as_judge.py +23 -18
  18. helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
  19. helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
  20. helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
  21. helm/benchmark/metrics/air_bench_metrics.py +3157 -1
  22. helm/benchmark/metrics/alrage_metric.py +35 -0
  23. helm/benchmark/metrics/basic_metrics.py +267 -2
  24. helm/benchmark/metrics/classification_metrics.py +19 -1
  25. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
  26. helm/benchmark/metrics/dry_run_metrics.py +30 -1
  27. helm/benchmark/metrics/efficiency_metrics.py +74 -0
  28. helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
  29. helm/benchmark/metrics/evaluate_reference_metrics.py +299 -0
  30. helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
  31. helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
  32. helm/benchmark/metrics/ifeval_metrics.py +13 -1
  33. helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
  34. helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
  35. helm/benchmark/metrics/language_modeling_metrics.py +13 -1
  36. helm/benchmark/metrics/live_qa_metrics.py +13 -1
  37. helm/benchmark/metrics/llm_jury_metrics.py +13 -1
  38. helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
  39. helm/benchmark/metrics/medec_metrics.py +25 -2
  40. helm/benchmark/metrics/metric.py +25 -0
  41. helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
  42. helm/benchmark/metrics/omni_math_metrics.py +13 -1
  43. helm/benchmark/metrics/seahelm_metrics.py +14 -1
  44. helm/benchmark/metrics/summac/model_summac.py +2 -2
  45. helm/benchmark/metrics/summarization_metrics.py +129 -1
  46. helm/benchmark/metrics/toxicity_metrics.py +31 -1
  47. helm/benchmark/metrics/wildbench_metrics.py +21 -1
  48. helm/benchmark/presentation/schema.py +5 -22
  49. helm/benchmark/presentation/summarize.py +180 -11
  50. helm/benchmark/presentation/taxonomy_info.py +20 -0
  51. helm/benchmark/run_expander.py +4 -0
  52. helm/benchmark/run_specs/arabic_run_specs.py +134 -16
  53. helm/benchmark/run_specs/bluex_run_specs.py +1 -1
  54. helm/benchmark/run_specs/classic_run_specs.py +2 -2
  55. helm/benchmark/run_specs/long_context_run_specs.py +2 -2
  56. helm/benchmark/run_specs/medhelm/__init__.py +0 -0
  57. helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
  58. helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
  59. helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
  60. helm/benchmark/scenarios/air_bench_scenario.py +21 -0
  61. helm/benchmark/scenarios/alrage_scenario.py +54 -0
  62. helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
  63. helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
  64. helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
  65. helm/benchmark/scenarios/aratrust_scenario.py +19 -0
  66. helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
  67. helm/benchmark/scenarios/bbq_scenario.py +15 -0
  68. helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
  69. helm/benchmark/scenarios/bluex_scenario.py +6 -2
  70. helm/benchmark/scenarios/bold_scenario.py +15 -0
  71. helm/benchmark/scenarios/boolq_scenario.py +20 -0
  72. helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
  73. helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
  74. helm/benchmark/scenarios/clear_scenario.py +23 -0
  75. helm/benchmark/scenarios/cleva_scenario.py +479 -0
  76. helm/benchmark/scenarios/code_scenario.py +28 -0
  77. helm/benchmark/scenarios/commonsense_scenario.py +26 -0
  78. helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
  79. helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
  80. helm/benchmark/scenarios/copyright_scenario.py +35 -1
  81. helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
  82. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
  83. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
  84. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
  85. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
  86. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
  87. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
  88. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
  89. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
  90. helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
  91. helm/benchmark/scenarios/disinformation_scenario.py +22 -0
  92. helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
  93. helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
  94. helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
  95. helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
  96. helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
  97. helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
  98. helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
  99. helm/benchmark/scenarios/gpqa_scenario.py +18 -0
  100. helm/benchmark/scenarios/grammar_scenario.py +20 -1
  101. helm/benchmark/scenarios/gsm_scenario.py +15 -0
  102. helm/benchmark/scenarios/headqa_scenario.py +22 -0
  103. helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
  104. helm/benchmark/scenarios/ice_scenario.py +21 -1
  105. helm/benchmark/scenarios/ifeval_scenario.py +18 -0
  106. helm/benchmark/scenarios/imdb_scenario.py +15 -0
  107. helm/benchmark/scenarios/koala_scenario.py +21 -1
  108. helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
  109. helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
  110. helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
  111. helm/benchmark/scenarios/legal_support_scenario.py +13 -0
  112. helm/benchmark/scenarios/legalbench_scenario.py +20 -0
  113. helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
  114. helm/benchmark/scenarios/lextreme_scenario.py +11 -0
  115. helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
  116. helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
  117. helm/benchmark/scenarios/math_scenario.py +26 -0
  118. helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
  119. helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
  120. helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
  121. helm/benchmark/scenarios/med_qa_scenario.py +14 -0
  122. helm/benchmark/scenarios/medalign_scenario.py +23 -0
  123. helm/benchmark/scenarios/medbullets_scenario.py +22 -0
  124. helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
  125. helm/benchmark/scenarios/medec_scenario.py +23 -0
  126. helm/benchmark/scenarios/medhallu_scenario.py +23 -0
  127. helm/benchmark/scenarios/medhelm/__init__.py +0 -0
  128. helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
  129. helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
  130. helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
  131. helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
  132. helm/benchmark/scenarios/mental_health_scenario.py +23 -0
  133. helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
  134. helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
  135. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
  136. helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
  137. helm/benchmark/scenarios/mmlu_scenario.py +15 -0
  138. helm/benchmark/scenarios/msmarco_scenario.py +30 -0
  139. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
  140. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
  141. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
  142. helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
  143. helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
  144. helm/benchmark/scenarios/omni_math_scenario.py +18 -0
  145. helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
  146. helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
  147. helm/benchmark/scenarios/quac_scenario.py +14 -0
  148. helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
  149. helm/benchmark/scenarios/raft_scenario.py +15 -0
  150. helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
  151. helm/benchmark/scenarios/scenario.py +31 -0
  152. helm/benchmark/scenarios/seahelm_scenario.py +348 -0
  153. helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
  154. helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
  155. helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
  156. helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
  157. helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
  158. helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
  159. helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
  160. helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
  161. helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
  162. helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
  163. helm/benchmark/scenarios/situation_prompts.yaml +49 -0
  164. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
  165. helm/benchmark/scenarios/summarization_scenario.py +37 -0
  166. helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
  167. helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
  168. helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
  169. helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
  170. helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
  171. helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
  172. helm/benchmark/scenarios/the_pile_scenario.py +13 -1
  173. helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
  174. helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
  175. helm/benchmark/scenarios/vicuna_scenario.py +21 -1
  176. helm/benchmark/scenarios/wikifact_scenario.py +20 -0
  177. helm/benchmark/scenarios/wildbench_scenario.py +18 -0
  178. helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
  179. helm/benchmark/static/schema_arabic.yaml +55 -12
  180. helm/benchmark/static/schema_long_context.yaml +17 -17
  181. helm/benchmark/static/schema_medhelm.yaml +36 -0
  182. helm/benchmark/static/schema_slp.yaml +219 -0
  183. helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
  184. helm/benchmark/static_build/assets/index-9352595e.css +1 -0
  185. helm/benchmark/static_build/index.html +2 -2
  186. helm/clients/audio_language/llama_omni/arguments.py +61 -0
  187. helm/clients/audio_language/llama_omni/constants.py +9 -0
  188. helm/clients/audio_language/llama_omni/conversation.py +213 -0
  189. helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
  190. helm/clients/audio_language/llama_omni/model/builder.py +88 -0
  191. helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
  192. helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
  193. helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
  194. helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
  195. helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
  196. helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
  197. helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
  198. helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
  199. helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
  200. helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
  201. helm/clients/audio_language/llama_omni/preprocess.py +295 -0
  202. helm/clients/audio_language/llama_omni/utils.py +202 -0
  203. helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
  204. helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
  205. helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
  206. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
  207. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
  208. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
  209. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
  210. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
  211. helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
  212. helm/clients/openai_client.py +31 -19
  213. helm/clients/openai_responses_client.py +27 -3
  214. helm/clients/openrouter_client.py +31 -0
  215. helm/clients/test_openrouter_client.py +69 -0
  216. helm/clients/together_client.py +48 -11
  217. helm/clients/vertexai_client.py +8 -2
  218. helm/config/model_deployments.yaml +75 -1
  219. helm/config/model_metadata.yaml +70 -2
  220. helm/config/tokenizer_configs.yaml +19 -1
  221. helm/proxy/example_queries.py +8 -8
  222. helm/proxy/server.py +2 -1
  223. helm/proxy/static/index.css +4 -0
  224. helm/proxy/static/index.js +7 -1
  225. helm/benchmark/metrics/aci_bench_metrics.py +0 -14
  226. helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
  227. helm/benchmark/metrics/dischargeme_metrics.py +0 -14
  228. helm/benchmark/metrics/med_dialog_metrics.py +0 -14
  229. helm/benchmark/metrics/medalign_metrics.py +0 -14
  230. helm/benchmark/metrics/medi_qa_metrics.py +0 -14
  231. helm/benchmark/metrics/medication_qa_metrics.py +0 -14
  232. helm/benchmark/metrics/mental_health_metrics.py +0 -14
  233. helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
  234. helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
  235. helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
  236. helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
  237. helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
  238. helm/benchmark/static_build/assets/index-b9779128.css +0 -1
  239. helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
  240. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
  241. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
  242. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
  243. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py
@@ -0,0 +1,380 @@
+ from __future__ import annotations
+
+ import base64
+ import logging
+ import math
+ import os
+ import sys
+ import time
+ import warnings
+ from functools import lru_cache
+ from io import BytesIO
+
+ import requests
+ import torch
+ import torchvision
+ from packaging import version
+ from PIL import Image
+ from torchvision import io, transforms
+ from torchvision.transforms import InterpolationMode
+ from typing import List, Optional, Union
+
+
+ logger = logging.getLogger(__name__)
+
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 4 * 28 * 28
+ MAX_PIXELS = 16384 * 28 * 28
+ MAX_RATIO = 200
+
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
+ FRAME_FACTOR = 2
+ FPS = 2.0
+ FPS_MIN_FRAMES = 4
+ FPS_MAX_FRAMES = 768
+
+ # Set the maximum number of video token inputs.
+ # Here, 128K represents the maximum number of input tokens for the VLLM model.
+ # Remember to adjust it according to your own configuration.
+ VIDEO_TOTAL_PIXELS = int(float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9)))
+ logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
+
+
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+
+ def ceil_by_factor(number: int, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+
+ def floor_by_factor(number: int, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+
+ def smart_resize(
+     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     if max(height, width) / min(height, width) > MAX_RATIO:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(int(height / beta), factor)
+         w_bar = floor_by_factor(int(width / beta), factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(int(height * beta), factor)
+         w_bar = ceil_by_factor(int(width * beta), factor)
+     return h_bar, w_bar
+
+
+ def to_rgb(pil_image: Image.Image) -> Image.Image:
+     if pil_image.mode == "RGBA":
+         white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+         white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
+         return white_background
+     else:
+         return pil_image.convert("RGB")
+
+
+ def fetch_image(ele, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+     if "image" in ele:
+         image = ele["image"]
+     else:
+         image = ele["image_url"]
+     image_obj = None
+     if isinstance(image, Image.Image):
+         image_obj = image
+     elif image.startswith("http://") or image.startswith("https://"):
+         response = requests.get(image, stream=True)
+         image_obj = Image.open(BytesIO(response.content))
+     elif image.startswith("file://"):
+         image_obj = Image.open(image[7:])
+     elif image.startswith("data:image"):
+         if "base64," in image:
+             _, base64_data = image.split("base64,", 1)
+             data = base64.b64decode(base64_data)
+             image_obj = Image.open(BytesIO(data))
+     else:
+         image_obj = Image.open(image)
+     if image_obj is None:
+         raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+     image = to_rgb(image_obj)
+     # resize
+     if "resized_height" in ele and "resized_width" in ele:
+         resized_height, resized_width = smart_resize(
+             int(ele["resized_height"]),
+             int(ele["resized_width"]),
+             factor=size_factor,
+         )
+     else:
+         width, height = image.size
+         min_pixels = int(ele.get("min_pixels", MIN_PIXELS))
+         max_pixels = int(ele.get("max_pixels", MAX_PIXELS))
+         resized_height, resized_width = smart_resize(
+             height,
+             width,
+             factor=size_factor,
+             min_pixels=min_pixels,
+             max_pixels=max_pixels,
+         )
+     image = image.resize((resized_width, resized_height))
+
+     return image
+
+
+ def smart_nframes(
+     ele: dict,
+     total_frames: int,
+     video_fps: Union[int, float],
+ ) -> int:
+     """calculate the number of frames for video used for model inputs.
+
+     Args:
+         ele (dict): a dict contains the configuration of video.
+             support either `fps` or `nframes`:
+                 - nframes: the number of frames to extract for model inputs.
+                 - fps: the fps to extract frames for model inputs.
+                     - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                     - max_frames: the maximum number of frames of the video, only used when fps is provided.
+         total_frames (int): the original total number of frames of the video.
+         video_fps (int | float): the original fps of the video.
+
+     Raises:
+         ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+     Returns:
+         int: the number of frames for video used for model inputs.
+     """
+     assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+     if "nframes" in ele:
+         nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+     else:
+         fps = ele.get("fps", FPS)
+         min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+         max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
+         nframes = total_frames / video_fps * fps
+         if nframes > total_frames:
+             logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
+         nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+         nframes = floor_by_factor(nframes, FRAME_FACTOR)
+     if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+         raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
+     return nframes
+
+
+ def _read_video_torchvision(
+     ele: dict,
+ ):
+     """read video using torchvision.io.read_video
+
+     Args:
+         ele (dict): a dict contains the configuration of video.
+             support keys:
+                 - video: the path of video. support "file://", "http://", "https://" and local path.
+                 - video_start: the start time of video.
+                 - video_end: the end time of video.
+     Returns:
+         torch.Tensor: the video tensor with shape (T, C, H, W).
+     """
+     video_path = ele["video"]
+     if version.parse(torchvision.__version__) < version.parse("0.19.0"):
+         if "http://" in video_path or "https://" in video_path:
+             warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
+         if "file://" in video_path:
+             video_path = video_path[7:]
+     st = time.time()
+     video, audio, info = io.read_video(
+         video_path,
+         start_pts=ele.get("video_start", 0.0),
+         end_pts=ele.get("video_end", None),
+         pts_unit="sec",
+         output_format="TCHW",
+     )
+     total_frames, video_fps = video.size(0), info["video_fps"]
+     logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+     idx = torch.linspace(0, total_frames - 1, nframes).round().long()
+     sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+     video = video[idx]
+     return video, sample_fps
+
+
+ def is_decord_available() -> bool:
+     import importlib.util
+
+     return importlib.util.find_spec("decord") is not None
+
+
+ def _read_video_decord(
+     ele: dict,
+ ):
+     """read video using decord.VideoReader
+
+     Args:
+         ele (dict): a dict contains the configuration of video.
+             support keys:
+                 - video: the path of video. support "file://", "http://", "https://" and local path.
+                 - video_start: the start time of video.
+                 - video_end: the end time of video.
+     Returns:
+         torch.Tensor: the video tensor with shape (T, C, H, W).
+     """
+     import decord
+
+     video_path = ele["video"]
+     st = time.time()
+     vr = decord.VideoReader(video_path)
+     # TODO: support start_pts and end_pts
+     if "video_start" in ele or "video_end" in ele:
+         raise NotImplementedError("not support start_pts and end_pts in decord for now.")
+     total_frames, video_fps = len(vr), vr.get_avg_fps()
+     logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+     idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+     video = vr.get_batch(idx).asnumpy()
+     video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+     sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+     return video, sample_fps
+
+
+ VIDEO_READER_BACKENDS = {
+     "decord": _read_video_decord,
+     "torchvision": _read_video_torchvision,
+ }
+
+ FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
+
+
+ @lru_cache(maxsize=1)
+ def get_video_reader_backend() -> str:
+     if FORCE_QWENVL_VIDEO_READER is not None:
+         video_reader_backend = FORCE_QWENVL_VIDEO_READER
+     elif is_decord_available():
+         video_reader_backend = "decord"
+     else:
+         video_reader_backend = "torchvision"
+     print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
+     return video_reader_backend
+
+
+ def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False):
+     if isinstance(ele["video"], str):
+         video_reader_backend = get_video_reader_backend()
+         try:
+             video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
+         except Exception as e:
+             logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
+             video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
+
+         nframes, _, height, width = video.shape
+         min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+         total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+         max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
+         max_pixels_supposed = ele.get("max_pixels", max_pixels)
+         if max_pixels_supposed > max_pixels:
+             logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
+         max_pixels = min(max_pixels_supposed, max_pixels)
+         if "resized_height" in ele and "resized_width" in ele:
+             resized_height, resized_width = smart_resize(
+                 ele["resized_height"],
+                 ele["resized_width"],
+                 factor=image_factor,
+             )
+         else:
+             resized_height, resized_width = smart_resize(
+                 height,
+                 width,
+                 factor=image_factor,
+                 min_pixels=min_pixels,
+                 max_pixels=max_pixels,
+             )
+         video = transforms.functional.resize(
+             video,
+             [resized_height, resized_width],
+             interpolation=InterpolationMode.BICUBIC,
+             antialias=True,
+         ).float()
+         if return_video_sample_fps:
+             return video, sample_fps
+         return video
+     else:
+         assert isinstance(ele["video"], (list, tuple))
+         process_info = ele.copy()
+         process_info.pop("type", None)
+         process_info.pop("video", None)
+         images = [
+             fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
+             for video_element in ele["video"]
+         ]
+         nframes = ceil_by_factor(len(images), FRAME_FACTOR)
+         if len(images) < nframes:
+             images.extend([images[-1]] * (nframes - len(images)))
+         if return_video_sample_fps:
+             return images, process_info.pop("fps", 2.0)
+         return images
+
+
+ def extract_vision_info(conversations) -> list[dict]:
+     vision_infos = []
+     if isinstance(conversations[0], dict):
+         conversations_p = [conversations]
+     for conversation in conversations_p:
+         for message in conversation:
+             if isinstance(message["content"], list):
+                 for ele in message["content"]:
+                     if (
+                         "image" in ele
+                         or "image_url" in ele
+                         or "video" in ele
+                         or ele["type"] in ("image", "image_url", "video")
+                     ):
+                         vision_infos.append(ele)
+     return vision_infos
+
+
+ def process_vision_info(
+     conversations: list[dict] | list[list[dict]],
+     return_video_kwargs: bool = False,
+ ):
+
+     vision_infos = extract_vision_info(conversations)
+     # Read images or videos
+     image_inputs: Optional[List] = []
+     video_inputs: Optional[List] = []
+     video_sample_fps_list = []
+     for vision_info in vision_infos:
+         if "image" in vision_info or "image_url" in vision_info:
+             assert image_inputs is not None
+             image_inputs.append(fetch_image(vision_info))
+         elif "video" in vision_info:
+             assert video_inputs is not None
+             video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
+             video_sample_fps_list.append(video_sample_fps)
+             video_inputs.append(video_input)
+         else:
+             raise ValueError("image, image_url or video should in content.")
+     if image_inputs is not None and len(image_inputs) == 0:
+         image_inputs = None
+     if video_inputs is not None and len(video_inputs) == 0:
+         video_inputs = None
+     if return_video_kwargs:
+         return image_inputs, video_inputs, {"fps": video_sample_fps_list}
+     return image_inputs, video_inputs
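
The new vision_process.py module above vendors Qwen's qwen-vl-utils preprocessing. Its core is smart_resize, which snaps both dimensions to multiples of factor (28) while keeping the pixel count within [min_pixels, max_pixels] and roughly preserving aspect ratio. A minimal sketch of that rounding behavior, assuming the wheel's optional torch/torchvision dependencies are installed so the module imports cleanly:

from helm.clients.audio_language.qwen_omni.qwen2_5_omni_utils.v2_5.vision_process import (
    IMAGE_FACTOR,
    smart_resize,
)

# Each dimension is rounded to the nearest multiple of 28: 1080 -> 1092, 1920 -> 1932.
# 1092 * 1932 pixels already lies inside [MIN_PIXELS, MAX_PIXELS], so no further rescale.
height, width = smart_resize(height=1080, width=1920)
assert height % IMAGE_FACTOR == 0 and width % IMAGE_FACTOR == 0
print(height, width)  # 1092 1932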
--- a/helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py
+++ b/helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py
@@ -184,7 +184,7 @@ def sparse_attention_2d_light(
      attention_dropout=None,
      log_attention_weights=None,
      add_scalar=0,
-     **kwargs
+     **kwargs,
  ):
      """
      q0, k0, v0: [batch_size, 1088, hidden_size]

--- a/helm/clients/image_generation/mindalle/models/stage1/layers.py
+++ b/helm/clients/image_generation/mindalle/models/stage1/layers.py
@@ -141,7 +141,7 @@ class Encoder(nn.Module):
          in_channels: int,
          resolution: int,
          z_channels: int,
-         double_z: Optional[bool] = None
+         double_z: Optional[bool] = None,
      ) -> None:
          super().__init__()
          self.ch = ch
@@ -232,7 +232,7 @@ class Decoder(nn.Module):
          in_channels: int,
          resolution: int,
          z_channels: int,
-         double_z: bool
+         double_z: bool,
      ) -> None:
          super().__init__()
          self.ch = ch

--- a/helm/clients/openai_client.py
+++ b/helm/clients/openai_client.py
@@ -33,9 +33,12 @@ class OpenAIClientUtils:
      @classmethod
      def is_reasoning_model(cls, model_engine: str) -> bool:
          # All OpenAI reasoning models start "o[somenumber]", so we regexp for that to future proof things
-         return bool(re.match(r"^o\d+", model_engine))
+         return bool(re.match(r"^o\d+", model_engine)) or bool(re.match(r"^gpt-5", model_engine))
 
      # Error OpenAI throws when the image in the prompt violates their content policy
+     HARMFUL_INFORMATION_ERROR: str = (
+         "Invalid prompt: we've limited access to this content for safety reasons. This type of information may be used to benefit or to harm people."  # noqa: E501
+     )
      INAPPROPRIATE_IMAGE_ERROR: str = "Your input image may contain content that is not allowed by our safety system"
      INAPPROPRIATE_PROMPT_ERROR: str = "Invalid prompt: your prompt was flagged"
      INAPPROPRIATE_PROMPT_AZURE_ERROR: str = (
@@ -44,12 +47,10 @@
      INAPPROPRIATE_PROMPT_MICROSOFT_ERROR: str = (
          "The response was filtered due to the prompt triggering Microsoft's content management policy."
      )
-
-     # OpenAI server error
-     OPENAI_SERVER_ERROR: str = (
-         "The server had an error processing your request. Sorry about that! You can retry your request, "
-         "or contact us through our help center at help.openai.com if you keep seeing this error."
-     )
+     # Grok content safety guidelines error message
+     # TODO: Refactor so that this is owned by the Grok client instead.
+     SAFETY_GUIDELINES_GROK_ERROR: str = "Content violates safety guidelines."
+     USAGE_GUIDELINES_GROK_ERROR: str = "Content violates usage guidelines."
 
      # Set the finish reason to this if the prompt violates OpenAI's content policy
      CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
@@ -74,21 +75,14 @@
                  completions=[empty_completion] * request.num_completions,
                  embedding=[],
              )
-         elif cls.OPENAI_SERVER_ERROR in str(e):
-             # Handle these errors by returning an empty completion to unblock
-             hwarn(f"OpenAI server error for request: {str(request)}")
-             empty_completion = GeneratedOutput(
-                 text="",
-                 logprob=0,
-                 tokens=[],
-                 finish_reason={"reason": cls.OPENAI_SERVER_ERROR},
-             )
+         elif cls.HARMFUL_INFORMATION_ERROR in str(e):
              return RequestResult(
-                 success=True,
+                 success=False,
                  cached=False,
-                 request_time=0,
-                 completions=[empty_completion] * request.num_completions,
+                 error="Prompt blocked by OpenAI's safety filter",
+                 completions=[],
                  embedding=[],
+                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
              )
          elif cls.INAPPROPRIATE_PROMPT_AZURE_ERROR in str(e) or cls.INAPPROPRIATE_PROMPT_MICROSOFT_ERROR in str(e):
              return RequestResult(
@@ -99,6 +93,24 @@
                  embedding=[],
                  error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
              )
+         elif cls.SAFETY_GUIDELINES_GROK_ERROR in str(e):
+             return RequestResult(
+                 success=False,
+                 cached=False,
+                 error="Grok API error: Content violates safety guidelines",
+                 completions=[],
+                 embedding=[],
+                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+             )
+         elif cls.USAGE_GUIDELINES_GROK_ERROR in str(e):
+             return RequestResult(
+                 success=False,
+                 cached=False,
+                 error="Grok API error: Content violates usage guidelines",
+                 completions=[],
+                 embedding=[],
+                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+             )
 
          error: str = f"OpenAI error: {e}"
          return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
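
A practical consequence of the is_reasoning_model change above is that gpt-5 engines now take the same code path as the o-series. A small sketch of the matching logic, with example engine names (not an exhaustive list of OpenAI models):

import re

def is_reasoning_model(model_engine: str) -> bool:
    # Mirrors the updated check: o-series engines and the gpt-5 family
    # are both treated as reasoning models.
    return bool(re.match(r"^o\d+", model_engine)) or bool(re.match(r"^gpt-5", model_engine))

assert is_reasoning_model("o3-mini")
assert is_reasoning_model("gpt-5-mini")
assert not is_reasoning_model("gpt-4o")  # matches neither "o<digit>..." nor "gpt-5..."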
--- a/helm/clients/openai_responses_client.py
+++ b/helm/clients/openai_responses_client.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union

  from helm.clients.openai_client import OpenAIClientUtils
  from helm.common.cache import CacheConfig
+ from helm.common.hierarchical_logger import hwarn
  from helm.common.media_object import TEXT_TYPE
  from helm.common.request import (
      Thinking,
@@ -60,7 +61,28 @@ class OpenAIResponseClient(CachingClient):

      def _make_raw_request(self, request: Request) -> dict[str, Any]:
          input: Union[str, List[Dict[str, Any]]]
-         if request.multimodal_prompt is not None:
+
+         if (
+             (request.prompt and request.messages)
+             or (request.prompt and request.multimodal_prompt)
+             or (request.messages and request.multimodal_prompt)
+         ):
+             raise ValueError(
+                 f"More than one of `prompt`, `messages` and `multimodal_prompt` was set in request: {request}"
+             )
+
+         if request.messages is not None:
+             # Checks that all messages have a role and some content
+             for message in request.messages:
+                 if not message.get("role") or not message.get("content"):
+                     raise ValueError("All messages must have a role and content")
+             # Checks that the last role is "user"
+             if request.messages[-1]["role"] != "user":
+                 raise ValueError("Last message must have role 'user'")
+             if request.prompt != "":
+                 hwarn("Since message is set, prompt will be ignored")
+             input = request.messages
+         elif request.multimodal_prompt is not None:
              content = []
              request.validate()
              for media_object in request.multimodal_prompt.media_objects:
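
With this hunk, openai_responses_client.py accepts chat-style messages as an alternative to a plain prompt or a multimodal prompt, and rejects requests that set more than one of them. A hedged sketch of a request that should pass the new validation, assuming the messages field of helm.common.request.Request takes role/content dicts as in the checks above (model names here are illustrative); the next two hunks continue in the same file:

from helm.common.request import Request

request = Request(
    model="openai/gpt-5",
    model_deployment="openai/gpt-5",
    messages=[
        {"role": "system", "content": "Answer in one sentence."},
        # Every message must carry a role and content, and the final
        # message must have role "user", or _make_raw_request raises.
        {"role": "user", "content": "What does HELM evaluate?"},
    ],
)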
@@ -101,6 +123,8 @@
      # Plus other changes
      model_engine: str = request.model_engine
      if OpenAIClientUtils.is_reasoning_model(model_engine):
+         if "reasoning" not in raw_request:
+             raw_request["reasoning"] = {}
          raw_request["reasoning"]["summary"] = "detailed"
      # Avoid error:
      # "Error code: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is
@@ -150,9 +174,9 @@
      ]  # one of "message" or "reasoning" from API observation, but can also include tool calls

      if output_type == "reasoning":
-         reasoning_output += "\n".join([raw_output["text"] for raw_output in output["summary"]])
+         reasoning_output += "\n\n".join([raw_output["text"] for raw_output in output["summary"]])
      elif output_type == "message":
-         text_output += "\n".join([raw_output["text"] for raw_output in output["content"]])
+         text_output += "\n\n".join([raw_output["text"] for raw_output in output["content"]])
      # (Other output types are ignored)

      completion = truncate_and_tokenize_response_text(

--- /dev/null
+++ b/helm/clients/openrouter_client.py
@@ -0,0 +1,31 @@
+ import os
+ from typing import Optional
+ from helm.clients.openai_client import OpenAIClient
+ from helm.common.cache import CacheConfig
+ from helm.tokenizers.tokenizer import Tokenizer
+
+
+ class OpenRouterClient(OpenAIClient):
+     def __init__(
+         self,
+         tokenizer_name: str,
+         tokenizer: Tokenizer,
+         cache_config: CacheConfig,
+         api_key: Optional[str] = None,
+         model_name: Optional[str] = None,
+         output_processor: Optional[str] = None,
+     ):
+         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+         self.base_url = "https://openrouter.ai/api/v1/"
+         super().__init__(
+             tokenizer,
+             tokenizer_name,
+             cache_config=cache_config,
+             output_processor=output_processor,
+             base_url=self.base_url,
+             api_key=self.api_key,
+         )
+         self.model_name = model_name
+
+     def _get_model_for_request(self, request):
+         return self.model_name or request.model

--- /dev/null
+++ b/helm/clients/test_openrouter_client.py
@@ -0,0 +1,69 @@
+ import os
+ import pytest
+ import tempfile
+
+ from helm.common.cache import BlackHoleCacheConfig, SqliteCacheConfig
+ from helm.common.request import Request
+ from helm.clients.openrouter_client import OpenRouterClient
+
+ from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
+
+
+ class TestOpenRouterClient:
+     def setup_method(self, method):
+         cache_file = tempfile.NamedTemporaryFile(delete=False)
+         self.cache_path: str = cache_file.name
+         self.tokenizer_name = "mistralai/Mistral-7B-v0.1"
+         self.tokenizer = HuggingFaceTokenizer(
+             cache_config=BlackHoleCacheConfig(),
+             tokenizer_name=self.tokenizer_name,
+         )
+
+     def teardown_method(self, method):
+         os.remove(self.cache_path)
+
+     @pytest.mark.parametrize(
+         "model_name,test_input,expected_model",
+         [
+             (
+                 "mistralai/mistral-medium-3.1",
+                 Request(
+                     model="mistralai/mistral-medium-3.1",
+                     model_deployment="openrouter/mistral-medium-3.1",
+                 ),
+                 "mistralai/mistral-medium-3.1",
+             ),
+             (
+                 None,
+                 Request(model="openai/gpt-oss-20b:free", model_deployment="openrouter/gpt-oss-20b:free"),
+                 "openai/gpt-oss-20b:free",
+             ),
+         ],
+     )
+     def test_get_model_for_request(self, model_name, test_input, expected_model):
+         client = OpenRouterClient(
+             tokenizer_name=self.tokenizer_name,
+             tokenizer=self.tokenizer,
+             cache_config=SqliteCacheConfig(self.cache_path),
+             model_name=model_name,
+             api_key="test_key",
+         )
+         assert client._get_model_for_request(test_input) == expected_model
+
+     def test_api_key_env_var(self, monkeypatch):
+         monkeypatch.setenv("OPENROUTER_API_KEY", "test_key")
+         client = OpenRouterClient(
+             tokenizer_name=self.tokenizer_name,
+             tokenizer=self.tokenizer,
+             cache_config=SqliteCacheConfig(self.cache_path),
+         )
+         assert client.api_key == "test_key"
+
+     def test_api_key_argument(self):
+         client = OpenRouterClient(
+             tokenizer_name=self.tokenizer_name,
+             tokenizer=self.tokenizer,
+             cache_config=BlackHoleCacheConfig(),
+             api_key="explicit_key",
+         )
+         assert client.api_key == "explicit_key"
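
Taken together, the two new files above add an OpenRouter backend that reuses the OpenAI client against https://openrouter.ai/api/v1/. A usage sketch, assuming a valid OPENROUTER_API_KEY, the tokenizer used in the tests above, and HELM's standard make_request entry point (the model name comes from the test file and depends on OpenRouter-side availability):

import os

from helm.clients.openrouter_client import OpenRouterClient
from helm.common.cache import BlackHoleCacheConfig
from helm.common.request import Request
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

tokenizer = HuggingFaceTokenizer(
    cache_config=BlackHoleCacheConfig(),
    tokenizer_name="mistralai/Mistral-7B-v0.1",
)
client = OpenRouterClient(
    tokenizer_name="mistralai/Mistral-7B-v0.1",
    tokenizer=tokenizer,
    cache_config=BlackHoleCacheConfig(),
    api_key=os.environ["OPENROUTER_API_KEY"],
)
# With model_name unset, the client forwards the request's own model identifier.
result = client.make_request(
    Request(
        model="openai/gpt-oss-20b:free",
        model_deployment="openrouter/gpt-oss-20b:free",
        prompt="Say hello in five words.",
    )
)
print(result.completions[0].text if result.success else result.error)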