crfm-helm 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/METADATA +11 -8
  2. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/RECORD +67 -38
  3. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/WHEEL +1 -1
  4. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/entry_points.txt +2 -1
  5. helm/benchmark/__init__.py +13 -0
  6. helm/benchmark/adaptation/adapter_spec.py +3 -0
  7. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -7
  8. helm/benchmark/augmentations/correct_to_misspelling.json +1 -0
  9. helm/benchmark/contamination/__init__.py +0 -0
  10. helm/benchmark/metrics/classification_metrics.py +70 -0
  11. helm/benchmark/metrics/machine_translation_metrics.py +36 -0
  12. helm/benchmark/metrics/summarization_metrics.py +7 -8
  13. helm/benchmark/metrics/test_classification_metrics.py +150 -0
  14. helm/benchmark/presentation/create_plots.py +617 -0
  15. helm/benchmark/presentation/run_display.py +7 -48
  16. helm/benchmark/presentation/summarize.py +4 -2
  17. helm/benchmark/presentation/test_create_plots.py +32 -0
  18. helm/benchmark/run.py +144 -48
  19. helm/benchmark/run_expander.py +164 -47
  20. helm/benchmark/run_specs.py +346 -39
  21. helm/benchmark/runner.py +34 -6
  22. helm/benchmark/scenarios/copyright_scenario.py +1 -1
  23. helm/benchmark/scenarios/covid_dialog_scenario.py +84 -0
  24. helm/benchmark/scenarios/imdb_listdir.json +50014 -0
  25. helm/benchmark/scenarios/lex_glue_scenario.py +253 -0
  26. helm/benchmark/scenarios/lextreme_scenario.py +458 -0
  27. helm/benchmark/scenarios/me_q_sum_scenario.py +86 -0
  28. helm/benchmark/scenarios/med_dialog_scenario.py +132 -0
  29. helm/benchmark/scenarios/med_mcqa_scenario.py +102 -0
  30. helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +119 -0
  31. helm/benchmark/scenarios/med_qa_scenario.py +96 -0
  32. helm/benchmark/scenarios/opinions_qa_scenario.py +194 -0
  33. helm/benchmark/scenarios/scenario.py +5 -0
  34. helm/benchmark/scenarios/the_pile_scenario.py +1 -1
  35. helm/benchmark/scenarios/wmt_14_scenario.py +96 -0
  36. helm/benchmark/static/benchmarking.css +14 -0
  37. helm/benchmark/static/benchmarking.js +43 -0
  38. helm/benchmark/static/index.html +2 -0
  39. helm/benchmark/static/json-urls.js +4 -0
  40. helm/benchmark/static/plot-captions.js +16 -0
  41. helm/benchmark/static/schema.yaml +154 -1
  42. helm/benchmark/window_services/cohere_window_service.py +20 -0
  43. helm/benchmark/window_services/flan_t5_window_service.py +29 -0
  44. helm/benchmark/window_services/huggingface_window_service.py +39 -0
  45. helm/benchmark/window_services/santacoder_window_service.py +27 -0
  46. helm/benchmark/window_services/test_flan_t5_window_service.py +12 -0
  47. helm/benchmark/window_services/wider_ai21_window_service.py +13 -0
  48. helm/benchmark/window_services/window_service_factory.py +34 -7
  49. helm/common/codec.py +123 -0
  50. helm/common/general.py +12 -5
  51. helm/common/test_codec.py +144 -0
  52. helm/proxy/clients/aleph_alpha_client.py +47 -28
  53. helm/proxy/clients/auto_client.py +32 -24
  54. helm/proxy/clients/google_client.py +88 -0
  55. helm/proxy/clients/huggingface_client.py +32 -16
  56. helm/proxy/clients/huggingface_model_registry.py +111 -0
  57. helm/proxy/clients/huggingface_tokenizer.py +25 -7
  58. helm/proxy/clients/openai_client.py +60 -2
  59. helm/proxy/clients/test_huggingface_model_registry.py +57 -0
  60. helm/proxy/clients/test_huggingface_tokenizer.py +3 -0
  61. helm/proxy/clients/together_client.py +17 -2
  62. helm/proxy/clients/yalm_tokenizer/voc_100b.sp +0 -0
  63. helm/proxy/clients/yalm_tokenizer/yalm_tokenizer.py +8 -2
  64. helm/proxy/models.py +115 -7
  65. helm/proxy/test_models.py +1 -1
  66. helm/benchmark/presentation/present.py +0 -249
  67. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/LICENSE +0 -0
  68. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/top_level.txt +0 -0
helm/common/codec.py ADDED
@@ -0,0 +1,123 @@
1
+ """Functions for converting to and from dataclasses."""
2
+
3
+ import dataclasses
4
+ import json
5
+ import typing
6
+ from typing import Any, Callable, Dict, List, Union, Type, TypeVar
7
+
8
+ from helm.benchmark.augmentations.dialect_perturbation import DialectPerturbation
9
+ from helm.benchmark.augmentations.extra_space_perturbation import ExtraSpacePerturbation
10
+ from helm.benchmark.augmentations.filler_words_perturbation import FillerWordsPerturbation
11
+ from helm.benchmark.augmentations.gender_perturbation import GenderPerturbation
12
+ from helm.benchmark.augmentations.misspelling_perturbation import MisspellingPerturbation
13
+ from helm.benchmark.augmentations.person_name_perturbation import PersonNamePerturbation
14
+ from helm.benchmark.augmentations.space_perturbation import SpacePerturbation
15
+ from helm.benchmark.augmentations.synonym_perturbation import SynonymPerturbation
16
+ from helm.benchmark.augmentations.typos_perturbation import TyposPerturbation
17
+ from helm.benchmark.augmentations.perturbation_description import PerturbationDescription
18
+
19
+ import cattrs
20
+ from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn
21
+
22
+
23
T = TypeVar("T")
# Shape of a cattrs structure hook: (raw dict, target class) -> instance.
StructureFn = Callable[[Dict[str, Any], Type[T]], T]  # dict -> dataclass
# Shape of a cattrs unstructure hook: instance -> raw dict.
UnstructureFn = Callable[[T], Dict[str, Any]]  # dataclass -> dict


# TODO(#1251): Add proper class registration
# Lookup table from a perturbation's `name` to its Description dataclass. The converter uses it
# to turn a serialized PerturbationDescription back into the matching subclass.
PERTURBATION_NAME_TO_DESCRIPTION = {
    DialectPerturbation.name: DialectPerturbation.Description,
    ExtraSpacePerturbation.name: ExtraSpacePerturbation.Description,
    FillerWordsPerturbation.name: FillerWordsPerturbation.Description,
    GenderPerturbation.name: GenderPerturbation.Description,
    MisspellingPerturbation.name: MisspellingPerturbation.Description,
    PersonNamePerturbation.name: PersonNamePerturbation.Description,
    SpacePerturbation.name: SpacePerturbation.Description,
    SynonymPerturbation.name: SynonymPerturbation.Description,
    TyposPerturbation.name: TyposPerturbation.Description,
}
40
+
41
+
42
def _build_converter() -> cattrs.Converter:
    """Build the cattrs converter behind `from_json` and `to_json`.

    The converter customizes two behaviors.

    1. Omission of Nones. To improve readability and reduce storage space, a dataclass field is
       dropped from the serialized dict when its value is None, provided that:
       - its annotation is a Union containing None, and
       - it has no default or a None default.
       When deserializing, such a field is filled back in with None if its key is missing.
       Optional fields with a non-None default (or a default_factory) are always serialized, so
       an explicit None round-trips instead of being replaced by the default.

    2. Polymorphic PerturbationDescription. The "name" entry of a serialized
       PerturbationDescription selects which Description subclass to build, via
       PERTURBATION_NAME_TO_DESCRIPTION. Unknown names fall back to PerturbationDescription
       itself.
    """
    import types  # local import: only needed to recognise PEP 604 unions below

    converter = cattrs.Converter()

    # Annotation origins that denote a union. `Optional[X]` / `Union[X, None]` have origin
    # `typing.Union`; on Python 3.10+ `X | None` has origin `types.UnionType` instead.
    pep604_union_type = getattr(types, "UnionType", None)
    union_origins = tuple(origin for origin in (Union, pep604_union_type) if origin is not None)

    def get_dataclass_optional_fields_without_default(cls: Type[T]) -> List[str]:
        """Return the names of `cls`'s fields whose None values may be omitted from JSON.

        Returns [] when `cls` is not a dataclass.
        """
        if not dataclasses.is_dataclass(cls):
            return []
        return [
            field.name
            for field in dataclasses.fields(cls)
            if typing.get_origin(field.type) in union_origins
            and type(None) in typing.get_args(field.type)
            # For optional fields with a non-None default value, do not replace a missing value
            # with None. Compare against the MISSING sentinel by identity: `==` would invoke the
            # default value's own __eq__.
            and (field.default is dataclasses.MISSING or field.default is None)
            and field.default_factory is dataclasses.MISSING
        ]

    def make_omit_nones_dict_structure_fn(cls: Type[T]) -> StructureFn[T]:
        """Make a dict -> `cls` hook that treats missing omittable fields as None."""
        field_names = get_dataclass_optional_fields_without_default(cls)
        _base_structure = make_dict_structure_fn(cls, converter)

        def structure(raw_dict: Dict[str, Any], inner_cls: Type[T]) -> T:
            # Fill in omitted fields on a shallow copy so the caller's dict is left untouched.
            completed_dict = dict(raw_dict)
            for field_name in field_names:
                completed_dict.setdefault(field_name, None)
            return _base_structure(completed_dict, inner_cls)

        return structure

    def make_omit_nones_dict_unstructure_fn(cls: Type[T]) -> UnstructureFn[T]:
        """Make a `cls` -> dict hook that drops omittable fields whose value is None."""
        omittable_field_names = frozenset(get_dataclass_optional_fields_without_default(cls))
        _base_unstructure = make_dict_unstructure_fn(cls, converter)

        def unstructure(data: T) -> Dict[str, Any]:
            raw_dict = _base_unstructure(data)
            # Preserve key order; only drop None values of omittable fields.
            return {
                key: value
                for key, value in raw_dict.items()
                if not (value is None and key in omittable_field_names)
            }

        return unstructure

    converter.register_structure_hook_factory(
        lambda cls: bool(get_dataclass_optional_fields_without_default(cls)), make_omit_nones_dict_structure_fn
    )
    converter.register_unstructure_hook_factory(
        lambda cls: bool(get_dataclass_optional_fields_without_default(cls)), make_omit_nones_dict_unstructure_fn
    )

    # Handle the use of the name field in PerturbationDescription to determine the subclass.
    base_perturbation_description_structure_fn: StructureFn = make_omit_nones_dict_structure_fn(PerturbationDescription)
    perturbation_name_to_base_structure_fn: Dict[str, StructureFn] = {
        name: make_omit_nones_dict_structure_fn(cls) for name, cls in PERTURBATION_NAME_TO_DESCRIPTION.items()
    }

    def structure_perturbation_description(
        raw_dict: Dict[Any, Any], cls: Type[PerturbationDescription]
    ) -> PerturbationDescription:
        """Convert a raw dictionary to a PerturbationDescription.

        This uses the name field to look up the correct PerturbationDescription subclass to output.
        Raises KeyError if `raw_dict` has no "name" entry.
        """
        structure = perturbation_name_to_base_structure_fn.get(
            raw_dict["name"], base_perturbation_description_structure_fn
        )
        return structure(raw_dict, cls)

    converter.register_structure_hook(PerturbationDescription, structure_perturbation_description)

    return converter
113
+
114
+
115
# Single module-wide converter shared by `from_json` and `to_json`; built once at import time.
_converter: cattrs.Converter = _build_converter()
116
+
117
+
118
def from_json(data: Union[bytes, str], cls: Type[T]) -> T:
    """Parse the JSON document `data` (text or bytes) and structure it into an instance of `cls`.

    Structuring goes through the module-level converter, so:
    - omitted Optional fields come back as None;
    - PerturbationDescriptions are rebuilt as the subclass named by their "name" entry.
    """
    parsed = json.loads(data)
    return _converter.structure(parsed, cls)
120
+
121
+
122
def to_json(data: Any) -> str:
    """Serialize `data` to a JSON string indented with 2 spaces.

    Unstructuring goes through the module-level converter. Optional dataclass fields whose
    value is None (and whose default is absent or None) are left out of the output.
    """
    unstructured = _converter.unstructure(data)
    return json.dumps(unstructured, indent=2)
helm/common/general.py CHANGED
@@ -49,7 +49,13 @@ def shell(args: List[str]):
49
49
 
50
50
 
51
51
  @htrack(None)
52
- def ensure_file_downloaded(source_url: str, target_path: str, unpack: bool = False, unpack_type: Optional[str] = None):
52
+ def ensure_file_downloaded(
53
+ source_url: str,
54
+ target_path: str,
55
+ unpack: bool = False,
56
+ downloader_executable: str = "wget",
57
+ unpack_type: Optional[str] = None,
58
+ ):
53
59
  """Download `source_url` to `target_path` if it doesn't exist."""
54
60
  if os.path.exists(target_path):
55
61
  # Assume it's all good
@@ -59,7 +65,8 @@ def ensure_file_downloaded(source_url: str, target_path: str, unpack: bool = Fal
59
65
  # Download
60
66
  # gdown is used to download large files/zip folders from Google Drive.
61
67
  # It bypasses security warnings which wget cannot handle.
62
- downloader_executable: str = "gdown" if source_url.startswith("https://drive.google.com") else "wget"
68
+ if source_url.startswith("https://drive.google.com"):
69
+ downloader_executable = "gdown"
63
70
  tmp_path: str = f"{target_path}.tmp"
64
71
  shell([downloader_executable, source_url, "-O", tmp_path])
65
72
 
@@ -195,13 +202,13 @@ def parallel_map(
195
202
  with htrack_block(f"Parallelizing computation on {len(items)} items over {parallelism} {units}"):
196
203
  results: List
197
204
  if parallelism == 1:
198
- results = list(tqdm(map(process, items), total=len(items)))
205
+ results = list(tqdm(map(process, items), total=len(items), disable=None))
199
206
  elif multiprocessing:
200
207
  with ProcessPoolExecutor(max_workers=parallelism) as executor:
201
- results = list(tqdm(executor.map(process, items), total=len(items)))
208
+ results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
202
209
  else:
203
210
  with ThreadPoolExecutor(max_workers=parallelism) as executor:
204
- results = list(tqdm(executor.map(process, items), total=len(items)))
211
+ results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
205
212
  return results
206
213
 
207
214
 
helm/common/test_codec.py ADDED
@@ -0,0 +1,144 @@
1
+ import unittest
2
+ import json
3
+
4
+ from dataclasses import dataclass
5
+
6
+ from typing import Dict, List, Optional
7
+
8
+ from helm.benchmark.augmentations.dialect_perturbation import DialectPerturbation
9
+ from helm.benchmark.augmentations.extra_space_perturbation import ExtraSpacePerturbation
10
+ from helm.benchmark.augmentations.filler_words_perturbation import FillerWordsPerturbation
11
+ from helm.benchmark.augmentations.gender_perturbation import GenderPerturbation
12
+ from helm.benchmark.augmentations.misspelling_perturbation import MisspellingPerturbation
13
+ from helm.benchmark.augmentations.person_name_perturbation import PersonNamePerturbation
14
+ from helm.benchmark.augmentations.space_perturbation import SpacePerturbation
15
+ from helm.benchmark.augmentations.synonym_perturbation import SynonymPerturbation
16
+ from helm.benchmark.augmentations.typos_perturbation import TyposPerturbation
17
+ from helm.benchmark.augmentations.perturbation_description import PerturbationDescription
18
+ from helm.common.codec import from_json, to_json
19
+
20
+
21
@dataclass(frozen=True)
class DataClassChildForTest:
    """Minimal frozen dataclass nested inside other test dataclasses to exercise recursion."""

    # Plain required field: not Optional and without a default.
    required_int: int
+
25
+
26
@dataclass(frozen=True)
class DataClassWithOptionals:
    """Test dataclass whose fields are all Optional and have no defaults.

    Every field therefore qualifies for None-omission in the JSON codec.
    """

    # Optional scalars.
    optional_str: Optional[str]
    optional_int: Optional[int]
    optional_bool: Optional[bool]
    # Optional containers.
    optional_list: Optional[List[int]]
    optional_dict: Optional[Dict[str, int]]
    # Optional nested dataclass.
    optional_child: Optional[DataClassChildForTest]
+
35
+
36
@dataclass(frozen=True)
class DataClassWithDefaults:
    """Test dataclass covering how field defaults interact with None-omission in the codec."""

    # Non-Optional field with a default.
    required_int_with_default: int = -1
    # Optional field whose default is not None: an explicit None must survive a round trip.
    optional_int_with_int_default: Optional[int] = -2
    # Optional field whose default is None: a None value may be omitted from the JSON.
    optional_int_with_none_default: Optional[int] = None
41
+
42
+
43
class TestJsonCodec(unittest.TestCase):
    """Round-trip tests for `helm.common.codec.to_json` and `from_json`."""

    def test_round_trip_optional(self):
        data = DataClassWithOptionals(
            optional_str="hello",
            optional_int=42,
            optional_bool=True,
            optional_list=[2, 3, 5],
            optional_dict={"x": 7},
            optional_child=DataClassChildForTest(137),
        )
        self.assertEqual(data, from_json(to_json(data), DataClassWithOptionals))

    def test_round_trip_optional_nones(self):
        data = DataClassWithOptionals(
            optional_str=None,
            optional_int=None,
            optional_bool=None,
            optional_list=None,
            optional_dict=None,
            optional_child=None,
        )
        data_json = to_json(data)
        # Every field is Optional without a default, so every None is omitted.
        self.assertEqual("{}", data_json)
        self.assertEqual(data, from_json(data_json, DataClassWithOptionals))

    def test_round_trip_default(self):
        data = DataClassWithDefaults()
        data_json = to_json(data)
        # `optional_int_with_none_default` is None with a None default, so it is omitted.
        self.assertEqual(
            {"required_int_with_default": -1, "optional_int_with_int_default": -2},
            json.loads(data_json),
        )
        self.assertEqual(data, from_json(data_json, DataClassWithDefaults))

    def test_missing_keys_use_defaults(self):
        # Keys absent from the JSON fall back to the dataclass defaults.
        self.assertEqual(DataClassWithDefaults(), from_json("{}", DataClassWithDefaults))

    def test_round_trip_default_ints(self):
        data = DataClassWithDefaults(
            required_int_with_default=1,
            optional_int_with_int_default=2,
            optional_int_with_none_default=3,
        )
        data_json = to_json(data)
        self.assertEqual(data, from_json(data_json, DataClassWithDefaults))

    def test_round_trip_default_nones(self):
        data = DataClassWithDefaults(
            optional_int_with_int_default=None,
            optional_int_with_none_default=None,
        )
        data_json = to_json(data)
        self.assertEqual(
            {
                "required_int_with_default": -1,
                # `optional_int_with_int_default` should deserialize back to None,
                # rather than the default int value. Therefore it must be
                # serialized to null in JSON instead of removed.
                "optional_int_with_int_default": None,
            },
            json.loads(data_json),
        )
        self.assertEqual(data, from_json(data_json, DataClassWithDefaults))

    def test_round_trip_perturbation_descriptions(self):
        descriptions = [
            PerturbationDescription(
                name="unknown",
            ),
            DialectPerturbation.Description(
                name=DialectPerturbation.name,
                fairness=True,
                prob=0.5,
                source_class="source_class",
                target_class="target_class",
                mapping_file_path="mapping_file_path",
            ),
            ExtraSpacePerturbation.Description(name=ExtraSpacePerturbation.name, robustness=True, num_spaces=2),
            FillerWordsPerturbation.Description(name=FillerWordsPerturbation.name, robustness=True, insert_prob=0.5),
            GenderPerturbation.Description(
                name=GenderPerturbation.name,
                mode="mode",
                fairness=True,
                prob=0.5,
                source_class="source_class",
                target_class="target_class",
                bidirectional=True,
            ),
            MisspellingPerturbation.Description(name=MisspellingPerturbation.name, robustness=True, prob=0.5),
            PersonNamePerturbation.Description(
                name=PersonNamePerturbation.name,
                fairness=True,
                prob=0.5,
                source_class="source_str",
                target_class="target_str",
                name_file_path="name_file_path",
                person_name_type="person_name_type",
                preserve_gender=True,
            ),
            SpacePerturbation.Description(name=SpacePerturbation.name, robustness=True, max_spaces=2),
            SynonymPerturbation.Description(name=SynonymPerturbation.name, robustness=True, prob=0.5),
            TyposPerturbation.Description(name=TyposPerturbation.name, robustness=True, prob=0.5),
        ]
        for description in descriptions:
            # subTest reports which description failed and keeps checking the rest.
            with self.subTest(name=description.name):
                self.assertEqual(description, from_json(to_json(description), PerturbationDescription))
helm/proxy/clients/aleph_alpha_client.py CHANGED
@@ -2,7 +2,11 @@ import json
2
2
  import requests
3
3
  from typing import Any, Dict, List
4
4
 
5
+ from aleph_alpha_client import Client as AlephAlphaPythonClient
6
+ from tokenizers import Tokenizer, Encoding
7
+
5
8
  from helm.common.cache import Cache, CacheConfig
9
+ from helm.common.hierarchical_logger import hlog
6
10
  from helm.common.request import Request, RequestResult, Sequence, Token
7
11
  from helm.common.tokenization_request import (
8
12
  DecodeRequest,
@@ -19,9 +23,27 @@ class AlephAlphaClient(Client):
19
23
  TOKENIZE_ENDPOINT: str = "tokenize"
20
24
  DETOKENIZE_ENDPOINT: str = "detokenize"
21
25
 
26
# Tokenizer names accepted by `_get_tokenizer`; any other name is rejected with a ValueError.
VALID_TOKENIZERS: List[str] = ["luminous-base", "luminous-extended", "luminous-supreme"]
31
+
22
32
  def __init__(self, api_key: str, cache_config: CacheConfig):
23
33
  self.api_key: str = api_key
24
34
  self.cache = Cache(cache_config)
35
+ self._aleph_alpha_client = AlephAlphaPythonClient(token=api_key)
36
+ self._tokenizer_name_to_tokenizer: Dict[str, Tokenizer] = {}
37
+
38
def _get_tokenizer(self, tokenizer_name: str) -> Tokenizer:
    """Return the tokenizer for `tokenizer_name`, creating and memoizing it on first use.

    Raises:
        ValueError: if `tokenizer_name` is not listed in `VALID_TOKENIZERS`.
            NOTE(review): `tokenize`/`decode` wrap this call in `except RuntimeError`,
            so this ValueError propagates to their callers instead of becoming an error result.
    """
    if tokenizer_name not in self.VALID_TOKENIZERS:
        raise ValueError(f"Invalid tokenizer: {tokenizer_name}")

    memo = self._tokenizer_name_to_tokenizer
    if tokenizer_name not in memo:
        # Obtain the tokenizer from the Aleph Alpha SDK client only once per name.
        memo[tokenizer_name] = self._aleph_alpha_client.tokenizer(tokenizer_name)
        hlog(f"Initialized tokenizer: {tokenizer_name}")
    return memo[tokenizer_name]
25
47
 
26
48
  def _send_request(self, endpoint: str, raw_request: Dict[str, Any]) -> Dict[str, Any]:
27
49
  response = requests.request(
@@ -33,6 +55,8 @@ class AlephAlphaClient(Client):
33
55
  "Authorization": f"Bearer {self.api_key}",
34
56
  },
35
57
  data=json.dumps(raw_request),
58
+ # Setting the nice flag prevents intensive benchmarking runs from saturating Aleph Alpha's API queues
59
+ params=json.dumps({"nice": True}),
36
60
  )
37
61
  result = json.loads(response.text)
38
62
  assert "error" not in result, f"Request failed with error: {result['error']}"
@@ -40,7 +64,6 @@ class AlephAlphaClient(Client):
40
64
 
41
65
  def make_request(self, request: Request) -> RequestResult:
42
66
  """Make a request following https://docs.aleph-alpha.com/api/complete."""
43
- # TODO: echo is not supported. Follow up on this.
44
67
  raw_request = {
45
68
  "model": request.model_engine,
46
69
  "prompt": request.prompt,
@@ -53,6 +76,7 @@ class AlephAlphaClient(Client):
53
76
  "n": request.num_completions,
54
77
  "stop_sequences": request.stop_sequences,
55
78
  "log_probs": request.top_k_per_token,
79
+ "echo": request.echo_prompt,
56
80
  "tokens": True, # Setting to True returns individual tokens of the completion
57
81
  }
58
82
 
@@ -102,24 +126,21 @@ class AlephAlphaClient(Client):
102
126
  )
103
127
 
104
128
  def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
105
- """Make a request following https://docs.aleph-alpha.com/api/tokenize."""
106
- raw_request = {
107
- "model": request.tokenizer_name,
108
- "prompt": request.text,
109
- "tokens": True,
110
- "token_ids": True,
111
- }
112
-
129
+ """
130
+ Encode the text using Aleph Alpha's tokenizer library:
131
+ https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.Client.tokenizer
132
+ """
113
133
  try:
114
134
 
115
135
  def do_it():
116
- result = self._send_request(AlephAlphaClient.TOKENIZE_ENDPOINT, raw_request)
117
- assert "tokens" in result and "token_ids" in result, f"Invalid response: {result}"
118
- return result
119
-
120
- response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
121
- except (requests.exceptions.RequestException, AssertionError) as e:
122
- error: str = f"AlephAlphaClient error: {e}"
136
+ tokenizer: Tokenizer = self._get_tokenizer(request.tokenizer_name)
137
+ result: Encoding = tokenizer.encode(request.text, add_special_tokens=False)
138
+ return {"token_ids": result.ids, "tokens": result.tokens}
139
+
140
+ cache_key = {"model": request.tokenizer_name, "prompt": request.text, "tokens": True, "token_ids": True}
141
+ response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
142
+ except RuntimeError as e:
143
+ error: str = f"AlephAlphaClient tokenize error: {e}"
123
144
  return TokenizationRequestResult(error=error, success=False, cached=False, text="", tokens=[])
124
145
 
125
146
  tokens = response["token_ids" if request.encode else "tokens"]
@@ -135,22 +156,20 @@ class AlephAlphaClient(Client):
135
156
  )
136
157
 
137
158
  def decode(self, request: DecodeRequest) -> DecodeRequestResult:
138
- """Make a request following https://docs.aleph-alpha.com/api/detokenize."""
139
- raw_request = {
140
- "model": request.tokenizer_name,
141
- "token_ids": request.tokens,
142
- }
143
-
159
+ """
160
+ Decode the tokens using Aleph Alpha's tokenizer library:
161
+ https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.Client.tokenizer
162
+ """
144
163
  try:
145
164
 
146
165
  def do_it():
147
- result = self._send_request(AlephAlphaClient.DETOKENIZE_ENDPOINT, raw_request)
148
- assert "result" in result, f"Invalid response: {result}"
149
- return result
166
+ tokenizer: Tokenizer = self._get_tokenizer(request.tokenizer_name)
167
+ return {"result": tokenizer.decode(request.tokens)}
150
168
 
151
- response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
152
- except (requests.exceptions.RequestException, AssertionError) as e:
153
- error: str = f"AlephAlphaClient error: {e}"
169
+ cache_key = {"model": request.tokenizer_name, "token_ids": request.tokens}
170
+ response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
171
+ except RuntimeError as e:
172
+ error: str = f"AlephAlphaClient decode error: {e}"
154
173
  return DecodeRequestResult(error=error, success=False, cached=False, text="")
155
174
 
156
175
  return DecodeRequestResult(
helm/proxy/clients/auto_client.py CHANGED
@@ -21,6 +21,7 @@ from .anthropic_client import AnthropicClient
21
21
  from .chat_gpt_client import ChatGPTClient
22
22
  from .cohere_client import CohereClient
23
23
  from .together_client import TogetherClient
24
+ from .google_client import GoogleClient
24
25
  from .goose_ai_client import GooseAIClient
25
26
  from .huggingface_client import HuggingFaceClient
26
27
  from .ice_tokenizer_client import ICETokenizerClient
@@ -29,6 +30,7 @@ from .microsoft_client import MicrosoftClient
29
30
  from .perspective_api_client import PerspectiveAPIClient
30
31
  from .yalm_tokenizer_client import YaLMTokenizerClient
31
32
  from .simple_client import SimpleClient
33
+ from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_config
32
34
 
33
35
 
34
36
  class AutoClient(Client):
@@ -53,15 +55,17 @@ class AutoClient(Client):
53
55
  # TODO: Allow setting CacheConfig.follower_cache_path from a command line flag.
54
56
  return SqliteCacheConfig(client_cache_path)
55
57
 
56
- def get_client(self, request: Request) -> Client:
57
- """Return a client based on `organization`, creating it if necessary."""
58
- organization: str = request.model_organization
59
- client: Optional[Client] = self.clients.get(organization)
58
+ def _get_client(self, model: str) -> Client:
59
+ """Return a client based on the model, creating it if necessary."""
60
+ client: Optional[Client] = self.clients.get(model)
60
61
 
61
62
  if client is None:
63
+ organization: str = model.split("/")[0]
62
64
  cache_config: CacheConfig = self._build_cache_config(organization)
63
65
 
64
- if organization == "openai":
66
+ if get_huggingface_model_config(model):
67
+ client = HuggingFaceClient(cache_config=cache_config)
68
+ elif organization == "openai":
65
69
  # TODO: add ChatGPT to the OpenAIClient when it's supported.
66
70
  # We're using a separate client for now since we're using an unofficial Python library.
67
71
  # See https://github.com/acheong08/ChatGPT/wiki/Setup on how to get a valid session token.
@@ -71,13 +75,14 @@ class AutoClient(Client):
71
75
  # TODO: use `cache_config` above. Since this feature is still experimental,
72
76
  # save queries and responses in a separate collection.
73
77
  cache_config=self._build_cache_config("ChatGPT"),
74
- tokenizer_client=self.get_tokenizer_client("huggingface"),
78
+ tokenizer_client=self._get_tokenizer_client("huggingface"),
75
79
  )
76
80
 
77
81
  org_id = self.credentials.get("openaiOrgId", None)
78
82
  client = OpenAIClient(
79
83
  api_key=self.credentials["openaiApiKey"],
80
84
  cache_config=cache_config,
85
+ tokenizer_client=self._get_tokenizer_client("huggingface"),
81
86
  chat_gpt_client=chat_gpt_client,
82
87
  org_id=org_id,
83
88
  )
@@ -105,18 +110,20 @@ class AutoClient(Client):
105
110
  cache_config=cache_config,
106
111
  org_id=org_id,
107
112
  )
113
+ elif organization == "google":
114
+ client = GoogleClient(cache_config=cache_config)
108
115
  elif organization == "together":
109
116
  client = TogetherClient(api_key=self.credentials.get("togetherApiKey", None), cache_config=cache_config)
110
117
  elif organization == "simple":
111
118
  client = SimpleClient(cache_config=cache_config)
112
119
  else:
113
- raise ValueError(f"Unknown organization: {organization}")
114
- self.clients[organization] = client
120
+ raise ValueError(f"Could not find client for model: {model}")
121
+ self.clients[model] = client
115
122
  return client
116
123
 
117
124
  def make_request(self, request: Request) -> RequestResult:
118
125
  """
119
- Dispatch based on the organization in the name of the model (e.g., openai/davinci).
126
+ Dispatch based on the the name of the model (e.g., openai/davinci).
120
127
  Retries if request fails.
121
128
  """
122
129
 
@@ -125,30 +132,33 @@ class AutoClient(Client):
125
132
  def make_request_with_retry(client: Client, request: Request) -> RequestResult:
126
133
  return client.make_request(request)
127
134
 
128
- organization: str = request.model_organization
129
- client: Client = self.get_client(request)
135
+ client: Client = self._get_client(request.model)
130
136
 
131
137
  try:
132
138
  return make_request_with_retry(client=client, request=request)
133
139
  except RetryError as e:
134
140
  last_attempt: Attempt = e.last_attempt
135
141
  retry_error: str = (
136
- f"Failed to make request to {organization} after retrying {last_attempt.attempt_number} times"
142
+ f"Failed to make request to {request.model} after retrying {last_attempt.attempt_number} times"
137
143
  )
138
144
  hlog(retry_error)
139
145
 
140
146
  # Notify our user that we failed to make the request even after retrying.
141
147
  return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")
142
148
 
143
- def get_tokenizer_client(self, organization: str) -> Client:
144
- """Return a client based on `organization`, creating it if necessary."""
145
- client: Optional[Client] = self.tokenizer_clients.get(organization)
149
+ def _get_tokenizer_client(self, tokenizer: str) -> Client:
150
+ """Return a client based on the tokenizer, creating it if necessary."""
151
+ organization: str = tokenizer.split("/")[0]
152
+ client: Optional[Client] = self.tokenizer_clients.get(tokenizer)
146
153
 
147
154
  if client is None:
148
155
  cache_config: CacheConfig = self._build_cache_config(organization)
149
- if organization in [
156
+ if get_huggingface_model_config(tokenizer):
157
+ client = HuggingFaceClient(cache_config=cache_config)
158
+ elif organization in [
150
159
  "anthropic",
151
160
  "bigscience",
161
+ "bigcode",
152
162
  "EleutherAI",
153
163
  "facebook",
154
164
  "google",
@@ -171,19 +181,18 @@ class AutoClient(Client):
171
181
  elif organization == "simple":
172
182
  client = SimpleClient(cache_config=cache_config)
173
183
  else:
174
- raise ValueError(f"Unknown organization: {organization}")
175
- self.tokenizer_clients[organization] = client
184
+ raise ValueError(f"Could not find tokenizer client for model: {tokenizer}")
185
+ self.tokenizer_clients[tokenizer] = client
176
186
  return client
177
187
 
178
188
  def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
179
- """Tokenizes based on the organization in the name of the tokenizer (e.g., huggingface/gpt2)."""
189
+ """Tokenizes based on the name of the tokenizer (e.g., huggingface/gpt2)."""
180
190
 
181
191
  @retry_request
182
192
  def tokenize_with_retry(client: Client, request: TokenizationRequest) -> TokenizationRequestResult:
183
193
  return client.tokenize(request)
184
194
 
185
- organization: str = request.tokenizer_organization
186
- client: Client = self.get_tokenizer_client(organization)
195
+ client: Client = self._get_tokenizer_client(request.tokenizer)
187
196
 
188
197
  try:
189
198
  return tokenize_with_retry(client=client, request=request)
@@ -194,14 +203,13 @@ class AutoClient(Client):
194
203
  return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")
195
204
 
196
205
  def decode(self, request: DecodeRequest) -> DecodeRequestResult:
197
- """Decodes based on the organization in the name of the tokenizer (e.g., huggingface/gpt2)."""
206
+ """Decodes based on the the name of the tokenizer (e.g., huggingface/gpt2)."""
198
207
 
199
208
  @retry_request
200
209
  def decode_with_retry(client: Client, request: DecodeRequest) -> DecodeRequestResult:
201
210
  return client.decode(request)
202
211
 
203
- organization: str = request.tokenizer_organization
204
- client: Client = self.get_tokenizer_client(organization)
212
+ client: Client = self._get_tokenizer_client(request.tokenizer)
205
213
 
206
214
  try:
207
215
  return decode_with_retry(client=client, request=request)