crfm-helm 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crfm-helm might be problematic; see the package registry page for more details.

Files changed (209)
  1. {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/METADATA +81 -112
  2. {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/RECORD +165 -155
  3. {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
  5. helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
  6. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
  7. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
  8. helm/benchmark/adaptation/common_adapter_specs.py +2 -0
  9. helm/benchmark/annotation/anthropic_red_team_annotator.py +57 -0
  10. helm/benchmark/annotation/call_center_annotator.py +258 -0
  11. helm/benchmark/annotation/financebench_annotator.py +79 -0
  12. helm/benchmark/annotation/harm_bench_annotator.py +55 -0
  13. helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
  14. helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
  15. helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
  16. helm/benchmark/annotation/live_qa_annotator.py +37 -45
  17. helm/benchmark/annotation/medication_qa_annotator.py +36 -44
  18. helm/benchmark/annotation/model_as_judge.py +96 -0
  19. helm/benchmark/annotation/simple_safety_tests_annotator.py +50 -0
  20. helm/benchmark/annotation/xstest_annotator.py +100 -0
  21. helm/benchmark/metrics/annotation_metrics.py +108 -0
  22. helm/benchmark/metrics/bhasa_metrics.py +188 -0
  23. helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
  24. helm/benchmark/metrics/code_metrics_helper.py +11 -1
  25. helm/benchmark/metrics/safety_metrics.py +79 -0
  26. helm/benchmark/metrics/summac/model_summac.py +3 -3
  27. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
  28. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
  29. helm/benchmark/metrics/unitxt_metrics.py +17 -3
  30. helm/benchmark/metrics/vision_language/image_metrics.py +7 -3
  31. helm/benchmark/metrics/vision_language/image_utils.py +1 -1
  32. helm/benchmark/model_metadata_registry.py +3 -3
  33. helm/benchmark/presentation/create_plots.py +1 -1
  34. helm/benchmark/presentation/schema.py +3 -0
  35. helm/benchmark/presentation/summarize.py +106 -256
  36. helm/benchmark/presentation/test_run_entry.py +1 -0
  37. helm/benchmark/presentation/test_summarize.py +145 -3
  38. helm/benchmark/run.py +15 -0
  39. helm/benchmark/run_expander.py +83 -30
  40. helm/benchmark/run_specs/bhasa_run_specs.py +652 -0
  41. helm/benchmark/run_specs/call_center_run_specs.py +152 -0
  42. helm/benchmark/run_specs/decodingtrust_run_specs.py +8 -8
  43. helm/benchmark/run_specs/experimental_run_specs.py +52 -0
  44. helm/benchmark/run_specs/finance_run_specs.py +82 -1
  45. helm/benchmark/run_specs/safety_run_specs.py +154 -0
  46. helm/benchmark/run_specs/vlm_run_specs.py +100 -24
  47. helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
  48. helm/benchmark/scenarios/banking77_scenario.py +51 -0
  49. helm/benchmark/scenarios/bhasa_scenario.py +1942 -0
  50. helm/benchmark/scenarios/call_center_scenario.py +84 -0
  51. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
  52. helm/benchmark/scenarios/ewok_scenario.py +116 -0
  53. helm/benchmark/scenarios/fin_qa_scenario.py +2 -0
  54. helm/benchmark/scenarios/financebench_scenario.py +53 -0
  55. helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
  56. helm/benchmark/scenarios/raft_scenario.py +1 -1
  57. helm/benchmark/scenarios/scenario.py +1 -1
  58. helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
  59. helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
  60. helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
  61. helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
  62. helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
  63. helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
  64. helm/benchmark/scenarios/test_math_scenario.py +2 -8
  65. helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
  66. helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
  67. helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
  68. helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
  69. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
  70. helm/benchmark/scenarios/vision_language/bingo_scenario.py +2 -2
  71. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
  72. helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
  73. helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
  74. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
  75. helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
  76. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
  77. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
  78. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -3
  79. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -1
  80. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
  81. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
  82. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
  83. helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +41 -12
  84. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
  85. helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
  86. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
  87. helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
  88. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
  89. helm/benchmark/scenarios/vision_language/pairs_scenario.py +1 -1
  90. helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
  91. helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
  92. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
  93. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +2 -2
  94. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +6 -3
  95. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
  96. helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
  97. helm/benchmark/scenarios/xstest_scenario.py +35 -0
  98. helm/benchmark/server.py +1 -6
  99. helm/benchmark/static/schema_air_bench.yaml +750 -750
  100. helm/benchmark/static/schema_bhasa.yaml +709 -0
  101. helm/benchmark/static/schema_call_center.yaml +232 -0
  102. helm/benchmark/static/schema_cleva.yaml +768 -0
  103. helm/benchmark/static/schema_decodingtrust.yaml +444 -0
  104. helm/benchmark/static/schema_ewok.yaml +367 -0
  105. helm/benchmark/static/schema_finance.yaml +55 -9
  106. helm/benchmark/static/{schema_image2structure.yaml → schema_image2struct.yaml} +231 -90
  107. helm/benchmark/static/schema_legal.yaml +566 -0
  108. helm/benchmark/static/schema_safety.yaml +266 -0
  109. helm/benchmark/static/schema_tables.yaml +149 -8
  110. helm/benchmark/static/schema_thai.yaml +21 -0
  111. helm/benchmark/static/schema_vhelm.yaml +137 -101
  112. helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
  113. helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
  114. helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
  115. helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
  116. helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
  117. helm/benchmark/static_build/assets/index-3ee38b3d.js +10 -0
  118. helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
  119. helm/benchmark/static_build/assets/vhelm-aspects-1437d673.png +0 -0
  120. helm/benchmark/static_build/assets/vhelm-framework-a1ca3f3f.png +0 -0
  121. helm/benchmark/static_build/assets/vhelm-model-8afb7616.png +0 -0
  122. helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
  123. helm/benchmark/static_build/index.html +2 -2
  124. helm/benchmark/window_services/test_openai_window_service.py +8 -8
  125. helm/benchmark/window_services/tokenizer_service.py +0 -5
  126. helm/clients/ai21_client.py +71 -1
  127. helm/clients/anthropic_client.py +7 -19
  128. helm/clients/huggingface_client.py +38 -37
  129. helm/clients/nvidia_nim_client.py +35 -0
  130. helm/clients/openai_client.py +18 -4
  131. helm/clients/palmyra_client.py +24 -0
  132. helm/clients/perspective_api_client.py +11 -6
  133. helm/clients/test_client.py +4 -6
  134. helm/clients/together_client.py +22 -0
  135. helm/clients/vision_language/open_flamingo_client.py +1 -2
  136. helm/clients/vision_language/palmyra_vision_client.py +28 -13
  137. helm/common/cache.py +8 -30
  138. helm/common/images_utils.py +6 -0
  139. helm/common/key_value_store.py +9 -9
  140. helm/common/mongo_key_value_store.py +5 -4
  141. helm/common/request.py +16 -0
  142. helm/common/test_cache.py +1 -48
  143. helm/common/tokenization_request.py +0 -9
  144. helm/config/model_deployments.yaml +444 -329
  145. helm/config/model_metadata.yaml +513 -111
  146. helm/config/tokenizer_configs.yaml +140 -11
  147. helm/proxy/example_queries.py +14 -21
  148. helm/proxy/server.py +0 -9
  149. helm/proxy/services/remote_service.py +0 -6
  150. helm/proxy/services/server_service.py +6 -20
  151. helm/proxy/services/service.py +0 -6
  152. helm/proxy/token_counters/test_auto_token_counter.py +2 -2
  153. helm/tokenizers/ai21_tokenizer.py +51 -59
  154. helm/tokenizers/cohere_tokenizer.py +0 -75
  155. helm/tokenizers/huggingface_tokenizer.py +0 -1
  156. helm/tokenizers/test_ai21_tokenizer.py +48 -0
  157. helm/benchmark/data_overlap/data_overlap_spec.py +0 -86
  158. helm/benchmark/data_overlap/export_scenario_text.py +0 -119
  159. helm/benchmark/data_overlap/light_scenario.py +0 -60
  160. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  161. helm/benchmark/static/benchmarking.css +0 -156
  162. helm/benchmark/static/benchmarking.js +0 -1705
  163. helm/benchmark/static/config.js +0 -3
  164. helm/benchmark/static/general.js +0 -122
  165. helm/benchmark/static/images/crfm-logo.png +0 -0
  166. helm/benchmark/static/images/helm-logo-simple.png +0 -0
  167. helm/benchmark/static/images/helm-logo.png +0 -0
  168. helm/benchmark/static/images/language-model-helm.png +0 -0
  169. helm/benchmark/static/images/organizations/ai21.png +0 -0
  170. helm/benchmark/static/images/organizations/anthropic.png +0 -0
  171. helm/benchmark/static/images/organizations/bigscience.png +0 -0
  172. helm/benchmark/static/images/organizations/cohere.png +0 -0
  173. helm/benchmark/static/images/organizations/eleutherai.png +0 -0
  174. helm/benchmark/static/images/organizations/google.png +0 -0
  175. helm/benchmark/static/images/organizations/meta.png +0 -0
  176. helm/benchmark/static/images/organizations/microsoft.png +0 -0
  177. helm/benchmark/static/images/organizations/nvidia.png +0 -0
  178. helm/benchmark/static/images/organizations/openai.png +0 -0
  179. helm/benchmark/static/images/organizations/together.png +0 -0
  180. helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
  181. helm/benchmark/static/images/organizations/yandex.png +0 -0
  182. helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
  183. helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
  184. helm/benchmark/static/index.html +0 -68
  185. helm/benchmark/static/info-icon.png +0 -0
  186. helm/benchmark/static/json-urls.js +0 -69
  187. helm/benchmark/static/plot-captions.js +0 -27
  188. helm/benchmark/static/utils.js +0 -285
  189. helm/benchmark/static_build/assets/index-30dbceba.js +0 -10
  190. helm/benchmark/static_build/assets/index-66b02d40.css +0 -1
  191. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  192. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  193. helm/benchmark/window_services/ai21_window_service.py +0 -247
  194. helm/benchmark/window_services/cohere_window_service.py +0 -101
  195. helm/benchmark/window_services/test_ai21_window_service.py +0 -163
  196. helm/benchmark/window_services/test_cohere_window_service.py +0 -75
  197. helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
  198. helm/benchmark/window_services/test_ice_window_service.py +0 -327
  199. helm/tokenizers/ice_tokenizer.py +0 -30
  200. helm/tokenizers/test_ice_tokenizer.py +0 -57
  201. {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/LICENSE +0 -0
  202. {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/entry_points.txt +0 -0
  203. {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/top_level.txt +0 -0
  204. /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
  205. /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
  206. /helm/benchmark/{data_overlap → scenarios/vision_language/image2struct}/__init__.py +0 -0
  207. /helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +0 -0
  208. /helm/benchmark/scenarios/vision_language/{image2structure → image2struct/webpage}/__init__.py +0 -0
  209. /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, TypedDict
9
9
 
10
10
  from helm.common.cache import CacheConfig
11
11
  from helm.common.hierarchical_logger import htrack_block, hlog
12
+ from helm.common.optional_dependencies import handle_module_not_found_error
12
13
  from helm.common.request import (
13
14
  wrap_request_time,
14
15
  EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -58,60 +59,61 @@ class HuggingFaceServer:
58
59
  self,
59
60
  pretrained_model_name_or_path: str,
60
61
  wrapped_tokenizer: WrappedPreTrainedTokenizer,
61
- openvino=False,
62
+ openvino: bool = False,
62
63
  **kwargs,
63
64
  ):
64
- if torch.cuda.is_available():
65
- hlog("CUDA is available, initializing with a GPU...")
66
- self.device: str = "cuda:0"
65
+ self.device: Optional[str]
66
+ if "device_map" in kwargs:
67
+ try:
68
+ import accelerate # noqa: F401
69
+ except ModuleNotFoundError as e:
70
+ handle_module_not_found_error(e, ["accelerate"])
71
+ hlog(f'Hugging Face device_map set to "{kwargs["device_map"]}".')
72
+ self.device = None
73
+ elif torch.cuda.is_available():
74
+ hlog('Hugging Face device set to "cuda:0" because CUDA is available.')
75
+ self.device = "cuda:0"
67
76
  else:
77
+ hlog('Hugging Face device set to "cpu" because CUDA is unavailable.')
68
78
  self.device = "cpu"
79
+
80
+ # Security issue: currently we trust remote code by default.
81
+ # We retain this temporarily to maintain reverse compatibility.
82
+ # TODO: Delete if-else and don't set trust_remote_code=True
83
+ if "trust_remote_code" not in kwargs:
84
+ kwargs["trust_remote_code"] = True
85
+
69
86
  with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
70
87
  # WARNING this may fail if your GPU does not have enough memory
71
88
  if openvino:
72
- """
73
- Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
74
- OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
75
- Intel® architectures using OpenVINO™ runtime.
76
- """
77
- from helm.common.optional_dependencies import handle_module_not_found_error
78
-
89
+ # Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
90
+ # OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
91
+ # Intel® architectures using OpenVINO™ runtime.
79
92
  try:
80
93
  from optimum.intel.openvino import OVModelForCausalLM
81
94
  except ModuleNotFoundError as e:
82
95
  handle_module_not_found_error(e, ["openvino"])
83
96
 
84
97
  self.device = "cpu"
85
- # Security issue: currently we trust remote code by default.
86
- # We retain this temporarily to maintain reverse compatibility.
87
- # TODO: Delete if-else and don't set trust_remote_code=True
88
- if "trust_remote_code" in kwargs:
89
- self.model = OVModelForCausalLM.from_pretrained(
90
- pretrained_model_name_or_path, export=True, **kwargs
91
- ).to(self.device)
92
- else:
93
- self.model = OVModelForCausalLM.from_pretrained(
94
- pretrained_model_name_or_path, export=True, trust_remote_code=True, **kwargs
95
- ).to(self.device)
98
+ self.model = OVModelForCausalLM.from_pretrained(
99
+ pretrained_model_name_or_path, export=True, **kwargs
100
+ ).to(self.device)
101
+ elif self.device is None:
102
+ # kwargs contains device_map=auto
103
+ # Do not call to() because accelerate will take care of model device placement.
104
+ self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
96
105
  else:
97
- # Security issue: currently we trust remote code by default.
98
- # We retain this temporarily to maintain reverse compatibility.
99
- # TODO: Delete if-else and don't set trust_remote_code=True
100
- if "trust_remote_code" in kwargs:
101
- self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
102
- self.device
103
- )
104
- else:
105
- self.model = AutoModelForCausalLM.from_pretrained(
106
- pretrained_model_name_or_path, trust_remote_code=True, **kwargs
107
- ).to(self.device)
106
+ self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
107
+ self.device
108
+ )
108
109
  self.wrapped_tokenizer = wrapped_tokenizer
109
110
 
110
111
  def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
111
112
  with self.wrapped_tokenizer as tokenizer:
112
113
  encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
113
- self.device
114
+ 0 if self.device is None else self.device
114
115
  )
116
+
115
117
  stopping_criteria: Optional[StoppingCriteriaList] = None
116
118
  optional_args = {}
117
119
  if len(raw_request["stop_sequences"]) > 0:
@@ -249,9 +251,8 @@ def _process_huggingface_client_kwargs(raw_kwargs: Dict[str, Any]):
249
251
  # e.g. the string "torch.bfloat16" is converted to torch.bfloat16
250
252
  torch_dtype = processed_kwargs.get(TORCH_DTYPE_KEY)
251
253
  if torch_dtype and isinstance(torch_dtype, str):
252
- if not torch_dtype.startswith(TORCH_DTYPE_VALUE_PREFIX):
253
- raise ValueError(f'Unknown dtype "{torch_dtype}"; expected a string such as "torch.bfloat16"')
254
- processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
254
+ if torch_dtype.startswith(TORCH_DTYPE_VALUE_PREFIX):
255
+ processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
255
256
 
256
257
  return processed_kwargs
257
258
 
@@ -0,0 +1,35 @@
1
+ from typing import Optional
2
+
3
+ from helm.clients.openai_client import OpenAIClient
4
+ from helm.common.cache import CacheConfig
5
+ from helm.common.request import Request
6
+ from helm.tokenizers.tokenizer import Tokenizer
7
+
8
+
9
+ class NvidiaNimClient(OpenAIClient):
10
+
11
+ BASE_URL = "https://integrate.api.nvidia.com/v1"
12
+
13
+ def __init__(
14
+ self,
15
+ tokenizer: Tokenizer,
16
+ tokenizer_name: str,
17
+ cache_config: CacheConfig,
18
+ api_key: Optional[str] = None,
19
+ ):
20
+ self.tokenizer = tokenizer
21
+ self.tokenizer_name = tokenizer_name
22
+ super().__init__(
23
+ tokenizer=tokenizer,
24
+ tokenizer_name=tokenizer_name,
25
+ cache_config=cache_config,
26
+ api_key=api_key,
27
+ org_id=None,
28
+ base_url=NvidiaNimClient.BASE_URL,
29
+ )
30
+
31
+ def _get_model_for_request(self, request: Request) -> str:
32
+ return request.model
33
+
34
+ def _is_chat_model_engine(self, model_engine: str) -> bool:
35
+ return True
@@ -12,8 +12,8 @@ from helm.common.tokenization_request import (
12
12
  TokenizationRequest,
13
13
  TokenizationRequestResult,
14
14
  )
15
- from helm.tokenizers.tokenizer import Tokenizer
16
15
  from .client import CachingClient, truncate_sequence, generate_uid_for_multimodal_prompt
16
+ from helm.tokenizers.tokenizer import Tokenizer
17
17
 
18
18
  try:
19
19
  import openai
@@ -51,7 +51,7 @@ class OpenAIClient(CachingClient):
51
51
  def _is_chat_model_engine(self, model_engine: str) -> bool:
52
52
  if model_engine == "gpt-3.5-turbo-instruct":
53
53
  return False
54
- elif model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4"):
54
+ elif model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4") or model_engine.startswith("o1"):
55
55
  return True
56
56
  return False
57
57
 
@@ -132,6 +132,7 @@ class OpenAIClient(CachingClient):
132
132
  content: Union[str, List[Union[str, Any]]]
133
133
  if request.multimodal_prompt is not None:
134
134
  content = []
135
+ request.validate()
135
136
  for media_object in request.multimodal_prompt.media_objects:
136
137
  if media_object.is_type("image") and media_object.location:
137
138
  from helm.common.images_utils import encode_base64
@@ -140,8 +141,6 @@ class OpenAIClient(CachingClient):
140
141
  image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
141
142
  content.append({"type": "image_url", "image_url": image_object})
142
143
  elif media_object.is_type(TEXT_TYPE):
143
- if media_object.text is None:
144
- raise ValueError("MediaObject of text type has missing text field value")
145
144
  content.append({"type": media_object.type, "text": media_object.text})
146
145
  else:
147
146
  raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
@@ -170,6 +169,21 @@ class OpenAIClient(CachingClient):
170
169
  if is_vlm(request.model) and raw_request["stop"] is None:
171
170
  raw_request.pop("stop")
172
171
 
172
+ # Special handling for o1 models.
173
+ # Refer to the "Reasoning models" documentation further discussion of o1 model limitations:
174
+ # https://platform.openai.com/docs/guides/reasoning
175
+ if request.model_engine.startswith("o1"):
176
+ # Avoid error:
177
+ # "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead." # noqa: E501
178
+ # Note that openai>=1.45 is needed for this
179
+ if raw_request["max_tokens"]:
180
+ raw_request["max_completion_tokens"] = raw_request["max_tokens"]
181
+ raw_request.pop("max_tokens")
182
+ # Avoid error:
183
+ # "Invalid type for 'stop': expected an unsupported value, but got null instead."
184
+ if raw_request["stop"] is None:
185
+ raw_request.pop("stop")
186
+
173
187
  def do_it() -> Dict[str, Any]:
174
188
  return self.client.chat.completions.create(**raw_request).model_dump(mode="json")
175
189
 
@@ -3,6 +3,7 @@ import json
3
3
  import requests
4
4
  from typing import Any, Dict, List
5
5
 
6
+ from helm.clients.openai_client import OpenAIClient
6
7
  from helm.common.cache import CacheConfig
7
8
  from helm.common.hierarchical_logger import hlog
8
9
  from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token, ErrorFlags
@@ -142,3 +143,26 @@ class PalmyraClient(CachingClient):
142
143
  completions=completions,
143
144
  embedding=[],
144
145
  )
146
+
147
+
148
+ class PalmyraChatClient(OpenAIClient):
149
+ """Sends request to a Palmyra model using a OpenAI-compatible Chat API."""
150
+
151
+ def __init__(
152
+ self,
153
+ tokenizer: Tokenizer,
154
+ tokenizer_name: str,
155
+ cache_config: CacheConfig,
156
+ api_key: str,
157
+ ):
158
+ super().__init__(
159
+ tokenizer=tokenizer,
160
+ tokenizer_name=tokenizer_name,
161
+ cache_config=cache_config,
162
+ api_key=api_key,
163
+ org_id=None,
164
+ base_url="https://api.writer.com/v1/chat",
165
+ )
166
+
167
+ def _is_chat_model_engine(self, model_engine: str) -> bool:
168
+ return True
@@ -4,16 +4,21 @@ from dataclasses import asdict
4
4
  from typing import Any, List, Dict, Optional
5
5
 
6
6
  from dacite import from_dict
7
- from googleapiclient import discovery
8
- from googleapiclient.errors import BatchError, HttpError
9
- from googleapiclient.http import BatchHttpRequest
10
- from httplib2 import HttpLib2Error
7
+
11
8
  from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
9
+ from helm.common.optional_dependencies import handle_module_not_found_error
12
10
  from helm.proxy.retry import NonRetriableException
13
-
14
11
  from helm.common.cache import Cache, CacheConfig
15
12
  from helm.common.perspective_api_request import ToxicityAttributes, PerspectiveAPIRequest, PerspectiveAPIRequestResult
16
- from google.auth.exceptions import DefaultCredentialsError
13
+
14
+ try:
15
+ from googleapiclient import discovery
16
+ from googleapiclient.errors import BatchError, HttpError
17
+ from googleapiclient.http import BatchHttpRequest
18
+ from httplib2 import HttpLib2Error
19
+ from google.auth.exceptions import DefaultCredentialsError
20
+ except ModuleNotFoundError as e:
21
+ handle_module_not_found_error(e, ["metrics"])
17
22
 
18
23
 
19
24
  class PerspectiveAPIClientCredentialsError(NonRetriableException):
@@ -23,30 +23,28 @@ def test_truncate_sequence():
23
23
  # echo_prompt = True, nothing gets truncated
24
24
  truncate_sequence_helper(
25
25
  ["a", "b", "c"],
26
- Request(
27
- model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", prompt="abc", echo_prompt=True
28
- ),
26
+ Request(model="openai/gpt2", model_deployment="huggingface/gpt2", prompt="abc", echo_prompt=True),
29
27
  ["a", "b", "c"],
30
28
  )
31
29
 
32
30
  # Nothing gets truncated
33
31
  truncate_sequence_helper(
34
32
  ["hello", " world"],
35
- Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", stop_sequences=["#"]),
33
+ Request(model="openai/gpt2", model_deployment="huggingface/gpt2", stop_sequences=["#"]),
36
34
  ["hello", " world"],
37
35
  )
38
36
 
39
37
  # Truncate using stop sequences
40
38
  truncate_sequence_helper(
41
39
  ["hello", " world", "\n", "what"],
42
- Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", stop_sequences=["\n"]),
40
+ Request(model="openai/gpt2", model_deployment="huggingface/gpt2", stop_sequences=["\n"]),
43
41
  ["hello", " world"],
44
42
  )
45
43
 
46
44
  # Truncate using max tokens
47
45
  truncate_sequence_helper(
48
46
  ["a", "b", "c"],
49
- Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", max_tokens=2),
47
+ Request(model="openai/gpt2", model_deployment="huggingface/gpt2", max_tokens=2),
50
48
  ["a", "b"],
51
49
  )
52
50
 
@@ -7,6 +7,7 @@ import requests
7
7
  from retrying import retry
8
8
 
9
9
  from helm.common.cache import CacheConfig
10
+ from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
10
11
  from helm.common.optional_dependencies import handle_module_not_found_error
11
12
  from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
12
13
  from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
@@ -323,8 +324,29 @@ class TogetherChatClient(CachingClient):
323
324
  self._together_model = together_model
324
325
 
325
326
  def convert_to_raw_chat_request(self, request: Request) -> TogetherRawChatRequest:
327
+ request.validate()
328
+ messages: List[Dict[str, Any]]
326
329
  if request.messages:
327
330
  messages = request.messages
331
+ elif request.multimodal_prompt:
332
+ message_contents = []
333
+ for media_object in request.multimodal_prompt.media_objects:
334
+ if media_object.is_type(IMAGE_TYPE) and media_object.location:
335
+ assert media_object.location
336
+ if media_object.is_local_file:
337
+ from helm.common.images_utils import encode_base64
338
+
339
+ base64_image: str = encode_base64(media_object.location)
340
+ image_url = f"data:image/jpeg;base64,{base64_image}"
341
+ else:
342
+ image_url = media_object.location
343
+ message_contents.append({"type": "image_url", "image_url": {"url": image_url}})
344
+ elif media_object.is_type(TEXT_TYPE):
345
+ assert media_object.text
346
+ message_contents.append({"type": "text", "text": media_object.text})
347
+ else:
348
+ raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
349
+ messages = [{"role": "user", "content": message_contents}]
328
350
  else:
329
351
  messages = [{"role": "user", "content": request.prompt}]
330
352
  if self._together_model is not None:
@@ -82,13 +82,12 @@ class OpenFlamingoClient(CachingClient):
82
82
  # Build the prompt
83
83
  prompt_text: str = ""
84
84
  images: List[Image.Image] = []
85
+ request.validate()
85
86
  for media_object in request.multimodal_prompt.media_objects:
86
87
  if media_object.is_type("image") and media_object.location:
87
88
  images.append(open_image(media_object.location))
88
89
  prompt_text += self.IMAGE_TOKEN
89
90
  elif media_object.is_type(TEXT_TYPE):
90
- if media_object.text is None:
91
- raise ValueError("MediaObject of text type has missing text field value")
92
91
  prompt_text += media_object.text
93
92
  else:
94
93
  raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
@@ -6,13 +6,19 @@ import requests
6
6
  from helm.common.cache import CacheConfig
7
7
  from helm.common.images_utils import encode_base64
8
8
  from helm.common.media_object import TEXT_TYPE
9
- from helm.common.request import Request, RequestResult, GeneratedOutput
9
+ from helm.common.request import Request, RequestResult, GeneratedOutput, ErrorFlags
10
10
  from helm.common.request import wrap_request_time
11
11
  from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt, truncate_and_tokenize_response_text
12
12
  from helm.tokenizers.tokenizer import Tokenizer
13
13
 
14
14
 
15
+ class PalmyraVisionContentBlockedError(Exception):
16
+ pass
17
+
18
+
15
19
  class PalmyraVisionClient(CachingClient):
20
+ CONTENT_BLOCKED_ERROR: str = "fail.input.content.moderation"
21
+
16
22
  def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, endpoint: str, cache_config: CacheConfig):
17
23
  super().__init__(cache_config)
18
24
  self.tokenizer: Tokenizer = tokenizer
@@ -49,17 +55,19 @@ class PalmyraVisionClient(CachingClient):
49
55
  response = requests.post(
50
56
  self.endpoint, headers={"Content-Type": "application/json"}, data=json.dumps({"parts": prompt})
51
57
  )
52
- if response.status_code != 200:
53
- curl_command: str = (
54
- f"curl --location '{self.endpoint}' --header 'Content-Type: application/json' "
55
- f"--data '{json.dumps({'parts': prompt})}'"
56
- )
57
- assert False, f"Got status code {response.status_code}. Try {curl_command}"
58
-
59
58
  json_response = json.loads(response.text)
60
- assert (
61
- "choices" in json_response and "errors" not in json_response
62
- ), f"Invalid response: {response.text}"
59
+
60
+ # Check for content blocked error
61
+ if (
62
+ "errors" in json_response
63
+ and "tpe" in json_response
64
+ and json_response["tpe"] == self.CONTENT_BLOCKED_ERROR
65
+ ):
66
+ raise PalmyraVisionContentBlockedError(json_response["errors"])
67
+
68
+ # Hard fail if the `choices` is missing from the response
69
+ assert "choices" in json_response, f"Invalid response: {response.text}"
70
+
63
71
  return json_response
64
72
 
65
73
  cache_key = CachingClient.make_cache_key(
@@ -67,8 +75,15 @@ class PalmyraVisionClient(CachingClient):
67
75
  request=request,
68
76
  )
69
77
  result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
70
- except RuntimeError as ex:
71
- return RequestResult(success=False, cached=False, error=str(ex), completions=[], embedding=[])
78
+ except PalmyraVisionContentBlockedError as ex:
79
+ return RequestResult(
80
+ success=False,
81
+ cached=False,
82
+ error=f"Content blocked: {str(ex)}",
83
+ completions=[],
84
+ embedding=[],
85
+ error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
86
+ )
72
87
 
73
88
  # The internal endpoint doesn't support any other parameters, so we have to truncate ourselves
74
89
  completions: List[GeneratedOutput] = [
helm/common/cache.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from collections import defaultdict
2
2
  from dataclasses import dataclass
3
- from typing import Dict, Callable, Generator, Mapping, Optional, Tuple
3
+ from typing import Dict, Callable, Generator, Mapping, Tuple
4
4
  import json
5
5
  import threading
6
6
 
@@ -38,6 +38,12 @@ class CacheConfig:
38
38
  class KeyValueStoreCacheConfig(CacheConfig):
39
39
  """Configuration for a cache backed by a key-value store."""
40
40
 
41
+ # This was originally to distinguish between "primitive" cache configs
42
+ # and "compound" cache configs. But we don't have any "compound" cache configs currently.
43
+ # Hypthetical "compound" example: ReadOnlyCacheConfig(SqliteCacheConfig("path"))
44
+ # TODO: Maybe remove this eventually?
45
+ pass
46
+
41
47
 
42
48
  @dataclass(frozen=True)
43
49
  class SqliteCacheConfig(KeyValueStoreCacheConfig):
@@ -78,24 +84,6 @@ class MongoCacheConfig(KeyValueStoreCacheConfig):
78
84
  return f"{self.uri}/{self.collection_name}"
79
85
 
80
86
 
81
- @dataclass(frozen=True)
82
- class WithFollowerCacheConfig(CacheConfig):
83
- """Configuration of a cache backed by a main cache and a follower cache."""
84
-
85
- # Configuration for the main cache.
86
- # Responses will be written to and served out of this cache.
87
- main: KeyValueStoreCacheConfig
88
-
89
- # Configuration for the follower cache.
90
- # The follower cache is a write-only cache. Responses will be written to this cache,
91
- # but not served out of this cache.
92
- follower: KeyValueStoreCacheConfig
93
-
94
- @property
95
- def cache_stats_key(self) -> str:
96
- return self.main.cache_stats_key
97
-
98
-
99
87
  def get_all_from_sqlite(path: str) -> Generator[Tuple[Dict, Dict], None, None]:
100
88
  """Yields all decoded key, value pairs from the SQLite cache.
101
89
 
@@ -126,7 +114,7 @@ def create_key_value_store(config: KeyValueStoreCacheConfig) -> KeyValueStore:
126
114
  elif isinstance(config, BlackHoleCacheConfig):
127
115
  return BlackHoleKeyValueStore()
128
116
  else:
129
- raise ValueError(f"KeyValueStoreCacheConfig with unknown type: {config}")
117
+ raise ValueError(f"CacheConfig with unknown type: {config}")
130
118
 
131
119
 
132
120
  @retry
@@ -189,14 +177,8 @@ class Cache(object):
189
177
 
190
178
  def __init__(self, config: CacheConfig):
191
179
  hlog(f"Created cache with config: {config}")
192
- self.config: KeyValueStoreCacheConfig
193
- self.follower_config: Optional[KeyValueStoreCacheConfig]
194
180
  if isinstance(config, KeyValueStoreCacheConfig):
195
181
  self.config = config
196
- self.follower_config = None
197
- elif isinstance(config, WithFollowerCacheConfig):
198
- self.config = config.main
199
- self.follower_config = config.follower
200
182
  else:
201
183
  raise ValueError(f"CacheConfig with unknown type: {config}")
202
184
 
@@ -216,8 +198,4 @@ class Cache(object):
216
198
  response = compute()
217
199
 
218
200
  write_to_key_value_store(key_value_store, request, response)
219
- if self.follower_config is not None:
220
- # TODO: Initialize follower_key_value_store in constructor
221
- with create_key_value_store(self.follower_config) as follower_key_value_store:
222
- write_to_key_value_store(follower_key_value_store, request, response)
223
201
  return response, cached
@@ -1,3 +1,4 @@
1
+ from hashlib import md5
1
2
  import base64
2
3
  import io
3
4
  import os
@@ -44,6 +45,11 @@ def encode_base64(image_location: str, format="JPEG") -> str:
44
45
  return base64.b64encode(image_file.getvalue()).decode("ascii")
45
46
 
46
47
 
48
+ def generate_hash(image: Image.Image) -> str:
49
+ """Generates a hash for the image."""
50
+ return md5(image.tobytes()).hexdigest()
51
+
52
+
47
53
  def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optional[int] = None) -> None:
48
54
  """
49
55
  Copies the image file from `src` path to `dest` path. If dimensions `width` and `height`
@@ -15,11 +15,11 @@ class KeyValueStore(contextlib.AbstractContextManager):
15
15
  """Key value store that persists writes."""
16
16
 
17
17
  @abstractmethod
18
- def contains(self, key: Dict) -> bool:
18
+ def contains(self, key: Mapping) -> bool:
19
19
  pass
20
20
 
21
21
  @abstractmethod
22
- def get(self, key: Dict) -> Optional[Dict]:
22
+ def get(self, key: Mapping) -> Optional[Dict]:
23
23
  pass
24
24
 
25
25
  @abstractmethod
@@ -35,7 +35,7 @@ class KeyValueStore(contextlib.AbstractContextManager):
35
35
  pass
36
36
 
37
37
  @abstractmethod
38
- def remove(self, key: Dict) -> None:
38
+ def remove(self, key: Mapping) -> None:
39
39
  pass
40
40
 
41
41
 
@@ -53,10 +53,10 @@ class SqliteKeyValueStore(KeyValueStore):
53
53
  def __exit__(self, exc_type, exc_value, traceback) -> None:
54
54
  self._sqlite_dict.__exit__(exc_type, exc_value, traceback)
55
55
 
56
- def contains(self, key: Dict) -> bool:
56
+ def contains(self, key: Mapping) -> bool:
57
57
  return request_to_key(key) in self._sqlite_dict
58
58
 
59
- def get(self, key: Dict) -> Optional[Dict]:
59
+ def get(self, key: Mapping) -> Optional[Dict]:
60
60
  key_string = request_to_key(key)
61
61
  result = self._sqlite_dict.get(key_string)
62
62
  if result is not None:
@@ -77,7 +77,7 @@ class SqliteKeyValueStore(KeyValueStore):
77
77
  for key, value in pairs:
78
78
  self.put(key, value)
79
79
 
80
- def remove(self, key: Dict) -> None:
80
+ def remove(self, key: Mapping) -> None:
81
81
  del self._sqlite_dict[key]
82
82
  self._sqlite_dict.commit()
83
83
 
@@ -91,10 +91,10 @@ class BlackHoleKeyValueStore(KeyValueStore):
91
91
  def __exit__(self, exc_type, exc_value, traceback) -> None:
92
92
  pass
93
93
 
94
- def contains(self, key: Dict) -> bool:
94
+ def contains(self, key: Mapping) -> bool:
95
95
  return False
96
96
 
97
- def get(self, key: Dict) -> Optional[Dict]:
97
+ def get(self, key: Mapping) -> Optional[Dict]:
98
98
  return None
99
99
 
100
100
  def get_all(self) -> Generator[Tuple[Dict, Dict], None, None]:
@@ -109,5 +109,5 @@ class BlackHoleKeyValueStore(KeyValueStore):
109
109
  def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
110
110
  return None
111
111
 
112
- def remove(self, key: Dict) -> None:
112
+ def remove(self, key: Mapping) -> None:
113
113
  return None
@@ -39,11 +39,11 @@ class MongoKeyValueStore(KeyValueStore):
39
39
  serialized = json.dumps(key, sort_keys=True)
40
40
  return json.loads(serialized, object_pairs_hook=SON)
41
41
 
42
- def contains(self, key: Dict) -> bool:
42
+ def contains(self, key: Mapping) -> bool:
43
43
  query = {self._REQUEST_KEY: self._canonicalize_key(key)}
44
44
  return self._collection.find_one(query) is not None
45
45
 
46
- def get(self, key: Dict) -> Optional[Dict]:
46
+ def get(self, key: Mapping) -> Optional[Dict]:
47
47
  query = {self._REQUEST_KEY: self._canonicalize_key(key)}
48
48
  document = self._collection.find_one(query)
49
49
  if document is not None:
@@ -84,5 +84,6 @@ class MongoKeyValueStore(KeyValueStore):
84
84
  # Note: unlike put, multi_put does not support documents with null bytes in keys.
85
85
  self._collection.bulk_write(operations)
86
86
 
87
- def remove(self, key: Dict) -> None:
88
- self._collection.delete_one(key)
87
+ def remove(self, key: Mapping) -> None:
88
+ query = {self._REQUEST_KEY: self._canonicalize_key(key)}
89
+ self._collection.delete_one(query)
helm/common/request.py CHANGED
@@ -72,6 +72,22 @@ class Request:
72
72
  image_generation_parameters: Optional[ImageGenerationParameters] = None
73
73
  """Parameters for image generation."""
74
74
 
75
+ def validate(self):
76
+ if (
77
+ (self.messages and self.prompt)
78
+ or (self.messages and self.multimodal_prompt)
79
+ or (self.prompt and self.multimodal_prompt)
80
+ ):
81
+ raise ValueError("Exactly one of the messages, prompt, multimodal_prompt fields should be set")
82
+
83
+ if self.multimodal_prompt:
84
+ for media_object in self.multimodal_prompt.media_objects:
85
+ if media_object.content_type == "text" and media_object.text is None:
86
+ raise ValueError("Media object with text content type must have text set")
87
+
88
+ if media_object.content_type == "image" and media_object.location is None:
89
+ raise ValueError("Media object with image content type must have location set")
90
+
75
91
  @property
76
92
  def model_host(self) -> str:
77
93
  """Returns the model host (referring to the deployment).