omni-split 0.0.3 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omni_split/__init__.py +16 -0
- omni_split/base/__init__.py +0 -0
- omni_split/base/chonkie_base.py +139 -0
- omni_split/base/chonkie_tokenizer.py +285 -0
- omni_split/base/chonkie_types.py +519 -0
- omni_split/base/md2json_list.py +303 -0
- omni_split/base/md_json_list2chunk.py +348 -0
- omni_split/main.py +73 -0
- omni_split/model/text_chunker_tokenizer/qwen_tokenizer.json +303282 -0
- omni_split/omni_split.py +93 -0
- omni_split/sub_chunker/__init__.py +0 -0
- omni_split/sub_chunker/document_split.py +32 -0
- omni_split/sub_chunker/markdown_split.py +47 -0
- omni_split/sub_chunker/text_split.py +343 -0
- omni_split/test.py +80 -0
- omni_split/utils/__init__.py +0 -0
- omni_split/utils/base_utils.py +181 -0
- omni_split/utils/download_test_doc.py +61 -0
- omni_split-0.0.3.dist-info/METADATA +147 -0
- omni_split-0.0.3.dist-info/RECORD +23 -0
- omni_split-0.0.3.dist-info/WHEEL +5 -0
- omni_split-0.0.3.dist-info/licenses/LICENSE +21 -0
- omni_split-0.0.3.dist-info/top_level.txt +1 -0
omni_split/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .omni_split import OmniSplit
from .utils.base_utils import word_preprocessing_and_return_bytesIO
from .utils.download_test_doc import download_files_to_test_doc

__version__ = "0.0.3"
__name__ = "omni_split"
__author__ = "dinobot22"

__all__ = [
    "__name__",
    "__version__",
    "__author__",
    "OmniSplit",
    "word_preprocessing_and_return_bytesIO",
    "download_files_to_test_doc",
]
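For orientation, a minimal sketch of consuming the public surface re-exported above. Only the names shown in this file are confirmed by the diff; OmniSplit's constructor is defined in omni_split/omni_split.py, whose contents are not shown here:

import omni_split

print(omni_split.__version__)   # "0.0.3"
print(omni_split.__author__)    # "dinobot22"

# Everything in __all__ is importable from the top level:
from omni_split import OmniSplit, word_preprocessing_and_return_bytesIO, download_files_to_test_doc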
omni_split/base/__init__.py
File without changes

omni_split/base/chonkie_base.py
ADDED
@@ -0,0 +1,139 @@
"""Base classes for chunking text."""

import warnings
from abc import ABC, abstractmethod
from multiprocessing import Pool, cpu_count
from typing import Any, Callable, List, Union

from tqdm import tqdm

from .chonkie_tokenizer import Tokenizer
from .chonkie_types import Chunk


class BaseChunker(ABC):
    """Abstract base class for all chunker implementations.

    All chunker implementations should inherit from this class and implement
    the chunk() method according to their specific chunking strategy.
    """

    def __init__(self, tokenizer_or_token_counter: Union[str, Any, Callable[[str], int]]):
        """Initialize the chunker with a tokenizer.

        Args:
            tokenizer_or_token_counter (Union[str, Any, Callable[[str], int]]): Tokenizer name, tokenizer object, or token-counting callable

        """
        self.tokenizer = Tokenizer(tokenizer_or_token_counter)

        # Set whether to use multiprocessing or not
        self._use_multiprocessing = True

    @abstractmethod
    def chunk(self, text: str) -> List[Chunk]:
        """Split text into chunks according to the implementation strategy.

        Args:
            text: Input text to be chunked

        Returns:
            List of Chunk objects containing the chunked text and metadata

        """
        pass

    def _determine_optimal_workers(self) -> int:
        """Determine the optimal number of workers based on system resources."""
        try:
            # Get CPU cores
            cpu_cores = cpu_count()

            # Never use more than 75% of available cores
            max_workers = max(1, int(cpu_cores * 0.75))

            # Cap at 8 workers
            return min(max_workers, 8)

        except Exception as e:
            warnings.warn(f"Error determining optimal workers: {e}. Using single process.")
            return 1

    def _process_batch_sequential(self, texts: List[str], show_progress_bar: bool = True) -> List[List[Chunk]]:
        """Process a batch of texts sequentially."""
        return [
            self.chunk(t)
            for t in tqdm(
                texts,
                desc="🦛",
                disable=not show_progress_bar,
                unit="doc",
                bar_format="{desc} ch{bar:20}nk {percentage:3.0f}% • {n_fmt}/{total_fmt} docs chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱",
                ascii=" o",
            )
        ]

    def _process_batch_multiprocessing(self, texts: List[str], show_progress_bar: bool = True) -> List[List[Chunk]]:
        """Process a batch of texts using multiprocessing."""
        num_workers = self._determine_optimal_workers()
        total = len(texts)
        chunksize = max(1, min(total // (num_workers * 16), 10))  # Optimize chunk size

        with Pool(processes=num_workers) as pool:
            results = []
            with tqdm(
                total=total,
                desc="🦛",
                disable=not show_progress_bar,
                unit="doc",
                bar_format="{desc} ch{bar:20}nk {percentage:3.0f}% • {n_fmt}/{total_fmt} docs chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱",
                ascii=" o",
            ) as pbar:
                for result in pool.imap(self.chunk, texts, chunksize=chunksize):
                    results.append(result)
                    pbar.update()
            return results

    def chunk_batch(
        self,
        texts: List[str],
        show_progress_bar: bool = True,
    ) -> List[List[Chunk]]:
        """Split a list of texts into their respective chunks.

        By default, this method uses multiprocessing to parallelize the chunking process.

        Args:
            texts: List of input texts to be chunked.
            show_progress_bar: Whether to show a progress bar.

        Returns:
            List of lists of Chunk objects containing the chunked text and metadata

        """
        if self._use_multiprocessing:
            return self._process_batch_multiprocessing(texts, show_progress_bar)
        else:
            return self._process_batch_sequential(texts, show_progress_bar)

    def __call__(self, text: Union[str, List[str]], show_progress_bar: bool = True) -> Union[List[Chunk], List[List[Chunk]]]:
        """Make the chunker callable directly.

        Args:
            text: Input text or list of texts to be chunked
            show_progress_bar: Whether to show a progress bar (for batch chunking)

        Returns:
            List of Chunk objects or list of lists of Chunk objects

        """
        if isinstance(text, str):
            return self.chunk(text)
        elif isinstance(text, list):
            return self.chunk_batch(text, show_progress_bar)
        else:
            raise ValueError("Input must be a string or a list of strings.")

    def __repr__(self) -> str:
        """Return a string representation of the chunker."""
        return f"{self.__class__.__name__}()"

omni_split/base/chonkie_tokenizer.py
ADDED
@@ -0,0 +1,285 @@
"""A utility module for handling tokenization across different backends."""

import importlib.util
import inspect
from collections import defaultdict
from typing import Any, Callable, Dict, List, Union


class Tokenizer:
    """Unified tokenizer interface for Chonkie.

    Handles tokenizer initialization and operations across different backends
    (HuggingFace, TikToken, custom tokenizers).

    Args:
        tokenizer: Tokenizer instance or identifier (e.g., "gpt2")

    Raises:
        ValueError: If the tokenizer backend cannot be determined or loaded

    """

    def __init__(self, tokenizer: Union[str, Callable, Any] = "gpt2"):
        """Initialize the tokenizer."""
        # Initialize the tokenizer
        if isinstance(tokenizer, str):
            self.tokenizer = self._load_tokenizer(tokenizer)
        else:
            self.tokenizer = tokenizer

        # Determine the tokenizer backend
        self._tokenizer_backend = self._get_tokenizer_backend()

    def _get_tokenizer_backend(self) -> str:
        """Determine the tokenizer backend."""
        # Check if the tokenizer is a character or word tokenizer
        if "chonkie" in str(type(self.tokenizer)):
            return "chonkie"
        elif "transformers" in str(type(self.tokenizer)):
            return "transformers"
        elif "tokenizers" in str(type(self.tokenizer)):
            return "tokenizers"
        elif "tiktoken" in str(type(self.tokenizer)):
            return "tiktoken"
        elif callable(self.tokenizer) or inspect.isfunction(self.tokenizer) or inspect.ismethod(self.tokenizer):
            return "callable"
        else:
            raise ValueError(f"Tokenizer backend {str(type(self.tokenizer))} not supported")

    def _load_tokenizer(self, tokenizer_name: str):
        """Load a tokenizer based on the backend."""
        # Check if the string is equal to "character"
        if tokenizer_name == "character":
            return CharacterTokenizer()
        elif tokenizer_name == "word":
            return WordTokenizer()
        else:
            try:
                if importlib.util.find_spec("tokenizers") is not None:
                    from tokenizers import Tokenizer

                    return Tokenizer.from_pretrained(tokenizer_name)
                else:
                    raise Warning("Tokenizers library not found. Trying tiktoken.")
            except Exception:
                try:
                    if importlib.util.find_spec("tiktoken") is not None:
                        from tiktoken import get_encoding

                        return get_encoding(tokenizer_name)
                    else:
                        raise Warning("TikToken library not found. Trying transformers.")
                except Exception:
                    try:
                        if importlib.util.find_spec("transformers") is not None:
                            from transformers import AutoTokenizer

                            return AutoTokenizer.from_pretrained(tokenizer_name)
                        else:
                            raise ValueError("Tokenizer not found in the following libraries: transformers, tokenizers, tiktoken")
                    except Exception:
                        raise ValueError("Tokenizer not found in the following libraries: transformers, tokenizers, tiktoken")

    def encode(self, text: str) -> List[int]:
        """Encode text to token ids."""
        if self._tokenizer_backend == "chonkie":
            return self.tokenizer.encode(text)
        elif self._tokenizer_backend == "transformers":
            return self.tokenizer.encode(text, add_special_tokens=False)
        elif self._tokenizer_backend == "tokenizers":
            return self.tokenizer.encode(text, add_special_tokens=False).ids
        elif self._tokenizer_backend == "tiktoken":
            return self.tokenizer.encode(text)
        elif self._tokenizer_backend == "callable":
            raise NotImplementedError("Callable tokenizer backend does not support encoding.")
        else:
            raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode a batch of texts."""
        if self._tokenizer_backend == "chonkie":
            return self.tokenizer.encode_batch(texts)
        elif self._tokenizer_backend == "transformers":
            return self.tokenizer.batch_encode_plus(texts, add_special_tokens=False)["input_ids"]
        elif self._tokenizer_backend == "tokenizers":
            return [t.ids for t in self.tokenizer.encode_batch(texts, add_special_tokens=False)]
        elif self._tokenizer_backend == "tiktoken":
            return self.tokenizer.encode_batch(texts)
        elif self._tokenizer_backend == "callable":
            raise NotImplementedError("Callable tokenizer backend does not support batch encoding.")
        else:
            raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")

    def decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        if self._tokenizer_backend == "callable":
            raise NotImplementedError("Callable tokenizer backend does not support decoding.")
        return self.tokenizer.decode(tokens)

    def decode_batch(self, token_lists: List[List[int]]) -> List[str]:
        """Decode multiple token lists."""
        if self._tokenizer_backend == "chonkie":
            return self.tokenizer.decode_batch(token_lists)
        elif self._tokenizer_backend == "transformers":
            return self.tokenizer.batch_decode(token_lists, skip_special_tokens=True)
        elif self._tokenizer_backend in ["tokenizers", "tiktoken"]:
            return self.tokenizer.decode_batch(token_lists)
        elif self._tokenizer_backend == "callable":
            raise NotImplementedError("Callable tokenizer backend does not support batch decoding.")
        else:
            raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in text."""
        if self._tokenizer_backend == "chonkie":
            return self.tokenizer.count_tokens(text)
        elif self._tokenizer_backend == "transformers":
            return len(self.tokenizer.encode(text, add_special_tokens=False))
        elif self._tokenizer_backend == "tokenizers":
            return len(self.tokenizer.encode(text, add_special_tokens=False).ids)
        elif self._tokenizer_backend == "tiktoken":
            return len(self.tokenizer.encode(text))
        elif self._tokenizer_backend == "callable":
            return self.tokenizer(text)
        else:
            raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")

    def count_tokens_batch(self, texts: List[str]) -> List[int]:
        """Count tokens in multiple texts."""
        if self._tokenizer_backend == "chonkie":
            return self.tokenizer.count_tokens_batch(texts)
        elif self._tokenizer_backend == "transformers":
            return [len(token_list) for token_list in self.tokenizer.batch_encode_plus(texts, add_special_tokens=False)["input_ids"]]
        elif self._tokenizer_backend == "tokenizers":
            return [len(token_list) for token_list in [t.ids for t in self.tokenizer.encode_batch(texts, add_special_tokens=False)]]
        elif self._tokenizer_backend == "tiktoken":
            return [len(token_list) for token_list in self.tokenizer.encode_batch(texts)]
        elif self._tokenizer_backend == "callable":
            return [self.tokenizer(text) for text in texts]
        else:
            raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")
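A quick usage sketch for the unified interface above. The "character" and "word" shortcuts resolve to the two fallback tokenizers defined next in this file, while names like "gpt2" are looked up in tokenizers, then tiktoken, then transformers, whichever is installed:

tok = Tokenizer("character")
tok.encode("hippo")           # [1, 2, 3, 3, 4]: ids are assigned in order of first appearance
tok.decode([1, 2, 3, 3, 4])   # "hippo"
tok.count_tokens("hippo")     # 5

# A plain callable is treated as a token counter only:
counter = Tokenizer(lambda text: len(text.split()))
counter.count_tokens("two words")  # 2
counter.encode("anything")         # raises NotImplementedError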

class CharacterTokenizer:
    """Character-based tokenizer."""

    def __init__(self):
        """Initialize the tokenizer."""
        # Initialize the vocabulary with a space character
        self.vocab = []
        self.token2id = defaultdict(lambda: len(self.vocab))

        # Add space character to vocabulary
        _ = self.token2id[" "]
        self.vocab.append(" ")

    def get_vocab(self) -> List[str]:
        """Get the vocabulary."""
        return self.vocab

    def get_token2id(self) -> Dict[str, int]:
        """Get the token to id mapping."""
        return self.token2id

    def encode(self, text: str) -> List[int]:
        """Encode text to token ids."""
        ids = []
        for token in text:
            token_id = self.token2id[token]
            if token_id >= len(self.vocab):
                self.vocab.append(token)
            ids.append(token_id)
        return ids

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode a batch of texts."""
        return [self.encode(text) for text in texts]

    def decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        try:
            return "".join([self.vocab[token] for token in tokens])
        except IndexError:
            raise ValueError(f"Token {tokens} not found in vocabulary.")

    def decode_batch(self, token_lists: List[List[int]]) -> List[str]:
        """Decode multiple token lists."""
        return [self.decode(token_list) for token_list in token_lists]

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in text."""
        return len(text)

    def count_tokens_batch(self, texts: List[str]) -> List[int]:
        """Count tokens in multiple texts."""
        return [len(text) for text in texts]

    def __repr__(self) -> str:
        """Return a string representation of the tokenizer."""
        return f"CharacterTokenizer(vocab_size={len(self.vocab)})"
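One behavior of the class above worth spelling out: the vocabulary is mutable and grows as encode() sees new characters, so ids are only meaningful for the instance that produced them. A short sketch:

ct = CharacterTokenizer()
ct.get_vocab()     # [" "]: only the space is pre-registered
ct.encode("ab")    # [1, 2]: "a" and "b" are appended on first sight
ct.get_vocab()     # [" ", "a", "b"]
ct.decode([2, 1])  # "ba"
ct.decode([9])     # raises ValueError: id 9 was never assigned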

class WordTokenizer:
    """Word-based tokenizer."""

    def __init__(self):
        """Initialize the tokenizer."""
        # Initialize the vocabulary with a space character
        self.vocab = []
        self.token2id = defaultdict(lambda: len(self.vocab))

        # Add space character to vocabulary
        _ = self.token2id[" "]
        self.vocab.append(" ")

    def get_vocab(self) -> List[str]:
        """Get the vocabulary."""
        return self.vocab

    def get_token2id(self) -> Dict[str, int]:
        """Get the token to id mapping."""
        return self.token2id

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text."""
        words = text.split(" ")
        return words

    def encode(self, text: str) -> List[int]:
        """Encode text to token ids."""
        tokens = self.tokenize(text)
        ids = []
        for token in tokens:
            token_id = self.token2id[token]
            if token_id >= len(self.vocab):
                self.vocab.append(token)
            ids.append(token_id)
        return ids

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode a batch of texts."""
        return [self.encode(text) for text in texts]

    def decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        try:
            return " ".join([self.vocab[token] for token in tokens])
        except IndexError:
            raise ValueError(f"Token {tokens} not found in vocabulary.")

    def decode_batch(self, token_lists: List[List[int]]) -> List[str]:
        """Decode multiple token lists."""
        return [self.decode(token_list) for token_list in token_lists]

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in text."""
        return len(self.encode(text))

    def count_tokens_batch(self, texts: List[str]) -> List[int]:
        """Count tokens in multiple texts."""
        return [len(self.encode(text)) for text in texts]

    def __repr__(self) -> str:
        """Return a string representation of the tokenizer."""
        return f"WordTokenizer(vocab_size={len(self.vocab)})"