evalscope 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of evalscope might be problematic; see the registry page for more details.

Files changed (108)
  1. evalscope/backend/opencompass/tasks/eval_api.py +2 -1
  2. evalscope/backend/opencompass/tasks/eval_datasets.py +1 -0
  3. evalscope/backend/rag_eval/clip_benchmark/utils/webdataset_convert.py +230 -0
  4. evalscope/backend/rag_eval/clip_benchmark/utils/webdatasets.txt +43 -0
  5. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +87 -0
  6. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +36 -0
  7. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +26 -0
  8. evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +41 -0
  9. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +60 -0
  10. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +36 -0
  11. evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +22 -0
  12. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +35 -0
  13. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  14. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  15. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  16. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  17. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +34 -0
  18. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +36 -0
  19. evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +25 -0
  20. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  21. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  22. evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +16 -0
  23. evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +24 -0
  24. evalscope/backend/rag_eval/ragas/prompts/persona_prompt.py +18 -0
  25. evalscope/backend/vlm_eval_kit/backend_manager.py +23 -21
  26. evalscope/benchmarks/ceval/samples.jsonl +1 -0
  27. evalscope/benchmarks/cmmlu/samples.jsonl +5 -0
  28. evalscope/benchmarks/mmlu/samples.jsonl +5 -0
  29. evalscope/benchmarks/race/samples.jsonl +5 -0
  30. evalscope/benchmarks/trivia_qa/samples.jsonl +5 -0
  31. evalscope/cli/start_perf.py +8 -11
  32. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +58485 -0
  33. evalscope/metrics/resources/gpt2-zhcn3-v4.json +1 -0
  34. evalscope/metrics/rouge_metric.py +30 -15
  35. evalscope/perf/arguments.py +179 -0
  36. evalscope/perf/benchmark.py +245 -0
  37. evalscope/perf/http_client.py +127 -711
  38. evalscope/perf/main.py +35 -0
  39. evalscope/perf/plugin/__init__.py +2 -0
  40. evalscope/perf/plugin/api/__init__.py +3 -0
  41. evalscope/perf/{api_plugin_base.py → plugin/api/base.py} +17 -18
  42. evalscope/perf/{custom_api.py → plugin/api/custom_api.py} +25 -19
  43. evalscope/perf/{dashscope_api.py → plugin/api/dashscope_api.py} +28 -14
  44. evalscope/perf/{openai_api.py → plugin/api/openai_api.py} +51 -27
  45. evalscope/perf/plugin/datasets/__init__.py +6 -0
  46. evalscope/perf/{dataset_plugin_base.py → plugin/datasets/base.py} +13 -10
  47. evalscope/perf/plugin/datasets/custom.py +21 -0
  48. evalscope/perf/plugin/datasets/flickr8k.py +51 -0
  49. evalscope/perf/{datasets → plugin/datasets}/line_by_line.py +9 -5
  50. evalscope/perf/plugin/datasets/longalpaca.py +28 -0
  51. evalscope/perf/plugin/datasets/openqa.py +38 -0
  52. evalscope/perf/plugin/datasets/speed_benchmark.py +50 -0
  53. evalscope/perf/plugin/registry.py +54 -0
  54. evalscope/perf/{how_to_analysis_result.py → utils/analysis_result.py} +11 -5
  55. evalscope/perf/utils/benchmark_util.py +135 -0
  56. evalscope/perf/utils/chat_service.py +252 -0
  57. evalscope/perf/utils/db_util.py +200 -0
  58. evalscope/perf/utils/handler.py +46 -0
  59. evalscope/perf/utils/local_server.py +139 -0
  60. evalscope/registry/config/cfg_arena.yaml +77 -0
  61. evalscope/registry/config/cfg_arena_zhihu.yaml +63 -0
  62. evalscope/registry/config/cfg_pairwise_baseline.yaml +83 -0
  63. evalscope/registry/config/cfg_single.yaml +78 -0
  64. evalscope/registry/data/prompt_template/lmsys_v2.jsonl +8 -0
  65. evalscope/registry/data/prompt_template/prompt_templates.jsonl +8 -0
  66. evalscope/registry/data/qa_browser/battle.jsonl +634 -0
  67. evalscope/registry/data/qa_browser/category_mapping.yaml +10 -0
  68. evalscope/registry/data/question.jsonl +80 -0
  69. evalscope/third_party/longbench_write/README.md +118 -0
  70. evalscope/third_party/longbench_write/default_task.json +27 -0
  71. evalscope/third_party/longbench_write/default_task.yaml +24 -0
  72. evalscope/third_party/toolbench_static/README.md +118 -0
  73. evalscope/third_party/toolbench_static/config_default.json +15 -0
  74. evalscope/third_party/toolbench_static/config_default.yaml +12 -0
  75. evalscope/third_party/toolbench_static/requirements.txt +2 -0
  76. evalscope/utils/logger.py +18 -20
  77. evalscope/utils/utils.py +41 -42
  78. evalscope/version.py +2 -2
  79. evalscope-0.7.1.dist-info/LICENSE +203 -0
  80. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/METADATA +93 -35
  81. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/RECORD +101 -31
  82. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/WHEEL +1 -1
  83. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/top_level.txt +1 -0
  84. tests/cli/__init__.py +1 -0
  85. tests/cli/test_run.py +76 -0
  86. tests/perf/__init__.py +1 -0
  87. tests/perf/test_perf.py +96 -0
  88. tests/rag/test_clip_benchmark.py +85 -0
  89. tests/rag/test_mteb.py +136 -0
  90. tests/rag/test_ragas.py +120 -0
  91. tests/swift/__init__.py +1 -0
  92. tests/swift/test_run_swift_eval.py +146 -0
  93. tests/swift/test_run_swift_vlm_eval.py +128 -0
  94. tests/swift/test_run_swift_vlm_jugde_eval.py +157 -0
  95. tests/test_run_all.py +12 -0
  96. tests/vlm/__init__.py +1 -0
  97. tests/vlm/test_vlmeval.py +59 -0
  98. evalscope/perf/_logging.py +0 -32
  99. evalscope/perf/datasets/longalpaca_12k.py +0 -20
  100. evalscope/perf/datasets/openqa.py +0 -22
  101. evalscope/perf/plugin_registry.py +0 -35
  102. evalscope/perf/query_parameters.py +0 -42
  103. evalscope/perf/server_sent_event.py +0 -43
  104. evalscope/preprocess/tokenizers/gpt2_tokenizer.py +0 -221
  105. /evalscope/perf/{datasets → utils}/__init__.py +0 -0
  106. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/entry_points.txt +0 -0
  107. {evalscope/preprocess → tests}/__init__.py +0 -0
  108. {evalscope/preprocess/tokenizers → tests/rag}/__init__.py +0 -0
tests/test_run_all.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import subprocess
+
+ if __name__ == '__main__':
+     cmd = f'TEST_LEVEL_LIST=0,1 python3 -m unittest discover tests'
+     run_res = subprocess.run(cmd, text=True, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+     if run_res.returncode == 0:
+         print(f'>>test_run_all stdout: {run_res.stdout}')
+     else:
+         print(f'>>test_run_all stderr: {run_res.stderr}')
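The runner above sets TEST_LEVEL_LIST=0,1 before invoking unittest discovery; the individual suites then gate their cases with @unittest.skipUnless(level in test_level_list(), ...), as in the VLM test further down. The real helper ships in evalscope.utils; the following is only a minimal sketch assuming it parses the environment variable into a list of integers.

# Hypothetical sketch of the TEST_LEVEL_LIST gate used by the test suites in this release.
# The actual helper lives in evalscope.utils; this only illustrates the assumed contract.
import os

def test_level_list(default='0'):
    # e.g. TEST_LEVEL_LIST=0,1 -> [0, 1]
    raw = os.environ.get('TEST_LEVEL_LIST', default)
    return [int(level) for level in raw.split(',') if level.strip()]

# A level-1 test would then be skipped unless 1 appears in the list:
#   @unittest.skipUnless(1 in test_level_list(), 'skip test in current test level')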
tests/vlm/__init__.py ADDED
@@ -0,0 +1 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
tests/vlm/test_vlmeval.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import subprocess
+ import unittest
+
+ from evalscope.run import run_task
+ from evalscope.summarizer import Summarizer
+ from evalscope.utils import is_module_installed, test_level_list
+ from evalscope.utils.logger import get_logger
+
+ logger = get_logger()
+
+
+ class TestVLMEval(unittest.TestCase):
+
+     def setUp(self) -> None:
+         self._check_env('vlmeval')
+
+     def tearDown(self) -> None:
+         pass
+
+     @staticmethod
+     def _check_env(module_name: str):
+         if is_module_installed(module_name):
+             logger.info(f'{module_name} is installed.')
+         else:
+             raise ModuleNotFoundError(f'run: pip install {module_name}')
+
+     @unittest.skipUnless(0 in test_level_list(), 'skip test in current test level')
+     def test_run_vlm_eval_local(self):
+         task_cfg = {
+             'eval_backend': 'VLMEvalKit',
+             'eval_config': {
+                 'data': ['SEEDBench_IMG', 'ChartQA_TEST'],
+                 'limit': 20,
+                 'mode': 'all',
+                 'model': [{
+                     'name': 'qwen-vl-chat',
+                     'model_path': '../models/Qwen-VL-Chat'
+                 }], # model name for VLMEval config
+                 'nproc': 1,
+                 'reuse': True,
+                 'work_dir': 'outputs'
+             }
+         }
+
+         logger.info(f'>> Start to run task: {task_cfg}')
+
+         run_task(task_cfg)
+
+         logger.info('>> Start to get the report with summarizer ...')
+         report_list = Summarizer.get_report_from_cfg(task_cfg)
+         logger.info(f'\n>>The report list: {report_list}')
+
+         assert len(report_list) > 0, f'Failed to get report list: {report_list}'
+
+
+ if __name__ == '__main__':
+     unittest.main(buffer=False)
evalscope/perf/_logging.py DELETED
@@ -1,32 +0,0 @@
- import logging
- import os
-
-
- logger = logging.getLogger('perf')
-
-
- def enable_logging():
-     level = os.environ.get('LOGGING_LEVEL', 'info')
-     if level is not None: # set logging level.
-         if level not in ['info', 'debug']:
-             # set logging level env, but invalid value, use default.
-             level = 'info'
-         if level == 'info':
-             logger.setLevel(logging.INFO)
-         else:
-             logger.setLevel(logging.DEBUG)
-     # set default logging handler
-     console_handler = logging.StreamHandler()
-     formatter = logging.Formatter(
-         '%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s' # noqa E501
-     )
-     #formatter = logging.Formatter(
-     #    '%(asctime)s - %(name)s - %(levelname)s - %(message)s' # noqa E501
-     #)
-     console_handler.setFormatter(formatter)
-     logger.addHandler(console_handler)
-
-
- # in release disable dashscope log
- # you can enable dashscope log for debugger.
- enable_logging()
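The ad-hoc 'perf' logger above is dropped in 0.7.1; per the file list, evalscope/utils/logger.py changed in the same release, and the new tests obtain their logger through get_logger(). A minimal sketch of the replacement pattern, assuming only the no-argument call visible elsewhere in this diff:

# Sketch of the replacement pattern (assumes only the get_logger() usage seen in the new tests).
from evalscope.utils.logger import get_logger

logger = get_logger()
logger.info('perf benchmark started')  # formatting and handlers now come from the shared logger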
evalscope/perf/datasets/longalpaca_12k.py DELETED
@@ -1,20 +0,0 @@
- import sys
- from typing import Any, Dict, Iterator, List
- from evalscope.perf.dataset_plugin_base import DatasetPluginBase
-
- from evalscope.perf.plugin_registry import register_dataset
- from evalscope.perf.query_parameters import QueryParameters
-
- @register_dataset('longalpaca')
- class LongAlpacaDatasetPlugin(DatasetPluginBase):
-     """Read data from file which is list of requests.
-     Sample: https://huggingface.co/datasets/Yukang/LongAlpaca-12k
-     """
-     def __init__(self, query_parameters: QueryParameters):
-         super().__init__(query_parameters)
-
-     def build_messages(self) -> Iterator[List[Dict]]:
-         for item in self.dataset_json_list(self.query_parameters.dataset_path):
-             prompt = item['instruction'].strip()
-             if len(prompt) > self.query_parameters.min_prompt_length and len(prompt) < self.query_parameters.max_prompt_length:
-                 yield [{'role': 'user', 'content': prompt}]
evalscope/perf/datasets/openqa.py DELETED
@@ -1,22 +0,0 @@
- from sys import maxsize
- import sys
- from typing import Any, Dict, Iterator, List
- import json
- from evalscope.perf.dataset_plugin_base import DatasetPluginBase
- from evalscope.perf.plugin_registry import register_dataset
- from evalscope.perf.query_parameters import QueryParameters
-
- @register_dataset('openqa')
- class OpenqaDatasetPlugin(DatasetPluginBase):
-     """Read dataset and return prompt.
-     Datasets: https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese/blob/main/open_qa.jsonl
-     """
-     def __init__(self, query_parameters: QueryParameters):
-         super().__init__(query_parameters)
-
-     def build_messages(self) -> Iterator[List[Dict]]:
-         for item in self.dataset_line_by_line(self.query_parameters.dataset_path):
-             item = json.loads(item)
-             prompt = item['question'].strip()
-             if len(prompt) > self.query_parameters.min_prompt_length and len(prompt) < self.query_parameters.max_prompt_length:
-                 yield [{'role': 'user', 'content': prompt}]
evalscope/perf/plugin_registry.py DELETED
@@ -1,35 +0,0 @@
-
- from typing import Any
-
-
- class PluginRegistry:
-     def __init__(self):
-         self._registry = {}
-
-     def register(self, name, cls):
-         self._registry[name] = cls
-         return cls
-
-     def get_class(self, name):
-         return self._registry[name]
-
-     def all_classes(self):
-         return list(self._registry.keys())
-
-     def __call__(self, name: str) -> Any:
-         return self.get_class(name)
-
- dataset_registry = PluginRegistry()
- api_registry = PluginRegistry()
-
- def register_dataset(name: str):
-     def class_decorator(cls):
-         dataset_registry.register(name, cls)
-         return cls
-     return class_decorator
-
- def register_api(name: str):
-     def class_decorator(cls):
-         api_registry.register(name, cls)
-         return cls
-     return class_decorator
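The removed registry resolves plugin classes by name: register_dataset/register_api store a class, and PluginRegistry.__call__ hands it back via get_class. The dataset plugins deleted above self-register this way; 0.7.1 relocates the machinery to evalscope/perf/plugin/registry.py. A short, illustrative sketch of the 0.6.x consumer side (the surrounding wiring is assumed):

# Illustrative only: consumer side of the (now removed) 0.6.x registry.
# Importing a plugin module runs its @register_dataset decorator, after which the
# class can be looked up by name.
import evalscope.perf.datasets.openqa  # noqa: F401  (registers 'openqa')
from evalscope.perf.plugin_registry import dataset_registry

plugin_cls = dataset_registry('openqa')  # PluginRegistry.__call__ -> get_class('openqa')
# plugin_cls(query_parameters).build_messages() would then yield
# [{'role': 'user', 'content': prompt}] message lists for the benchmark loop.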
evalscope/perf/query_parameters.py DELETED
@@ -1,42 +0,0 @@
- from dataclasses import dataclass
- from typing import Optional
-
-
- @dataclass
- class QueryParameters:
-     model: str
-     prompt: Optional[str]
-     dataset: Optional[str]
-     query_template: Optional[str]
-     dataset_path: Optional[str]
-     frequency_penalty: Optional[float]
-     logprobs: Optional[bool]
-     max_tokens: Optional[int]
-     n_choices: Optional[int]
-     seed: Optional[int]
-     stop: Optional[str]
-     stream: Optional[bool]
-     temperature: Optional[float]
-     top_p: Optional[float]
-     max_prompt_length: Optional[int]
-     min_prompt_length: Optional[int]
-     include_usage: Optional[bool]
-
-     def __init__(self, args):
-         self.model = args.model
-         self.prompt = args.prompt
-         self.dataset = args.dataset
-         self.query_template = args.query_template
-         self.dataset_path = args.dataset_path
-         self.frequency_penalty = args.frequency_penalty
-         self.logprobs = args.logprobs
-         self.max_tokens = args.max_tokens
-         self.n_choices = args.n_choices
-         self.seed = args.seed
-         self.stop = args.stop
-         self.stream = args.stream
-         self.temperature = args.temperature
-         self.top_p = args.top_p
-         self.max_prompt_length = args.max_prompt_length
-         self.min_prompt_length = args.min_prompt_length
-         self.stop_token_ids = args.stop_token_ids
evalscope/perf/server_sent_event.py DELETED
@@ -1,43 +0,0 @@
- from dataclasses import dataclass
-
- @dataclass
- class ServerSentEvent(object):
-     def __init__(self, data='', event=None, id=None, retry=None):
-         self.data = data
-         self.event = event
-         self.id = id
-         self.retry = retry
-
-     @classmethod
-     def decode(cls, line):
-         """ Decode line to ServerSentEvent
-
-
-         Args:
-             line (str): The line.
-
-         Return:
-             ServerSentEvent (obj:`ServerSentEvent`): The ServerSentEvent object.
-
-         """
-         if not line:
-             return None
-         sse_msg = cls()
-         # format data:xxx
-         field_type, _, field_value = line.partition(":")
-         if field_value.startswith(" "): # compatible with openai api
-             field_value = field_value[1:]
-         if field_type == "event":
-             sse_msg.event = field_value
-         elif field_type == "data":
-             field_value = field_value.rstrip()
-             sse_msg.data = field_value
-         elif field_type == "id":
-             sse_msg.id = field_value
-         elif field_type == "retry":
-             sse_msg.retry = field_value
-         else:
-             pass
-
-         return sse_msg
-
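ServerSentEvent.decode splits a single SSE line on the first ':', strips one leading space from the value (for OpenAI-style streams), and fills the matching field. A usage sketch against the removed helper:

# Usage sketch for the removed helper: decoding one line of an OpenAI-style SSE stream.
from evalscope.perf.server_sent_event import ServerSentEvent

msg = ServerSentEvent.decode('data: {"choices": [{"delta": {"content": "hi"}}]}')
print(msg.data)  # -> {"choices": [{"delta": {"content": "hi"}}]}  (leading space already stripped)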
evalscope/preprocess/tokenizers/gpt2_tokenizer.py DELETED
@@ -1,221 +0,0 @@
- import logging
- import sys
- from functools import lru_cache
- from typing import Sequence
-
- import json
- import regex as re
-
- logger = logging.getLogger(__name__)
-
-
- def get_pairs(word):
-     """Return set of symbol pairs in a word.
-
-     Word is represented as tuple of symbols (symbols being variable-length strings).
-     """
-     pairs = set()
-     prev_char = word[0]
-     for char in word[1:]:
-         pairs.add((prev_char, char))
-         prev_char = char
-     return pairs
-
-
- @lru_cache()
- def bytes_to_unicode():
-     """
-     Returns list of utf-8 byte and a corresponding list of unicode strings.
-     The reversible bpe codes work on unicode strings.
-     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-     This is a signficant percentage of your normal, say, 32K bpe vocab.
-     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-     And avoids mapping to whitespace/control characters the bpe code barfs on.
-     """
-     bs = list(range(ord('!'),
-                     ord('~') + 1)) + list(range(
-                         ord('¡'),
-                         ord('¬') + 1)) + list(range(ord('®'),
-                                                     ord('ÿ') + 1))
-     cs = bs[:]
-     n = 0
-     for b in range(2**8):
-         if b not in bs:
-             bs.append(b)
-             cs.append(2**8 + n)
-             n += 1
-     cs = [chr(n) for n in cs]
-     return dict(zip(bs, cs))
-
-
- class GPT2Tokenizer(object):
-     """
-     GPT-2 BPE tokenizer. Peculiarities:
-     - Byte-level BPE
-
-     [NOTE]: Copied from megatron.tokenizer.gpt2_tokenization.GPT2Tokenizer.
-     """
-
-     def __init__(self,
-                  vocab_file,
-                  merges_file,
-                  errors='replace',
-                  special_tokens=None,
-                  max_len=None):
-         assert sys.version_info[0] != 2
-
-         self.max_len = max_len if max_len is not None else int(1e12)
-         self.encoder = json.load(open(vocab_file))
-         self.decoder = {v: k for k, v in self.encoder.items()}
-         self.errors = errors # how to handle errors in decoding
-         self.byte_encoder = bytes_to_unicode()
-         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
-         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
-         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-         self.cache = {}
-
-         # Should haved added re.IGNORECASE so BPE merges can happen for
-         # capitalized versions of contractions
-         self.pat = re.compile(
-             r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-         )
-
-         self.special_tokens = {}
-         self.special_tokens_decoder = {}
-         self.set_special_tokens(special_tokens)
-
-     def __len__(self):
-         return len(self.encoder) + len(self.special_tokens)
-
-     def set_special_tokens(self, special_tokens):
-         """ Add a list of additional tokens to the encoder.
-         The additional tokens are indexed starting from the last index of the
-         current vocabulary in the order of the `special_tokens` list.
-         """
-         if not special_tokens:
-             self.special_tokens = {}
-             self.special_tokens_decoder = {}
-             return
-         self.special_tokens = dict((tok, len(self.encoder) + i)
-                                    for i, tok in enumerate(special_tokens))
-         self.special_tokens_decoder = {
-             v: k
-             for k, v in self.special_tokens.items()
-         }
-         logger.info('Special tokens {}'.format(self.special_tokens))
-
-     def bpe(self, token):
-         if token in self.cache:
-             return self.cache[token]
-         word = tuple(token)
-         pairs = get_pairs(word)
-
-         if not pairs:
-             return token
-
-         while True:
-             bigram = min(
-                 pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
-             if bigram not in self.bpe_ranks:
-                 break
-             first, second = bigram
-             new_word = []
-             i = 0
-             while i < len(word):
-                 try:
-                     j = word.index(first, i)
-                     new_word.extend(word[i:j])
-                     i = j
-                 except BaseException:
-                     new_word.extend(word[i:])
-                     break
-
-                 if word[i] == first and i < len(word) - 1 and word[
-                         i + 1] == second:
-                     new_word.append(first + second)
-                     i += 2
-                 else:
-                     new_word.append(word[i])
-                     i += 1
-             new_word = tuple(new_word)
-             word = new_word
-             if len(word) == 1:
-                 break
-             else:
-                 pairs = get_pairs(word)
-         word = ' '.join(word)
-         self.cache[token] = word
-         return word
-
-     def tokenize(self, text):
-         """ Tokenize a string. """
-         bpe_tokens = []
-         for token in re.findall(self.pat, text):
-             if sys.version_info[0] == 2:
-                 token = ''.join(self.byte_encoder[ord(b)] for b in token)
-             else:
-                 token = ''.join(self.byte_encoder[b]
-                                 for b in token.encode('utf-8'))
-             bpe_tokens.extend(
-                 bpe_token for bpe_token in self.bpe(token).split(' '))
-         return bpe_tokens
-
-     def convert_tokens_to_ids(self, tokens):
-         """ Converts a sequence of tokens into ids using the vocab. """
-         ids = []
-         if isinstance(tokens, str):
-             if tokens in self.special_tokens:
-                 return self.special_tokens[tokens]
-             else:
-                 return self.encoder.get(tokens, 0)
-         for token in tokens:
-             if token in self.special_tokens:
-                 ids.append(self.special_tokens[token])
-             else:
-                 ids.append(self.encoder.get(token, 0))
-         if len(ids) > self.max_len:
-             logger.warning(
-                 'Token indices sequence length is longer than the specified maximum '
-                 ' sequence length for this OpenAI GPT model ({} > {}). Running this'
-                 ' sequence through the model will result in indexing errors'.
-                 format(len(ids), self.max_len))
-         return ids
-
-     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-         """Converts a sequence of ids in BPE tokens using the vocab."""
-         tokens = []
-         for i in ids:
-             if i in self.special_tokens_decoder:
-                 if not skip_special_tokens:
-                     tokens.append(self.special_tokens_decoder[i])
-             else:
-                 tokens.append(self.decoder[i])
-         return tokens
-
-     def encode(self, text):
-         return self.convert_tokens_to_ids(self.tokenize(text))
-
-     def decode(self, tokens):
-         text = ''.join([self.decoder[token] for token in tokens])
-         text = bytearray([self.byte_decoder[c] for c in text]).decode(
-             'utf-8', errors=self.errors)
-         return text
-
-
- class DummyTokenizer:
-
-     def tokenize(self, text: str):
-         return text.split()
-
-
- def get_tokenized_string(tokenizer: GPT2Tokenizer, text_list: Sequence[str]):
-     token_ids_list, tokenized_string_list = [], []
-     for text in text_list:
-         assert tokenizer is not None
-         token_ids = tokenizer.encode(text)
-         tokenized_string = ' '.join(tokenizer.convert_ids_to_tokens(token_ids))
-         token_ids_list.append(token_ids)
-         tokenized_string_list.append(tokenized_string)
-     return token_ids_list, tokenized_string_list