crfm-helm 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {crfm_helm-0.2.1.dist-info → crfm_helm-0.2.2.dist-info}/METADATA +10 -8
  2. {crfm_helm-0.2.1.dist-info → crfm_helm-0.2.2.dist-info}/RECORD +50 -37
  3. {crfm_helm-0.2.1.dist-info → crfm_helm-0.2.2.dist-info}/WHEEL +1 -1
  4. {crfm_helm-0.2.1.dist-info → crfm_helm-0.2.2.dist-info}/entry_points.txt +1 -0
  5. helm/benchmark/__init__.py +2 -0
  6. helm/benchmark/adaptation/adapter_spec.py +3 -0
  7. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -7
  8. helm/benchmark/contamination/__init__.py +0 -0
  9. helm/benchmark/metrics/classification_metrics.py +28 -23
  10. helm/benchmark/metrics/test_classification_metrics.py +44 -9
  11. helm/benchmark/presentation/create_plots.py +617 -0
  12. helm/benchmark/presentation/summarize.py +4 -2
  13. helm/benchmark/presentation/test_create_plots.py +32 -0
  14. helm/benchmark/run.py +23 -1
  15. helm/benchmark/run_expander.py +161 -47
  16. helm/benchmark/run_specs.py +84 -10
  17. helm/benchmark/runner.py +31 -3
  18. helm/benchmark/scenarios/copyright_scenario.py +1 -1
  19. helm/benchmark/scenarios/imdb_listdir.json +50014 -0
  20. helm/benchmark/scenarios/lex_glue_scenario.py +58 -17
  21. helm/benchmark/scenarios/lextreme_scenario.py +37 -25
  22. helm/benchmark/scenarios/opinions_qa_scenario.py +194 -0
  23. helm/benchmark/scenarios/scenario.py +5 -0
  24. helm/benchmark/scenarios/the_pile_scenario.py +1 -1
  25. helm/benchmark/static/benchmarking.css +14 -0
  26. helm/benchmark/static/benchmarking.js +43 -0
  27. helm/benchmark/static/index.html +2 -0
  28. helm/benchmark/static/json-urls.js +4 -0
  29. helm/benchmark/static/plot-captions.js +16 -0
  30. helm/benchmark/static/schema.yaml +66 -8
  31. helm/benchmark/window_services/cohere_window_service.py +20 -0
  32. helm/benchmark/window_services/flan_t5_window_service.py +29 -0
  33. helm/benchmark/window_services/huggingface_window_service.py +39 -0
  34. helm/benchmark/window_services/test_flan_t5_window_service.py +12 -0
  35. helm/benchmark/window_services/wider_ai21_window_service.py +13 -0
  36. helm/benchmark/window_services/window_service_factory.py +27 -6
  37. helm/common/general.py +12 -5
  38. helm/proxy/clients/aleph_alpha_client.py +47 -28
  39. helm/proxy/clients/auto_client.py +28 -24
  40. helm/proxy/clients/huggingface_client.py +30 -17
  41. helm/proxy/clients/huggingface_model_registry.py +111 -0
  42. helm/proxy/clients/huggingface_tokenizer.py +23 -7
  43. helm/proxy/clients/openai_client.py +60 -2
  44. helm/proxy/clients/test_huggingface_model_registry.py +57 -0
  45. helm/proxy/clients/together_client.py +17 -2
  46. helm/proxy/clients/yalm_tokenizer/voc_100b.sp +0 -0
  47. helm/proxy/clients/yalm_tokenizer/yalm_tokenizer.py +8 -2
  48. helm/proxy/models.py +82 -2
  49. {crfm_helm-0.2.1.dist-info → crfm_helm-0.2.2.dist-info}/LICENSE +0 -0
  50. {crfm_helm-0.2.1.dist-info → crfm_helm-0.2.2.dist-info}/top_level.txt +0 -0
helm/benchmark/static/schema.yaml CHANGED
@@ -30,6 +30,27 @@ models:
     access: limited
     num_parameters: 17000000000
     release_date: 2022-10-28
+  - name: ai21/j2-jumbo
+    display_name: Jurassic-2 Jumbo (178B)
+    description: Jurassic-2 Jumbo (178B parameters) ([docs](https://www.ai21.com/blog/introducing-j2))
+    creator_organization: AI21 Labs
+    access: limited
+    num_parameters: 178000000000
+    release_date: 2023-03-09
+  - name: ai21/j2-grande
+    display_name: Jurassic-2 Grande (17B)
+    description: Jurassic-2 Grande (17B parameters) ([docs](https://www.ai21.com/blog/introducing-j2))
+    creator_organization: AI21 Labs
+    access: limited
+    num_parameters: 17000000000
+    release_date: 2023-03-09
+  - name: ai21/j2-large
+    display_name: Jurassic-2 Large (7.5B)
+    description: Jurassic-2 Large (7.5B parameters) ([docs](https://www.ai21.com/blog/introducing-j2))
+    creator_organization: AI21 Labs
+    access: limited
+    num_parameters: 7500000000
+    release_date: 2023-03-09
 
   # Aleph Alpha
   # TODO: add Luminous World when it's released
@@ -142,9 +163,17 @@ models:
     access: limited
     num_parameters: 6100000000
     release_date: 2022-11-08
-  - name: cohere/command-xlarge-nightly
-    display_name: Cohere command nightly (52.4B)
-    description: Cohere command nightly (52.4B parameters) is fine-tuned from the XL model to respond well with instruction-like prompts ([details](https://docs.cohere.ai/docs/command-beta)).
+  - name: cohere/command-medium-beta
+    display_name: Cohere Command beta (6.1B)
+    description: Cohere Command beta (6.1B parameters) is fine-tuned from the medium model to respond well with instruction-like prompts ([details](https://docs.cohere.ai/docs/command-beta)).
+    creator_organization: Cohere
+    access: limited
+    num_parameters: 6100000000
+    release_date: 2022-11-08
+    todo: true
+  - name: cohere/command-xlarge-beta
+    display_name: Cohere Command beta (52.4B)
+    description: Cohere Command beta (52.4B parameters) is fine-tuned from the XL model to respond well with instruction-like prompts ([details](https://docs.cohere.ai/docs/command-beta)).
     creator_organization: Cohere
     access: limited
     num_parameters: 52400000000
@@ -203,7 +232,6 @@ models:
     description: Flan-T5 (11B parameters) is T5 fine-tuned on 1.8K tasks ([paper](https://arxiv.org/pdf/2210.11416.pdf)).
     creator_organization: Google
     access: open
-    todo: true
 
   - name: google/palm
     display_name: PaLM (540B)
@@ -379,6 +407,12 @@ models:
     description: Codex-style model that is a stronger, multilingual version of the Codex (12B) model in the [Codex paper](https://arxiv.org/pdf/2107.03374.pdf).
     creator_organization: OpenAI
     access: limited
+  - name: openai/gpt-3.5-turbo-0301
+    display_name: gpt-3.5-turbo-0301
+    description: Sibling model of text-davinci-003 that is optimized for chat but works well for traditional completions tasks as well. Snapshot from 2023-03-01.
+    creator_organization: OpenAI
+    access: limited
+    release_date: 2023-03-01
   - name: openai/chat-gpt
     display_name: ChatGPT
     description: Sibling model to InstructGPT which interacts in a conversational way. See [OpenAI's announcement](https://openai.com/blog/chatgpt/). The size of the model is unknown.
@@ -396,6 +430,14 @@ models:
     num_parameters: 6700000000
     release_date: 2022-11-29
     todo: true
+  - name: together/gpt-neoxt-chat-base-20b
+    display_name: GPT-NeoXT-Chat-Base (20B)
+    description: GPT-NeoXT-Chat-Base (20B) is fine-tuned from GPT-NeoX, serving as a base model for developing open-source chatbots.
+    creator_organization: Together
+    access: open
+    num_parameters: 20000000000
+    release_date: 2023-03-08
+    todo: true
 
   # Salesforce
   - name: together/codegen
@@ -634,6 +676,14 @@ metrics:
     display_name: F1
     description: Average F1 score in terms of word overlap between the model output and correct reference.
     lower_is_better: false
+  - name: classification_macro_f1
+    display_name: Macro-F1
+    description: Population-level macro-averaged F1 score.
+    lower_is_better: false
+  - name: classification_micro_f1
+    display_name: Micro-F1
+    description: Population-level micro-averaged F1 score.
+    lower_is_better: false
   - name: absolute_value_difference
     display_name: Absolute difference
     short_display_name: Diff.
@@ -1165,6 +1215,14 @@ metric_groups:
       - name: monte_carlo_entropy
         split: ${main_split}
 
+  - name: classification_metrics
+    display_name: Classification metrics
+    metrics:
+      - name: classification_macro_f1
+        split: ${main_split}
+      - name: classification_micro_f1
+        split: ${main_split}
+
 ############################################################
 run_groups:
   ## Top-level
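
The two F1 variants added above differ only in how per-class scores are aggregated. As an illustration (not part of the diff, and not necessarily how helm/benchmark/metrics/classification_metrics.py computes them), scikit-learn expresses the distinction as follows:

    # Illustrative sketch with made-up labels; assumes scikit-learn is installed.
    from sklearn.metrics import f1_score

    gold = ["contract", "contract", "tort", "criminal"]  # hypothetical gold labels
    pred = ["contract", "tort", "tort", "criminal"]      # hypothetical predictions

    # classification_macro_f1: F1 is computed per class, then classes are averaged
    # with equal weight, so rare classes count as much as frequent ones.
    macro_f1 = f1_score(gold, pred, average="macro")

    # classification_micro_f1: true/false positives and negatives are pooled across
    # all classes before a single F1 is computed, so frequent classes dominate.
    micro_f1 = f1_score(gold, pred, average="micro")

    print(f"macro={macro_f1:.3f} micro={micro_f1:.3f}")
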
@@ -2106,24 +2164,24 @@ run_groups:
     display_name: LEXTREME
     description: A Multilingual Legal Benchmark for Natural Language Understanding
     metric_groups:
-      - accuracy
+      - classification_metrics
       - calibration
       - efficiency
       - general_information
     environment:
-      main_name: f1_score
+      main_name: classification_macro_f1
       main_split: test
 
   - name: lex_glue
     display_name: LexGLUE
     description: A Benchmark Dataset for Legal Language Understanding in English
     metric_groups:
-      - accuracy
+      - classification_metrics
       - calibration
       - efficiency
       - general_information
     environment:
-      main_name: f1_score
+      main_name: classification_macro_f1
       main_split: test
 
   - name: entity_data_imputation
helm/benchmark/window_services/cohere_window_service.py CHANGED
@@ -141,3 +141,23 @@ class CohereWindowService(LocalWindowService):
             result = result[:-1]
 
         return result
+
+
+class CohereCommandWindowService(CohereWindowService):
+    def __init__(self, service: TokenizerService):
+        super().__init__(service)
+
+    @property
+    def max_request_length(self) -> int:
+        """
+        The max request length of the model. For Cohere, this is the same as the `max_sequence_length`.
+        If we exceed the `max_sequence_length`, we get the following error:
+
+        Request failed with too many tokens: total number of tokens (prompt and prediction) cannot
+        exceed 2048 - received 2049. Try using a shorter prompt or a smaller max_tokens value.
+
+        For the Command model, in rare situations, co.tokenize returns a shorter list of tokens
+        than co.generate. This causes sequence length errors for rare inputs. Cohere's advice is
+        to reduce the sequence length to 2020 to avoid these issues.
+        """
+        return 2020
helm/benchmark/window_services/flan_t5_window_service.py ADDED
@@ -0,0 +1,29 @@
+from .encoder_decoder_window_service import EncoderDecoderWindowService
+from .tokenizer_service import TokenizerService
+
+
+class FlanT5WindowService(EncoderDecoderWindowService):
+    def __init__(self, service: TokenizerService):
+        super().__init__(service)
+
+    @property
+    def max_sequence_length(self) -> int:
+        """Return the max sequence length."""
+        # We subtract 1 to account for <extra_id_0> that gets appended to prompts.
+        return 512 - 1
+
+    @property
+    def end_of_text_token(self) -> str:
+        """The end of text token."""
+        return "</s>"
+
+    @property
+    def tokenizer_name(self) -> str:
+        """Name of the tokenizer to use when sending a request."""
+        return "google/flan-t5-xxl"
+
+    @property
+    def prefix_token(self) -> str:
+        """The prefix token is the same as the end of text token."""
+        # echo=True is not supported
+        return ""
helm/benchmark/window_services/huggingface_window_service.py ADDED
@@ -0,0 +1,39 @@
+from helm.proxy.clients.huggingface_tokenizer import HuggingFaceTokenizers
+from .local_window_service import LocalWindowService
+from .tokenizer_service import TokenizerService
+from helm.proxy.clients.huggingface_client import HuggingFaceModelConfig
+
+
+class HuggingFaceWindowService(LocalWindowService):
+    def __init__(self, service: TokenizerService, model_config: HuggingFaceModelConfig):
+        super().__init__(service)
+        self._tokenizer_name = str(model_config)
+        tokenizer = HuggingFaceTokenizers.get_tokenizer(self._tokenizer_name)
+        self._prefix_token = tokenizer.bos_token
+        self._end_of_text_token = tokenizer.eos_token
+        self._max_request_length = tokenizer.model_max_length
+
+    @property
+    def max_sequence_length(self) -> int:
+        """Return the max sequence length of this tokenizer."""
+        return self._max_request_length
+
+    @property
+    def max_request_length(self) -> int:
+        """Return the max request length of this tokenizer."""
+        return self.max_sequence_length
+
+    @property
+    def end_of_text_token(self) -> str:
+        """The end of text token."""
+        return self._end_of_text_token
+
+    @property
+    def tokenizer_name(self) -> str:
+        """Name of the tokenizer to use when sending a request."""
+        return self._tokenizer_name
+
+    @property
+    def prefix_token(self) -> str:
+        """The prefix token."""
+        return self._prefix_token
helm/benchmark/window_services/test_flan_t5_window_service.py ADDED
@@ -0,0 +1,12 @@
+import tempfile
+
+from helm.benchmark.window_services.test_t511b_window_service import TestT511bWindowService
+from helm.benchmark.window_services.window_service_factory import TokenizerService, WindowServiceFactory
+from helm.benchmark.window_services.test_utils import get_tokenizer_service
+
+
+class TestFlanT5WindowService(TestT511bWindowService):
+    def setup_method(self):
+        self.path: str = tempfile.mkdtemp()
+        service: TokenizerService = get_tokenizer_service(self.path)
+        self.window_service = WindowServiceFactory.get_window_service("together/flan-t5-xxl", service)
helm/benchmark/window_services/wider_ai21_window_service.py ADDED
@@ -0,0 +1,13 @@
+from .ai21_window_service import AI21WindowService
+
+
+class WiderAI21WindowService(AI21WindowService):
+    @property
+    def max_sequence_length(self) -> int:
+        """
+        Return the max sequence length of the larger AI21 Jurassic-2 models.
+
+        The AI21 server automatically prepends a token to every prompt,
+        so the actual max sequence length is 8192 - 1 = 8191.
+        """
+        return 8191
helm/benchmark/window_services/window_service_factory.py CHANGED
@@ -1,7 +1,14 @@
-from helm.proxy.models import get_model, get_model_names_with_tag, Model, WIDER_CONTEXT_WINDOW_TAG
+from helm.proxy.models import (
+    get_model,
+    get_model_names_with_tag,
+    Model,
+    AI21_WIDER_CONTEXT_WINDOW_TAG,
+    WIDER_CONTEXT_WINDOW_TAG,
+)
 from .ai21_window_service import AI21WindowService
+from .wider_ai21_window_service import WiderAI21WindowService
 from .anthropic_window_service import AnthropicWindowService
-from .cohere_window_service import CohereWindowService
+from .cohere_window_service import CohereWindowService, CohereCommandWindowService
 from .luminous_window_service import (
     LuminousBaseWindowService,
     LuminousExtendedWindowService,
@@ -12,6 +19,7 @@ from .openai_window_service import OpenAIWindowService
 from .wider_openai_window_service import WiderOpenAIWindowService
 from .mt_nlg_window_service import MTNLGWindowService
 from .bloom_window_service import BloomWindowService
+from .huggingface_window_service import HuggingFaceWindowService
 from .ice_window_service import ICEWindowService
 from .santacoder_window_service import SantaCoderWindowService
 from .gpt2_window_service import GPT2WindowService
@@ -20,10 +28,12 @@ from .gptneox_window_service import GPTNeoXWindowService
 from .opt_window_service import OPTWindowService
 from .t0pp_window_service import T0ppWindowService
 from .t511b_window_service import T511bWindowService
+from .flan_t5_window_service import FlanT5WindowService
 from .ul2_window_service import UL2WindowService
 from .yalm_window_service import YaLMWindowService
 from .window_service import WindowService
 from .tokenizer_service import TokenizerService
+from helm.proxy.clients.huggingface_client import get_huggingface_model_config
 
 
 class WindowServiceFactory:
@@ -38,7 +48,10 @@ class WindowServiceFactory:
         engine: str = model.engine
 
         window_service: WindowService
-        if model_name in get_model_names_with_tag(WIDER_CONTEXT_WINDOW_TAG):
+        huggingface_model_config = get_huggingface_model_config(model_name)
+        if huggingface_model_config:
+            window_service = HuggingFaceWindowService(service=service, model_config=huggingface_model_config)
+        elif model_name in get_model_names_with_tag(WIDER_CONTEXT_WINDOW_TAG):
             window_service = WiderOpenAIWindowService(service)
         # For the Google models, we approximate with the OpenAIWindowService
         elif organization == "openai" or organization == "simple" or organization == "google":
@@ -70,7 +83,7 @@ class WindowServiceFactory:
             window_service = ICEWindowService(service)
         elif model_name in ["huggingface/gpt-j-6b", "together/gpt-j-6b", "gooseai/gpt-j-6b"]:
             window_service = GPTJWindowService(service)
-        elif model_name in ["together/gpt-neox-20b", "gooseai/gpt-neo-20b"]:
+        elif model_name in ["together/gpt-neox-20b", "gooseai/gpt-neo-20b", "together/gpt-neoxt-chat-base-20b"]:
            window_service = GPTNeoXWindowService(service)
         elif model_name == "together/h3-2.7b":
             window_service = GPT2WindowService(service)
@@ -80,14 +93,22 @@ class WindowServiceFactory:
             window_service = T0ppWindowService(service)
         elif model_name == "together/t5-11b":
             window_service = T511bWindowService(service)
+        elif model_name == "together/flan-t5-xxl":
+            window_service = FlanT5WindowService(service)
         elif model_name == "together/ul2":
             window_service = UL2WindowService(service)
         elif model_name == "together/yalm":
             window_service = YaLMWindowService(service)
         elif organization == "cohere":
-            window_service = CohereWindowService(service)
+            if "command" in engine:
+                window_service = CohereCommandWindowService(service)
+            else:
+                window_service = CohereWindowService(service)
         elif organization == "ai21":
-            window_service = AI21WindowService(service=service, gpt2_window_service=GPT2WindowService(service))
+            if model_name in get_model_names_with_tag(AI21_WIDER_CONTEXT_WINDOW_TAG):
+                window_service = WiderAI21WindowService(service=service, gpt2_window_service=GPT2WindowService(service))
+            else:
+                window_service = AI21WindowService(service=service, gpt2_window_service=GPT2WindowService(service))
         else:
             raise ValueError(f"Unhandled model name: {model_name}")
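
A quick way to sanity-check the new routing is to resolve window services by model name, as the test above does, and inspect the limits introduced in this release. This is a hedged sketch rather than part of the package: it reuses the get_tokenizer_service test helper shown earlier and assumes the Jurassic-2 models carry the new AI21_WIDER_CONTEXT_WINDOW_TAG in helm/proxy/models.py.

    # Hedged sketch; reuses helpers that appear in the diff above.
    import tempfile

    from helm.benchmark.window_services.test_utils import get_tokenizer_service
    from helm.benchmark.window_services.window_service_factory import WindowServiceFactory

    service = get_tokenizer_service(tempfile.mkdtemp())

    # Cohere engines containing "command" should route to CohereCommandWindowService,
    # which caps requests at 2020 tokens.
    command = WindowServiceFactory.get_window_service("cohere/command-xlarge-beta", service)
    print(command.max_request_length)  # expected: 2020

    # Assuming ai21/j2-jumbo is tagged with AI21_WIDER_CONTEXT_WINDOW_TAG, it should
    # get WiderAI21WindowService's 8191-token sequence length.
    jurassic = WindowServiceFactory.get_window_service("ai21/j2-jumbo", service)
    print(jurassic.max_sequence_length)  # expected: 8191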
 
helm/common/general.py CHANGED
@@ -49,7 +49,13 @@ def shell(args: List[str]):
 
 
 @htrack(None)
-def ensure_file_downloaded(source_url: str, target_path: str, unpack: bool = False, unpack_type: Optional[str] = None):
+def ensure_file_downloaded(
+    source_url: str,
+    target_path: str,
+    unpack: bool = False,
+    downloader_executable: str = "wget",
+    unpack_type: Optional[str] = None,
+):
     """Download `source_url` to `target_path` if it doesn't exist."""
     if os.path.exists(target_path):
         # Assume it's all good
@@ -59,7 +65,8 @@ def ensure_file_downloaded(
     # Download
     # gdown is used to download large files/zip folders from Google Drive.
     # It bypasses security warnings which wget cannot handle.
-    downloader_executable: str = "gdown" if source_url.startswith("https://drive.google.com") else "wget"
+    if source_url.startswith("https://drive.google.com"):
+        downloader_executable = "gdown"
     tmp_path: str = f"{target_path}.tmp"
     shell([downloader_executable, source_url, "-O", tmp_path])
 
@@ -195,13 +202,13 @@ def parallel_map(
     with htrack_block(f"Parallelizing computation on {len(items)} items over {parallelism} {units}"):
         results: List
         if parallelism == 1:
-            results = list(tqdm(map(process, items), total=len(items)))
+            results = list(tqdm(map(process, items), total=len(items), disable=None))
         elif multiprocessing:
             with ProcessPoolExecutor(max_workers=parallelism) as executor:
-                results = list(tqdm(executor.map(process, items), total=len(items)))
+                results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
         else:
             with ThreadPoolExecutor(max_workers=parallelism) as executor:
-                results = list(tqdm(executor.map(process, items), total=len(items)))
+                results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
     return results
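
Two small behavior notes on the hunks above: passing disable=None to tqdm turns the progress bar off automatically when output is not a TTY (for example, when logs are redirected to a file), and the new downloader_executable parameter lets callers override the default wget while Google Drive URLs still switch to gdown. A hedged usage sketch, with a placeholder URL and path:

    # Hedged sketch; the URL and target path are placeholders, not from the package.
    from helm.common.general import ensure_file_downloaded

    ensure_file_downloaded(
        source_url="https://example.com/dataset.json",          # hypothetical URL
        target_path="benchmark_output/scenarios/dataset.json",  # hypothetical path
        downloader_executable="wget",  # new parameter; Google Drive URLs are still forced to "gdown"
    )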
 
helm/proxy/clients/aleph_alpha_client.py CHANGED
@@ -2,7 +2,11 @@ import json
 import requests
 from typing import Any, Dict, List
 
+from aleph_alpha_client import Client as AlephAlphaPythonClient
+from tokenizers import Tokenizer, Encoding
+
 from helm.common.cache import Cache, CacheConfig
+from helm.common.hierarchical_logger import hlog
 from helm.common.request import Request, RequestResult, Sequence, Token
 from helm.common.tokenization_request import (
     DecodeRequest,
@@ -19,9 +23,27 @@ class AlephAlphaClient(Client):
     TOKENIZE_ENDPOINT: str = "tokenize"
     DETOKENIZE_ENDPOINT: str = "detokenize"
 
+    VALID_TOKENIZERS: List[str] = [
+        "luminous-base",
+        "luminous-extended",
+        "luminous-supreme",
+    ]
+
     def __init__(self, api_key: str, cache_config: CacheConfig):
         self.api_key: str = api_key
         self.cache = Cache(cache_config)
+        self._aleph_alpha_client = AlephAlphaPythonClient(token=api_key)
+        self._tokenizer_name_to_tokenizer: Dict[str, Tokenizer] = {}
+
+    def _get_tokenizer(self, tokenizer_name: str) -> Tokenizer:
+        if tokenizer_name not in self.VALID_TOKENIZERS:
+            raise ValueError(f"Invalid tokenizer: {tokenizer_name}")
+
+        # Check if the tokenizer is cached
+        if tokenizer_name not in self._tokenizer_name_to_tokenizer:
+            self._tokenizer_name_to_tokenizer[tokenizer_name] = self._aleph_alpha_client.tokenizer(tokenizer_name)
+            hlog(f"Initialized tokenizer: {tokenizer_name}")
+        return self._tokenizer_name_to_tokenizer[tokenizer_name]
 
     def _send_request(self, endpoint: str, raw_request: Dict[str, Any]) -> Dict[str, Any]:
         response = requests.request(
@@ -33,6 +55,8 @@ class AlephAlphaClient(Client):
                 "Authorization": f"Bearer {self.api_key}",
             },
             data=json.dumps(raw_request),
+            # Setting the nice flag prevents intensive benchmarking runs from saturating Aleph Alpha's API queues
+            params=json.dumps({"nice": True}),
         )
         result = json.loads(response.text)
         assert "error" not in result, f"Request failed with error: {result['error']}"
@@ -40,7 +64,6 @@ class AlephAlphaClient(Client):
 
     def make_request(self, request: Request) -> RequestResult:
         """Make a request following https://docs.aleph-alpha.com/api/complete."""
-        # TODO: echo is not supported. Follow up on this.
         raw_request = {
             "model": request.model_engine,
             "prompt": request.prompt,
@@ -53,6 +76,7 @@ class AlephAlphaClient(Client):
             "n": request.num_completions,
             "stop_sequences": request.stop_sequences,
             "log_probs": request.top_k_per_token,
+            "echo": request.echo_prompt,
             "tokens": True,  # Setting to True returns individual tokens of the completion
         }
 
@@ -102,24 +126,21 @@ class AlephAlphaClient(Client):
         )
 
     def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
-        """Make a request following https://docs.aleph-alpha.com/api/tokenize."""
-        raw_request = {
-            "model": request.tokenizer_name,
-            "prompt": request.text,
-            "tokens": True,
-            "token_ids": True,
-        }
-
+        """
+        Encode the text using Aleph Alpha's tokenizer library:
+        https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.Client.tokenizer
+        """
         try:
 
             def do_it():
-                result = self._send_request(AlephAlphaClient.TOKENIZE_ENDPOINT, raw_request)
-                assert "tokens" in result and "token_ids" in result, f"Invalid response: {result}"
-                return result
-
-            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
-        except (requests.exceptions.RequestException, AssertionError) as e:
-            error: str = f"AlephAlphaClient error: {e}"
+                tokenizer: Tokenizer = self._get_tokenizer(request.tokenizer_name)
+                result: Encoding = tokenizer.encode(request.text, add_special_tokens=False)
+                return {"token_ids": result.ids, "tokens": result.tokens}
+
+            cache_key = {"model": request.tokenizer_name, "prompt": request.text, "tokens": True, "token_ids": True}
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"AlephAlphaClient tokenize error: {e}"
             return TokenizationRequestResult(error=error, success=False, cached=False, text="", tokens=[])
 
         tokens = response["token_ids" if request.encode else "tokens"]
@@ -135,22 +156,20 @@ class AlephAlphaClient(Client):
         )
 
     def decode(self, request: DecodeRequest) -> DecodeRequestResult:
-        """Make a request following https://docs.aleph-alpha.com/api/detokenize."""
-        raw_request = {
-            "model": request.tokenizer_name,
-            "token_ids": request.tokens,
-        }
-
+        """
+        Decode the tokens using Aleph Alpha's tokenizer library:
+        https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.Client.tokenizer
+        """
         try:
 
             def do_it():
-                result = self._send_request(AlephAlphaClient.DETOKENIZE_ENDPOINT, raw_request)
-                assert "result" in result, f"Invalid response: {result}"
-                return result
+                tokenizer: Tokenizer = self._get_tokenizer(request.tokenizer_name)
+                return {"result": tokenizer.decode(request.tokens)}
 
-            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
-        except (requests.exceptions.RequestException, AssertionError) as e:
-            error: str = f"AlephAlphaClient error: {e}"
+            cache_key = {"model": request.tokenizer_name, "token_ids": request.tokens}
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"AlephAlphaClient decode error: {e}"
             return DecodeRequestResult(error=error, success=False, cached=False, text="")
 
         return DecodeRequestResult(