crfm-helm 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crfm-helm might be problematic. Click here for more details.

Files changed (236) hide show
  1. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/METADATA +41 -57
  2. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/RECORD +197 -152
  3. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +32 -31
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
  6. helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
  7. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
  8. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
  9. helm/benchmark/adaptation/common_adapter_specs.py +2 -0
  10. helm/benchmark/annotation/air_bench_annotator.py +64 -0
  11. helm/benchmark/annotation/annotator_factory.py +6 -0
  12. helm/benchmark/annotation/anthropic_red_team_annotator.py +70 -0
  13. helm/benchmark/annotation/call_center_annotator.py +247 -0
  14. helm/benchmark/annotation/financebench_annotator.py +79 -0
  15. helm/benchmark/annotation/harm_bench_annotator.py +68 -0
  16. helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
  17. helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
  18. helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
  19. helm/benchmark/annotation/live_qa_annotator.py +71 -0
  20. helm/benchmark/annotation/medication_qa_annotator.py +68 -0
  21. helm/benchmark/annotation/model_as_judge.py +45 -0
  22. helm/benchmark/annotation/simple_safety_tests_annotator.py +64 -0
  23. helm/benchmark/annotation/xstest_annotator.py +110 -0
  24. helm/benchmark/augmentations/translate_perturbation.py +1 -0
  25. helm/benchmark/huggingface_registration.py +16 -6
  26. helm/benchmark/metrics/air_bench_metrics.py +56 -0
  27. helm/benchmark/metrics/annotation_metrics.py +108 -0
  28. helm/benchmark/metrics/bhasa_metrics.py +188 -0
  29. helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
  30. helm/benchmark/metrics/code_metrics_helper.py +11 -1
  31. helm/benchmark/metrics/fin_qa_metrics.py +60 -0
  32. helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
  33. helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
  34. helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
  35. helm/benchmark/metrics/live_qa_metrics.py +23 -0
  36. helm/benchmark/metrics/medication_qa_metrics.py +23 -0
  37. helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
  38. helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
  39. helm/benchmark/metrics/safety_metrics.py +57 -0
  40. helm/benchmark/metrics/summac/model_summac.py +3 -3
  41. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
  42. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
  43. helm/benchmark/metrics/unitxt_metrics.py +20 -10
  44. helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
  45. helm/benchmark/metrics/vision_language/image_metrics.py +30 -72
  46. helm/benchmark/metrics/vision_language/image_utils.py +1 -1
  47. helm/benchmark/model_metadata_registry.py +3 -3
  48. helm/benchmark/presentation/schema.py +54 -4
  49. helm/benchmark/presentation/test_run_entry.py +1 -0
  50. helm/benchmark/presentation/test_schema.py +11 -0
  51. helm/benchmark/run.py +31 -2
  52. helm/benchmark/run_expander.py +113 -10
  53. helm/benchmark/run_spec_factory.py +4 -0
  54. helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
  55. helm/benchmark/run_specs/bhasa_run_specs.py +638 -0
  56. helm/benchmark/run_specs/call_center_run_specs.py +152 -0
  57. helm/benchmark/run_specs/classic_run_specs.py +15 -11
  58. helm/benchmark/run_specs/decodingtrust_run_specs.py +11 -9
  59. helm/benchmark/run_specs/experimental_run_specs.py +85 -0
  60. helm/benchmark/run_specs/finance_run_specs.py +110 -0
  61. helm/benchmark/run_specs/safety_run_specs.py +154 -0
  62. helm/benchmark/run_specs/vlm_run_specs.py +251 -57
  63. helm/benchmark/scenarios/air_bench_scenario.py +50 -0
  64. helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
  65. helm/benchmark/scenarios/banking77_scenario.py +51 -0
  66. helm/benchmark/scenarios/bhasa_scenario.py +1798 -0
  67. helm/benchmark/scenarios/call_center_scenario.py +84 -0
  68. helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
  69. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
  70. helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
  71. helm/benchmark/scenarios/ewok_scenario.py +116 -0
  72. helm/benchmark/scenarios/fin_qa_scenario.py +119 -0
  73. helm/benchmark/scenarios/financebench_scenario.py +53 -0
  74. helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
  75. helm/benchmark/scenarios/scenario.py +1 -1
  76. helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
  77. helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
  78. helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
  79. helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
  80. helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
  81. helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
  82. helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
  83. helm/benchmark/scenarios/test_math_scenario.py +2 -8
  84. helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
  85. helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
  86. helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
  87. helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
  88. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
  89. helm/benchmark/scenarios/vision_language/bingo_scenario.py +5 -5
  90. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
  91. helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
  92. helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
  93. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
  94. helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
  95. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
  96. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
  97. helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +13 -2
  98. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -7
  99. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -5
  100. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
  101. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
  102. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
  103. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +44 -13
  104. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
  105. helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
  106. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
  107. helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
  108. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
  109. helm/benchmark/scenarios/vision_language/pairs_scenario.py +7 -6
  110. helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
  111. helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
  112. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
  113. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +5 -5
  114. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +98 -0
  115. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
  116. helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
  117. helm/benchmark/scenarios/xstest_scenario.py +35 -0
  118. helm/benchmark/server.py +1 -6
  119. helm/benchmark/static/schema_air_bench.yaml +3149 -0
  120. helm/benchmark/static/schema_bhasa.yaml +709 -0
  121. helm/benchmark/static/schema_call_center.yaml +232 -0
  122. helm/benchmark/static/schema_classic.yaml +3 -59
  123. helm/benchmark/static/schema_cleva.yaml +768 -0
  124. helm/benchmark/static/schema_decodingtrust.yaml +444 -0
  125. helm/benchmark/static/schema_ewok.yaml +367 -0
  126. helm/benchmark/static/schema_finance.yaml +189 -0
  127. helm/benchmark/static/schema_image2struct.yaml +588 -0
  128. helm/benchmark/static/schema_instruction_following.yaml +3 -52
  129. helm/benchmark/static/schema_lite.yaml +3 -61
  130. helm/benchmark/static/schema_medical.yaml +255 -0
  131. helm/benchmark/static/schema_mmlu.yaml +3 -61
  132. helm/benchmark/static/schema_safety.yaml +247 -0
  133. helm/benchmark/static/schema_tables.yaml +317 -0
  134. helm/benchmark/static/schema_thai.yaml +244 -0
  135. helm/benchmark/static/schema_unitxt.yaml +3 -61
  136. helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +304 -298
  137. helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
  138. helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
  139. helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
  140. helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
  141. helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
  142. helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
  143. helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
  144. helm/benchmark/static_build/assets/index-58f97dcd.js +10 -0
  145. helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
  146. helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
  147. helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
  148. helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
  149. helm/benchmark/static_build/index.html +2 -2
  150. helm/benchmark/window_services/test_openai_window_service.py +8 -8
  151. helm/clients/ai21_client.py +71 -1
  152. helm/clients/anthropic_client.py +50 -28
  153. helm/clients/auto_client.py +11 -0
  154. helm/clients/client.py +24 -7
  155. helm/clients/cohere_client.py +98 -3
  156. helm/clients/huggingface_client.py +79 -19
  157. helm/clients/nvidia_nim_client.py +35 -0
  158. helm/clients/openai_client.py +11 -5
  159. helm/clients/palmyra_client.py +25 -0
  160. helm/clients/perspective_api_client.py +11 -6
  161. helm/clients/reka_client.py +189 -0
  162. helm/clients/test_client.py +7 -9
  163. helm/clients/test_huggingface_client.py +19 -3
  164. helm/clients/test_together_client.py +72 -2
  165. helm/clients/together_client.py +129 -23
  166. helm/clients/vertexai_client.py +62 -18
  167. helm/clients/vision_language/huggingface_vlm_client.py +1 -0
  168. helm/clients/vision_language/open_flamingo_client.py +1 -2
  169. helm/clients/vision_language/paligemma_client.py +146 -0
  170. helm/clients/vision_language/palmyra_vision_client.py +99 -0
  171. helm/clients/yi_client.py +31 -0
  172. helm/common/critique_request.py +10 -1
  173. helm/common/images_utils.py +25 -0
  174. helm/common/mongo_key_value_store.py +2 -1
  175. helm/common/request.py +16 -0
  176. helm/config/model_deployments.yaml +740 -363
  177. helm/config/model_metadata.yaml +824 -128
  178. helm/config/tokenizer_configs.yaml +207 -10
  179. helm/proxy/critique/model_critique_client.py +32 -4
  180. helm/proxy/example_queries.py +14 -21
  181. helm/proxy/services/server_service.py +2 -3
  182. helm/proxy/token_counters/test_auto_token_counter.py +2 -2
  183. helm/tokenizers/ai21_tokenizer.py +51 -59
  184. helm/tokenizers/auto_tokenizer.py +1 -1
  185. helm/tokenizers/cohere_tokenizer.py +29 -62
  186. helm/tokenizers/huggingface_tokenizer.py +35 -13
  187. helm/tokenizers/test_ai21_tokenizer.py +48 -0
  188. helm/tokenizers/test_cohere_tokenizer.py +39 -0
  189. helm/tokenizers/test_huggingface_tokenizer.py +5 -1
  190. helm/benchmark/static/benchmarking.css +0 -156
  191. helm/benchmark/static/benchmarking.js +0 -1705
  192. helm/benchmark/static/config.js +0 -3
  193. helm/benchmark/static/general.js +0 -122
  194. helm/benchmark/static/images/crfm-logo.png +0 -0
  195. helm/benchmark/static/images/helm-logo-simple.png +0 -0
  196. helm/benchmark/static/images/helm-logo.png +0 -0
  197. helm/benchmark/static/images/language-model-helm.png +0 -0
  198. helm/benchmark/static/images/organizations/ai21.png +0 -0
  199. helm/benchmark/static/images/organizations/anthropic.png +0 -0
  200. helm/benchmark/static/images/organizations/bigscience.png +0 -0
  201. helm/benchmark/static/images/organizations/cohere.png +0 -0
  202. helm/benchmark/static/images/organizations/eleutherai.png +0 -0
  203. helm/benchmark/static/images/organizations/google.png +0 -0
  204. helm/benchmark/static/images/organizations/meta.png +0 -0
  205. helm/benchmark/static/images/organizations/microsoft.png +0 -0
  206. helm/benchmark/static/images/organizations/nvidia.png +0 -0
  207. helm/benchmark/static/images/organizations/openai.png +0 -0
  208. helm/benchmark/static/images/organizations/together.png +0 -0
  209. helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
  210. helm/benchmark/static/images/organizations/yandex.png +0 -0
  211. helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
  212. helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
  213. helm/benchmark/static/index.html +0 -68
  214. helm/benchmark/static/info-icon.png +0 -0
  215. helm/benchmark/static/json-urls.js +0 -69
  216. helm/benchmark/static/plot-captions.js +0 -27
  217. helm/benchmark/static/schema_image2structure.yaml +0 -304
  218. helm/benchmark/static/utils.js +0 -285
  219. helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
  220. helm/benchmark/static_build/assets/index-878a1094.css +0 -1
  221. helm/benchmark/window_services/ai21_window_service.py +0 -247
  222. helm/benchmark/window_services/cohere_window_service.py +0 -101
  223. helm/benchmark/window_services/test_ai21_window_service.py +0 -163
  224. helm/benchmark/window_services/test_cohere_window_service.py +0 -75
  225. helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
  226. helm/benchmark/window_services/test_ice_window_service.py +0 -327
  227. helm/tokenizers/ice_tokenizer.py +0 -30
  228. helm/tokenizers/test_ice_tokenizer.py +0 -57
  229. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/LICENSE +0 -0
  230. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/entry_points.txt +0 -0
  231. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/top_level.txt +0 -0
  232. /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
  233. /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
  234. /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/__init__.py +0 -0
  235. /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/__init__.py +0 -0
  236. /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
@@ -1,7 +1,7 @@
1
1
  import requests
2
2
  from abc import ABC, abstractmethod
3
3
  from threading import Lock
4
- from typing import Any, Dict, Optional, List, Union
4
+ from typing import Any, Dict, Mapping, Optional, List, Union
5
5
 
6
6
  from helm.common.cache import CacheConfig
7
7
  from helm.common.media_object import TEXT_TYPE
@@ -26,22 +26,62 @@ class VertexAIContentBlockedError(Exception):
26
26
  pass
27
27
 
28
28
 
29
+ class SafetySettingPresets:
30
+ BLOCK_NONE = "block_none" # Disable all blocking
31
+ DEFAULT = "default" # Use default safety settings
32
+
33
+
34
+ def _get_safety_settings_for_preset(
35
+ safety_settings_preset: Optional[str],
36
+ ) -> Optional[Dict[HarmCategory, SafetySetting.HarmBlockThreshold]]:
37
+ """Get the safety settings for the safety_settings_preset.
38
+
39
+ If safety_settings_preset is None, use the default value of BLOCK_NONE (*not* DEFAULT)."""
40
+ if safety_settings_preset is None or safety_settings_preset == SafetySettingPresets.BLOCK_NONE:
41
+ return {
42
+ harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
43
+ for harm_category in iter(HarmCategory)
44
+ }
45
+ elif safety_settings_preset == SafetySettingPresets.DEFAULT:
46
+ return None
47
+ else:
48
+ raise ValueError(f"Unknown safety_settings_preset: {safety_settings_preset}")
49
+
50
+
51
+ def _get_model_name_for_request(request: Request) -> str:
52
+ # We have to strip "-safety-" suffixes from model names because they are not part of the Vertex AI model name
53
+ # TODO: Clean up this hack
54
+ return request.model_engine.split("-safety-")[0]
55
+
56
+
29
57
  class VertexAIClient(CachingClient, ABC):
30
58
  """Client for Vertex AI models"""
31
59
 
32
- def __init__(self, cache_config: CacheConfig, project_id: str, location: str) -> None:
60
+ def __init__(
61
+ self, cache_config: CacheConfig, project_id: str, location: str, safety_settings_preset: Optional[str] = None
62
+ ) -> None:
33
63
  super().__init__(cache_config=cache_config)
34
64
  self.project_id = project_id
35
65
  self.location = location
36
66
 
37
- # VertexAI's default safety filter is overly sensitive, so we disable it.
38
- self.safety_settings: Dict[HarmCategory, SafetySetting.HarmBlockThreshold] = {
39
- harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
40
- for harm_category in iter(HarmCategory)
41
- }
67
+ self.safety_settings_preset = safety_settings_preset
68
+ self.safety_settings = _get_safety_settings_for_preset(safety_settings_preset)
42
69
 
43
70
  vertexai.init(project=self.project_id, location=self.location)
44
71
 
72
+ def make_cache_key_with_safety_settings_preset(self, raw_request: Mapping, request: Request) -> Mapping:
73
+ """Construct the key for the cache using the raw request.
74
+
75
+ Add `self.safety_settings_preset` to the key, if not None."""
76
+ if self.safety_settings_preset is not None:
77
+ assert "safety_settings_preset" not in raw_request
78
+ return {
79
+ **CachingClient.make_cache_key(raw_request, request),
80
+ "safety_settings_preset": self.safety_settings_preset,
81
+ }
82
+ else:
83
+ return CachingClient.make_cache_key(raw_request, request)
84
+
45
85
  @abstractmethod
46
86
  def make_request(self, request: Request) -> RequestResult:
47
87
  raise NotImplementedError
@@ -71,7 +111,7 @@ class VertexAITextClient(VertexAIClient):
71
111
  }
72
112
 
73
113
  completions: List[GeneratedOutput] = []
74
- model_name: str = request.model_engine
114
+ model_name: str = _get_model_name_for_request(request)
75
115
 
76
116
  try:
77
117
 
@@ -87,9 +127,9 @@ class VertexAITextClient(VertexAIClient):
87
127
  # We need to include the engine's name to differentiate among requests made for different model
88
128
  # engines since the engine name is not included in the request itself.
89
129
  # Same for the prompt.
90
- cache_key = CachingClient.make_cache_key(
130
+ cache_key = self.make_cache_key_with_safety_settings_preset(
91
131
  {
92
- "engine": request.model_engine,
132
+ "engine": model_name,
93
133
  "prompt": request.prompt,
94
134
  **parameters,
95
135
  },
@@ -177,7 +217,7 @@ class VertexAIChatClient(VertexAIClient):
177
217
  }
178
218
 
179
219
  completions: List[GeneratedOutput] = []
180
- model_name: str = request.model_engine
220
+ model_name: str = _get_model_name_for_request(request)
181
221
  model = self.get_model(model_name)
182
222
 
183
223
  try:
@@ -197,7 +237,7 @@ class VertexAIChatClient(VertexAIClient):
197
237
 
198
238
  # Depending on the version of the Vertex AI library and the type of prompt blocking,
199
239
  # prompt blocking can show up in many ways, so this defensively handles most of these ways
200
- if response.prompt_feedback.block_reason:
240
+ if response.prompt_feedback and response.prompt_feedback.block_reason:
201
241
  raise VertexAIContentBlockedError(
202
242
  f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
203
243
  )
@@ -209,8 +249,10 @@ class VertexAIChatClient(VertexAIClient):
209
249
  # content blocking can show up in many ways, so this defensively handles most of these ways
210
250
  if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
211
251
  raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
252
+ if not candidate.content:
253
+ raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
212
254
  if not candidate.content.parts:
213
- raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
255
+ raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
214
256
  predictions.append({"text": candidate.content.text})
215
257
  # TODO: Extract more information from the response
216
258
  return {"predictions": predictions}
@@ -218,7 +260,7 @@ class VertexAIChatClient(VertexAIClient):
218
260
  # We need to include the engine's name to differentiate among requests made for different model
219
261
  # engines since the engine name is not included in the request itself.
220
262
  # Same for the prompt.
221
- cache_key = CachingClient.make_cache_key(
263
+ cache_key = self.make_cache_key_with_safety_settings_preset(
222
264
  {
223
265
  "model_name": model_name,
224
266
  "prompt": request.prompt,
@@ -313,7 +355,7 @@ class VertexAIChatClient(VertexAIClient):
313
355
  }
314
356
 
315
357
  completions: List[GeneratedOutput] = []
316
- model_name: str = request.model_engine
358
+ model_name: str = _get_model_name_for_request(request)
317
359
  model = self.get_model(model_name)
318
360
 
319
361
  request_time = 0
@@ -330,7 +372,7 @@ class VertexAIChatClient(VertexAIClient):
330
372
  )
331
373
  # Depending on the version of the Vertex AI library and the type of prompt blocking,
332
374
  # prompt blocking can show up in many ways, so this defensively handles most of these ways
333
- if response.prompt_feedback.block_reason:
375
+ if response.prompt_feedback and response.prompt_feedback.block_reason:
334
376
  raise VertexAIContentBlockedError(
335
377
  f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
336
378
  )
@@ -345,15 +387,17 @@ class VertexAIChatClient(VertexAIClient):
345
387
  # content blocking can show up in many ways, so this defensively handles most of these ways
346
388
  if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
347
389
  raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
390
+ if not candidate.content:
391
+ raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
348
392
  if not candidate.content.parts:
349
- raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
393
+ raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
350
394
  return {"predictions": [{"text": candidate.text}]}
351
395
 
352
396
  raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
353
397
  if completion_index > 0:
354
398
  raw_cache_key["completion_index"] = completion_index
355
399
 
356
- cache_key = CachingClient.make_cache_key(raw_cache_key, request)
400
+ cache_key = self.make_cache_key_with_safety_settings_preset(raw_cache_key, request)
357
401
  response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
358
402
  except requests.exceptions.RequestException as e:
359
403
  error: str = f"Gemini Vision error: {e}"
@@ -38,6 +38,7 @@ class HuggingFaceVLMClient(CachingClient):
38
38
  "huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
39
39
  "huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
40
40
  "huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
41
+ "huggingface/prometheus-vision-13b-v1.0-hf": "PahaII/prometheus-vision-13b-v1.0-hf",
41
42
  }
42
43
 
43
44
  def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
@@ -82,13 +82,12 @@ class OpenFlamingoClient(CachingClient):
82
82
  # Build the prompt
83
83
  prompt_text: str = ""
84
84
  images: List[Image.Image] = []
85
+ request.validate()
85
86
  for media_object in request.multimodal_prompt.media_objects:
86
87
  if media_object.is_type("image") and media_object.location:
87
88
  images.append(open_image(media_object.location))
88
89
  prompt_text += self.IMAGE_TOKEN
89
90
  elif media_object.is_type(TEXT_TYPE):
90
- if media_object.text is None:
91
- raise ValueError("MediaObject of text type has missing text field value")
92
91
  prompt_text += media_object.text
93
92
  else:
94
93
  raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
@@ -0,0 +1,146 @@
1
+ from threading import Lock
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+ from dataclasses import dataclass
6
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
7
+
8
+ from helm.common.cache import CacheConfig
9
+ from helm.common.images_utils import open_image
10
+ from helm.common.gpu_utils import get_torch_device_name
11
+ from helm.common.hierarchical_logger import hlog, htrack_block
12
+ from helm.common.media_object import TEXT_TYPE
13
+ from helm.common.optional_dependencies import handle_module_not_found_error
14
+ from helm.common.request import Request, RequestResult, GeneratedOutput, Token
15
+ from helm.common.tokenization_request import TokenizationRequest
16
+ from helm.common.request import wrap_request_time
17
+ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
18
+ from helm.tokenizers.tokenizer import Tokenizer
19
+
20
+ try:
21
+ from PIL import Image
22
+ except ModuleNotFoundError as e:
23
+ handle_module_not_found_error(e, ["images"])
24
+
25
+ # Added to solve: cutlassF: no kernel found to launch!
26
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
27
+ torch.backends.cuda.enable_flash_sdp(False)
28
+
29
+
30
+ @dataclass(frozen=True)
31
+ class LoadedPaliGemmaForConditionalGeneration:
32
+ """Loaded model and processor for PaliGemma."""
33
+
34
+ model: PaliGemmaForConditionalGeneration
35
+ processor: AutoProcessor
36
+
37
+
38
+ _models_lock: Lock = Lock()
39
+ _models: Dict[str, Optional[LoadedPaliGemmaForConditionalGeneration]] = {}
40
+
41
+
42
+ class PaliGemmaClient(CachingClient):
43
+ """
44
+ PaliGemma is a versatile and lightweight vision-language model (VLM) inspired by PaLI-3
45
+ and based on open components such as the SigLIP vision model and the Gemma language model.
46
+ It takes both image and text as input and generates text as output, supporting multiple languages.
47
+ It is designed for class-leading fine-tune performance on a wide range of vision-language tasks
48
+ such as image and short video caption, visual question answering, text reading, object detection
49
+ and object segmentation.
50
+ """
51
+
52
+ def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
53
+ super().__init__(cache_config=cache_config)
54
+ self.tokenizer = tokenizer
55
+ self.tokenizer_name = tokenizer_name
56
+ self._device: str = get_torch_device_name()
57
+
58
+ def _get_model(self, checkpoint: str) -> LoadedPaliGemmaForConditionalGeneration:
59
+ global _models_lock
60
+ global _models
61
+
62
+ # Ensure that only one thread is loading the model at a time
63
+ with _models_lock:
64
+ if checkpoint not in _models or _models[checkpoint] is None:
65
+ hlog(f"Loading model {checkpoint} and caching in memory...")
66
+ model = PaliGemmaForConditionalGeneration.from_pretrained(
67
+ checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
68
+ ).eval()
69
+ processor = AutoProcessor.from_pretrained(checkpoint)
70
+ _models[checkpoint] = LoadedPaliGemmaForConditionalGeneration(model, processor)
71
+ loaded_model_processor = _models[checkpoint]
72
+
73
+ assert loaded_model_processor is not None
74
+ return loaded_model_processor
75
+
76
+ def make_request(self, request: Request) -> RequestResult:
77
+ assert request.multimodal_prompt is not None, "Multimodal prompt is required"
78
+
79
+ loaded_model_processor: LoadedPaliGemmaForConditionalGeneration = self._get_model(request.model_deployment)
80
+ model = loaded_model_processor.model
81
+ processor = loaded_model_processor.processor
82
+ generation_args = {"max_new_tokens": request.max_tokens}
83
+
84
+ images: List[Image.Image] = []
85
+ prompt_pieces: List[str] = []
86
+ for media_object in request.multimodal_prompt.media_objects:
87
+ if media_object.is_type("image") and media_object.location:
88
+ images += [open_image(media_object.location).convert("RGB")]
89
+ elif media_object.is_type(TEXT_TYPE):
90
+ if media_object.text is None:
91
+ raise ValueError("MediaObject of text type has missing text field value")
92
+ prompt_pieces.append(media_object.text)
93
+ else:
94
+ raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
95
+ prompt_text: str = "\n".join(prompt_pieces)
96
+ model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
97
+ input_len = model_inputs["input_ids"].shape[-1]
98
+
99
+ completions: List[GeneratedOutput] = []
100
+ with htrack_block(f"Generating for prompt: {prompt_text}"):
101
+ try:
102
+ concat_results = []
103
+ for i_completion in range(request.num_completions):
104
+
105
+ def do_it() -> Dict[str, Any]:
106
+ with torch.inference_mode():
107
+ generation = model.generate(
108
+ **model_inputs, max_new_tokens=request.max_tokens, do_sample=False
109
+ )[0]
110
+ if not request.echo_prompt:
111
+ generation = generation[input_len:]
112
+ decoded = processor.decode(generation, skip_special_tokens=True)
113
+ return {"output": decoded}
114
+
115
+ # Include the prompt and model name in the cache key
116
+ cache_key = CachingClient.make_cache_key(
117
+ raw_request={
118
+ "n": request.num_completions,
119
+ "i": i_completion,
120
+ "model": request.model,
121
+ "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
122
+ **generation_args,
123
+ },
124
+ request=request,
125
+ )
126
+ result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
127
+ concat_results.append(result)
128
+ except RuntimeError as model_error:
129
+ return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
130
+
131
+ for result in concat_results:
132
+ text = result["output"]
133
+ hlog(f"Generated text: {text}")
134
+ tokenization_result = self.tokenizer.tokenize(
135
+ TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
136
+ )
137
+ tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
138
+ completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
139
+
140
+ return RequestResult(
141
+ success=True,
142
+ cached=cached,
143
+ request_time=result["request_time"],
144
+ completions=completions,
145
+ embedding=[],
146
+ )
@@ -0,0 +1,99 @@
1
+ from typing import Dict, List
2
+ import json
3
+
4
+ import requests
5
+
6
+ from helm.common.cache import CacheConfig
7
+ from helm.common.images_utils import encode_base64
8
+ from helm.common.media_object import TEXT_TYPE
9
+ from helm.common.request import Request, RequestResult, GeneratedOutput, ErrorFlags
10
+ from helm.common.request import wrap_request_time
11
+ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt, truncate_and_tokenize_response_text
12
+ from helm.tokenizers.tokenizer import Tokenizer
13
+
14
+
15
class PalmyraVisionContentBlockedError(Exception):
    """Raised when the Palmyra Vision endpoint rejects the input via content
    moderation (the JSON response carries ``errors`` and
    ``tpe == "fail.input.content.moderation"``)."""

    pass
17
+
18
+
19
class PalmyraVisionClient(CachingClient):
    """Client for the Palmyra Vision model, reached through a private endpoint."""

    # Error code the endpoint returns when input is blocked by moderation.
    CONTENT_BLOCKED_ERROR: str = "fail.input.content.moderation"

    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, endpoint: str, cache_config: CacheConfig):
        super().__init__(cache_config)
        self.tokenizer: Tokenizer = tokenizer
        self.tokenizer_name: str = tokenizer_name

        # Palmyra Vision has no public API yet, so requests go to a secret endpoint.
        self.endpoint: str = endpoint

    def _prompt_parts(self, request: Request) -> List[Dict[str, str]]:
        """Translate the request's multimodal prompt into the endpoint's "parts" payload.

        Raises ValueError on a text media object without text, or on an
        unrecognized media object type.
        """
        parts: List[Dict[str, str]] = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                parts.append(
                    {
                        "type": "InlineData",
                        "value": encode_base64(media_object.location, format="JPEG"),
                        "contentType": "image/jpeg",
                    }
                )
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                parts.append({"type": "Text", "value": media_object.text})
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
        return parts

    def make_request(self, request: Request) -> RequestResult:
        """Send the multimodal prompt to the endpoint and wrap the completions."""
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
        parts: List[Dict[str, str]] = self._prompt_parts(request)

        def do_it():
            http_response = requests.post(
                self.endpoint, headers={"Content-Type": "application/json"}, data=json.dumps({"parts": parts})
            )
            payload = json.loads(http_response.text)

            # Surface a moderation block as a dedicated exception so the caller
            # can report it as a non-retriable, non-fatal error.
            if "errors" in payload and "tpe" in payload and payload["tpe"] == self.CONTENT_BLOCKED_ERROR:
                raise PalmyraVisionContentBlockedError(payload["errors"])

            # Hard fail if the `choices` is missing from the response
            assert "choices" in payload, f"Invalid response: {http_response.text}"

            return payload

        try:
            cache_key = CachingClient.make_cache_key(
                raw_request={"prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt)},
                request=request,
            )
            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except PalmyraVisionContentBlockedError as blocked:
            return RequestResult(
                success=False,
                cached=False,
                error=f"Content blocked: {str(blocked)}",
                completions=[],
                embedding=[],
                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
            )

        # The internal endpoint doesn't support any other parameters, so we have to truncate ourselves
        completions: List[GeneratedOutput] = [
            truncate_and_tokenize_response_text(choice["text"], request, self.tokenizer, self.tokenizer_name)
            for choice in result["choices"]
        ]
        return RequestResult(
            success=True,
            cached=cached,
            request_time=result["request_time"],
            completions=completions,
            embedding=[],
        )
@@ -0,0 +1,31 @@
1
+ from typing import Optional
2
+
3
+ from helm.clients.openai_client import OpenAIClient
4
+ from helm.common.cache import CacheConfig
5
+ from helm.tokenizers.tokenizer import Tokenizer
6
+
7
+
8
class YiChatClient(OpenAIClient):
    """OpenAI-compatible client for Yi chat models served from a custom base URL."""

    BASE_URL = "http://api.01ww.xyz/v1"

    def __init__(
        self,
        tokenizer: Tokenizer,
        tokenizer_name: str,
        cache_config: CacheConfig,
        api_key: Optional[str] = None,
    ):
        # Keep local references before delegating; the parent initializer
        # receives the same tokenizer so both layers stay consistent.
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        super().__init__(
            tokenizer=tokenizer,
            tokenizer_name=tokenizer_name,
            cache_config=cache_config,
            api_key=api_key,
            org_id=None,
            base_url=self.BASE_URL,
        )

    def _is_chat_model_engine(self, model_engine: str) -> bool:
        # Every Yi engine served through this endpoint speaks the chat API.
        return True
@@ -1,5 +1,6 @@
1
1
  from dataclasses import dataclass
2
- from typing import Dict, List, Union
2
+ from typing import Dict, List, Union, Optional
3
+ from helm.common.media_object import MediaObject
3
4
 
4
5
 
5
6
  class QuestionType:
@@ -34,6 +35,11 @@ class CritiqueQuestionTemplate:
34
35
 
35
36
  Can contain placeholders like {{placeholder}} that will be interpolated using the fields in CritiqueRequest."""
36
37
 
38
+ media_object: Optional[MediaObject] = None
39
+ """Path of image for multimodal input.
40
+
41
+ Image path or URL of the question."""
42
+
37
43
 
38
44
  @dataclass(frozen=True)
39
45
  class CritiqueTaskTemplate:
@@ -53,6 +59,9 @@ class CritiqueTaskTemplate:
53
59
  questions: List[CritiqueQuestionTemplate]
54
60
  """List of templates for the questions."""
55
61
 
62
+ max_tokens: Optional[int] = None
63
+ """Max token to be generated for the free-end generation."""
64
+
56
65
 
57
66
  @dataclass(frozen=True)
58
67
  class CritiqueRequest:
@@ -1,5 +1,7 @@
1
+ from hashlib import md5
1
2
  import base64
2
3
  import io
4
+ import os
3
5
 
4
6
  import requests
5
7
  import shutil
@@ -43,6 +45,11 @@ def encode_base64(image_location: str, format="JPEG") -> str:
43
45
  return base64.b64encode(image_file.getvalue()).decode("ascii")
44
46
 
45
47
 
48
+ def generate_hash(image: Image.Image) -> str:
49
+ """Generates a hash for the image."""
50
+ return md5(image.tobytes()).hexdigest()
51
+
52
+
46
53
  def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optional[int] = None) -> None:
47
54
  """
48
55
  Copies the image file from `src` path to `dest` path. If dimensions `width` and `height`
@@ -57,6 +64,24 @@ def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optiona
57
64
  shutil.copy(src, dest)
58
65
 
59
66
 
67
def resize_image_to_max_file_size(src: str, dest: str, max_size_in_bytes: int, step=10):
    """Write a copy of the image at `src` to `dest` that is under `max_size_in_bytes`.

    The image is saved, and if the resulting file is still too large, both
    dimensions are reduced by `step` pixels and the image re-saved, repeatedly,
    until the file fits.

    NOTE(review): subtracting the same absolute `step` from both dimensions
    does not preserve the aspect ratio — kept as-is to preserve behavior.

    :param src: Path of the source image.
    :param dest: Path the resized image is written to (also used for the
        intermediate size checks).
    :param max_size_in_bytes: Upper bound (exclusive) on the output file size.
    :param step: Number of pixels to shave off each dimension per iteration.
    :raises ValueError: If the image cannot be made small enough — previously
        this case decremented the dimensions to zero or below and crashed
        inside `Image.resize` with an unhelpful error.
    """
    with Image.open(src) as img:
        width, height = img.size

        # Reduce dimensions iteratively until the file size is under the limit
        while True:
            # Save the image temporarily to check the file size
            img.save(dest, quality=95)  # Start with high quality
            if os.path.getsize(dest) < max_size_in_bytes:
                break

            # Guard: stop before dimensions reach zero or negative values,
            # which would make Image.resize raise an opaque error.
            if width <= step or height <= step:
                raise ValueError(
                    f"Could not resize image at {src} to under {max_size_in_bytes} bytes"
                )

            # Reduce dimensions
            width -= step
            height -= step
            img = img.resize((width, height), Image.Resampling.LANCZOS)
83
+
84
+
60
85
  def is_blacked_out_image(image_location: str) -> bool:
61
86
  """Returns True if the image is all black. False otherwise."""
62
87
  try:
@@ -85,4 +85,5 @@ class MongoKeyValueStore(KeyValueStore):
85
85
  self._collection.bulk_write(operations)
86
86
 
87
87
def remove(self, key: Dict) -> None:
    """Delete the entry stored under `key`, matching on its canonical form."""
    self._collection.delete_one({self._REQUEST_KEY: self._canonicalize_key(key)})
helm/common/request.py CHANGED
@@ -72,6 +72,22 @@ class Request:
72
72
  image_generation_parameters: Optional[ImageGenerationParameters] = None
73
73
  """Parameters for image generation."""
74
74
 
75
def validate(self):
    """Validate the mutually-exclusive prompt fields of the request.

    Raises:
        ValueError: if more than one of `messages`, `prompt` and
            `multimodal_prompt` is set, or if a media object inside
            `multimodal_prompt` is missing its payload.
    """
    # At most one of the three prompt fields may be non-empty.
    if (
        (self.messages and self.prompt)
        or (self.messages and self.multimodal_prompt)
        or (self.prompt and self.multimodal_prompt)
    ):
        raise ValueError("Exactly one of the messages, prompt, multimodal_prompt fields should be set")

    if self.multimodal_prompt:
        for media_object in self.multimodal_prompt.media_objects:
            # BUGFIX: content_type is a full MIME type (e.g. "image/jpeg",
            # as built by the multimodal clients), so comparing it to the
            # bare "text"/"image" never matched and these checks were dead.
            # Use is_type(), which other call sites use to match the main type.
            if media_object.is_type("text") and media_object.text is None:
                raise ValueError("Media object with text content type must have text set")

            if media_object.is_type("image") and media_object.location is None:
                raise ValueError("Media object with image content type must have location set")
+
75
91
  @property
76
92
  def model_host(self) -> str:
77
93
  """Returns the model host (referring to the deployment).