langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/py.typed
ADDED
langtune/schedulers.py
ADDED
@@ -0,0 +1,234 @@
"""
schedulers.py: Learning rate schedulers for Langtune

Provides various learning rate scheduling strategies.
"""

import math
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typing import Optional, List
import logging

logger = logging.getLogger(__name__)


class WarmupScheduler(_LRScheduler):
    """Linear warmup scheduler."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        base_lr: float = None,
        last_epoch: int = -1
    ):
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        return self.base_lrs


class CosineAnnealingWithWarmup(_LRScheduler):
    """Cosine annealing with linear warmup."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        total_steps: int,
        warmup_steps: int = 0,
        min_lr: float = 0.0,
        last_epoch: int = -1
    ):
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup
            warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            progress = (self.last_epoch - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
            return [
                self.min_lr + 0.5 * (base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
                for base_lr in self.base_lrs
            ]


class LinearDecayWithWarmup(_LRScheduler):
    """Linear decay with warmup."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        total_steps: int,
        warmup_steps: int = 0,
        min_lr: float = 0.0,
        last_epoch: int = -1
    ):
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup
            warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        else:
            # Linear decay
            progress = (self.last_epoch - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
            return [
                max(self.min_lr, base_lr * (1 - progress))
                for base_lr in self.base_lrs
            ]


class PolynomialDecayWithWarmup(_LRScheduler):
    """Polynomial decay with warmup."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        total_steps: int,
        warmup_steps: int = 0,
        power: float = 2.0,
        min_lr: float = 0.0,
        last_epoch: int = -1
    ):
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        else:
            progress = (self.last_epoch - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
            decay_factor = (1 - progress) ** self.power
            return [
                max(self.min_lr, base_lr * decay_factor)
                for base_lr in self.base_lrs
            ]


class ConstantWithWarmup(_LRScheduler):
    """Constant learning rate with warmup."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int = 0,
        last_epoch: int = -1
    ):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        return self.base_lrs


class OneCycleLRWithWarmup(_LRScheduler):
    """1cycle LR policy with customizable warmup."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_lr: float,
        total_steps: int,
        pct_start: float = 0.3,
        div_factor: float = 25.0,
        final_div_factor: float = 10000.0,
        last_epoch: int = -1
    ):
        self.max_lr = max_lr
        self.total_steps = total_steps
        self.pct_start = pct_start
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor

        # Calculate phase boundaries
        self.warmup_steps = int(total_steps * pct_start)
        self.cooldown_steps = total_steps - self.warmup_steps

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Warmup phase: linear increase
            progress = self.last_epoch / max(self.warmup_steps, 1)
            start_lr = self.max_lr / self.div_factor
            return [start_lr + (self.max_lr - start_lr) * progress for _ in self.base_lrs]
        else:
            # Cooldown phase: cosine decay
            progress = (self.last_epoch - self.warmup_steps) / max(self.cooldown_steps, 1)
            end_lr = self.max_lr / self.final_div_factor
            return [
                end_lr + 0.5 * (self.max_lr - end_lr) * (1 + math.cos(math.pi * progress))
                for _ in self.base_lrs
            ]


def get_scheduler(
    name: str,
    optimizer: torch.optim.Optimizer,
    total_steps: int,
    warmup_steps: int = 0,
    **kwargs
) -> _LRScheduler:
    """
    Get a scheduler by name.

    Args:
        name: Scheduler name ('cosine', 'linear', 'constant', 'polynomial', 'onecycle')
        optimizer: Optimizer
        total_steps: Total training steps
        warmup_steps: Warmup steps
        **kwargs: Additional scheduler arguments

    Returns:
        Learning rate scheduler
    """
    schedulers = {
        "cosine": CosineAnnealingWithWarmup,
        "linear": LinearDecayWithWarmup,
        "constant": ConstantWithWarmup,
        "polynomial": PolynomialDecayWithWarmup,
        "warmup": WarmupScheduler,
    }

    if name == "onecycle":
        max_lr = kwargs.get("max_lr", optimizer.param_groups[0]['lr'])
        return OneCycleLRWithWarmup(
            optimizer, max_lr, total_steps,
            pct_start=warmup_steps / total_steps if total_steps > 0 else 0.3,
            **{k: v for k, v in kwargs.items() if k != "max_lr"}
        )

    if name not in schedulers:
        raise ValueError(f"Unknown scheduler: {name}. Options: {list(schedulers.keys())}")

    scheduler_cls = schedulers[name]

    if name == "warmup":
        return scheduler_cls(optimizer, warmup_steps, **kwargs)
    elif name == "constant":
        return scheduler_cls(optimizer, warmup_steps, **kwargs)
    else:
        return scheduler_cls(optimizer, total_steps, warmup_steps, **kwargs)
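The `get_scheduler` factory above dispatches on a name string to one of the warmup-aware scheduler classes and steps per training batch (each `scheduler.step()` call advances `last_epoch`). The following is a minimal usage sketch, not code shipped in the wheel; the placeholder model, optimizer, and step counts are illustrative assumptions, only the `get_scheduler` signature and the "cosine" option come from the module itself.

# Hedged usage sketch (not part of the package): drive get_scheduler with a
# throwaway model/optimizer and call scheduler.step() once per training step.
import torch
from langtune.schedulers import get_scheduler

model = torch.nn.Linear(16, 4)                                  # placeholder model (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)      # base lr seen by the scheduler

# 1000 total steps; the first 100 warm up linearly to the base lr,
# then the lr cosine-anneals toward min_lr.
scheduler = get_scheduler(
    "cosine", optimizer, total_steps=1000, warmup_steps=100, min_lr=1e-6
)

for step in range(1000):
    # forward/backward pass would go here
    optimizer.step()
    scheduler.step()          # advances last_epoch, i.e. one call per batch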
langtune/tokenizers.py
ADDED
@@ -0,0 +1,275 @@
"""
tokenizers.py: Tokenization utilities for Langtune

Provides tokenization helpers and wrappers for various tokenizers.
"""

import re
from typing import List, Dict, Optional, Union
import logging

logger = logging.getLogger(__name__)


class CharacterTokenizer:
    """Simple character-level tokenizer."""

    def __init__(self, vocab_size: int = 256):
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to character IDs."""
        tokens = []
        if add_special_tokens:
            tokens.append(self.bos_token_id)

        for char in text:
            token_id = ord(char) % (self.vocab_size - 4) + 4
            tokens.append(token_id)

        if add_special_tokens:
            tokens.append(self.eos_token_id)
        return tokens

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text."""
        chars = []
        for tid in token_ids:
            if skip_special_tokens and tid < 4:
                continue
            if tid >= 4:
                chars.append(chr((tid - 4) % 128 + 32))
        return ''.join(chars)

    def __call__(self, text: str, **kwargs) -> Dict[str, List[int]]:
        return {"input_ids": self.encode(text, **kwargs)}


class WordTokenizer:
    """Simple word-level tokenizer with vocabulary."""

    def __init__(self, vocab: Optional[Dict[str, int]] = None, max_vocab_size: int = 32000):
        self.max_vocab_size = max_vocab_size
        self.vocab = vocab or {}
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

        # Special tokens
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"

        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

        if not self.vocab:
            self.vocab = {
                self.pad_token: 0,
                self.unk_token: 1,
                self.bos_token: 2,
                self.eos_token: 3
            }
            self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def fit(self, texts: List[str], min_freq: int = 1):
        """Build vocabulary from texts."""
        word_counts = {}
        for text in texts:
            for word in self._tokenize(text):
                word_counts[word] = word_counts.get(word, 0) + 1

        # Sort by frequency
        sorted_words = sorted(word_counts.items(), key=lambda x: -x[1])

        # Add to vocabulary
        for word, count in sorted_words:
            if count < min_freq:
                break
            if len(self.vocab) >= self.max_vocab_size:
                break
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)

        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        logger.info(f"Built vocabulary with {len(self.vocab)} tokens")

    def _tokenize(self, text: str) -> List[str]:
        """Split text into words."""
        text = text.lower()
        return re.findall(r'\b\w+\b|[^\w\s]', text)

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs."""
        tokens = []
        if add_special_tokens:
            tokens.append(self.bos_token_id)

        for word in self._tokenize(text):
            tokens.append(self.vocab.get(word, self.unk_token_id))

        if add_special_tokens:
            tokens.append(self.eos_token_id)
        return tokens

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text."""
        words = []
        for tid in token_ids:
            if skip_special_tokens and tid < 4:
                continue
            word = self.inv_vocab.get(tid, self.unk_token)
            if not skip_special_tokens or not word.startswith("<"):
                words.append(word)
        return ' '.join(words)

    def __call__(self, text: str, **kwargs) -> Dict[str, List[int]]:
        return {"input_ids": self.encode(text, **kwargs)}

    def save(self, path: str):
        """Save vocabulary to file."""
        import json
        with open(path, 'w') as f:
            json.dump(self.vocab, f, indent=2)

    @classmethod
    def load(cls, path: str) -> "WordTokenizer":
        """Load vocabulary from file."""
        import json
        with open(path) as f:
            vocab = json.load(f)
        return cls(vocab=vocab)


class BPETokenizer:
    """Simple Byte-Pair Encoding tokenizer."""

    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.merges = {}
        self.vocab = {}

        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    def _get_pairs(self, word: List[str]) -> Dict[tuple, int]:
        """Get pairs of consecutive symbols."""
        pairs = {}
        for i in range(len(word) - 1):
            pair = (word[i], word[i + 1])
            pairs[pair] = pairs.get(pair, 0) + 1
        return pairs

    def fit(self, texts: List[str], num_merges: int = None):
        """Learn BPE merges from texts."""
        num_merges = num_merges or (self.vocab_size - 256)

        # Initialize vocabulary with characters
        word_freqs = {}
        for text in texts:
            for word in text.split():
                word = ' '.join(list(word)) + ' </w>'
                word_freqs[word] = word_freqs.get(word, 0) + 1

        # Build initial vocab
        self.vocab = {chr(i): i for i in range(256)}
        self.vocab['</w>'] = 256

        # Learn merges
        for i in range(num_merges):
            pairs = {}
            for word, freq in word_freqs.items():
                word_pairs = self._get_pairs(word.split())
                for pair, count in word_pairs.items():
                    pairs[pair] = pairs.get(pair, 0) + count * freq

            if not pairs:
                break

            best_pair = max(pairs, key=pairs.get)
            new_token = ''.join(best_pair)

            if len(self.vocab) >= self.vocab_size:
                break

            self.merges[best_pair] = len(self.vocab)
            self.vocab[new_token] = len(self.vocab)

            # Update word_freqs
            new_word_freqs = {}
            pattern = ' '.join(best_pair)
            replacement = new_token
            for word, freq in word_freqs.items():
                new_word = word.replace(pattern, replacement)
                new_word_freqs[new_word] = freq
            word_freqs = new_word_freqs

        logger.info(f"Learned {len(self.merges)} BPE merges")

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text using BPE."""
        tokens = []
        if add_special_tokens:
            tokens.append(self.bos_token_id)

        for word in text.split():
            word = ' '.join(list(word)) + ' </w>'
            while True:
                pairs = self._get_pairs(word.split())
                if not pairs:
                    break

                # Find best pair that exists in merges
                best_pair = None
                best_rank = float('inf')
                for pair in pairs:
                    if pair in self.merges and self.merges[pair] < best_rank:
                        best_pair = pair
                        best_rank = self.merges[pair]

                if best_pair is None:
                    break

                pattern = ' '.join(best_pair)
                replacement = ''.join(best_pair)
                word = word.replace(pattern, replacement)

            for token in word.split():
                tokens.append(self.vocab.get(token, self.unk_token_id))

        if add_special_tokens:
            tokens.append(self.eos_token_id)
        return tokens

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode BPE tokens."""
        inv_vocab = {v: k for k, v in self.vocab.items()}
        tokens = []
        for tid in token_ids:
            if skip_special_tokens and tid < 4:
                continue
            token = inv_vocab.get(tid, '')
            tokens.append(token.replace('</w>', ' '))
        return ''.join(tokens).strip()


def get_tokenizer(name: str = "character", **kwargs):
    """Get a tokenizer by name."""
    tokenizers = {
        "character": CharacterTokenizer,
        "char": CharacterTokenizer,
        "word": WordTokenizer,
        "bpe": BPETokenizer
    }

    if name not in tokenizers:
        raise ValueError(f"Unknown tokenizer: {name}. Options: {list(tokenizers.keys())}")

    return tokenizers[name](**kwargs)
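Like the schedulers module, this file exposes a name-based factory (`get_tokenizer`) plus fit/encode/decode methods on each tokenizer class. Below is a minimal round-trip sketch, not code from the wheel; the toy corpus and parameter values are assumptions, while the factory name, the "word" option, and the `fit`/`encode`/`decode`/`__call__` methods come from the module above.

# Hedged usage sketch (not part of the package): build a word-level tokenizer
# via the factory and round-trip a sentence through it.
from langtune.tokenizers import get_tokenizer

corpus = ["hello world", "hello langtune"]          # toy corpus (assumption)

tok = get_tokenizer("word", max_vocab_size=1000)    # resolves to WordTokenizer
tok.fit(corpus)                                     # builds the vocabulary from the corpus

ids = tok.encode("hello world")                     # [bos, id("hello"), id("world"), eos]
print(tok.decode(ids))                              # -> "hello world"
print(tok("hello world"))                           # -> {"input_ids": [...]}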