crfm-helm 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of crfm-helm might be problematic.
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/METADATA +41 -57
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/RECORD +197 -152
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
- helm/benchmark/adaptation/common_adapter_specs.py +2 -0
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/anthropic_red_team_annotator.py +70 -0
- helm/benchmark/annotation/call_center_annotator.py +247 -0
- helm/benchmark/annotation/financebench_annotator.py +79 -0
- helm/benchmark/annotation/harm_bench_annotator.py +68 -0
- helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
- helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
- helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
- helm/benchmark/annotation/live_qa_annotator.py +71 -0
- helm/benchmark/annotation/medication_qa_annotator.py +68 -0
- helm/benchmark/annotation/model_as_judge.py +45 -0
- helm/benchmark/annotation/simple_safety_tests_annotator.py +64 -0
- helm/benchmark/annotation/xstest_annotator.py +110 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/annotation_metrics.py +108 -0
- helm/benchmark/metrics/bhasa_metrics.py +188 -0
- helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
- helm/benchmark/metrics/code_metrics_helper.py +11 -1
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/safety_metrics.py +57 -0
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +30 -72
- helm/benchmark/metrics/vision_language/image_utils.py +1 -1
- helm/benchmark/model_metadata_registry.py +3 -3
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_run_entry.py +1 -0
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +31 -2
- helm/benchmark/run_expander.py +113 -10
- helm/benchmark/run_spec_factory.py +4 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/bhasa_run_specs.py +638 -0
- helm/benchmark/run_specs/call_center_run_specs.py +152 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +11 -9
- helm/benchmark/run_specs/experimental_run_specs.py +85 -0
- helm/benchmark/run_specs/finance_run_specs.py +110 -0
- helm/benchmark/run_specs/safety_run_specs.py +154 -0
- helm/benchmark/run_specs/vlm_run_specs.py +251 -57
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
- helm/benchmark/scenarios/banking77_scenario.py +51 -0
- helm/benchmark/scenarios/bhasa_scenario.py +1798 -0
- helm/benchmark/scenarios/call_center_scenario.py +84 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/ewok_scenario.py +116 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +119 -0
- helm/benchmark/scenarios/financebench_scenario.py +53 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
- helm/benchmark/scenarios/scenario.py +1 -1
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
- helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
- helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
- helm/benchmark/scenarios/test_math_scenario.py +2 -8
- helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
- helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
- helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +5 -5
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +13 -2
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -7
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +44 -13
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +7 -6
- helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +5 -5
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +98 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
- helm/benchmark/scenarios/xstest_scenario.py +35 -0
- helm/benchmark/server.py +1 -6
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_bhasa.yaml +709 -0
- helm/benchmark/static/schema_call_center.yaml +232 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_cleva.yaml +768 -0
- helm/benchmark/static/schema_decodingtrust.yaml +444 -0
- helm/benchmark/static/schema_ewok.yaml +367 -0
- helm/benchmark/static/schema_finance.yaml +189 -0
- helm/benchmark/static/schema_image2struct.yaml +588 -0
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_safety.yaml +247 -0
- helm/benchmark/static/schema_tables.yaml +317 -0
- helm/benchmark/static/schema_thai.yaml +244 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +304 -298
- helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
- helm/benchmark/static_build/assets/index-58f97dcd.js +10 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/test_openai_window_service.py +8 -8
- helm/clients/ai21_client.py +71 -1
- helm/clients/anthropic_client.py +50 -28
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +79 -19
- helm/clients/nvidia_nim_client.py +35 -0
- helm/clients/openai_client.py +11 -5
- helm/clients/palmyra_client.py +25 -0
- helm/clients/perspective_api_client.py +11 -6
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +7 -9
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +129 -23
- helm/clients/vertexai_client.py +62 -18
- helm/clients/vision_language/huggingface_vlm_client.py +1 -0
- helm/clients/vision_language/open_flamingo_client.py +1 -2
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +99 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +25 -0
- helm/common/mongo_key_value_store.py +2 -1
- helm/common/request.py +16 -0
- helm/config/model_deployments.yaml +740 -363
- helm/config/model_metadata.yaml +824 -128
- helm/config/tokenizer_configs.yaml +207 -10
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/example_queries.py +14 -21
- helm/proxy/services/server_service.py +2 -3
- helm/proxy/token_counters/test_auto_token_counter.py +2 -2
- helm/tokenizers/ai21_tokenizer.py +51 -59
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +29 -62
- helm/tokenizers/huggingface_tokenizer.py +35 -13
- helm/tokenizers/test_ai21_tokenizer.py +48 -0
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static/benchmarking.css +0 -156
- helm/benchmark/static/benchmarking.js +0 -1705
- helm/benchmark/static/config.js +0 -3
- helm/benchmark/static/general.js +0 -122
- helm/benchmark/static/images/crfm-logo.png +0 -0
- helm/benchmark/static/images/helm-logo-simple.png +0 -0
- helm/benchmark/static/images/helm-logo.png +0 -0
- helm/benchmark/static/images/language-model-helm.png +0 -0
- helm/benchmark/static/images/organizations/ai21.png +0 -0
- helm/benchmark/static/images/organizations/anthropic.png +0 -0
- helm/benchmark/static/images/organizations/bigscience.png +0 -0
- helm/benchmark/static/images/organizations/cohere.png +0 -0
- helm/benchmark/static/images/organizations/eleutherai.png +0 -0
- helm/benchmark/static/images/organizations/google.png +0 -0
- helm/benchmark/static/images/organizations/meta.png +0 -0
- helm/benchmark/static/images/organizations/microsoft.png +0 -0
- helm/benchmark/static/images/organizations/nvidia.png +0 -0
- helm/benchmark/static/images/organizations/openai.png +0 -0
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
- helm/benchmark/static/images/organizations/yandex.png +0 -0
- helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
- helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
- helm/benchmark/static/index.html +0 -68
- helm/benchmark/static/info-icon.png +0 -0
- helm/benchmark/static/json-urls.js +0 -69
- helm/benchmark/static/plot-captions.js +0 -27
- helm/benchmark/static/schema_image2structure.yaml +0 -304
- helm/benchmark/static/utils.js +0 -285
- helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
- helm/benchmark/static_build/assets/index-878a1094.css +0 -1
- helm/benchmark/window_services/ai21_window_service.py +0 -247
- helm/benchmark/window_services/cohere_window_service.py +0 -101
- helm/benchmark/window_services/test_ai21_window_service.py +0 -163
- helm/benchmark/window_services/test_cohere_window_service.py +0 -75
- helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
- helm/benchmark/window_services/test_ice_window_service.py +0 -327
- helm/tokenizers/ice_tokenizer.py +0 -30
- helm/tokenizers/test_ice_tokenizer.py +0 -57
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/top_level.txt +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
--- a/helm/benchmark/window_services/ai21_window_service.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import re
-
-from typing import List, Optional, Tuple
-from urllib.parse import unquote
-
-from helm.common.tokenization_request import (
-    TokenizationRequest,
-    TokenizationRequestResult,
-    TokenizationToken,
-    TextRange,
-)
-from .window_service import ConfigurableWindowService, EncodeResult, WindowService
-from .tokenizer_service import TokenizerService
-
-
-class AI21WindowService(ConfigurableWindowService):
-    """Tokenizes by making a request to the proxy server with REST endpoint: `/api/tokenize`."""
-
-    # AI21's tokenizer API rejects a tokenization request if the input sequence is too long, so
-    # we need to set an upper limit for the length of the request. Empirically, if the GPT2 tokenizer tokenizes a
-    # sequence to <= 11000 tokens, then it is most likely safe to assume that AI21's tokenization API will
-    # process this request.
-    MAX_TOKENIZATION_REQUEST_LENGTH: int = 11000
-
-    # The AI21 tokenizer throws the following error when sending a request with text that has too many characters:
-    # "Text must be under 100,000 characters (type=value_error)"
-    # Sending a request with 100,000 characters seem to work though.
-    MAX_CHARACTER_LENGTH: int = 100_000
-
-    NOT_IMPLEMENTED_ERROR_MESSAGE: str = (
-        "AI21 only gave API access to their tokenizer, so this method is not supported."
-    )
-
-    def __init__(
-        self,
-        gpt2_window_service: WindowService,
-        service: TokenizerService,
-        tokenizer_name: str,
-        max_sequence_length: int,
-        max_request_length: Optional[int] = None,
-        max_sequence_and_generated_tokens_length: Optional[int] = None,
-        end_of_text_token: Optional[str] = None,
-        prefix_token: Optional[str] = None,
-    ):
-        super().__init__(
-            tokenizer_name=tokenizer_name,
-            max_sequence_length=max_sequence_length,
-            max_request_length=max_request_length,
-            max_sequence_and_generated_tokens_length=max_sequence_and_generated_tokens_length,
-            end_of_text_token=end_of_text_token,
-            prefix_token=prefix_token,
-        )
-        # We need the `TokenizerService` to make requests to the server.
-        self.service: TokenizerService = service
-        # As explained above, we need a `GPT2WindowService` to help tokenize long text sequences.
-        self.gpt2_window_service: WindowService = gpt2_window_service
-
-    def encode(self, text: str, truncation: bool = False, max_length: Optional[int] = None) -> EncodeResult:
-        """
-        Encodes the input text to tokens.
-        """
-        tokens: List[TokenizationToken]
-        normalized_text: str
-        tokens, normalized_text = self._make_long_tokenization_request(text)
-        # The end position of the last token should be the end of the text.
-        if len(tokens) > 0:
-            assert tokens[-1].text_range is not None
-            assert tokens[-1].text_range.end == len(normalized_text)
-
-        return EncodeResult(text=normalized_text, tokens=tokens)
-
-    def decode(self, tokens: List[TokenizationToken], normalized_text: Optional[str] = None) -> str:
-        """
-        Given the model and a list of tokens, outputs the corresponding text.
-
-        For models using the GPT-2 tokenizer, the tokens are integers; for AI21
-        models, the tokens are `TokenizationToken`s.
-
-        Some tokenizers (e.g. AI21) normalize the text before encoding it and
-        thus require the `normalized_text` for decoding.
-        """
-        if not tokens:
-            return ""
-
-        # `normalized_text` is necessary for decoding AI21 tokens.
-        assert normalized_text, "The AI21 tokenizer needs `normalized_text` for decoding"
-        for j in range(len(tokens) - 1):
-            first_text_range = tokens[j].text_range
-            second_text_range = tokens[j + 1].text_range
-            assert first_text_range is not None
-            assert second_text_range is not None
-            assert (
-                first_text_range.end == second_text_range.start
-            ), "The tokens to be decoded must form a substring of `normalized_text`."
-
-        token_texts: List[str] = []
-        # The format of AI21 byte token representations. e.g. <0xE8>
-        byte_pattern = "<0x[0-9A-F]{2}>"
-        i: int = 0
-        while i < len(tokens):
-            # If there are byte tokens, aggregates them to a string
-            token_value = tokens[i].value
-            assert isinstance(token_value, str)
-            if re.match(byte_pattern, token_value):
-                bytestring = ""
-                while i < len(tokens) and re.match(byte_pattern, token_value):
-                    # e.g. <0xE8> -> \xE8
-                    bytestring += "\\" + token_value[2:-1]
-                    i += 1
-                # Convert to encoded URI (e.g., %e2%80%99) and decode
-                token_text = unquote(bytestring.replace("\\x", "%"))
-            # Not a byte token: retrieves the token text based on text_range.
-            else:
-                token: TokenizationToken = tokens[i]
-                assert token.text_range is not None
-                token_text = normalized_text[token.text_range.start : token.text_range.end]
-                i += 1
-            token_texts.append(token_text)
-        return "".join(token_texts)
-
-    def tokenize(self, text: str) -> List[str]:
-        """
-        Tokenizes the text via the /api/tokenize REST endpoint.
-        """
-        response: TokenizationRequestResult = self._make_tokenization_request(text)
-        result = []
-        for token in response.tokens:
-            assert isinstance(token.value, str)
-            result.append(token.value)
-        return result
-
-    def get_num_tokens(self, text: str) -> int:
-        """Tokenizes the text using the GPT-2 tokenizer and returns the number of tokens."""
-        return len(self.tokenize(text))
-
-    def fits_within_context_window(self, text: str, expected_completion_token_length: int = 0) -> bool:
-        return (
-            len(text) <= AI21WindowService.MAX_CHARACTER_LENGTH
-            and self.get_num_tokens(text) + expected_completion_token_length <= self.max_request_length
-        )
-
-    def truncate_from_right(self, text: str, expected_completion_token_length: int = 0) -> str:
-        """
-        Truncates the text using the AI21 Jurassic tokenizer.
-        First, ensures the text is shorter than `AI21Tokenizer.MAX_CHARACTER_LENGTH` long.
-        Tokenizes, then truncates the list of tokens to fit within the context window minus the
-        expected completion length (defaults to 0), then uses the start of the text range of the first
-        token and the end of the text range of the last token of the truncated list of tokens to
-        build the truncated text.
-        """
-        text = text[: AI21WindowService.MAX_CHARACTER_LENGTH]
-        response: TokenizationRequestResult = self._make_tokenization_request(text)
-
-        # Only look at the first `self.max_request_length` - `expected_completion_token_length`
-        # number of tokens to the fit the text within the context window.
-        # Each token is represented like this: {'text': '▁Hello', 'textRange': {'start': 0, 'end': 5}}
-        max_length: int = self.max_request_length - expected_completion_token_length
-        tokens: List[TokenizationToken] = response.tokens[:max_length]
-
-        # If there is no tokens, just return the original text
-        if len(tokens) == 0:
-            return text
-
-        # AI21 uses "_" to represent a single space in their tokens, so we have to build the new text from the
-        # original text after truncation using the text ranges of tokens generated from the original text.
-        assert tokens[0].text_range is not None
-        first_text_range: TextRange = tokens[0].text_range
-        assert tokens[-1].text_range is not None
-        last_text_range: TextRange = tokens[-1].text_range
-        start: int = first_text_range.start
-        end: int = last_text_range.end
-        truncated_text: str = text[start:end]
-
-        # HACK: For the vast majority of cases, the above logic works, but there are a few where the
-        # token count exceeds `max_length` by 1. This might be a bug with the AI21 tokenizer API.
-        # We handle those by removing characters one by one until it fits within the context window.
-        while not self.fits_within_context_window(truncated_text, expected_completion_token_length):
-            end -= 1
-            truncated_text = text[start:end]
-        return truncated_text
-
-    def _make_tokenization_request(self, text: str) -> TokenizationRequestResult:
-        """Sends a request to the server to tokenize the text via the `TokenizerService`."""
-        return self.service.tokenize(TokenizationRequest(text=text, tokenizer=self.tokenizer_name))
-
-    def _make_long_tokenization_request(self, text: str) -> Tuple[List[TokenizationToken], str]:
-        """If the text is too long (longer than 11,000 tokens when tokenized by the GPT-2 tokenizer),
-        the AI21 server will close the connection. Therefore, we need to split the text into smaller
-        chunks, tokenize each chunk, and re-assemble the tokenization results."""
-        # Uses the number of gpt2-style tokens as a measure of text length.
-        gpt2_tokens: List[TokenizationToken] = self.gpt2_window_service.encode(text).tokens
-
-        # If the text is short, just makes one request and returns the result.
-        if len(gpt2_tokens) < AI21WindowService.MAX_TOKENIZATION_REQUEST_LENGTH:
-            result: TokenizationRequestResult = self._make_tokenization_request(text)
-            return result.tokens, result.text
-        # Otherwise, splits the text to chunks, tokenizes each chunk, and re-assembles them.
-        else:
-            all_tokens: List[TokenizationToken] = []
-            normalized_text_chunks: List[str] = []
-            # The number of gpt2-style tokens we have tokenized with the AI21 tokenizer.
-            num_processed_tokens: int = 0
-            # The length of the (normalized) text string we have tokenized with the AI21 tokenizer.
-            num_processed_positions: int = 0
-            while num_processed_tokens < len(gpt2_tokens):
-                token_chunk_size: int = min(
-                    len(gpt2_tokens) - num_processed_tokens, AI21WindowService.MAX_TOKENIZATION_REQUEST_LENGTH
-                )
-                token_chunk: List[TokenizationToken] = gpt2_tokens[
-                    num_processed_tokens : num_processed_tokens + token_chunk_size
-                ]
-                text_chunk: str = self.gpt2_window_service.decode(token_chunk)
-                # We need to avoid generating byte tokens when splitting the text
-                while text_chunk.endswith("\ufffd"):
-                    token_chunk_size -= 1
-                    token_chunk = gpt2_tokens[num_processed_tokens : num_processed_tokens + token_chunk_size]
-                    text_chunk = self.gpt2_window_service.decode(token_chunk)
-                chunk_result: TokenizationRequestResult = self._make_tokenization_request(text_chunk)
-                chunk_tokens: List[TokenizationToken]
-                normalized_text_chunk: str
-                chunk_tokens, normalized_text_chunk = chunk_result.tokens, chunk_result.text
-                # Removes the empty tokens introduced by the split.
-                assert chunk_tokens[0].text_range is not None
-                if num_processed_tokens != 0 and chunk_tokens[0].text_range.start == chunk_tokens[0].text_range.end:
-                    chunk_tokens = chunk_tokens[1:]
-                else:
-                    chunk_tokens = chunk_tokens[:]
-
-                # Shifts the start and end index of each token
-                shifted_tokens: List[TokenizationToken] = []
-                for token in chunk_tokens:
-                    assert token.text_range is not None
-                    shifted_tokens.append(
-                        TokenizationToken(
-                            value=token.value,
-                            text_range=TextRange(
-                                start=token.text_range.start + num_processed_positions,
-                                end=token.text_range.end + num_processed_positions,
-                            ),
-                        )
-                    )
-                all_tokens.extend(shifted_tokens)
-                normalized_text_chunks.append(normalized_text_chunk)
-                num_processed_tokens += token_chunk_size
-                num_processed_positions += len(normalized_text_chunk)
-
-            return all_tokens, "".join(normalized_text_chunks)
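The removed `decode` above re-assembles runs of AI21 byte tokens (e.g. `<0xE8>`) by converting each hex byte to a percent-encoded character and URI-decoding the run. A minimal standalone sketch of that trick, illustrative only and not part of the diff:

import re
from urllib.parse import unquote

# Matches AI21's byte-token representation, e.g. "<0xE8>".
BYTE_PATTERN = re.compile(r"<0x([0-9A-F]{2})>")


def decode_byte_tokens(token_values):
    """Collapse runs of byte tokens into text; pass other tokens through."""
    pieces = []
    i = 0
    while i < len(token_values):
        if BYTE_PATTERN.match(token_values[i]):
            percent_encoded = ""
            while i < len(token_values) and BYTE_PATTERN.match(token_values[i]):
                percent_encoded += "%" + token_values[i][3:-1]  # "<0xE2>" -> "%E2"
                i += 1
            pieces.append(unquote(percent_encoded))  # "%E2%80%99" -> "'"
        else:
            pieces.append(token_values[i])
            i += 1
    return "".join(pieces)


print(decode_byte_tokens(["Hello", "<0xE2>", "<0x80>", "<0x99>"]))  # Hello’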
--- a/helm/benchmark/window_services/cohere_window_service.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from typing import List, Optional
-
-from helm.tokenizers.cohere_tokenizer import CohereTokenizer
-from .local_window_service import LocalWindowService
-from .window_service import EncodeResult
-from helm.common.tokenization_request import (
-    TokenizationRequest,
-    TokenizationRequestResult,
-    TokenizationToken,
-)
-
-
-class CohereWindowService(LocalWindowService):
-    def encode(self, text: str, truncation: bool = False, max_length: Optional[int] = None) -> EncodeResult:
-        """
-        Encodes the input text to tokens.
-        """
-        if max_length is None:
-            max_length = self.max_request_length
-
-        response: TokenizationRequestResult
-        tokens: List[TokenizationToken] = []
-        if truncation or len(text) <= CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH:
-            response = self.service.tokenize(
-                TokenizationRequest(
-                    text,
-                    tokenizer=self.tokenizer_name,
-                    # The Cohere API does not support decoding, so set `encode` to False to get the value of tokens
-                    # as strings so we can simply concatenate them when we need to decode.
-                    encode=False,
-                    truncation=truncation,
-                    max_length=max_length,
-                )
-            )
-            tokens = response.tokens
-        else:
-            # Perform chunk encoding: Cohere doesn't support long sequences, so break it up into chunks
-            # and make a request for each chunk.
-            # This can potentially break up valid tokens at the end of the chunk, but the chunk size
-            # is large enough that this happens infrequently.
-            chunk_size: int = CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH
-            for i in range(0, len(text), chunk_size):
-                chunk: str = text[i : chunk_size + i]
-                response = self.service.tokenize(
-                    TokenizationRequest(chunk, tokenizer=self.tokenizer_name, encode=False, truncation=False)
-                )
-                tokens.extend(response.tokens)
-
-        return EncodeResult(text=text, tokens=tokens)
-
-    def get_num_tokens(self, text: str) -> int:
-        """Tokenizes the text and returns the number of tokens."""
-        # We need this check since we can't pass in empty string via the `tokenize` endpoint
-        if len(text) == 0:
-            return 0
-        return len(self.encode(text).tokens)
-
-    def decode(self, tokens: List[TokenizationToken], normalized_text: Optional[str] = None) -> str:
-        """
-        The Cohere API does not support decoding, but we're able to recover the original text from the
-        values of the tokens by concatenating them.
-
-        Note this logic currently only works with English text.
-        """
-        token_strings = []
-        for token in tokens:
-            assert isinstance(token.value, str)
-            token_strings.append(token.value)
-        return "".join(token_strings)
-
-    def fits_within_context_window(self, text: str, expected_completion_token_length: int = 0) -> bool:
-        """
-        Checks if the given text fits within the context window given by `max_request_length`
-        taking to account the expected completion length (defaults to 0).
-
-        According to https://docs.cohere.ai/tokenize-reference#request, for tokenize, text: "the string to
-        be tokenized, the minimum text length is 1 character, and the maximum text length is 65,536 characters.",
-        so first check if the text has fewer than 65,536 characters.
-        """
-        return (
-            len(text) <= CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH
-            and self.get_num_tokens(text) + expected_completion_token_length <= self.max_request_length
-        )
-
-    def truncate_from_right(self, text: str, expected_completion_token_length: int = 0) -> str:
-        """
-        Truncates text from the right to fit within the context window given by `max_request_length`
-        minus the expected completion length (defaults to 0).
-        """
-        # First truncate the text so it's within `CohereClient.TOKENIZE_MAX_TEXT_LENGTH` length.
-        text = text[: CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH]
-
-        max_length: int = self.max_request_length - expected_completion_token_length
-        result: str = self.decode(self.encode(text, truncation=True, max_length=max_length).tokens)
-
-        # HACK: For the vast majority of cases, the above logic works, but it sometimes doesn't work
-        # for non-English text, since Cohere technically only supports English at the moment.
-        while not self.fits_within_context_window(result, expected_completion_token_length):
-            result = result[:-1]
-
-        return result
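The `encode` path above works around the tokenize API's 65,536-character cap by chunking the text and concatenating the per-chunk token lists, accepting that a token may occasionally be split at a chunk boundary. A minimal sketch of that pattern with a toy tokenizer standing in for the API call (names illustrative):

from typing import Callable, List


def encode_in_chunks(text: str, tokenize: Callable[[str], List[str]], chunk_size: int) -> List[str]:
    """Tokenize text one fixed-size character chunk at a time and concatenate
    the results; with a large chunk_size, boundary splits are rare."""
    tokens: List[str] = []
    for i in range(0, len(text), chunk_size):
        tokens.extend(tokenize(text[i : i + chunk_size]))
    return tokens


# Toy usage: whitespace "tokenization" over 5-character chunks.
# Note how "quick" gets split at a chunk boundary, the caveat the
# original comment mentions.
print(encode_in_chunks("the quick brown fox", str.split, chunk_size=5))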
--- a/helm/benchmark/window_services/test_ai21_window_service.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import pytest
-
-from typing import List
-import unittest
-from unittest import mock
-
-from helm.common.authentication import Authentication
-from helm.common.tokenization_request import TokenizationRequestResult, TokenizationToken, TextRange
-from helm.proxy.services.remote_service import RemoteService
-from .test_utils import TEST_PROMPT
-from .tokenizer_service import TokenizerService
-from .window_service_factory import WindowServiceFactory
-
-# TODO(#1522): Remove "▁" from the tokens.
-TEST_TOKEN_REPRESENTATIONS: List[TokenizationToken] = [
-    TokenizationToken(value="▁The▁Center▁for", text_range=TextRange(start=0, end=14)),
-    TokenizationToken(value="▁Research▁on", text_range=TextRange(start=14, end=26)),
-    TokenizationToken(value="▁Foundation", text_range=TextRange(start=26, end=37)),
-    TokenizationToken(value="▁Models", text_range=TextRange(start=37, end=44)),
-    TokenizationToken(value="▁", text_range=TextRange(start=44, end=45)),
-    TokenizationToken(value="(", text_range=TextRange(start=45, end=46)),
-    TokenizationToken(value="CRF", text_range=TextRange(start=46, end=49)),
-    TokenizationToken(value="M", text_range=TextRange(start=49, end=50)),
-    TokenizationToken(value=")", text_range=TextRange(start=50, end=51)),
-    TokenizationToken(value="▁is", text_range=TextRange(start=51, end=54)),
-    TokenizationToken(value="▁an▁interdisciplinary", text_range=TextRange(start=54, end=75)),
-    TokenizationToken(value="▁initiative", text_range=TextRange(start=75, end=86)),
-    TokenizationToken(value="▁born▁out▁of", text_range=TextRange(start=86, end=98)),
-    TokenizationToken(value="▁the", text_range=TextRange(start=98, end=102)),
-    TokenizationToken(value="▁Stanford", text_range=TextRange(start=102, end=111)),
-    TokenizationToken(value="▁Institute▁for", text_range=TextRange(start=111, end=125)),
-    TokenizationToken(value="▁Human", text_range=TextRange(start=125, end=131)),
-    TokenizationToken(value="-Centered", text_range=TextRange(start=131, end=140)),
-    TokenizationToken(value="▁Artificial▁Intelligence", text_range=TextRange(start=140, end=164)),
-    TokenizationToken(value="▁", text_range=TextRange(start=164, end=165)),
-    TokenizationToken(value="(", text_range=TextRange(start=165, end=166)),
-    TokenizationToken(value="HAI", text_range=TextRange(start=166, end=169)),
-    TokenizationToken(value=")", text_range=TextRange(start=169, end=170)),
-    TokenizationToken(value="▁that", text_range=TextRange(start=170, end=175)),
-    TokenizationToken(value="▁aims▁to▁make", text_range=TextRange(start=175, end=188)),
-    TokenizationToken(value="▁fundamental", text_range=TextRange(start=188, end=200)),
-    TokenizationToken(value="▁advances▁in", text_range=TextRange(start=200, end=212)),
-    TokenizationToken(value="▁the▁study", text_range=TextRange(start=212, end=222)),
-    TokenizationToken(value=",", text_range=TextRange(start=222, end=223)),
-    TokenizationToken(value="▁development", text_range=TextRange(start=223, end=235)),
-    TokenizationToken(value=",", text_range=TextRange(start=235, end=236)),
-    TokenizationToken(value="▁and", text_range=TextRange(start=236, end=240)),
-    TokenizationToken(value="▁deployment▁of", text_range=TextRange(start=240, end=254)),
-    TokenizationToken(value="▁foundation", text_range=TextRange(start=254, end=265)),
-    TokenizationToken(value="▁models", text_range=TextRange(start=265, end=272)),
-    TokenizationToken(value=".", text_range=TextRange(start=272, end=273)),
-]
-
-TEST_TOKENS: List[str] = [
-    "▁The▁Center▁for",
-    "▁Research▁on",
-    "▁Foundation",
-    "▁Models",
-    "▁",
-    "(",
-    "CRF",
-    "M",
-    ")",
-    "▁is",
-    "▁an▁interdisciplinary",
-    "▁initiative",
-    "▁born▁out▁of",
-    "▁the",
-    "▁Stanford",
-    "▁Institute▁for",
-    "▁Human",
-    "-Centered",
-    "▁Artificial▁Intelligence",
-    "▁",
-    "(",
-    "HAI",
-    ")",
-    "▁that",
-    "▁aims▁to▁make",
-    "▁fundamental",
-    "▁advances▁in",
-    "▁the▁study",
-    ",",
-    "▁development",
-    ",",
-    "▁and",
-    "▁deployment▁of",
-    "▁foundation",
-    "▁models",
-    ".",
-]
-
-REQUEST_RESULT: TokenizationRequestResult
-LONG_REQUEST_RESULT: TokenizationRequestResult
-TRUNCATED_REQUEST_RESULT: TokenizationRequestResult
-
-# The request results are too long to be put here, so we save them to file.
-# TODO: Re-encode requests and results.
-# with open("src/helm/benchmark/window_services/mock_ai21_tokenizer_request_results.pkl", "rb") as f:
-#     REQUEST_RESULT, LONG_REQUEST_RESULT, TRUNCATED_REQUEST_RESULT = pickle.load(f)
-REQUEST_RESULT = None  # type:ignore
-LONG_REQUEST_RESULT = None  # type:ignore
-TRUNCATED_REQUEST_RESULT = None  # type:ignore
-
-
-@unittest.skip("The requests and results cannot be unpicked after the modules moved")
-class TestAI21WindowService:
-    def setup_method(self):
-        # We use mocking for tokenization calls so no valid api_keys are required.
-        auth = Authentication(api_key="DUMMY_API_KEY")
-        service = TokenizerService(RemoteService("DUMMY_URL"), auth)
-        self.window_service = WindowServiceFactory.get_window_service("ai21/j1-jumbo", service)
-
-    @mock.patch("helm.benchmark.tokenizer.ai21_tokenizer.TokenizerService.tokenize", return_value=REQUEST_RESULT)
-    @pytest.mark.skip("TODO: update the pickle file with the response")
-    def test_encode(self, mock_tokenize):
-        assert self.window_service.encode(TEST_PROMPT).tokens == TEST_TOKEN_REPRESENTATIONS
-
-    @pytest.mark.skip("TODO: update the pickle file with the response")
-    def test_decode(self):
-        assert self.window_service.decode(TEST_TOKEN_REPRESENTATIONS, TEST_PROMPT) == TEST_PROMPT
-        assert self.window_service.decode(TEST_TOKEN_REPRESENTATIONS, TEST_PROMPT)[:-1] == TEST_PROMPT[:-1]
-
-    @mock.patch("helm.benchmark.tokenizer.ai21_tokenizer.TokenizerService.tokenize", return_value=REQUEST_RESULT)
-    @pytest.mark.skip("TODO: update the pickle file with the response")
-    def test_tokenize(self, mock_tokenize):
-        assert self.window_service.tokenize(TEST_PROMPT) == TEST_TOKENS
-
-    @mock.patch("helm.benchmark.tokenizer.ai21_tokenizer.TokenizerService.tokenize", return_value=REQUEST_RESULT)
-    @pytest.mark.skip("TODO: update the pickle file with the response")
-    def test_fits_within_context_window(self, mock_tokenize):
-        # Should fit in the context window since we subtracted the number of tokens of the test prompt
-        # from the max context window
-        assert self.window_service.fits_within_context_window(TEST_PROMPT, 2047 - 36)
-        # Should not fit in the context window because we're expecting one more extra token in the completion
-        assert not self.window_service.fits_within_context_window(TEST_PROMPT, 2047 - 36 + 1)
-
-    @mock.patch(
-        "helm.benchmark.tokenizer.ai21_tokenizer.TokenizerService.tokenize",
-        side_effect=[
-            LONG_REQUEST_RESULT,
-            LONG_REQUEST_RESULT,
-            TRUNCATED_REQUEST_RESULT,
-            TRUNCATED_REQUEST_RESULT,
-            TRUNCATED_REQUEST_RESULT,
-        ],
-    )
-    @pytest.mark.skip("TODO: update the pickle file with the response")
-    def test_truncate_from_right(self, mock_tokenize):
-        # Create a prompt that exceed max context length: 36 * 57 = 2052 tokens.
-        # Our naive concatenation of the strings here also leads to extra tokens.
-        long_prompt: str = TEST_PROMPT * 57
-        assert not self.window_service.fits_within_context_window(long_prompt)
-
-        # Truncate and ensure it fits within the context window
-        truncated_long_prompt: str = self.window_service.truncate_from_right(long_prompt)
-        assert self.window_service.get_num_tokens(truncated_long_prompt) == 2047
-        assert self.window_service.fits_within_context_window(truncated_long_prompt)
-
-    @mock.patch("helm.benchmark.tokenizer.ai21_tokenizer.TokenizerService.tokenize", return_value=REQUEST_RESULT)
-    @pytest.mark.skip("TODO: update the pickle file with the response")
-    def test_tokenize_and_count(self, mock_tokenize):
-        assert self.window_service.get_num_tokens(TEST_PROMPT) == 36
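All of these tests are skipped because their pickled fixtures went stale when the modules moved. A hedged sketch of the same mock-the-tokenizer pattern with an in-memory fake result instead of a pickle (all names here are hypothetical):

from unittest import mock


class FakeTokenizationResult:
    """Hypothetical stand-in for TokenizationRequestResult."""

    def __init__(self, tokens):
        self.tokens = tokens


def get_num_tokens(service, text: str) -> int:
    # Counts tokens by calling the (mocked) tokenize endpoint.
    return len(service.tokenize(text).tokens)


def test_get_num_tokens():
    service = mock.Mock()
    service.tokenize.return_value = FakeTokenizationResult(tokens=["▁The▁Center▁for", "▁Research▁on"])
    assert get_num_tokens(service, "The Center for Research on") == 2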
--- a/helm/benchmark/window_services/test_cohere_window_service.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# mypy: check_untyped_defs = False
-import os
-import shutil
-import tempfile
-from typing import List
-
-from sqlitedict import SqliteDict
-
-from helm.common.cache_backend_config import SqliteCacheBackendConfig
-from helm.common.general import ensure_directory_exists
-from .test_cohere_window_service_utils import REQUESTS_TO_RESPONSES, TEST_PROMPT, TOKENIZED_PROMPT
-from .tokenizer_service import TokenizerService
-from .window_service_factory import WindowServiceFactory
-from .test_utils import get_tokenizer_service
-
-
-class TestCohereWindowService:
-    @classmethod
-    def setup_class(cls):
-        cls.path: str = tempfile.mkdtemp()
-        cache_path: str = os.path.join(cls.path, "cache")
-        ensure_directory_exists(cache_path)
-
-        # Build the cache with real requests and responses
-        with SqliteDict(os.path.join(cache_path, "cohere.sqlite")) as cache:
-            for request, response in REQUESTS_TO_RESPONSES.items():
-                cache[request] = response
-            cache.commit()
-
-        # Requests/responses are already cached. Just write out a fake key to credentials.conf.
-        with open(os.path.join(cls.path, "credentials.conf"), "w") as f:
-            f.write("cohereApiKey: secret")
-
-        service: TokenizerService = get_tokenizer_service(cls.path, SqliteCacheBackendConfig(cache_path))
-        cls.window_service = WindowServiceFactory.get_window_service("cohere/xlarge-20220609", service)
-        cls.prompt: str = TEST_PROMPT
-        cls.tokenized_prompt: List[str] = TOKENIZED_PROMPT
-
-    @classmethod
-    def teardown_class(cls):
-        shutil.rmtree(cls.path)
-
-    def test_max_request_length(self):
-        assert self.window_service.max_request_length == 2048
-
-    def test_encode(self):
-        assert self.window_service.encode(self.prompt).token_values == self.tokenized_prompt
-
-    def test_decode(self):
-        assert self.window_service.decode(self.window_service.encode(self.prompt).tokens) == self.prompt
-
-    def test_tokenize(self):
-        assert self.window_service.tokenize(self.prompt) == self.tokenized_prompt
-
-    def test_tokenize_and_count(self):
-        assert self.window_service.get_num_tokens(self.prompt) == 6
-
-    def test_fits_within_context_window(self):
-        # Should fit in the context window since we subtracted the number of tokens of the prompt
-        # from the max context window.
-        assert self.window_service.fits_within_context_window(self.prompt, self.window_service.max_request_length - 6)
-        # Should not fit in the context window because we're expecting one more extra token in the completion.
-        assert not self.window_service.fits_within_context_window(
-            self.prompt, self.window_service.max_request_length - 6 + 1
-        )
-
-    def test_truncate_from_right(self):
-        # Create a prompt that exceed max context length: 6 * 342 = 2,052 tokens
-        long_prompt: str = self.prompt * 342
-        assert not self.window_service.fits_within_context_window(long_prompt)
-
-        # Truncate and ensure it fits within the context window
-        truncated_long_prompt: str = self.window_service.truncate_from_right(long_prompt)
-        assert self.window_service.get_num_tokens(truncated_long_prompt) == self.window_service.max_request_length
-        assert self.window_service.fits_within_context_window(truncated_long_prompt)
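The setup above primes a SqliteDict cache with recorded request/response pairs so the window service under test never reaches the real Cohere API. A minimal sketch of that priming step on its own (paths and payloads illustrative):

import os
import tempfile

from sqlitedict import SqliteDict

cache_dir = tempfile.mkdtemp()
cache_file = os.path.join(cache_dir, "cohere.sqlite")

# Write known request -> response pairs up front; a cache-backed client
# will then serve these entries without making any network calls.
with SqliteDict(cache_file) as cache:
    cache['{"text": "The Center for"}'] = {"tokens": ["▁The▁Center▁for"]}
    cache.commit()

with SqliteDict(cache_file) as cache:
    assert cache['{"text": "The Center for"}'] == {"tokens": ["▁The▁Center▁for"]}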