crfm-helm 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of crfm-helm might be problematic.

Files changed (98)
  1. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +13 -3
  2. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +96 -63
  3. helm/benchmark/adaptation/adapter_spec.py +32 -31
  4. helm/benchmark/annotation/air_bench_annotator.py +64 -0
  5. helm/benchmark/annotation/annotator_factory.py +6 -0
  6. helm/benchmark/annotation/live_qa_annotator.py +84 -0
  7. helm/benchmark/annotation/medication_qa_annotator.py +81 -0
  8. helm/benchmark/augmentations/translate_perturbation.py +1 -0
  9. helm/benchmark/huggingface_registration.py +16 -6
  10. helm/benchmark/metrics/air_bench_metrics.py +56 -0
  11. helm/benchmark/metrics/fin_qa_metrics.py +60 -0
  12. helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
  13. helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
  14. helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
  15. helm/benchmark/metrics/live_qa_metrics.py +23 -0
  16. helm/benchmark/metrics/medication_qa_metrics.py +23 -0
  17. helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
  18. helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
  19. helm/benchmark/metrics/unitxt_metrics.py +20 -10
  20. helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
  21. helm/benchmark/metrics/vision_language/image_metrics.py +29 -71
  22. helm/benchmark/presentation/schema.py +54 -4
  23. helm/benchmark/presentation/test_schema.py +11 -0
  24. helm/benchmark/run.py +16 -2
  25. helm/benchmark/run_expander.py +77 -0
  26. helm/benchmark/run_spec_factory.py +4 -0
  27. helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
  28. helm/benchmark/run_specs/classic_run_specs.py +15 -11
  29. helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
  30. helm/benchmark/run_specs/experimental_run_specs.py +33 -0
  31. helm/benchmark/run_specs/finance_run_specs.py +33 -0
  32. helm/benchmark/run_specs/vlm_run_specs.py +168 -45
  33. helm/benchmark/scenarios/air_bench_scenario.py +50 -0
  34. helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
  35. helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
  36. helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
  37. helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
  38. helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
  39. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
  40. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
  41. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +0 -4
  42. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +4 -2
  43. helm/benchmark/scenarios/vision_language/pairs_scenario.py +6 -5
  44. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
  45. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
  46. helm/benchmark/static/schema_air_bench.yaml +3149 -0
  47. helm/benchmark/static/schema_classic.yaml +3 -59
  48. helm/benchmark/static/schema_finance.yaml +143 -0
  49. helm/benchmark/static/schema_image2structure.yaml +254 -111
  50. helm/benchmark/static/schema_instruction_following.yaml +3 -52
  51. helm/benchmark/static/schema_lite.yaml +3 -61
  52. helm/benchmark/static/schema_medical.yaml +255 -0
  53. helm/benchmark/static/schema_mmlu.yaml +3 -61
  54. helm/benchmark/static/schema_tables.yaml +200 -0
  55. helm/benchmark/static/schema_thai.yaml +223 -0
  56. helm/benchmark/static/schema_unitxt.yaml +3 -61
  57. helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +294 -293
  58. helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
  59. helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
  60. helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
  61. helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
  62. helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
  63. helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
  64. helm/benchmark/static_build/index.html +2 -2
  65. helm/clients/anthropic_client.py +43 -9
  66. helm/clients/auto_client.py +11 -0
  67. helm/clients/client.py +24 -7
  68. helm/clients/cohere_client.py +98 -3
  69. helm/clients/huggingface_client.py +71 -12
  70. helm/clients/openai_client.py +9 -2
  71. helm/clients/reka_client.py +189 -0
  72. helm/clients/test_client.py +3 -3
  73. helm/clients/test_huggingface_client.py +19 -3
  74. helm/clients/test_together_client.py +72 -2
  75. helm/clients/together_client.py +129 -23
  76. helm/clients/vertexai_client.py +62 -18
  77. helm/clients/vision_language/huggingface_vlm_client.py +1 -0
  78. helm/clients/vision_language/paligemma_client.py +146 -0
  79. helm/clients/vision_language/palmyra_vision_client.py +84 -0
  80. helm/clients/yi_client.py +31 -0
  81. helm/common/critique_request.py +10 -1
  82. helm/common/images_utils.py +19 -0
  83. helm/config/model_deployments.yaml +412 -18
  84. helm/config/model_metadata.yaml +447 -25
  85. helm/config/tokenizer_configs.yaml +93 -1
  86. helm/proxy/critique/model_critique_client.py +32 -4
  87. helm/proxy/services/server_service.py +1 -1
  88. helm/tokenizers/auto_tokenizer.py +1 -1
  89. helm/tokenizers/cohere_tokenizer.py +44 -2
  90. helm/tokenizers/huggingface_tokenizer.py +36 -13
  91. helm/tokenizers/test_cohere_tokenizer.py +39 -0
  92. helm/tokenizers/test_huggingface_tokenizer.py +5 -1
  93. helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
  94. helm/benchmark/static_build/assets/index-878a1094.css +0 -1
  95. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
  96. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
  97. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
  98. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0

helm/clients/test_client.py
@@ -1,5 +1,5 @@
- from helm.common.cache import BlackHoleCacheConfig
- from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
+ from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
+ from helm.tokenizers.auto_tokenizer import AutoTokenizer
  from .client import truncate_sequence, truncate_and_tokenize_response_text
  from typing import List
  from helm.common.request import Request, GeneratedOutput, Token
@@ -52,8 +52,8 @@ def test_truncate_sequence():


  def test_truncate_and_tokenize_response_text():
- tokenizer = HuggingFaceTokenizer(BlackHoleCacheConfig())
  tokenizer_name = "huggingface/gpt2"
+ tokenizer = AutoTokenizer(credentials={}, cache_backend_config=BlackHoleCacheBackendConfig())

  # No truncation
  response = truncate_and_tokenize_response_text(
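
The test above now builds its tokenizer through the AutoTokenizer dispatcher with a BlackHoleCacheBackendConfig instead of constructing HuggingFaceTokenizer directly. A minimal sketch of the same construction outside the test, assuming AutoTokenizer exposes the usual tokenize(TokenizationRequest) interface and that the "huggingface/gpt2" tokenizer is registered:

from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
from helm.common.tokenization_request import TokenizationRequest
from helm.tokenizers.auto_tokenizer import AutoTokenizer

# Build the dispatching tokenizer with no credentials and a no-op cache backend,
# mirroring the updated test.
tokenizer = AutoTokenizer(credentials={}, cache_backend_config=BlackHoleCacheBackendConfig())
result = tokenizer.tokenize(
    TokenizationRequest("I am a computer scientist.", tokenizer="huggingface/gpt2", encode=False)
)
print(result.raw_tokens)  # GPT-2 token strings for the prompt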

helm/clients/test_huggingface_client.py
@@ -3,12 +3,18 @@ import pytest
  from helm.common.cache import BlackHoleCacheConfig
  from helm.common.request import Request, RequestResult
  from helm.clients.huggingface_client import HuggingFaceClient
+ from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer


  class TestHuggingFaceClient:
  def test_gpt2(self):
+ tokenizer = HuggingFaceTokenizer(
+ BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+ )
  client = HuggingFaceClient(
- cache_config=BlackHoleCacheConfig(), pretrained_model_name_or_path="openai-community/gpt2"
+ cache_config=BlackHoleCacheConfig(),
+ tokenizer=tokenizer,
+ pretrained_model_name_or_path="openai-community/gpt2",
  )
  prompt: str = "I am a computer scientist."
  result: RequestResult = client.make_request(
@@ -29,8 +35,13 @@ class TestHuggingFaceClient:

  @pytest.mark.skip(reason="GPT-J 6B is 22 GB and extremely slow without a GPU.")
  def test_gptj_6b(self):
+ tokenizer = HuggingFaceTokenizer(
+ BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+ )
  client = HuggingFaceClient(
- cache_config=BlackHoleCacheConfig(), pretrained_model_name_or_path="openai-community/gpt2"
+ cache_config=BlackHoleCacheConfig(),
+ tokenizer=tokenizer,
+ pretrained_model_name_or_path="openai-community/gpt2",
  )
  result: RequestResult = client.make_request(
  Request(
@@ -45,8 +56,13 @@ class TestHuggingFaceClient:
  assert len(result.completions) == 3

  def test_logprob(self):
+ tokenizer = HuggingFaceTokenizer(
+ BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+ )
  client = HuggingFaceClient(
- cache_config=BlackHoleCacheConfig(), pretrained_model_name_or_path="openai-community/gpt2"
+ cache_config=BlackHoleCacheConfig(),
+ tokenizer=tokenizer,
+ pretrained_model_name_or_path="openai-community/gpt2",
  )
  prompt: str = "I am a computer scientist."
  result: RequestResult = client.make_request(

helm/clients/test_together_client.py
@@ -2,10 +2,10 @@ import os
  import pytest
  import tempfile

- from helm.common.cache import SqliteCacheConfig
+ from helm.common.cache import BlackHoleCacheConfig, SqliteCacheConfig
  from helm.common.request import Request

- from .together_client import TogetherClient, TogetherClientError
+ from .together_client import TogetherClient, TogetherChatClient, TogetherCompletionClient, TogetherClientError


  class TestTogetherClient:
@@ -107,3 +107,73 @@ class TestTogetherClient:
  model_deployment="together/redpajama-incite-base-3b-v1",
  )
  )
+
+
+ @pytest.mark.models
+ def test_together_chat_client_make_request():
+ # Requires setting TOGETHER_API_KEY environment variable.
+ client = TogetherChatClient(
+ cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-chat-hf"
+ )
+ request = Request(
+ model="meta/llama-3-8b-instruct",
+ model_deployment="together/llama-3-8b-instruct",
+ prompt="Elephants are one of the most",
+ temperature=0.0,
+ max_tokens=10,
+ )
+ result = client.make_request(request)
+ assert result.success
+ assert not result.cached
+ assert result.embedding == []
+ assert len(result.completions) == 1
+ assert result.completions[0].text == "...intelligent animals on Earth!assistant"
+ assert result.completions[0].logprob == 0.0
+ result_token_strings = [token.text for token in result.completions[0].tokens]
+ assert result_token_strings == [
+ "...",
+ "int",
+ "elligent",
+ " animals",
+ " on",
+ " Earth",
+ "!",
+ "<|eot_id|>",
+ "<|start_header_id|>",
+ "assistant",
+ ]
+
+
+ @pytest.mark.models
+ def test_together_completion_client_make_request():
+ # Requires setting TOGETHER_API_KEY environment variable.
+ client = TogetherCompletionClient(
+ cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-hf"
+ )
+ request = Request(
+ model="meta/llama-3-8b",
+ model_deployment="together/llama-3-8b",
+ prompt="Elephants are one of the most",
+ temperature=0.0,
+ max_tokens=10,
+ )
+ result = client.make_request(request)
+ assert result.success
+ assert not result.cached
+ assert result.embedding == []
+ assert len(result.completions) == 1
+ assert result.completions[0].text == " popular animals in the world. They are known for"
+ assert result.completions[0].logprob == 0.0
+ result_token_strings = [token.text for token in result.completions[0].tokens]
+ assert result_token_strings == [
+ " popular",
+ " animals",
+ " in",
+ " the",
+ " world",
+ ".",
+ " They",
+ " are",
+ " known",
+ " for",
+ ]

helm/clients/together_client.py
@@ -1,6 +1,7 @@
  from copy import deepcopy
  from itertools import zip_longest
- from typing import List, Dict, Any, Optional, TypedDict, Union
+ import threading
+ from typing import List, Dict, Any, Mapping, Optional, TypedDict, Union

  import requests
  from retrying import retry
@@ -12,7 +13,7 @@ from helm.clients.client import CachingClient, truncate_sequence, cleanup_str

  try:
  from together import Together
- from together.types import ChatCompletionResponse
+ from together.types import ChatCompletionResponse, CompletionResponse
  except ModuleNotFoundError as e:
  handle_module_not_found_error(e, ["together"])

@@ -282,6 +283,24 @@ class TogetherClient(CachingClient):
  )


+ _MODEL_TO_DEFAULT_STOP_TOKENS: Optional[Mapping[str, List[str]]] = None
+ _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK = threading.Lock()
+
+
+ def get_default_stop_tokens_for_model(together_model: str, together_client: Together) -> List[str]:
+ global _MODEL_TO_DEFAULT_STOP_TOKENS
+ global _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK
+ with _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK:
+ if _MODEL_TO_DEFAULT_STOP_TOKENS is None:
+ _MODEL_TO_DEFAULT_STOP_TOKENS = {}
+ for model in together_client.models.list():
+ _MODEL_TO_DEFAULT_STOP_TOKENS[model.id.lower()] = model.config["stop"]
+ stop_tokens = _MODEL_TO_DEFAULT_STOP_TOKENS.get(together_model.lower())
+ if stop_tokens is None:
+ raise ValueError(f"Unknown together_model {together_model}")
+ return stop_tokens
+
+
  class TogetherRawChatRequest(TypedDict):
  messages: List[Dict[str, str]]
  model: str
@@ -295,34 +314,38 @@ class TogetherRawChatRequest(TypedDict):
  n: int


- def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
- if request.messages:
- messages = request.messages
- else:
- messages = [{"role": "user", "content": request.prompt}]
- return {
- "messages": messages,
- "model": request.model,
- "max_tokens": request.max_tokens,
- "stop": request.stop_sequences,
- "temperature": request.temperature,
- "top_p": request.top_p,
- "top_k": request.top_k_per_token,
- "logprobs": min(request.top_k_per_token, 1),
- "echo": request.echo_prompt,
- "n": request.num_completions,
- }
-
-
  class TogetherChatClient(CachingClient):
  """Client that uses the Python Together library for chat models."""

- def __init__(self, cache_config: CacheConfig, api_key: str, together_model: Optional[str] = None):
+ def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
  super().__init__(cache_config=cache_config)
  self._client = Together(api_key=api_key)
+ self._together_model = together_model
+
+ def convert_to_raw_chat_request(self, request: Request) -> TogetherRawChatRequest:
+ if request.messages:
+ messages = request.messages
+ else:
+ messages = [{"role": "user", "content": request.prompt}]
+ if self._together_model is not None:
+ model = self._together_model
+ else:
+ model = request.model
+ return {
+ "messages": messages,
+ "model": model,
+ "max_tokens": request.max_tokens,
+ "stop": request.stop_sequences + get_default_stop_tokens_for_model(model, self._client),
+ "temperature": request.temperature,
+ "top_p": request.top_p,
+ "top_k": request.top_k_per_token,
+ "logprobs": min(request.top_k_per_token, 1),
+ "echo": request.echo_prompt,
+ "n": request.num_completions,
+ }

  def make_request(self, request: Request) -> RequestResult:
- raw_request = convert_to_raw_chat_request(request)
+ raw_request = self.convert_to_raw_chat_request(request)
  cache_key = CachingClient.make_cache_key(raw_request, request)

  def do_it() -> Dict[Any, Any]:
@@ -363,3 +386,86 @@ class TogetherChatClient(CachingClient):
  completions=generated_outputs,
  embedding=[],
  )
+
+
+ class TogetherRawCompletionRequest(TypedDict):
+ prompt: str
+ model: str
+ max_tokens: int
+ stop: List[str]
+ temperature: float
+ top_p: float
+ top_k: int
+ logprobs: int
+ echo: bool
+ n: int
+
+
+ class TogetherCompletionClient(CachingClient):
+ """Client that uses the Python Together library for text completion models."""
+
+ def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
+ super().__init__(cache_config=cache_config)
+ self._client = Together(api_key=api_key)
+ self._together_model = together_model
+
+ def convert_to_raw_completion_request(self, request: Request) -> TogetherRawCompletionRequest:
+ if self._together_model is not None:
+ model = self._together_model
+ else:
+ model = request.model
+ return {
+ "prompt": request.prompt,
+ "model": model,
+ "max_tokens": request.max_tokens,
+ "stop": request.stop_sequences + get_default_stop_tokens_for_model(model, self._client),
+ "temperature": request.temperature,
+ "top_p": request.top_p,
+ "top_k": request.top_k_per_token,
+ "logprobs": min(request.top_k_per_token, 1),
+ "echo": request.echo_prompt,
+ "n": request.num_completions,
+ }
+
+ def make_request(self, request: Request) -> RequestResult:
+ raw_request = self.convert_to_raw_completion_request(request)
+ cache_key = CachingClient.make_cache_key(raw_request, request)
+
+ def do_it() -> Dict[Any, Any]:
+ response = self._client.completions.create(**raw_request)
+ return response.model_dump(mode="json")
+
+ try:
+ raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+ response = CompletionResponse.model_validate(raw_response)
+ except Exception as error:
+ return RequestResult(
+ success=False,
+ cached=False,
+ error=str(error),
+ completions=[],
+ embedding=[],
+ )
+
+ generated_outputs: List[GeneratedOutput] = []
+ for choice in response.choices:
+ # NOTE: Together always returns None for choice.finish_reason
+ # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
+ tokens: List[Token] = []
+ if choice.logprobs:
+ for token_text, token_logprob in zip_longest(
+ choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
+ ):
+ if token_text is None:
+ break
+ tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+ assert choice.text
+ generated_outputs.append(GeneratedOutput(text=choice.text, logprob=0.0, tokens=tokens))
+ return RequestResult(
+ success=True,
+ cached=cached,
+ request_time=raw_response["request_time"],
+ request_datetime=raw_response["request_datetime"],
+ completions=generated_outputs,
+ embedding=[],
+ )
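
Two details of the new Together clients are worth noting: both accept an optional together_model that overrides the HELM model name in the raw request, and both append the model's default stop tokens, fetched once from together_client.models.list() and cached behind a module-level lock, to the request's stop sequences. A usage sketch mirroring the added test (requires TOGETHER_API_KEY when api_key is None; the model names are the ones used in the test):

from helm.common.cache import BlackHoleCacheConfig
from helm.common.request import Request
from helm.clients.together_client import TogetherCompletionClient

# Build the completion client against the raw Together model id; api_key=None
# defers to the TOGETHER_API_KEY environment variable.
client = TogetherCompletionClient(
    cache_config=BlackHoleCacheConfig(),
    api_key=None,
    together_model="meta-llama/Llama-3-8b-hf",
)
request = Request(
    model="meta/llama-3-8b",
    model_deployment="together/llama-3-8b",
    prompt="Elephants are one of the most",
    temperature=0.0,
    max_tokens=10,
)
result = client.make_request(request)
print(result.completions[0].text)  # e.g. " popular animals in the world. They are known for"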

helm/clients/vertexai_client.py
@@ -1,7 +1,7 @@
  import requests
  from abc import ABC, abstractmethod
  from threading import Lock
- from typing import Any, Dict, Optional, List, Union
+ from typing import Any, Dict, Mapping, Optional, List, Union

  from helm.common.cache import CacheConfig
  from helm.common.media_object import TEXT_TYPE
@@ -26,22 +26,62 @@ class VertexAIContentBlockedError(Exception):
  pass


+ class SafetySettingPresets:
+ BLOCK_NONE = "block_none" # Disable all blocking
+ DEFAULT = "default" # Use default safety settings
+
+
+ def _get_safety_settings_for_preset(
+ safety_settings_preset: Optional[str],
+ ) -> Optional[Dict[HarmCategory, SafetySetting.HarmBlockThreshold]]:
+ """Get the safety settings for the safety_settings_preset.
+
+ If safety_settings_preset is None, use the default value of BLOCK_NONE (*not* DEFAULT)."""
+ if safety_settings_preset is None or safety_settings_preset == SafetySettingPresets.BLOCK_NONE:
+ return {
+ harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
+ for harm_category in iter(HarmCategory)
+ }
+ elif safety_settings_preset == SafetySettingPresets.DEFAULT:
+ return None
+ else:
+ raise ValueError(f"Unknown safety_settings_preset: {safety_settings_preset}")
+
+
+ def _get_model_name_for_request(request: Request) -> str:
+ # We have to strip "-safety-" suffixes from model names because they are not part of the Vertex AI model name
+ # TODO: Clean up this hack
+ return request.model_engine.split("-safety-")[0]
+
+
  class VertexAIClient(CachingClient, ABC):
  """Client for Vertex AI models"""

- def __init__(self, cache_config: CacheConfig, project_id: str, location: str) -> None:
+ def __init__(
+ self, cache_config: CacheConfig, project_id: str, location: str, safety_settings_preset: Optional[str] = None
+ ) -> None:
  super().__init__(cache_config=cache_config)
  self.project_id = project_id
  self.location = location

- # VertexAI's default safety filter is overly sensitive, so we disable it.
- self.safety_settings: Dict[HarmCategory, SafetySetting.HarmBlockThreshold] = {
- harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
- for harm_category in iter(HarmCategory)
- }
+ self.safety_settings_preset = safety_settings_preset
+ self.safety_settings = _get_safety_settings_for_preset(safety_settings_preset)

  vertexai.init(project=self.project_id, location=self.location)

+ def make_cache_key_with_safety_settings_preset(self, raw_request: Mapping, request: Request) -> Mapping:
+ """Construct the key for the cache using the raw request.
+
+ Add `self.safety_settings_preset` to the key, if not None."""
+ if self.safety_settings_preset is not None:
+ assert "safety_settings_preset" not in raw_request
+ return {
+ **CachingClient.make_cache_key(raw_request, request),
+ "safety_settings_preset": self.safety_settings_preset,
+ }
+ else:
+ return CachingClient.make_cache_key(raw_request, request)
+
  @abstractmethod
  def make_request(self, request: Request) -> RequestResult:
  raise NotImplementedError
@@ -71,7 +111,7 @@ class VertexAITextClient(VertexAIClient):
  }

  completions: List[GeneratedOutput] = []
- model_name: str = request.model_engine
+ model_name: str = _get_model_name_for_request(request)

  try:

@@ -87,9 +127,9 @@ class VertexAITextClient(VertexAIClient):
  # We need to include the engine's name to differentiate among requests made for different model
  # engines since the engine name is not included in the request itself.
  # Same for the prompt.
- cache_key = CachingClient.make_cache_key(
+ cache_key = self.make_cache_key_with_safety_settings_preset(
  {
- "engine": request.model_engine,
+ "engine": model_name,
  "prompt": request.prompt,
  **parameters,
  },
@@ -177,7 +217,7 @@ class VertexAIChatClient(VertexAIClient):
  }

  completions: List[GeneratedOutput] = []
- model_name: str = request.model_engine
+ model_name: str = _get_model_name_for_request(request)
  model = self.get_model(model_name)

  try:
@@ -197,7 +237,7 @@

  # Depending on the version of the Vertex AI library and the type of prompt blocking,
  # prompt blocking can show up in many ways, so this defensively handles most of these ways
- if response.prompt_feedback.block_reason:
+ if response.prompt_feedback and response.prompt_feedback.block_reason:
  raise VertexAIContentBlockedError(
  f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
  )
@@ -209,8 +249,10 @@
  # content blocking can show up in many ways, so this defensively handles most of these ways
  if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
  raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+ if not candidate.content:
+ raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
  if not candidate.content.parts:
- raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+ raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
  predictions.append({"text": candidate.content.text})
  # TODO: Extract more information from the response
  return {"predictions": predictions}
@@ -218,7 +260,7 @@
  # We need to include the engine's name to differentiate among requests made for different model
  # engines since the engine name is not included in the request itself.
  # Same for the prompt.
- cache_key = CachingClient.make_cache_key(
+ cache_key = self.make_cache_key_with_safety_settings_preset(
  {
  "model_name": model_name,
  "prompt": request.prompt,
@@ -313,7 +355,7 @@ class VertexAIChatClient(VertexAIClient):
  }

  completions: List[GeneratedOutput] = []
- model_name: str = request.model_engine
+ model_name: str = _get_model_name_for_request(request)
  model = self.get_model(model_name)

  request_time = 0
@@ -330,7 +372,7 @@
  )
  # Depending on the version of the Vertex AI library and the type of prompt blocking,
  # prompt blocking can show up in many ways, so this defensively handles most of these ways
- if response.prompt_feedback.block_reason:
+ if response.prompt_feedback and response.prompt_feedback.block_reason:
  raise VertexAIContentBlockedError(
  f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
  )
@@ -345,15 +387,17 @@
  # content blocking can show up in many ways, so this defensively handles most of these ways
  if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
  raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+ if not candidate.content:
+ raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
  if not candidate.content.parts:
- raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+ raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
  return {"predictions": [{"text": candidate.text}]}

  raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
  if completion_index > 0:
  raw_cache_key["completion_index"] = completion_index

- cache_key = CachingClient.make_cache_key(raw_cache_key, request)
+ cache_key = self.make_cache_key_with_safety_settings_preset(raw_cache_key, request)
  response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
  except requests.exceptions.RequestException as e:
  error: str = f"Gemini Vision error: {e}"

helm/clients/vision_language/huggingface_vlm_client.py
@@ -38,6 +38,7 @@ class HuggingFaceVLMClient(CachingClient):
  "huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
  "huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
  "huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
+ "huggingface/prometheus-vision-13b-v1.0-hf": "PahaII/prometheus-vision-13b-v1.0-hf",
  }

  def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):

helm/clients/vision_language/paligemma_client.py
@@ -0,0 +1,146 @@
+ from threading import Lock
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from dataclasses import dataclass
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ from helm.common.cache import CacheConfig
+ from helm.common.images_utils import open_image
+ from helm.common.gpu_utils import get_torch_device_name
+ from helm.common.hierarchical_logger import hlog, htrack_block
+ from helm.common.media_object import TEXT_TYPE
+ from helm.common.optional_dependencies import handle_module_not_found_error
+ from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+ from helm.common.tokenization_request import TokenizationRequest
+ from helm.common.request import wrap_request_time
+ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+ from helm.tokenizers.tokenizer import Tokenizer
+
+ try:
+ from PIL import Image
+ except ModuleNotFoundError as e:
+ handle_module_not_found_error(e, ["images"])
+
+ # Added to solve: cutlassF: no kernel found to launch!
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
+ torch.backends.cuda.enable_flash_sdp(False)
+
+
+ @dataclass(frozen=True)
+ class LoadedPaliGemmaForConditionalGeneration:
+ """Loaded model and processor for PaliGemma."""
+
+ model: PaliGemmaForConditionalGeneration
+ processor: AutoProcessor
+
+
+ _models_lock: Lock = Lock()
+ _models: Dict[str, Optional[LoadedPaliGemmaForConditionalGeneration]] = {}
+
+
+ class PaliGemmaClient(CachingClient):
+ """
+ PaliGemma is a versatile and lightweight vision-language model (VLM) inspired by PaLI-3
+ and based on open components such as the SigLIP vision model and the Gemma language model.
+ It takes both image and text as input and generates text as output, supporting multiple languages.
+ It is designed for class-leading fine-tune performance on a wide range of vision-language tasks
+ such as image and short video caption, visual question answering, text reading, object detection
+ and object segmentation.
+ """
+
+ def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
+ super().__init__(cache_config=cache_config)
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self._device: str = get_torch_device_name()
+
+ def _get_model(self, checkpoint: str) -> LoadedPaliGemmaForConditionalGeneration:
+ global _models_lock
+ global _models
+
+ # Ensure that only one thread is loading the model at a time
+ with _models_lock:
+ if checkpoint not in _models or _models[checkpoint] is None:
+ hlog(f"Loading model {checkpoint} and caching in memory...")
+ model = PaliGemmaForConditionalGeneration.from_pretrained(
+ checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
+ ).eval()
+ processor = AutoProcessor.from_pretrained(checkpoint)
+ _models[checkpoint] = LoadedPaliGemmaForConditionalGeneration(model, processor)
+ loaded_model_processor = _models[checkpoint]
+
+ assert loaded_model_processor is not None
+ return loaded_model_processor
+
+ def make_request(self, request: Request) -> RequestResult:
+ assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+ loaded_model_processor: LoadedPaliGemmaForConditionalGeneration = self._get_model(request.model_deployment)
+ model = loaded_model_processor.model
+ processor = loaded_model_processor.processor
+ generation_args = {"max_new_tokens": request.max_tokens}
+
+ images: List[Image.Image] = []
+ prompt_pieces: List[str] = []
+ for media_object in request.multimodal_prompt.media_objects:
+ if media_object.is_type("image") and media_object.location:
+ images += [open_image(media_object.location).convert("RGB")]
+ elif media_object.is_type(TEXT_TYPE):
+ if media_object.text is None:
+ raise ValueError("MediaObject of text type has missing text field value")
+ prompt_pieces.append(media_object.text)
+ else:
+ raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+ prompt_text: str = "\n".join(prompt_pieces)
+ model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
+ input_len = model_inputs["input_ids"].shape[-1]
+
+ completions: List[GeneratedOutput] = []
+ with htrack_block(f"Generating for prompt: {prompt_text}"):
+ try:
+ concat_results = []
+ for i_completion in range(request.num_completions):
+
+ def do_it() -> Dict[str, Any]:
+ with torch.inference_mode():
+ generation = model.generate(
+ **model_inputs, max_new_tokens=request.max_tokens, do_sample=False
+ )[0]
+ if not request.echo_prompt:
+ generation = generation[input_len:]
+ decoded = processor.decode(generation, skip_special_tokens=True)
+ return {"output": decoded}
+
+ # Include the prompt and model name in the cache key
+ cache_key = CachingClient.make_cache_key(
+ raw_request={
+ "n": request.num_completions,
+ "i": i_completion,
+ "model": request.model,
+ "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+ **generation_args,
+ },
+ request=request,
+ )
+ result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+ concat_results.append(result)
+ except RuntimeError as model_error:
+ return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+ for result in concat_results:
+ text = result["output"]
+ hlog(f"Generated text: {text}")
+ tokenization_result = self.tokenizer.tokenize(
+ TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+ )
+ tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+ completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
+ return RequestResult(
+ success=True,
+ cached=cached,
+ request_time=result["request_time"],
+ completions=completions,
+ embedding=[],
+ )
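
The new PaliGemmaClient loads each checkpoint once per process behind a lock, flattens the multimodal prompt into newline-joined text plus RGB images, and generates greedily with max_new_tokens. A sketch of the request it expects; the deployment name and the MediaObject/MultimediaObject constructor arguments are assumptions based on how make_request reads request.multimodal_prompt above:

from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.request import Request

request = Request(
    model="google/paligemma-3b-mix-224",             # hypothetical model name for illustration
    model_deployment="google/paligemma-3b-mix-224",  # resolved by _get_model() as the Hugging Face checkpoint
    multimodal_prompt=MultimediaObject(
        media_objects=[
            MediaObject(location="cat.png", content_type="image/png"),
            MediaObject(text="Answer in one word: what animal is shown?", content_type="text/plain"),
        ]
    ),
    max_tokens=10,
    num_completions=1,
)
# PaliGemmaClient.make_request(request) converts each image to RGB, joins the text
# pieces with newlines, runs model.generate(..., do_sample=False), and tokenizes the
# decoded output with the configured HELM tokenizer to build the returned Token list.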