crfm-helm 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/METADATA +11 -8
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/RECORD +67 -38
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/WHEEL +1 -1
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/entry_points.txt +2 -1
- helm/benchmark/__init__.py +13 -0
- helm/benchmark/adaptation/adapter_spec.py +3 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -7
- helm/benchmark/augmentations/correct_to_misspelling.json +1 -0
- helm/benchmark/contamination/__init__.py +0 -0
- helm/benchmark/metrics/classification_metrics.py +70 -0
- helm/benchmark/metrics/machine_translation_metrics.py +36 -0
- helm/benchmark/metrics/summarization_metrics.py +7 -8
- helm/benchmark/metrics/test_classification_metrics.py +150 -0
- helm/benchmark/presentation/create_plots.py +617 -0
- helm/benchmark/presentation/run_display.py +7 -48
- helm/benchmark/presentation/summarize.py +4 -2
- helm/benchmark/presentation/test_create_plots.py +32 -0
- helm/benchmark/run.py +144 -48
- helm/benchmark/run_expander.py +164 -47
- helm/benchmark/run_specs.py +346 -39
- helm/benchmark/runner.py +34 -6
- helm/benchmark/scenarios/copyright_scenario.py +1 -1
- helm/benchmark/scenarios/covid_dialog_scenario.py +84 -0
- helm/benchmark/scenarios/imdb_listdir.json +50014 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +253 -0
- helm/benchmark/scenarios/lextreme_scenario.py +458 -0
- helm/benchmark/scenarios/me_q_sum_scenario.py +86 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +132 -0
- helm/benchmark/scenarios/med_mcqa_scenario.py +102 -0
- helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +119 -0
- helm/benchmark/scenarios/med_qa_scenario.py +96 -0
- helm/benchmark/scenarios/opinions_qa_scenario.py +194 -0
- helm/benchmark/scenarios/scenario.py +5 -0
- helm/benchmark/scenarios/the_pile_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +96 -0
- helm/benchmark/static/benchmarking.css +14 -0
- helm/benchmark/static/benchmarking.js +43 -0
- helm/benchmark/static/index.html +2 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/plot-captions.js +16 -0
- helm/benchmark/static/schema.yaml +154 -1
- helm/benchmark/window_services/cohere_window_service.py +20 -0
- helm/benchmark/window_services/flan_t5_window_service.py +29 -0
- helm/benchmark/window_services/huggingface_window_service.py +39 -0
- helm/benchmark/window_services/santacoder_window_service.py +27 -0
- helm/benchmark/window_services/test_flan_t5_window_service.py +12 -0
- helm/benchmark/window_services/wider_ai21_window_service.py +13 -0
- helm/benchmark/window_services/window_service_factory.py +34 -7
- helm/common/codec.py +123 -0
- helm/common/general.py +12 -5
- helm/common/test_codec.py +144 -0
- helm/proxy/clients/aleph_alpha_client.py +47 -28
- helm/proxy/clients/auto_client.py +32 -24
- helm/proxy/clients/google_client.py +88 -0
- helm/proxy/clients/huggingface_client.py +32 -16
- helm/proxy/clients/huggingface_model_registry.py +111 -0
- helm/proxy/clients/huggingface_tokenizer.py +25 -7
- helm/proxy/clients/openai_client.py +60 -2
- helm/proxy/clients/test_huggingface_model_registry.py +57 -0
- helm/proxy/clients/test_huggingface_tokenizer.py +3 -0
- helm/proxy/clients/together_client.py +17 -2
- helm/proxy/clients/yalm_tokenizer/voc_100b.sp +0 -0
- helm/proxy/clients/yalm_tokenizer/yalm_tokenizer.py +8 -2
- helm/proxy/models.py +115 -7
- helm/proxy/test_models.py +1 -1
- helm/benchmark/presentation/present.py +0 -249
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/LICENSE +0 -0
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/top_level.txt +0 -0
helm/common/codec.py
ADDED
```diff
@@ -0,0 +1,123 @@
+"""Functions for converting to and from dataclasses."""
+
+import dataclasses
+import json
+import typing
+from typing import Any, Callable, Dict, List, Union, Type, TypeVar
+
+from helm.benchmark.augmentations.dialect_perturbation import DialectPerturbation
+from helm.benchmark.augmentations.extra_space_perturbation import ExtraSpacePerturbation
+from helm.benchmark.augmentations.filler_words_perturbation import FillerWordsPerturbation
+from helm.benchmark.augmentations.gender_perturbation import GenderPerturbation
+from helm.benchmark.augmentations.misspelling_perturbation import MisspellingPerturbation
+from helm.benchmark.augmentations.person_name_perturbation import PersonNamePerturbation
+from helm.benchmark.augmentations.space_perturbation import SpacePerturbation
+from helm.benchmark.augmentations.synonym_perturbation import SynonymPerturbation
+from helm.benchmark.augmentations.typos_perturbation import TyposPerturbation
+from helm.benchmark.augmentations.perturbation_description import PerturbationDescription
+
+import cattrs
+from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn
+
+
+T = TypeVar("T")
+StructureFn = Callable[[Dict[str, Any], Type[T]], T]  # dict -> dataclass
+UnstructureFn = Callable[[T], Dict[str, Any]]  # dataclass -> dict
+
+
+# TODO(#1251): Add proper class registration
+PERTURBATION_NAME_TO_DESCRIPTION = {
+    DialectPerturbation.name: DialectPerturbation.Description,
+    ExtraSpacePerturbation.name: ExtraSpacePerturbation.Description,
+    FillerWordsPerturbation.name: FillerWordsPerturbation.Description,
+    GenderPerturbation.name: GenderPerturbation.Description,
+    MisspellingPerturbation.name: MisspellingPerturbation.Description,
+    PersonNamePerturbation.name: PersonNamePerturbation.Description,
+    SpacePerturbation.name: SpacePerturbation.Description,
+    SynonymPerturbation.name: SynonymPerturbation.Description,
+    TyposPerturbation.name: TyposPerturbation.Description,
+}
+
+
+def _build_converter() -> cattrs.Converter:
+    converter = cattrs.Converter()
+
+    # Handle omission of Nones in JSON.
+    # To improve readability and reduce storage space, if a field value is None and the field
+    # has no default value or a None default value, the field is omitted in the serialized JSON.
+    def get_dataclass_optional_fields_without_default(cls: Type[T]) -> List[str]:
+        if not dataclasses.is_dataclass(cls):
+            return []
+        return [
+            field.name
+            for field in dataclasses.fields(cls)
+            if typing.get_origin(field.type) == Union and type(None) in typing.get_args(field.type)
+            # For optional fields with a non-None default value, do not replace a missing value
+            # with None.
+            and (field.default == dataclasses.MISSING or field.default is None)
+            and field.default_factory == dataclasses.MISSING
+        ]
+
+    def make_omit_nones_dict_structure_fn(cls: Type[T]) -> StructureFn[T]:
+        field_names = get_dataclass_optional_fields_without_default(cls)
+        _base_structure = make_dict_structure_fn(cls, converter)
+
+        def structure(raw_dict: Dict[str, Any], inner_cls: Type[T]) -> T:
+            for field_name in field_names:
+                if field_name not in raw_dict:
+                    raw_dict[field_name] = None
+            return _base_structure(raw_dict, inner_cls)
+
+        return structure
+
+    def make_omit_nones_dict_unstructure_fn(cls: Type[T]) -> UnstructureFn[T]:
+        field_names = get_dataclass_optional_fields_without_default(cls)
+        _base_unstructure = make_dict_unstructure_fn(cls, converter)
+
+        def structure(data: T) -> Dict[str, Any]:
+            raw_dict = _base_unstructure(data)
+            for field_name in field_names:
+                if raw_dict[field_name] is None:
+                    del raw_dict[field_name]
+            return raw_dict
+
+        return structure
+
+    converter.register_structure_hook_factory(
+        lambda cls: bool(get_dataclass_optional_fields_without_default(cls)), make_omit_nones_dict_structure_fn
+    )
+    converter.register_unstructure_hook_factory(
+        lambda cls: bool(get_dataclass_optional_fields_without_default(cls)), make_omit_nones_dict_unstructure_fn
+    )
+
+    # Handle the use of the name field in PerturbationDescription to determine the subclass.
+    base_perturbation_description_structure_fn: StructureFn = make_omit_nones_dict_structure_fn(PerturbationDescription)
+    perturbation_name_to_base_structure_fn: Dict[str, StructureFn] = {
+        name: make_omit_nones_dict_structure_fn(cls) for name, cls in PERTURBATION_NAME_TO_DESCRIPTION.items()
+    }
+
+    def structure_perturbation_description(
+        raw_dict: Dict[Any, Any], cls: Type[PerturbationDescription]
+    ) -> PerturbationDescription:
+        """Convert a raw dictionary to a PerturbationDescription.
+        This uses the name field to look up the correct PerturbationDescription subclass to output.
+        """
+        structure = perturbation_name_to_base_structure_fn.get(
+            raw_dict["name"], base_perturbation_description_structure_fn
+        )
+        return structure(raw_dict, cls)
+
+    converter.register_structure_hook(PerturbationDescription, structure_perturbation_description)
+
+    return converter
+
+
+_converter = _build_converter()
+
+
+def from_json(data: Union[bytes, str], cls: Type[T]) -> T:
+    return _converter.structure(json.loads(data), cls)
+
+
+def to_json(data: Any) -> str:
+    return json.dumps(_converter.unstructure(data), indent=2)
```
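The codec's None handling is symmetric: `to_json` drops Optional fields whose value is None when they have no default or a None default, and `from_json` restores the missing fields as None. A minimal usage sketch (the `Example` dataclass is hypothetical, purely for illustration):

```python
from dataclasses import dataclass
from typing import Optional

from helm.common.codec import from_json, to_json


@dataclass(frozen=True)
class Example:
    """Hypothetical dataclass for illustration."""

    text: str
    score: Optional[float] = None


# "score" is None with a None default, so it is dropped from the JSON.
serialized = to_json(Example(text="hello"))
assert '"score"' not in serialized

# On the way back in, the omitted field is restored as None.
restored = from_json(serialized, Example)
assert restored == Example(text="hello", score=None)
```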
helm/common/general.py
CHANGED
```diff
@@ -49,7 +49,13 @@ def shell(args: List[str]):
 
 
 @htrack(None)
-def ensure_file_downloaded(source_url: str, target_path: str, unpack: bool = False, unpack_type: Optional[str] = None):
+def ensure_file_downloaded(
+    source_url: str,
+    target_path: str,
+    unpack: bool = False,
+    downloader_executable: str = "wget",
+    unpack_type: Optional[str] = None,
+):
     """Download `source_url` to `target_path` if it doesn't exist."""
     if os.path.exists(target_path):
         # Assume it's all good
@@ -59,7 +65,8 @@ def ensure_file_downloaded(source_url: str, target_path: str, unpack: bool = Fal
     # Download
     # gdown is used to download large files/zip folders from Google Drive.
    # It bypasses security warnings which wget cannot handle.
-    downloader_executable: str = "gdown" if source_url.startswith("https://drive.google.com") else "wget"
+    if source_url.startswith("https://drive.google.com"):
+        downloader_executable = "gdown"
     tmp_path: str = f"{target_path}.tmp"
     shell([downloader_executable, source_url, "-O", tmp_path])
 
@@ -195,13 +202,13 @@ def parallel_map(
     with htrack_block(f"Parallelizing computation on {len(items)} items over {parallelism} {units}"):
         results: List
         if parallelism == 1:
-            results = list(tqdm(map(process, items), total=len(items)))
+            results = list(tqdm(map(process, items), total=len(items), disable=None))
         elif multiprocessing:
             with ProcessPoolExecutor(max_workers=parallelism) as executor:
-                results = list(tqdm(executor.map(process, items), total=len(items)))
+                results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
         else:
             with ThreadPoolExecutor(max_workers=parallelism) as executor:
-                results = list(tqdm(executor.map(process, items), total=len(items)))
+                results = list(tqdm(executor.map(process, items), total=len(items), disable=None))
         return results
 
 
```
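The three `tqdm` calls in `parallel_map` now pass `disable=None`, which tqdm treats as "auto": show a progress bar on an interactive terminal, suppress it when output is not a TTY (for example, redirected to a log file). A minimal usage sketch, assuming the `parallel_map(process, items, parallelism)` signature implied by the hunk header above (the worker function is illustrative):

```python
from helm.common.general import parallel_map


def square(x: int) -> int:
    # Illustrative worker function.
    return x * x


# On an interactive terminal this shows a progress bar; when stdout is
# redirected (e.g., `python run.py > run.log`), tqdm stays silent.
results = parallel_map(square, list(range(100)), parallelism=4)
assert results == [x * x for x in range(100)]
```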
helm/common/test_codec.py
ADDED
```diff
@@ -0,0 +1,144 @@
+import unittest
+import json
+
+from dataclasses import dataclass
+
+from typing import Dict, List, Optional
+
+from helm.benchmark.augmentations.dialect_perturbation import DialectPerturbation
+from helm.benchmark.augmentations.extra_space_perturbation import ExtraSpacePerturbation
+from helm.benchmark.augmentations.filler_words_perturbation import FillerWordsPerturbation
+from helm.benchmark.augmentations.gender_perturbation import GenderPerturbation
+from helm.benchmark.augmentations.misspelling_perturbation import MisspellingPerturbation
+from helm.benchmark.augmentations.person_name_perturbation import PersonNamePerturbation
+from helm.benchmark.augmentations.space_perturbation import SpacePerturbation
+from helm.benchmark.augmentations.synonym_perturbation import SynonymPerturbation
+from helm.benchmark.augmentations.typos_perturbation import TyposPerturbation
+from helm.benchmark.augmentations.perturbation_description import PerturbationDescription
+from helm.common.codec import from_json, to_json
+
+
+@dataclass(frozen=True)
+class DataClassChildForTest:
+    required_int: int
+
+
+@dataclass(frozen=True)
+class DataClassWithOptionals:
+    optional_str: Optional[str]
+    optional_int: Optional[int]
+    optional_bool: Optional[bool]
+    optional_list: Optional[List[int]]
+    optional_dict: Optional[Dict[str, int]]
+    optional_child: Optional[DataClassChildForTest]
+
+
+@dataclass(frozen=True)
+class DataClassWithDefaults:
+    required_int_with_default: int = -1
+    optional_int_with_int_default: Optional[int] = -2
+    optional_int_with_none_default: Optional[int] = None
+
+
+class TestJsonCodec(unittest.TestCase):
+    def test_round_trip_optional(self):
+        data = DataClassWithOptionals(
+            optional_str="hello",
+            optional_int=42,
+            optional_bool=True,
+            optional_list=[2, 3, 5],
+            optional_dict={"x": 7},
+            optional_child=DataClassChildForTest(137),
+        )
+        self.assertEqual(data, from_json(to_json(data), DataClassWithOptionals))
+
+    def test_round_trip_optional_nones(self):
+        data = DataClassWithOptionals(
+            optional_str=None,
+            optional_int=None,
+            optional_bool=None,
+            optional_list=None,
+            optional_dict=None,
+            optional_child=None,
+        )
+        data_json = to_json(data)
+        self.assertEqual("{}", data_json)
+        self.assertEqual(data, from_json(data_json, DataClassWithOptionals))
+
+    def test_round_trip_default(self):
+        data = DataClassWithDefaults()
+        data_json = to_json(data)
+        self.assertCountEqual(
+            {"required_int_with_default": -1, "optional_int_with_int_default": -2}.items(),
+            json.loads(data_json).items(),
+        )
+        self.assertEqual(data, from_json(data_json, DataClassWithDefaults))
+
+    def test_round_trip_default_ints(self):
+        data = DataClassWithDefaults(
+            required_int_with_default=1,
+            optional_int_with_int_default=2,
+            optional_int_with_none_default=3,
+        )
+        data_json = to_json(data)
+        self.assertEqual(data, from_json(data_json, DataClassWithDefaults))
+
+    def test_round_trip_default_nones(self):
+        data = DataClassWithDefaults(
+            optional_int_with_int_default=None,
+            optional_int_with_none_default=None,
+        )
+        data_json = to_json(data)
+        self.assertCountEqual(
+            {
+                "required_int_with_default": -1,
+                # `optional_int_with_int_default` should deserialize back to None,
+                # rather than the default int value. Therefore it must be
+                # serialized to null in JSON instead of removed.
+                "optional_int_with_int_default": None,
+            }.items(),
+            json.loads(data_json).items(),
+        )
+        self.assertEqual(data, from_json(data_json, DataClassWithDefaults))
+
+    def test_round_trip_perturbation_descriptions(self):
+        descriptions = [
+            PerturbationDescription(
+                name="unknown",
+            ),
+            DialectPerturbation.Description(
+                name=DialectPerturbation.name,
+                fairness=True,
+                prob=0.5,
+                source_class="source_class",
+                target_class="target_class",
+                mapping_file_path="mapping_file_path",
+            ),
+            ExtraSpacePerturbation.Description(name=ExtraSpacePerturbation.name, robustness=True, num_spaces=2),
+            FillerWordsPerturbation.Description(name=FillerWordsPerturbation.name, robustness=True, insert_prob=0.5),
+            GenderPerturbation.Description(
+                name=GenderPerturbation.name,
+                mode="mode",
+                fairness=True,
+                prob=0.5,
+                source_class="source_class",
+                target_class="target_class",
+                bidirectional=True,
+            ),
+            MisspellingPerturbation.Description(name=MisspellingPerturbation.name, robustness=True, prob=0.5),
+            PersonNamePerturbation.Description(
+                name=PersonNamePerturbation.name,
+                fairness=True,
+                prob=0.5,
+                source_class="source_str",
+                target_class="target_str",
+                name_file_path="name_file_path",
+                person_name_type="person_name_type",
+                preserve_gender=True,
+            ),
+            SpacePerturbation.Description(name=SpacePerturbation.name, robustness=True, max_spaces=2),
+            SynonymPerturbation.Description(name=SynonymPerturbation.name, robustness=True, prob=0.5),
+            TyposPerturbation.Description(name=TyposPerturbation.name, robustness=True, prob=0.5),
+        ]
+        for description in descriptions:
+            self.assertEqual(description, from_json(to_json(description), PerturbationDescription))
```
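The `test_round_trip_default_nones` case pins down the codec's core subtlety: an Optional field whose default is a non-None value must be serialized as an explicit `null` rather than dropped, or deserialization would resurrect the default. A standalone sketch of that distinction, reusing the test's own dataclass:

```python
import json

from helm.common.codec import from_json, to_json
from helm.common.test_codec import DataClassWithDefaults

data = DataClassWithDefaults(optional_int_with_int_default=None)
payload = json.loads(to_json(data))

# Omitting the field would make from_json fall back to the default (-2),
# so the None must survive as an explicit null in the JSON.
assert payload["optional_int_with_int_default"] is None
assert from_json(to_json(data), DataClassWithDefaults).optional_int_with_int_default is None
```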
helm/proxy/clients/aleph_alpha_client.py
CHANGED
```diff
@@ -2,7 +2,11 @@ import json
 import requests
 from typing import Any, Dict, List
 
+from aleph_alpha_client import Client as AlephAlphaPythonClient
+from tokenizers import Tokenizer, Encoding
+
 from helm.common.cache import Cache, CacheConfig
+from helm.common.hierarchical_logger import hlog
 from helm.common.request import Request, RequestResult, Sequence, Token
 from helm.common.tokenization_request import (
     DecodeRequest,
@@ -19,9 +23,27 @@ class AlephAlphaClient(Client):
     TOKENIZE_ENDPOINT: str = "tokenize"
     DETOKENIZE_ENDPOINT: str = "detokenize"
 
+    VALID_TOKENIZERS: List[str] = [
+        "luminous-base",
+        "luminous-extended",
+        "luminous-supreme",
+    ]
+
     def __init__(self, api_key: str, cache_config: CacheConfig):
         self.api_key: str = api_key
         self.cache = Cache(cache_config)
+        self._aleph_alpha_client = AlephAlphaPythonClient(token=api_key)
+        self._tokenizer_name_to_tokenizer: Dict[str, Tokenizer] = {}
+
+    def _get_tokenizer(self, tokenizer_name: str) -> Tokenizer:
+        if tokenizer_name not in self.VALID_TOKENIZERS:
+            raise ValueError(f"Invalid tokenizer: {tokenizer_name}")
+
+        # Check if the tokenizer is cached
+        if tokenizer_name not in self._tokenizer_name_to_tokenizer:
+            self._tokenizer_name_to_tokenizer[tokenizer_name] = self._aleph_alpha_client.tokenizer(tokenizer_name)
+            hlog(f"Initialized tokenizer: {tokenizer_name}")
+        return self._tokenizer_name_to_tokenizer[tokenizer_name]
 
     def _send_request(self, endpoint: str, raw_request: Dict[str, Any]) -> Dict[str, Any]:
         response = requests.request(
@@ -33,6 +55,8 @@ class AlephAlphaClient(Client):
                 "Authorization": f"Bearer {self.api_key}",
             },
             data=json.dumps(raw_request),
+            # Setting the nice flag prevents intensive benchmarking runs from saturating Aleph Alpha's API queues
+            params=json.dumps({"nice": True}),
         )
         result = json.loads(response.text)
         assert "error" not in result, f"Request failed with error: {result['error']}"
@@ -40,7 +64,6 @@ class AlephAlphaClient(Client):
 
     def make_request(self, request: Request) -> RequestResult:
         """Make a request following https://docs.aleph-alpha.com/api/complete."""
-        # TODO: echo is not supported. Follow up on this.
         raw_request = {
             "model": request.model_engine,
             "prompt": request.prompt,
@@ -53,6 +76,7 @@ class AlephAlphaClient(Client):
             "n": request.num_completions,
             "stop_sequences": request.stop_sequences,
             "log_probs": request.top_k_per_token,
+            "echo": request.echo_prompt,
             "tokens": True,  # Setting to True returns individual tokens of the completion
         }
 
@@ -102,24 +126,21 @@ class AlephAlphaClient(Client):
         )
 
     def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
-        """
-        raw_request = {
-            "model": request.tokenizer_name,
-            "prompt": request.text,
-            "tokens": True,
-            "token_ids": True,
-        }
-
+        """
+        Encode the text using Aleph Alpha's tokenizer library:
+        https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.Client.tokenizer
+        """
         try:
 
             def do_it():
-
-
-                return result
-
-
-
-
+                tokenizer: Tokenizer = self._get_tokenizer(request.tokenizer_name)
+                result: Encoding = tokenizer.encode(request.text, add_special_tokens=False)
+                return {"token_ids": result.ids, "tokens": result.tokens}
+
+            cache_key = {"model": request.tokenizer_name, "prompt": request.text, "tokens": True, "token_ids": True}
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"AlephAlphaClient tokenize error: {e}"
             return TokenizationRequestResult(error=error, success=False, cached=False, text="", tokens=[])
 
         tokens = response["token_ids" if request.encode else "tokens"]
@@ -135,22 +156,20 @@ class AlephAlphaClient(Client):
         )
 
     def decode(self, request: DecodeRequest) -> DecodeRequestResult:
-        """
-        raw_request = {
-            "model": request.tokenizer_name,
-            "token_ids": request.tokens,
-        }
-
+        """
+        Decode the tokens using Aleph Alpha's tokenizer library:
+        https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.Client.tokenizer
+        """
         try:
 
             def do_it():
-
-
-                return result
+                tokenizer: Tokenizer = self._get_tokenizer(request.tokenizer_name)
+                return {"result": tokenizer.decode(request.tokens)}
 
-
-
-
+            cache_key = {"model": request.tokenizer_name, "token_ids": request.tokens}
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except RuntimeError as e:
+            error: str = f"AlephAlphaClient decode error: {e}"
             return DecodeRequestResult(error=error, success=False, cached=False, text="")
 
         return DecodeRequestResult(
```
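The client now tokenizes and detokenizes locally through Hugging Face's `tokenizers` package, fetching each Luminous tokenizer once via `aleph_alpha_client` and memoizing it, instead of calling the tokenize/detokenize HTTP endpoints. A round-trip sketch of the `tokenizers` API the new code relies on, using `gpt2` purely for illustration (the client itself loads Luminous tokenizers):

```python
from tokenizers import Encoding, Tokenizer

# Any pretrained tokenizer exposes the same API the client now uses:
# Encoding.ids, Encoding.tokens, and Tokenizer.decode.
tokenizer: Tokenizer = Tokenizer.from_pretrained("gpt2")

encoding: Encoding = tokenizer.encode("Hello world", add_special_tokens=False)
print(encoding.ids)     # e.g., [15496, 995]
print(encoding.tokens)  # e.g., ["Hello", "Ġworld"]

print(tokenizer.decode(encoding.ids))  # "Hello world"
```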
helm/proxy/clients/auto_client.py
CHANGED
```diff
@@ -21,6 +21,7 @@ from .anthropic_client import AnthropicClient
 from .chat_gpt_client import ChatGPTClient
 from .cohere_client import CohereClient
 from .together_client import TogetherClient
+from .google_client import GoogleClient
 from .goose_ai_client import GooseAIClient
 from .huggingface_client import HuggingFaceClient
 from .ice_tokenizer_client import ICETokenizerClient
@@ -29,6 +30,7 @@ from .microsoft_client import MicrosoftClient
 from .perspective_api_client import PerspectiveAPIClient
 from .yalm_tokenizer_client import YaLMTokenizerClient
 from .simple_client import SimpleClient
+from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_config
 
 
 class AutoClient(Client):
@@ -53,15 +55,17 @@ class AutoClient(Client):
         # TODO: Allow setting CacheConfig.follower_cache_path from a command line flag.
         return SqliteCacheConfig(client_cache_path)
 
-    def get_client(self, request: Request) -> Client:
-        """Return a client based on the request, creating it if necessary."""
-
-        client: Optional[Client] = self.clients.get(organization)
+    def _get_client(self, model: str) -> Client:
+        """Return a client based on the model, creating it if necessary."""
+        client: Optional[Client] = self.clients.get(model)
 
         if client is None:
+            organization: str = model.split("/")[0]
             cache_config: CacheConfig = self._build_cache_config(organization)
 
-            if organization == "openai":
+            if get_huggingface_model_config(model):
+                client = HuggingFaceClient(cache_config=cache_config)
+            elif organization == "openai":
                 # TODO: add ChatGPT to the OpenAIClient when it's supported.
                 # We're using a separate client for now since we're using an unofficial Python library.
                 # See https://github.com/acheong08/ChatGPT/wiki/Setup on how to get a valid session token.
@@ -71,13 +75,14 @@ class AutoClient(Client):
                     # TODO: use `cache_config` above. Since this feature is still experimental,
                     # save queries and responses in a separate collection.
                     cache_config=self._build_cache_config("ChatGPT"),
-                    tokenizer_client=self.get_tokenizer_client("huggingface"),
+                    tokenizer_client=self._get_tokenizer_client("huggingface"),
                 )
 
                 org_id = self.credentials.get("openaiOrgId", None)
                 client = OpenAIClient(
                     api_key=self.credentials["openaiApiKey"],
                     cache_config=cache_config,
+                    tokenizer_client=self._get_tokenizer_client("huggingface"),
                     chat_gpt_client=chat_gpt_client,
                     org_id=org_id,
                 )
@@ -105,18 +110,20 @@ class AutoClient(Client):
                     cache_config=cache_config,
                     org_id=org_id,
                 )
+            elif organization == "google":
+                client = GoogleClient(cache_config=cache_config)
             elif organization == "together":
                 client = TogetherClient(api_key=self.credentials.get("togetherApiKey", None), cache_config=cache_config)
             elif organization == "simple":
                 client = SimpleClient(cache_config=cache_config)
             else:
-                raise ValueError(f"
-            self.clients[organization] = client
+                raise ValueError(f"Could not find client for model: {model}")
+            self.clients[model] = client
         return client
 
     def make_request(self, request: Request) -> RequestResult:
         """
-        Dispatch based on the
+        Dispatch based on the the name of the model (e.g., openai/davinci).
         Retries if request fails.
         """
@@ -125,30 +132,33 @@ class AutoClient(Client):
         def make_request_with_retry(client: Client, request: Request) -> RequestResult:
             return client.make_request(request)
 
-
-        client: Client = self.get_client(request)
+        client: Client = self._get_client(request.model)
 
         try:
             return make_request_with_retry(client=client, request=request)
         except RetryError as e:
             last_attempt: Attempt = e.last_attempt
             retry_error: str = (
-                f"Failed to make request to {
+                f"Failed to make request to {request.model} after retrying {last_attempt.attempt_number} times"
             )
             hlog(retry_error)
 
             # Notify our user that we failed to make the request even after retrying.
             return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")
 
-    def get_tokenizer_client(self, organization: str) -> Client:
-        """Return a client based on the organization, creating it if necessary."""
-        client: Optional[Client] = self.tokenizer_clients.get(organization)
+    def _get_tokenizer_client(self, tokenizer: str) -> Client:
+        """Return a client based on the tokenizer, creating it if necessary."""
+        organization: str = tokenizer.split("/")[0]
+        client: Optional[Client] = self.tokenizer_clients.get(tokenizer)
 
         if client is None:
             cache_config: CacheConfig = self._build_cache_config(organization)
-            if organization in [
+            if get_huggingface_model_config(tokenizer):
+                client = HuggingFaceClient(cache_config=cache_config)
+            elif organization in [
                 "anthropic",
                 "bigscience",
+                "bigcode",
                 "EleutherAI",
                 "facebook",
                 "google",
@@ -171,19 +181,18 @@ class AutoClient(Client):
             elif organization == "simple":
                 client = SimpleClient(cache_config=cache_config)
             else:
-                raise ValueError(f"
-            self.tokenizer_clients[organization] = client
+                raise ValueError(f"Could not find tokenizer client for model: {tokenizer}")
+            self.tokenizer_clients[tokenizer] = client
         return client
 
     def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
-        """Tokenizes based on the
+        """Tokenizes based on the name of the tokenizer (e.g., huggingface/gpt2)."""
 
         @retry_request
         def tokenize_with_retry(client: Client, request: TokenizationRequest) -> TokenizationRequestResult:
             return client.tokenize(request)
 
-
-        client: Client = self.get_tokenizer_client(organization)
+        client: Client = self._get_tokenizer_client(request.tokenizer)
 
         try:
             return tokenize_with_retry(client=client, request=request)
@@ -194,14 +203,13 @@ class AutoClient(Client):
         return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")
 
     def decode(self, request: DecodeRequest) -> DecodeRequestResult:
-        """Decodes based on the
+        """Decodes based on the the name of the tokenizer (e.g., huggingface/gpt2)."""
 
         @retry_request
         def decode_with_retry(client: Client, request: DecodeRequest) -> DecodeRequestResult:
             return client.decode(request)
 
-
-        client: Client = self.get_tokenizer_client(organization)
+        client: Client = self._get_tokenizer_client(request.tokenizer)
 
         try:
             return decode_with_retry(client=client, request=request)
```
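Client and tokenizer caches are now keyed by the full model or tokenizer name rather than by organization, with the organization derived from the name prefix and Hugging Face registry entries checked first. A condensed, self-contained sketch of that dispatch shape (the client classes are stand-ins, not HELM's, and the registry check is omitted for brevity):

```python
from typing import Dict, Optional


class Client:
    """Stand-in base class for illustration."""


class OpenAIClient(Client):
    pass


class GoogleClient(Client):
    pass


_clients: Dict[str, Client] = {}


def get_client(model: str) -> Client:
    # Cache by the full model name, e.g. "openai/davinci".
    client: Optional[Client] = _clients.get(model)
    if client is None:
        organization = model.split("/")[0]  # "openai/davinci" -> "openai"
        if organization == "openai":
            client = OpenAIClient()
        elif organization == "google":
            client = GoogleClient()
        else:
            raise ValueError(f"Could not find client for model: {model}")
        _clients[model] = client
    return client
```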