codon-model 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codon/utils/mask.py ADDED
@@ -0,0 +1,266 @@
1
+ import torch
2
+
3
+ from tokenizers import Tokenizer
4
+ from dataclasses import dataclass
5
+ from enum import Enum, auto
6
+
7
+ from typing import Union
8
+
9
+
10
class MaskMode(Enum):
    '''
    Strategies used by TokenMask to turn separator positions into a 0/1 mask.

    A value of 0 marks a masked position and 1 a kept one. In every mode the
    separator token itself belongs to the segment that precedes it.

    Members:
        FIRST_MASK_PRE: Split at the first separator. Positions up to and
            including it are 0; everything after is 1.
        FIRST_MASK_POST: Split at the first separator. Positions up to and
            including it are 1; everything after is 0.
        LAST_MASK_PRE: Same as FIRST_MASK_PRE, but split at the last separator.
        LAST_MASK_POST: Same as FIRST_MASK_POST, but split at the last separator.
        ALL_MASK_FIRST: Split at every separator. Segments alternate
            0, 1, 0, ... starting with 0.
        ALL_KEEP_FIRST: Split at every separator. Segments alternate
            1, 0, 1, ... starting with 1.
    '''

    # auto() numbers the members 1..6 in declaration order.
    FIRST_MASK_PRE = auto()
    FIRST_MASK_POST = auto()
    LAST_MASK_PRE = auto()
    LAST_MASK_POST = auto()
    ALL_MASK_FIRST = auto()
    ALL_KEEP_FIRST = auto()
37
+
38
+
39
def make_padding_mask(src: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    '''
    Build a key-side mask that hides padding positions.

    Args:
        src (torch.Tensor): Token ids, shape [B, L_src].
        pad_idx (int, optional): Id of the padding token. Defaults to 0.

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, 1, L_src]. True marks a real
            (non-padding) token that may be attended to.
    '''
    not_pad = src.ne(pad_idx)
    # Insert two singleton axes after the batch axis so the mask broadcasts
    # over heads and query positions: [B, L] -> [B, 1, L] -> [B, 1, 1, L].
    return not_pad.unsqueeze(1).unsqueeze(1)
53
+
54
+
55
def make_lookahead_mask(size: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    '''
    Build a lower-triangular "no peeking ahead" mask.

    Args:
        size (int): Sequence length.
        device (torch.device, optional): Device for the result. Defaults to cpu.

    Returns:
        torch.Tensor: Boolean tensor of shape [size, size]. Entry [i, j] is True
            exactly when j <= i, i.e. query i may attend to key j.
    '''
    full = torch.ones((size, size), dtype=torch.bool, device=device)
    return torch.tril(full)
69
+
70
+
71
def make_causal_mask(tgt: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    '''
    Build a decoder self-attention mask that hides both padding and future tokens.

    Args:
        tgt (torch.Tensor): Target token ids, shape [B, L_tgt].
        pad_idx (int, optional): Id of the padding token. Defaults to 0.

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, L_tgt, L_tgt]. Entry
            [b, 0, i, j] is True when j <= i and key j is not padding.
    '''
    key_not_pad = make_padding_mask(tgt, pad_idx)  # [B, 1, 1, L]
    no_future = make_lookahead_mask(tgt.size(1), device=tgt.device)  # [L, L]
    # Broadcasting [B, 1, 1, L] & [L, L] yields [B, 1, L, L].
    return key_not_pad & no_future
90
+
91
+
92
def make_sliding_window_mask(
    tensor: torch.Tensor, window_size: int, pad_idx: int = 0, causal: bool = True
) -> torch.Tensor:
    '''
    Build a banded (local) attention mask combined with a padding mask.

    Args:
        tensor (torch.Tensor): Token ids, shape [B, L].
        window_size (int): One-sided window width.
        pad_idx (int, optional): Id of the padding token. Defaults to 0.
        causal (bool, optional): Restrict attention to the past. Defaults to True.
            True: query i sees keys j with i - window_size <= j <= i.
            False: query i sees keys j with i - window_size <= j <= i + window_size.

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, L, L].
    '''
    key_not_pad = make_padding_mask(tensor, pad_idx)  # [B, 1, 1, L]
    length = tensor.size(1)
    full = torch.ones((length, length), device=tensor.device, dtype=torch.bool)

    # Allowed offsets (j - i) form the band [-window_size, upper].
    upper = 0 if causal else window_size
    band = torch.tril(full, diagonal=upper) & torch.triu(full, diagonal=-window_size)

    return key_not_pad & band
127
+
128
+
129
@dataclass
class MaskedContent:
    '''
    Output of TokenMask.mask.

    Attributes:
        content (str): The text that was tokenized.
        tokenized (Union[list[int], torch.Tensor]): Token ids, as a list or a tensor.
        mask (Union[list[int], torch.Tensor]): One value per token: 0 = masked,
            1 = kept. Same container kind as ``tokenized``.
    '''
    # Input text, unchanged.
    content: str
    # Token ids produced by the tokenizer.
    tokenized: Union[list[int], torch.Tensor]
    # Per-token 0/1 mask aligned with ``tokenized``.
    mask: Union[list[int], torch.Tensor]
142
+
143
+
144
class TokenMask:
    '''
    Handles token masking logic based on special (separator) tokens.

    The wrapped tokenizer encodes text via ``encode(...).ids``. It also resolves
    string special tokens to ids via ``token_to_id``.
    '''

    def __init__(self, tokenizer: Tokenizer) -> None:
        '''
        Initializes the TokenMask.

        Args:
            tokenizer (Tokenizer): The configured tokenizer instance.
        '''
        self.tokenizer = tokenizer

    def _resolve_separator(
        self, candidates: list[Union[str, int]], ids: list[int]
    ) -> Union[int, None]:
        '''Return the id of the first candidate that occurs in ``ids``, else None.'''
        # Build the set once; membership is tested once per candidate.
        present = set(ids)
        for cand in candidates:
            if isinstance(cand, str):
                tid = self.tokenizer.token_to_id(cand)
            elif isinstance(cand, int):
                tid = cand
            else:
                tid = None  # unsupported candidate types are skipped
            if tid is not None and tid in present:
                return tid
        return None

    @staticmethod
    def _split_mask(length: int, idx: int, head_value: int) -> list[int]:
        '''Positions 0..idx (inclusive) get ``head_value``; the rest get its complement.'''
        return [head_value] * (idx + 1) + [1 - head_value] * (length - idx - 1)

    @staticmethod
    def _alternating_mask(length: int, indices: list[int], first_value: int) -> list[int]:
        '''Alternate 0/1 per segment; each segment ends with (and includes) a separator.'''
        mask: list[int] = []
        value = first_value
        start = 0
        for idx in indices:
            mask.extend([value] * (idx - start + 1))
            value = 1 - value
            start = idx + 1
        # Tail after the last separator (possibly empty) takes the next value.
        mask.extend([value] * (length - start))
        return mask

    def mask(
        self,
        content: str,
        special_token: Union[str, int, list[Union[str, int]]],
        mode: MaskMode = MaskMode.FIRST_MASK_PRE,
        tensor_mask: bool = True
    ) -> MaskedContent:
        '''
        Tokenizes content and generates a mask based on the specified mode.

        The separator is the first entry of ``special_token`` (string token or
        integer id) that actually occurs in the encoded ids.

        If none occurs, the mask is all 1 for FIRST_MASK_POST, LAST_MASK_POST and
        ALL_KEEP_FIRST. It is all 0 for the other modes.

        Args:
            content (str): The text content to tokenize and mask.
            special_token (Union[str, int, list[Union[str, int]]]): The separator
                token(s). A list or tuple gives candidates in priority order.
            mode (MaskMode): The masking mode. Defaults to MaskMode.FIRST_MASK_PRE.
            tensor_mask (bool, optional): Return int64 tensors instead of lists.
                Defaults to True.

        Returns:
            MaskedContent: The original content, the token ids and the 0/1 mask.

        Raises:
            ValueError: If ``mode`` is not a MaskMode member.
        '''
        # Without this check an unknown mode silently yields an empty mask whose
        # length does not match the token ids.
        if not isinstance(mode, MaskMode):
            raise ValueError(f'Unsupported mask mode: {mode!r}')

        ids = self.tokenizer.encode(content).ids
        if isinstance(special_token, (list, tuple)):
            candidates = list(special_token)
        else:
            candidates = [special_token]
        sep_id = self._resolve_separator(candidates, ids)
        length = len(ids)

        if sep_id is None:
            keep_all = mode in (
                MaskMode.FIRST_MASK_POST, MaskMode.LAST_MASK_POST, MaskMode.ALL_KEEP_FIRST
            )
            mask = [1 if keep_all else 0] * length
        else:
            indices = [i for i, x in enumerate(ids) if x == sep_id]
            if mode == MaskMode.FIRST_MASK_PRE:
                mask = self._split_mask(length, indices[0], 0)   # [0, 0, sep, 1, 1]
            elif mode == MaskMode.FIRST_MASK_POST:
                mask = self._split_mask(length, indices[0], 1)   # [1, 1, sep, 0, 0]
            elif mode == MaskMode.LAST_MASK_PRE:
                mask = self._split_mask(length, indices[-1], 0)
            elif mode == MaskMode.LAST_MASK_POST:
                mask = self._split_mask(length, indices[-1], 1)
            elif mode == MaskMode.ALL_MASK_FIRST:
                mask = self._alternating_mask(length, indices, 0)
            else:  # MaskMode.ALL_KEEP_FIRST
                mask = self._alternating_mask(length, indices, 1)

        if tensor_mask:
            # Explicit dtype: torch.tensor([]) would otherwise infer float32 for
            # empty content, while non-empty int lists give int64.
            ids = torch.tensor(ids, dtype=torch.long)
            mask = torch.tensor(mask, dtype=torch.long)

        return MaskedContent(content=content, tokenized=ids, mask=mask)
@@ -0,0 +1,24 @@
1
+ import random
2
+ import string
3
+
4
+
5
def safecode(length: int = 4, exclude_confusing: bool = False) -> str:
    '''
    Return a random alphanumeric code.

    Characters are drawn with replacement from ASCII letters and digits using the
    module-level ``random`` generator. Output is therefore reproducible after
    ``random.seed`` and is not intended for secrets.

    Args:
        length (int): Number of characters to generate. Defaults to 4.
        exclude_confusing (bool): When True, drop look-alike characters
            ('0oO1iIlLq9g') so codes are easier to read and retype.
            Defaults to False.

    Returns:
        str: The generated code.
    '''
    alphabet = string.ascii_letters + string.digits
    if exclude_confusing:
        # Deleting characters keeps the remaining alphabet in its original order.
        alphabet = alphabet.translate(str.maketrans('', '', '0oO1iIlLq9g'))
    return ''.join(random.choices(alphabet, k=length))
codon/utils/seed.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import os
5
+
6
+ from typing import Any, Optional
7
+
8
def seed_everything(seed: int = 42, strict: bool = False, warn_only: bool = True) -> None:
    '''
    Sets all random seeds to ensure reproducibility of PyTorch experiments.

    The following are seeded:
    - Python's ``random`` module.
    - NumPy's global generator.
    - PyTorch on the CPU and, when CUDA is available, all CUDA devices.

    The seed is recorded as ``codon.__seed__``. When CUDA is available, cuDNN
    benchmarking is turned off and deterministic cuDNN kernels are requested.

    Args:
        seed (int): The random seed value. Defaults to 42.
        strict (bool): Whether to enable strict deterministic mode.
            If True, calls torch.use_deterministic_algorithms(True), on CPU-only
            machines as well. This may raise errors for operations without a
            deterministic implementation and may reduce training speed.
        warn_only (bool): In strict mode, if True, nondeterministic operations
            emit a warning instead of raising. Defaults to True.

    Note:
        ``PYTHONHASHSEED`` is only exported to the environment. It affects child
        processes, not hash randomization of the already-running interpreter.
    '''
    import codon
    codon.__seed__ = seed

    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if strict:
        # Export the cuBLAS workspace setting before requesting deterministic
        # algorithms so it is in place before any cuBLAS work happens.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        try:
            torch.use_deterministic_algorithms(True, warn_only=warn_only)
            print(f'[Info] Strict deterministic mode enabled. (seed={seed})')
        except AttributeError:
            print('[Warning] torch.use_deterministic_algorithms is not available in your PyTorch version.')
    else:
        print(f'[Info] Random seed set as {seed}')
44
+
45
def get_seed() -> Optional[int]:
    '''
    Retrieves the global random seed stored on the ``codon`` package.

    Returns:
        Optional[int]: The current seed value, or None if not set.
    '''
    import codon
    # The attribute may not exist if no seed has been recorded. Return None in
    # that case (as documented) instead of raising AttributeError.
    return getattr(codon, '__seed__', None)
54
+
55
def worker_init_fn(worker_id: Any) -> None:
    '''
    DataLoader ``worker_init_fn`` that reseeds Python's and NumPy's RNGs.

    The seed is derived from ``torch.initial_seed()`` so that each worker
    gets its own seed.

    Args:
        worker_id (Any): The worker ID (unused).
    '''
    # Reduce modulo 2**32 to stay inside NumPy's legacy seed range.
    derived_seed = torch.initial_seed() % (1 << 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)
65
+
66
def create_generator() -> torch.Generator:
    '''
    Build a fresh ``torch.Generator`` seeded with the global seed.

    Falls back to 42 when ``codon.__seed__`` is missing or None.

    Returns:
        torch.Generator: The seeded generator.
    '''
    import codon
    stored_seed = getattr(codon, '__seed__', None)
    generator = torch.Generator()
    generator.manual_seed(42 if stored_seed is None else stored_seed)
    return generator
codon/utils/theta.py ADDED
@@ -0,0 +1,55 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
@dataclass
class ValidateRoPEConfig:
    '''
    Verdict returned by ``validate_rope_config``.

    Attributes:
        is_passed (bool): True when the base is considered acceptable.
        info (str): Status label: 'critical_low', 'low', 'too_high' or 'optimal'.
        suggested_base (float): Recommended base for the checked maximum length.
    '''
    # Overall pass/fail.
    is_passed: bool
    # Short status label.
    info: str
    # Recommended base value.
    suggested_base: float
17
+
18
def validate_rope_config(max_len: int, base: float) -> ValidateRoPEConfig:
    '''
    Check whether a RoPE/sinusoidal ``base`` suits a given maximum length.

    The longest wavelength, 2 * pi * base, is compared with ``max_len``:
    - below max_len: 'critical_low' (fails)
    - below 2 * max_len: 'low' (fails)
    - above 100 * max_len: 'too_high' (passes)
    - otherwise: 'optimal' (passes)

    Args:
        max_len (int): The maximum sequence length.
        base (float): The RoPE base value.

    Returns:
        ValidateRoPEConfig: Status plus a recommended base. The recommendation is
            10000 up to 8192 tokens. Beyond that it is
            10000 * (max_len / 4096) ** 1.1, rounded to the nearest multiple of
            10^4, 10^5 or 10^6 depending on its magnitude.
    '''
    longest_wavelength = 2 * math.pi * base

    if max_len > 8192:
        # base = 10000 * (scaling_factor ^ 1.1), scaling relative to 4096 tokens
        suggestion = 10000.0 * ((max_len / 4096) ** 1.1)
    else:
        suggestion = 10000.0

    if suggestion > 1000000:
        step = 1000000
    elif suggestion > 100000:
        step = 100000
    else:
        step = 10000
    suggestion = round(suggestion / step) * float(step)

    if longest_wavelength < max_len:
        passed, label = False, 'critical_low'
    elif longest_wavelength < max_len * 2:
        passed, label = False, 'low'
    elif longest_wavelength > max_len * 100:
        passed, label = True, 'too_high'
    else:
        passed, label = True, 'optimal'

    return ValidateRoPEConfig(is_passed=passed, info=label, suggested_base=suggestion)
codon/utils/token.py ADDED
@@ -0,0 +1,276 @@
1
+ import json
2
+ import os
3
+ import zipfile
4
+ from dataclasses import dataclass
5
+ from tokenizers import Tokenizer, pre_tokenizers, decoders
6
+ from tokenizers import normalizers
7
+ from tokenizers.models import BPE
8
+ from tokenizers.trainers import BpeTrainer
9
+
10
+ from transformers import PreTrainedTokenizerFast
11
+
12
+ from typing import Union, Optional
13
+
14
@dataclass
class TokenizerTrainerResult:
    '''
    Pair returned by ``create_tokenizer_trainer``.

    Attributes:
        tokenizer (Tokenizer): The untrained, pre-configured tokenizer.
        trainer (BpeTrainer): The BPE trainer to train it with.
    '''
    # Tokenizer with normalizer, pre-tokenizer and decoder already set.
    tokenizer: Tokenizer
    # Trainer holding vocabulary size and special tokens.
    trainer: BpeTrainer
25
+
26
+
27
# Special-token groups. Their concatenation order determines token ids when
# passed to the trainer, so it must not change.
core_tokens = ['[unk]', '[pad]', '[sep]']
chat_tokens = [
    '[im_start]', '[im_end]',
    '[system]', '[user]', '[model]', '[tool]', '[train]',
    '[interruption]', '[fim]',
]
reasoning_tokens = ['[cot_start]', '[cot_end]', '[verification]', '[solution]']
code_tokens = ['[fim_pre]', '[fim_mid]', '[fim_suf]']
tool_tokens = ['[tool_start]', '[tool_name]', '[tool_args]', '[tool_end]']
multimodal_tokens = [
    '[image_start]', '[image_end]', '[audio_start]', '[audio_end]',
    '[video_start]', '[video_end]',
]

base_special_tokens = [
    *core_tokens,
    *chat_tokens,
    *reasoning_tokens,
    *code_tokens,
    *tool_tokens,
    *multimodal_tokens,
]

# Pad the named tokens with '[unused_i]' placeholders up to 64 entries
# (presumably reserving slots for future named tokens), then add 32 mask tokens.
base_special_tokens.extend(f'[unused_{i}]' for i in range(len(base_special_tokens), 64))
base_special_tokens.extend(f'[mask_{i}]' for i in range(32))
53
+
54
# Jinja chat template (handed to PreTrainedTokenizerFast as chat_template).
# Each message is wrapped in [im_start] ... [im_end]. Messages with role 'fim'
# use the fill-in-the-middle layout instead of a role header.
chat_template = (
    "{% for message in messages %}"
    "{{ '[im_start]' }}"

    # Fill-in-the-middle: [fim][fim_pre]prefix[fim_suf]suffix[fim_mid](middle[im_end])
    "{% if message['role'] == 'fim' %}"
    "{{ '[fim]' }}"
    "{{ '[fim_pre]' + message['prefix'] + '[fim_suf]' + message['suffix'] + '[fim_mid]' }}"

    # Without a middle the sequence is left open after [fim_mid].
    "{% if message['middle'] %}"
    "{{ message['middle'] + '[im_end]' }}"
    "{% endif %}"

    "{% else %}"

    # Role header; unrecognized roles are emitted verbatim.
    "{% if message['role'] in ['system', 'instruction'] %}"
    "{{ '[system]' }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '[user]' }}"
    "{% elif message['role'] in ['assistant', 'model'] %}"
    "{{ '[model]' }}"
    "{% elif message['role'] == 'tool' %}"
    "{{ '[tool]' }}"
    "{% elif message['role'] == 'train' %}"
    "{{ '[train]' }}"
    "{% else %}"
    "{{ message['role'] }}"
    "{% endif %}"

    "{{ '\n' }}"

    # Reasoning block: always emitted, empty when the message has no thought.
    "{% set thought_content = message['thought'] or message['reasoning_content'] %}"
    "{% if thought_content %}"
    "{{ '[cot_start]' + thought_content + '[cot_end]\n' }}"
    "{% else %}"
    "{{ '[cot_start][cot_end]\n' }}"
    "{% endif %}"

    # Content: a plain string, or a list of typed parts (media become placeholders).
    "{% if message['content'] is defined and message['content'] is not none %}"
    "{% if message['content'] is string %}"
    "{{ message['content'] }}"
    "{% else %}"
    "{% for item in message['content'] %}"
    "{% if item['type'] == 'text' %}"
    "{{ item['text'] }}"
    "{% elif item['type'] == 'image' %}"
    "{{ '[image_start][image_end]' }}"
    "{% elif item['type'] == 'audio' %}"
    "{{ '[audio_start][audio_end]' }}"
    "{% elif item['type'] == 'video' %}"
    "{{ '[video_start][video_end]' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% endif %}"
    "{% endif %}"

    # Tool calls. String arguments are emitted verbatim. Non-string arguments
    # (e.g. a dict) are serialized with `tojson` instead of failing on string
    # concatenation.
    "{% if message['tool_calls'] is defined and message['tool_calls'] %}"
    "{% for tool_call in message['tool_calls'] %}"
    "{{ '[tool_start][tool_name]' + tool_call.function.name + '[tool_args]' }}"
    "{% if tool_call.function.arguments is string %}"
    "{{ tool_call.function.arguments }}"
    "{% else %}"
    "{{ tool_call.function.arguments | tojson }}"
    "{% endif %}"
    "{{ '[tool_end]' }}"
    "{% endfor %}"
    "{% endif %}"

    "{{ '[im_end]\n' }}"
    "{% endif %}"
    "{% endfor %}"

    # Generation prompt: open a model turn unless the last message is a FIM request.
    "{% if add_generation_prompt %}"
    "{% if messages[-1]['role'] != 'fim' %}"
    "{{ '[im_start][model]\n' }}"
    "{% if enable_thinking is defined and enable_thinking %}"
    "{{ '[cot_start]' }}"
    "{% elif enable_thinking is defined and not enable_thinking %}"
    "{{ '[cot_start][cot_end]\n' }}"
    "{% endif %}"
    "{% endif %}"
    "{% endif %}"
)
130
+
131
+
132
def create_tokenizer_trainer(
    unk_token: str = '[unk]',
    vocab_size: int = 32000,
    special_tokens: Optional[list[str]] = None,
    min_frequency: int = 10,
    max_token_length: int = 32,
) -> TokenizerTrainerResult:
    '''
    Creates a BPE Tokenizer trainer.

    Configures and returns a tokenizer trainer object for training BPE
    (Byte-Pair Encoding) models. The tokenizer is pre-configured with:
    - NFKC normalization,
    - digit splitting (one digit per token),
    - byte-level pre-tokenization and a byte-level decoder.

    Args:
        unk_token (str): Identifier for unknown tokens. Defaults to '[unk]'.
        vocab_size (int): Target vocabulary size. Defaults to 32000.
        special_tokens (Optional[list[str]]): List of special tokens. None (the
            default) means base_special_tokens, which includes the core, chat,
            reasoning, code, tool and multimodal tokens.
        min_frequency (int): Minimum pair frequency for a merge. Defaults to 10.
        max_token_length (int): Maximum length of a learned token. Defaults to 32.

    Returns:
        TokenizerTrainerResult: A dataclass containing the tokenizer and trainer instances.
    '''
    # None sentinel instead of a shared mutable module-level default. A copy is
    # handed to the trainer so the caller's (or the module's) list is never aliased.
    if special_tokens is None:
        special_tokens = base_special_tokens

    tokenizer = Tokenizer(BPE(unk_token=unk_token))
    tokenizer.normalizer = normalizers.NFKC()
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Digits(individual_digits=True),
        pre_tokenizers.ByteLevel(
            add_prefix_space=False,
            use_regex=True
        )
    ])
    tokenizer.decoder = decoders.ByteLevel()

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=list(special_tokens),
        # Seed the alphabet with all 256 byte symbols so any input is representable.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        max_token_length=max_token_length,
        min_frequency=min_frequency
    )

    return TokenizerTrainerResult(tokenizer=tokenizer, trainer=trainer)
175
+
176
+
177
class PackedTokenizer:
    '''
    Bundles a ``tokenizers.Tokenizer`` with its special-token config and chat template.

    The bundle is persisted as one zip archive containing three files:
    - ``tokenizer.json``
    - ``tokenizer_config.json``
    - ``chat_template.jinja``

    A ``PreTrainedTokenizerFast`` built from the same pieces is cached and
    exposed through ``fast_tokenizer``.
    '''

    def __init__(self, tokenizer: Optional[Union[Tokenizer, str, os.PathLike]] = None):
        '''
        Args:
            tokenizer: One of:
                - a ``Tokenizer`` to wrap (paired with the default special-token
                  config),
                - a path to a zip archive written by ``save``,
                - None, to create an empty instance to be filled via ``load``.

        Raises:
            TypeError: If ``tokenizer`` is none of the supported types.
        '''
        self._tokenizer: Optional[Tokenizer] = None
        self._fast_tokenizer: Optional[PreTrainedTokenizerFast] = None
        self.config = {}
        self.template = chat_template

        if isinstance(tokenizer, (str, os.PathLike)):
            self.load(os.fspath(tokenizer))
        elif isinstance(tokenizer, Tokenizer):
            self._tokenizer = tokenizer
            self.config = {
                'unk_token': '[unk]',
                'pad_token': '[pad]',
                'bos_token': '[im_start]',
                'eos_token': '[im_end]',
            }
            self._update_fast_tokenizer()
        elif tokenizer is not None:
            # Previously ignored silently, which only surfaced later as "not loaded".
            raise TypeError(
                f'Expected a Tokenizer, a path or None, got {type(tokenizer).__name__}'
            )

    @staticmethod
    def _build_fast_tokenizer(
        tokenizer: Tokenizer, config: dict, template: str
    ) -> PreTrainedTokenizerFast:
        '''Wrap ``tokenizer`` in a PreTrainedTokenizerFast using ``config`` and ``template``.'''
        return PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            unk_token=config.get('unk_token', '[unk]'),
            pad_token=config.get('pad_token', '[pad]'),
            bos_token=config.get('bos_token', '[im_start]'),
            eos_token=config.get('eos_token', '[im_end]'),
            chat_template=template,
            clean_up_tokenization_spaces=False
        )

    def _update_fast_tokenizer(self) -> None:
        '''
        Updates the cached PreTrainedTokenizerFast instance (None when nothing is loaded).
        '''
        if self._tokenizer is None:
            self._fast_tokenizer = None
            return
        self._fast_tokenizer = self._build_fast_tokenizer(
            self._tokenizer, self.config, self.template
        )

    @property
    def tokenizer(self) -> Tokenizer:
        '''The wrapped ``tokenizers.Tokenizer``; raises ValueError if not loaded.'''
        if self._tokenizer is None:
            raise ValueError("Tokenizer is not loaded.")
        return self._tokenizer

    @property
    def fast_tokenizer(self) -> PreTrainedTokenizerFast:
        '''The cached ``PreTrainedTokenizerFast``; raises ValueError if not loaded.'''
        if self._fast_tokenizer is None:
            raise ValueError('Tokenizer is not loaded.')
        return self._fast_tokenizer

    def save(self, path: str) -> 'PackedTokenizer':
        '''
        Write tokenizer, config and chat template into a zip archive at ``path``.

        Returns:
            PackedTokenizer: ``self``, for chaining.

        Raises:
            ValueError: If no tokenizer is loaded.
        '''
        if self._tokenizer is None:
            raise ValueError("No tokenizer to save.")

        # All entries are text, so compress them. ``load`` reads stored and
        # deflated archives alike.
        with zipfile.ZipFile(path, 'w', compression=zipfile.ZIP_DEFLATED) as z:
            z.writestr('tokenizer.json', self._tokenizer.to_str())
            z.writestr('tokenizer_config.json', json.dumps(self.config, indent=2))
            z.writestr('chat_template.jinja', self.template)

        return self

    def load(self, path: str) -> 'PackedTokenizer':
        '''
        Load a zip archive written by ``save``.

        Files may sit at the archive root or under a directory prefix.
        ``tokenizer.json`` is required. If the config or template is missing,
        the current ``config`` / ``template`` is kept.

        Returns:
            PackedTokenizer: ``self``, for chaining.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the archive contains no ``tokenizer.json``.
        '''
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")

        with zipfile.ZipFile(path, 'r') as z:
            names = z.namelist()

            def find_file(name: str) -> Optional[str]:
                # Match the bare name or the name under any directory prefix.
                return next(
                    (f for f in names if f == name or f.endswith(f'/{name}')), None
                )

            tokenizer_file = find_file('tokenizer.json')
            if tokenizer_file is None:
                raise ValueError("tokenizer.json not found in zip file")
            tokenizer = Tokenizer.from_str(z.read(tokenizer_file).decode('utf-8'))

            config = self.config
            config_file = find_file('tokenizer_config.json')
            if config_file is not None:
                config = json.loads(z.read(config_file).decode('utf-8'))

            template = self.template
            template_file = find_file('chat_template.jinja')
            if template_file is not None:
                template = z.read(template_file).decode('utf-8')

        # Build everything before mutating self. A failure (e.g. malformed config
        # JSON) then leaves the instance unchanged, instead of a new _tokenizer
        # paired with a stale _fast_tokenizer.
        fast = self._build_fast_tokenizer(tokenizer, config, template)
        self._tokenizer = tokenizer
        self.config = config
        self.template = template
        self._fast_tokenizer = fast
        return self