codon-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codon/__init__.py +5 -0
- codon/base.py +167 -0
- codon/exp/__init__.py +0 -0
- codon/exp/moe.py +307 -0
- codon/model/__init__.py +0 -0
- codon/model/motif/__init__.py +1 -0
- codon/model/motif/motif_a1.py +121 -0
- codon/model/patch_disc.py +151 -0
- codon/model/tcn.py +124 -0
- codon/ops/__init__.py +3 -0
- codon/ops/attention.py +107 -0
- codon/ops/bio.py +0 -0
- codon/utils/__init__.py +0 -0
- codon/utils/dataset/__init__.py +3 -0
- codon/utils/dataset/base.py +46 -0
- codon/utils/dataset/corpus.py +478 -0
- codon/utils/dataset/dataviewer.py +196 -0
- codon/utils/dataset/flatdata.py +455 -0
- codon/utils/mask.py +266 -0
- codon/utils/safecode.py +24 -0
- codon/utils/seed.py +75 -0
- codon/utils/theta.py +55 -0
- codon/utils/token.py +276 -0
- codon_model-0.0.1.dist-info/METADATA +17 -0
- codon_model-0.0.1.dist-info/RECORD +28 -0
- codon_model-0.0.1.dist-info/WHEEL +5 -0
- codon_model-0.0.1.dist-info/licenses/LICENSE +201 -0
- codon_model-0.0.1.dist-info/top_level.txt +1 -0
codon/utils/mask.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from tokenizers import Tokenizer
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from enum import Enum, auto
|
|
6
|
+
|
|
7
|
+
from typing import Union
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MaskMode(Enum):
    '''
    Strategies that TokenMask uses to turn separator positions into a 0/1 mask.

    Convention: 0 means the position is masked out, 1 means it is kept. A
    separator token always belongs to the segment it closes.

    Members:
        FIRST_MASK_PRE: Split at the first separator. Everything up to and
            including it is 0; everything after it is 1.
        FIRST_MASK_POST: Split at the first separator. Everything up to and
            including it is 1; everything after it is 0.
        LAST_MASK_PRE: Split at the last separator. Everything up to and
            including it is 0; everything after it is 1.
        LAST_MASK_POST: Split at the last separator. Everything up to and
            including it is 1; everything after it is 0.
        ALL_MASK_FIRST: Split at every separator. Segments alternate
            0, 1, 0, ... starting with the segment ending at the first separator.
        ALL_KEEP_FIRST: Split at every separator. Segments alternate
            1, 0, 1, ... starting with the segment ending at the first separator.
    '''
    # Explicit values match what auto() assigns (1-based, declaration order).
    FIRST_MASK_PRE = 1
    FIRST_MASK_POST = 2
    LAST_MASK_PRE = 3
    LAST_MASK_POST = 4
    ALL_MASK_FIRST = 5
    ALL_KEEP_FIRST = 6
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def make_padding_mask(src: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    '''
    Marks which key positions hold real tokens rather than padding.

    Args:
        src (torch.Tensor): Source token ids, shape [B, L_src].
        pad_idx (int, optional): Id used for padding. Defaults to 0.

    Returns:
        torch.Tensor: Boolean tensor of shape [B, 1, 1, L_src]; True where the
            token differs from ``pad_idx`` (i.e. the position may be attended to).
            The two singleton axes let it broadcast over heads and queries.
    '''
    is_token = src.ne(pad_idx)
    # Two consecutive insertions at axis 1: [B, L] -> [B, 1, L] -> [B, 1, 1, L].
    return is_token.unsqueeze(1).unsqueeze(1)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def make_lookahead_mask(size: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    '''
    Builds a square causal ("no peeking ahead") mask.

    Args:
        size (int): Sequence length L.
        device (torch.device, optional): Device to allocate the mask on. Defaults to cpu.

    Returns:
        torch.Tensor: Boolean tensor of shape [size, size]. Entry [i, j] is True
            exactly when j <= i (lower triangle including the diagonal), i.e.
            query i may attend to key j.
    '''
    everything = torch.ones(size, size, dtype=torch.bool, device=device)
    return everything.tril()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def make_causal_mask(tgt: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    '''
    Builds a decoder self-attention mask that hides both padding keys and future positions.

    Args:
        tgt (torch.Tensor): Target token ids, shape [B, L_tgt].
        pad_idx (int, optional): Id used for padding. Defaults to 0.

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, L_tgt, L_tgt]. Entry
            [b, 0, i, j] is True when key j is not padding and j <= i.
    '''
    valid_keys = make_padding_mask(tgt, pad_idx)  # [B, 1, 1, L]
    no_future = make_lookahead_mask(tgt.size(1), device=tgt.device)  # [L, L]
    # Broadcasting [B, 1, 1, L] against [L, L] yields [B, 1, L, L].
    return valid_keys & no_future
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def make_sliding_window_mask(
    tensor: torch.Tensor, window_size: int, pad_idx: int = 0, causal: bool = True
) -> torch.Tensor:
    '''
    Builds a local-attention mask restricted to a fixed window around each query.

    Args:
        tensor (torch.Tensor): Input token ids, shape [B, L].
        window_size (int): One-sided window width.
        pad_idx (int, optional): Id used for padding. Defaults to 0.
        causal (bool, optional): If True (default), query i may attend to keys
            in [i - window_size, i]. If False, to keys in
            [i - window_size, i + window_size].

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, L, L]; True where the key is
            inside the window and is not padding.
    '''
    valid_keys = make_padding_mask(tensor, pad_idx)  # [B, 1, 1, L]

    positions = torch.arange(tensor.size(1), device=tensor.device)
    # lag[i, j] = i - j: how far key j lies behind query i (negative = ahead).
    lag = positions.unsqueeze(1) - positions.unsqueeze(0)

    if causal:
        in_window = (lag >= 0) & (lag <= window_size)
    else:
        in_window = lag.abs() <= window_size

    return valid_keys & in_window
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
class MaskedContent:
    '''
    Output of ``TokenMask.mask``.

    Attributes:
        content (str): The input text, exactly as it was passed in.
        tokenized (Union[list[int], torch.Tensor]): Token ids from the tokenizer,
            either as a Python list or as a tensor.
        mask (Union[list[int], torch.Tensor]): Per-token values aligned with
            ``tokenized``: 0 for masked, 1 for kept.
    '''
    content: str  # original text
    tokenized: Union[list[int], torch.Tensor]  # token ids
    mask: Union[list[int], torch.Tensor]  # 0 = masked, 1 = kept
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class TokenMask:
    '''
    Tokenizes text and builds a 0/1 mask around separator ("special") tokens.

    The mask uses 0 for masked positions and 1 for kept positions. How the
    separator occurrences partition the sequence is chosen with ``MaskMode``.
    '''

    def __init__(self, tokenizer: Tokenizer) -> None:
        '''
        Initializes the TokenMask.

        Args:
            tokenizer (Tokenizer): The configured tokenizer instance. Only its
                ``encode(text).ids`` and ``token_to_id(token)`` are used.
        '''
        self.tokenizer = tokenizer

    def _resolve_separator(
        self, ids: list[int], candidates: list[Union[str, int]]
    ) -> Union[int, None]:
        '''
        Returns the id of the first candidate that actually occurs in ``ids``.

        String candidates are resolved with ``token_to_id``; int candidates are
        used as ids directly; any other type is ignored. Returns None when no
        candidate occurs in ``ids``.
        '''
        present = set(ids)
        for cand in candidates:
            if isinstance(cand, str):
                tid = self.tokenizer.token_to_id(cand)
            elif isinstance(cand, int):
                tid = cand
            else:
                tid = None
            if tid is not None and tid in present:
                return tid
        return None

    @staticmethod
    def _split_at(length: int, idx: int, head: int) -> list[int]:
        '''Returns ``head`` for positions 0..idx (separator included) and ``1 - head`` after.'''
        return [head] * (idx + 1) + [1 - head] * (length - idx - 1)

    @staticmethod
    def _alternate(length: int, indices: list[int], first: int) -> list[int]:
        '''
        Alternates ``first`` / ``1 - first`` over consecutive segments.

        Each segment ends with (and includes) one separator index from
        ``indices``; the tail after the last separator continues the alternation.
        '''
        mask: list[int] = []
        value, start = first, 0
        for idx in indices:
            mask.extend([value] * (idx - start + 1))
            value = 1 - value
            start = idx + 1
        # The tail may be empty; multiplying by 0 then adds nothing.
        mask.extend([value] * (length - start))
        return mask

    def mask(
        self,
        content: str,
        special_token: Union[str, int, list[Union[str, int]]],
        mode: MaskMode = MaskMode.FIRST_MASK_PRE,
        tensor_mask: bool = True
    ) -> MaskedContent:
        '''
        Tokenizes content and generates a mask based on the specified mode.

        Args:
            content (str): The text content to tokenize and mask.
            special_token (Union[str, int, list[Union[str, int]]]): The separator
                token, given as a token string or id. A list (or tuple) is
                treated as ordered fallbacks: the first entry that occurs in the
                encoded content is used.
            mode (MaskMode): The masking mode. Defaults to MaskMode.FIRST_MASK_PRE.
            tensor_mask (bool, optional): Whether to return tensors instead of
                lists. Defaults to True.

        Returns:
            MaskedContent: The original content, token ids, and the generated
                mask (same length as the ids). When no separator is found,
                FIRST_MASK_POST, LAST_MASK_POST and ALL_KEEP_FIRST keep
                everything (all 1); the other modes mask everything (all 0).

        Raises:
            ValueError: If ``mode`` is not a ``MaskMode`` member.
        '''
        if not isinstance(mode, MaskMode):
            # Previously an unknown mode silently produced an empty mask.
            raise ValueError(f'Unsupported mask mode: {mode!r}')

        ids = self.tokenizer.encode(content).ids
        n = len(ids)

        if isinstance(special_token, (list, tuple)):
            candidates = list(special_token)
        else:
            candidates = [special_token]

        sep_id = self._resolve_separator(ids, candidates)

        if sep_id is None:
            keep_all = mode in (
                MaskMode.FIRST_MASK_POST, MaskMode.LAST_MASK_POST, MaskMode.ALL_KEEP_FIRST
            )
            mask = [1 if keep_all else 0] * n
        else:
            indices = [i for i, x in enumerate(ids) if x == sep_id]
            if mode is MaskMode.FIRST_MASK_PRE:
                # [0, 0, sep, 1, 1]
                mask = self._split_at(n, indices[0], 0)
            elif mode is MaskMode.FIRST_MASK_POST:
                # [1, 1, sep, 0, 0]
                mask = self._split_at(n, indices[0], 1)
            elif mode is MaskMode.LAST_MASK_PRE:
                mask = self._split_at(n, indices[-1], 0)
            elif mode is MaskMode.LAST_MASK_POST:
                mask = self._split_at(n, indices[-1], 1)
            elif mode is MaskMode.ALL_MASK_FIRST:
                mask = self._alternate(n, indices, 0)
            else:  # MaskMode.ALL_KEEP_FIRST
                mask = self._alternate(n, indices, 1)

        if tensor_mask:
            # Explicit dtype: torch.tensor([]) would otherwise infer float32
            # for empty content, unlike the int64 produced for non-empty input.
            ids = torch.tensor(ids, dtype=torch.long)
            mask = torch.tensor(mask, dtype=torch.long)

        return MaskedContent(content=content, tokenized=ids, mask=mask)
|
codon/utils/safecode.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import string
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def safecode(length: int = 4, exclude_confusing: bool = False) -> str:
    '''
    Returns a random code made of ASCII letters and digits.

    Characters are drawn with replacement from the module-level ``random``
    generator, so the output is reproducible after ``random.seed``.

    Args:
        length (int): Number of characters to produce. Defaults to 4.
        exclude_confusing (bool): If True, the look-alike characters
            '0oO1iIlLq9g' are removed from the pool to reduce human
            transcription errors. Defaults to False.

    Returns:
        str: The generated code.
    '''
    pool = string.ascii_letters + string.digits
    if exclude_confusing:
        # Deleting via translate preserves the order of the remaining characters.
        pool = pool.translate(str.maketrans('', '', '0oO1iIlLq9g'))
    picked = random.choices(pool, k=length)
    return ''.join(picked)
|
codon/utils/seed.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
import random
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
def seed_everything(seed: int = 42, strict: bool = False, warn_only: bool = True) -> None:
    '''
    Sets all random seeds to ensure reproducibility of PyTorch experiments.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and, when
    available, all CUDA devices), and records the seed as ``codon.__seed__``.

    Args:
        seed (int): The random seed value. Defaults to 42.
        strict (bool): Whether to enable strict deterministic mode.
            If True, calls torch.use_deterministic_algorithms(True) regardless
            of whether CUDA is available. This may raise errors (or warnings,
            see ``warn_only``) for operations that have no deterministic
            implementation and may reduce training speed.
        warn_only (bool): If True, only warns when deterministic algorithms
            are not available. Defaults to True.
    '''
    import codon
    codon.__seed__ = seed

    random.seed(seed)
    np.random.seed(seed)

    # NOTE: hash randomization of the *running* interpreter is fixed at
    # startup; this only affects child processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if strict:
        # Deterministic cuBLAS requires this workspace config; set it before
        # enabling deterministic mode so it is in place before any cuBLAS call.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        try:
            # Previously this only ran when CUDA was available, so strict mode
            # was silently ignored on CPU-only machines.
            torch.use_deterministic_algorithms(True, warn_only=warn_only)
        except AttributeError:
            print('[Warning] torch.use_deterministic_algorithms is not available in your PyTorch version.')
        else:
            print(f'[Info] Strict deterministic mode enabled. (seed={seed})')
    else:
        print(f'[Info] Random seed set as {seed}')
|
|
44
|
+
|
|
45
|
+
def get_seed() -> Optional[int]:
    '''
    Retrieves the global random seed recorded by ``seed_everything``.

    Returns:
        Optional[int]: The current seed value, or None if not set — including
            when ``codon`` has no ``__seed__`` attribute at all.
    '''
    import codon
    # getattr keeps the documented "None if not set" contract instead of raising
    # AttributeError; create_generator guards the same attribute the same way.
    return getattr(codon, '__seed__', None)
|
|
54
|
+
|
|
55
|
+
def worker_init_fn(worker_id: Any) -> None:
    '''
    DataLoader ``worker_init_fn`` that reseeds Python and NumPy per worker.

    The seed is derived from ``torch.initial_seed()`` in the calling process,
    reduced to the 32-bit range NumPy accepts.

    Args:
        worker_id (Any): The worker ID (unused; the torch seed is what differs).
    '''
    derived = torch.initial_seed() % (1 << 32)
    random.seed(derived)
    np.random.seed(derived)
|
|
65
|
+
|
|
66
|
+
def create_generator() -> torch.Generator:
    '''
    Returns a fresh ``torch.Generator`` seeded with the global seed.

    Uses ``codon.__seed__`` when it exists and is not None; otherwise falls back to 42.

    Returns:
        torch.Generator: The seeded random number generator.
    '''
    import codon
    recorded = getattr(codon, '__seed__', None)
    generator = torch.Generator()
    return generator.manual_seed(42 if recorded is None else recorded)
|
codon/utils/theta.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
@dataclass
class ValidateRoPEConfig:
    '''
    Outcome of ``validate_rope_config``.

    Attributes:
        is_passed (bool): True for the 'optimal' and 'too_high' verdicts.
        info (str): One of 'critical_low', 'low', 'optimal', 'too_high'.
        suggested_base (float): Heuristic base recommended for the requested length.
    '''
    is_passed: bool
    info: str
    suggested_base: float


def validate_rope_config(max_len: int, base: float) -> ValidateRoPEConfig:
    '''
    Checks whether a RoPE/Sinusoidal ``base`` is large enough for ``max_len`` positions.

    The longest period is taken as ``2 * pi * base`` and compared with ``max_len``:

    * below ``max_len``        -> 'critical_low' (fails)
    * below ``2 * max_len``    -> 'low' (fails)
    * above ``100 * max_len``  -> 'too_high' (still passes)
    * otherwise                -> 'optimal'

    Independently, a base is suggested: 10000 for ``max_len <= 8192``, otherwise
    ``10000 * (max_len / 4096) ** 1.1`` snapped to a round figure.

    Args:
        max_len (int): The maximum sequence length.
        base (float): The base value for RoPE.

    Returns:
        ValidateRoPEConfig: The validation result containing status and recommendation.
    '''
    longest_period = 2 * math.pi * base

    if max_len > 8192:
        # base = 10000 * (scaling_factor ^ 1.1), scaling measured against 4096 positions
        suggestion = 10000.0 * (max_len / 4096) ** 1.1
    else:
        suggestion = 10000.0

    # Snap to 1e6 / 1e5 / 1e4 granularity by magnitude (Python round: ties to even).
    if suggestion > 1000000:
        step = 1000000
    elif suggestion > 100000:
        step = 100000
    else:
        step = 10000
    suggestion = round(suggestion / step) * float(step)

    if longest_period < max_len:
        passed, verdict = False, 'critical_low'
    elif longest_period < max_len * 2:
        passed, verdict = False, 'low'
    elif longest_period > max_len * 100:
        passed, verdict = True, 'too_high'
    else:
        passed, verdict = True, 'optimal'

    return ValidateRoPEConfig(is_passed=passed, info=verdict, suggested_base=suggestion)
|
codon/utils/token.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import zipfile
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from tokenizers import Tokenizer, pre_tokenizers, decoders
|
|
6
|
+
from tokenizers import normalizers
|
|
7
|
+
from tokenizers.models import BPE
|
|
8
|
+
from tokenizers.trainers import BpeTrainer
|
|
9
|
+
|
|
10
|
+
from transformers import PreTrainedTokenizerFast
|
|
11
|
+
|
|
12
|
+
from typing import Union, Optional
|
|
13
|
+
|
|
14
|
+
@dataclass
class TokenizerTrainerResult:
    '''
    Pair returned by ``create_tokenizer_trainer``.

    Attributes:
        tokenizer (Tokenizer): The untrained tokenizer with normalizer,
            pre-tokenizer and decoder already attached.
        trainer (BpeTrainer): The BPE trainer configured to train it.
    '''
    tokenizer: Tokenizer  # pipeline to be trained
    trainer: BpeTrainer  # training configuration for the tokenizer above
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Tokens required by the tokenizer itself: unknown, padding, separator.
core_tokens = ['[unk]', '[pad]', '[sep]']

# Conversation framing and role markers used by the chat template.
chat_tokens = [
    '[im_start]',
    '[im_end]',
    '[system]',
    '[user]',
    '[model]',
    '[tool]',
    '[train]',
    '[interruption]',
    '[fim]',
]

# Markers for chain-of-thought style reasoning blocks.
reasoning_tokens = [
    '[cot_start]',
    '[cot_end]',
    '[verification]',
    '[solution]',
]

# Fill-in-the-middle (prefix / middle / suffix) markers.
code_tokens = ['[fim_pre]', '[fim_mid]', '[fim_suf]']

# Tool-call framing.
tool_tokens = [
    '[tool_start]',
    '[tool_name]',
    '[tool_args]',
    '[tool_end]',
]

# Placeholders delimiting non-text media.
multimodal_tokens = [
    '[image_start]',
    '[image_end]',
    '[audio_start]',
    '[audio_end]',
    '[video_start]',
    '[video_end]',
]

# Order matters: position in this list determines the token id assigned by
# the trainer, so groups are concatenated in a fixed sequence.
base_special_tokens = [
    *core_tokens,
    *chat_tokens,
    *reasoning_tokens,
    *code_tokens,
    *tool_tokens,
    *multimodal_tokens,
]

# Pad the named tokens with reserved '[unused_i]' slots up to 64 entries,
# then append 32 '[mask_i]' tokens.
_named_count = len(base_special_tokens)
base_special_tokens.extend([f'[unused_{i}]' for i in range(_named_count, 64)])
base_special_tokens.extend([f'[mask_{i}]' for i in range(32)])
|
|
53
|
+
|
|
54
|
+
# Jinja chat template passed to PreTrainedTokenizerFast (see PackedTokenizer).
# Each message is wrapped as "[im_start]<role tag>\n<thought block><content><tool calls>[im_end]\n",
# except 'fim' messages, which use the fill-in-the-middle layout below.
chat_template = (
    "{% for message in messages %}"
    "{{ '[im_start]' }}"

    # Fill-in-the-middle turn: [fim][fim_pre]<prefix>[fim_suf]<suffix>[fim_mid]<middle>[im_end].
    # [im_end] is only emitted when 'middle' is truthy, leaving the turn open otherwise.
    "{% if message['role'] == 'fim' %}"
    "{{ '[fim]' }}"
    "{{ '[fim_pre]' + message['prefix'] + '[fim_suf]' + message['suffix'] + '[fim_mid]' }}"

    "{% if message['middle'] %}"
    "{{ message['middle'] + '[im_end]' }}"
    "{% endif %}"

    "{% else %}"

    # Role tag: 'system'/'instruction' -> [system], 'assistant'/'model' -> [model];
    # any unrecognised role is written out verbatim.
    "{% if message['role'] in ['system', 'instruction'] %}"
    "{{ '[system]' }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '[user]' }}"
    "{% elif message['role'] in ['assistant', 'model'] %}"
    "{{ '[model]' }}"
    "{% elif message['role'] == 'tool' %}"
    "{{ '[tool]' }}"
    "{% elif message['role'] == 'train' %}"
    "{{ '[train]' }}"
    "{% else %}"
    "{{ message['role'] }}"
    "{% endif %}"

    "{{ '\n' }}"

    # Reasoning block, read from 'thought' or else 'reasoning_content'. An empty
    # [cot_start][cot_end] pair is emitted when neither is set.
    # NOTE(review): this runs for every non-FIM message, including user/system
    # turns — confirm that is intended.
    "{% set thought_content = message['thought'] or message['reasoning_content'] %}"
    "{% if thought_content %}"
    "{{ '[cot_start]' + thought_content + '[cot_end]\n' }}"
    "{% else %}"
    "{{ '[cot_start][cot_end]\n' }}"
    "{% endif %}"

    # Content: either a plain string, or a list of typed parts. Media parts
    # become empty start/end placeholder pairs; other part types are dropped.
    "{% if message['content'] is defined and message['content'] is not none %}"
    "{% if message['content'] is string %}"
    "{{ message['content'] }}"
    "{% else %}"
    "{% for item in message['content'] %}"
    "{% if item['type'] == 'text' %}"
    "{{ item['text'] }}"
    "{% elif item['type'] == 'image' %}"
    "{{ '[image_start][image_end]' }}"
    "{% elif item['type'] == 'audio' %}"
    "{{ '[audio_start][audio_end]' }}"
    "{% elif item['type'] == 'video' %}"
    "{{ '[video_start][video_end]' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% endif %}"
    "{% endif %}"

    # Tool calls, concatenated with '+'.
    # NOTE(review): assumes function.arguments is already a string (e.g. JSON
    # text); a dict here would fail string concatenation — TODO confirm.
    "{% if message['tool_calls'] is defined and message['tool_calls'] %}"
    "{% for tool_call in message['tool_calls'] %}"
    "{{ '[tool_start][tool_name]' + tool_call.function.name + '[tool_args]' + tool_call.function.arguments + '[tool_end]' }}"
    "{% endfor %}"
    "{% endif %}"

    "{{ '[im_end]\n' }}"
    "{% endif %}"
    "{% endfor %}"

    # Generation prompt: opens a [model] turn unless the last message is FIM.
    # enable_thinking=True opens a reasoning block; enable_thinking=False emits an
    # empty one; if enable_thinking is undefined nothing extra is added.
    "{% if add_generation_prompt %}"
    "{% if messages[-1]['role'] != 'fim' %}"
    "{{ '[im_start][model]\n' }}"
    "{% if enable_thinking is defined and enable_thinking %}"
    "{{ '[cot_start]' }}"
    "{% elif enable_thinking is defined and not enable_thinking %}"
    "{{ '[cot_start][cot_end]\n' }}"
    "{% endif %}"
    "{% endif %}"
    "{% endif %}"
)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def create_tokenizer_trainer(
    unk_token: str = '[unk]',
    vocab_size: int = 32000,
    special_tokens: Optional[list[str]] = None
) -> TokenizerTrainerResult:
    '''
    Creates a BPE Tokenizer trainer.

    Configures and returns a tokenizer plus a trainer for training BPE
    (Byte-Pair Encoding) models. The tokenizer is pre-configured with NFKC
    normalization, per-digit splitting, and byte-level pre-tokenization and
    decoding.

    Args:
        unk_token (str): Identifier for unknown tokens. Defaults to '[unk]'.
        vocab_size (int): Target vocabulary size. Defaults to 32000.
        special_tokens (Optional[list[str]]): List of special tokens. When None
            (the default), a fresh copy of base_special_tokens is used, which
            includes core, chat, reasoning, code, tool, and multimodal tokens.

    Returns:
        TokenizerTrainerResult: A dataclass containing the tokenizer and trainer instances.
    '''
    if special_tokens is None:
        # Copy per call instead of sharing the module-level list as a mutable
        # default argument.
        special_tokens = list(base_special_tokens)

    tokenizer = Tokenizer(BPE(unk_token=unk_token))
    tokenizer.normalizer = normalizers.NFKC()

    # Digits are split individually before byte-level splitting, so BPE never
    # learns multi-digit tokens.
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Digits(individual_digits=True),
        pre_tokenizers.ByteLevel(
            add_prefix_space=False,
            use_regex=True
        )
    ])

    tokenizer.decoder = decoders.ByteLevel()

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        # Start from the full byte-level alphabet so every input stays encodable.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        max_token_length=32,
        min_frequency=10
    )

    return TokenizerTrainerResult(tokenizer=tokenizer, trainer=trainer)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class PackedTokenizer:
    '''
    Bundles a ``tokenizers.Tokenizer`` with its config and chat template.

    All three are persisted together in a single zip archive containing
    ``tokenizer.json``, ``tokenizer_config.json`` and ``chat_template.jinja``.
    A ``transformers.PreTrainedTokenizerFast`` view is rebuilt whenever the
    underlying tokenizer is set or loaded.
    '''

    def __init__(self, tokenizer: Optional[Union[Tokenizer, str]] = None):
        '''
        Args:
            tokenizer (Optional[Union[Tokenizer, str]]): Either a path to a zip
                archive (loaded via ``load``), an in-memory ``Tokenizer`` (paired
                with default special-token config), or None for an empty
                instance. Defaults to None.
        '''
        self._tokenizer: Optional[Tokenizer] = None
        self._fast_tokenizer: Optional[PreTrainedTokenizerFast] = None
        self.config = {}
        self.template = chat_template

        if isinstance(tokenizer, str):
            self.load(tokenizer)
        elif isinstance(tokenizer, Tokenizer):
            self._tokenizer = tokenizer
            self.config = {
                'unk_token': '[unk]',
                'pad_token': '[pad]',
                'bos_token': '[im_start]',
                'eos_token': '[im_end]',
            }
            self._update_fast_tokenizer()

    def _update_fast_tokenizer(self) -> None:
        '''
        Updates the cached PreTrainedTokenizerFast instance.

        Special tokens come from ``self.config`` (falling back to the defaults
        used in ``__init__``); the chat template comes from ``self.template``.
        '''
        if self._tokenizer is None:
            self._fast_tokenizer = None
            return

        self._fast_tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=self._tokenizer,
            unk_token=self.config.get('unk_token', '[unk]'),
            pad_token=self.config.get('pad_token', '[pad]'),
            bos_token=self.config.get('bos_token', '[im_start]'),
            eos_token=self.config.get('eos_token', '[im_end]'),
            chat_template=self.template,
            clean_up_tokenization_spaces=False
        )

    @property
    def tokenizer(self) -> Tokenizer:
        '''The raw ``tokenizers.Tokenizer``; raises ValueError if none is loaded.'''
        if self._tokenizer is None:
            raise ValueError("Tokenizer is not loaded.")
        return self._tokenizer

    @property
    def fast_tokenizer(self) -> PreTrainedTokenizerFast:
        '''The cached ``PreTrainedTokenizerFast``; raises ValueError if none is loaded.'''
        if self._fast_tokenizer is None:
            raise ValueError('Tokenizer is not loaded.')
        return self._fast_tokenizer

    def save(self, path: str) -> 'PackedTokenizer':
        '''
        Writes the tokenizer, its config and chat template into a zip archive.

        Args:
            path (str): Destination path; an existing file is overwritten.

        Returns:
            PackedTokenizer: ``self``, for chaining.

        Raises:
            ValueError: If no tokenizer is loaded.
        '''
        if self._tokenizer is None:
            raise ValueError("No tokenizer to save.")

        with zipfile.ZipFile(path, 'w') as z:
            z.writestr('tokenizer.json', self._tokenizer.to_str())
            z.writestr('tokenizer_config.json', json.dumps(self.config, indent=2))
            z.writestr('chat_template.jinja', self.template)

        return self

    def load(self, path: str) -> 'PackedTokenizer':
        '''
        Loads a zip archive written by ``save``.

        Files may sit under a directory prefix inside the archive. A missing
        ``tokenizer_config.json`` or ``chat_template.jinja`` keeps the current
        config or template. All parts are read and parsed before any attribute
        is replaced, so a failure leaves the instance unchanged.

        Args:
            path (str): Path to the zip archive.

        Returns:
            PackedTokenizer: ``self``, for chaining.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the archive has no ``tokenizer.json``.
        '''
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")

        with zipfile.ZipFile(path, 'r') as z:
            file_list = z.namelist()

            def find_file(name: str) -> Optional[str]:
                # Match the bare name or the name under any directory prefix.
                for f in file_list:
                    if f == name or f.endswith(f'/{name}'):
                        return f
                return None

            tokenizer_file = find_file('tokenizer.json')
            if not tokenizer_file:
                raise ValueError("tokenizer.json not found in zip file")
            new_tokenizer = Tokenizer.from_str(z.read(tokenizer_file).decode('utf-8'))

            config_file = find_file('tokenizer_config.json')
            new_config = (
                json.loads(z.read(config_file).decode('utf-8')) if config_file else self.config
            )

            template_file = find_file('chat_template.jinja')
            new_template = (
                z.read(template_file).decode('utf-8') if template_file else self.template
            )

        # Commit only after every part parsed successfully; previously the
        # tokenizer was replaced before the config was parsed, so a corrupt
        # config left the instance half-updated with a stale fast tokenizer.
        self._tokenizer = new_tokenizer
        self.config = new_config
        self.template = new_template

        self._update_fast_tokenizer()
        return self
|