langtune 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/py.typed ADDED
@@ -0,0 +1,2 @@
+ # Marker file for PEP 561.
+ # The langtune package uses inline type annotations.
langtune/schedulers.py ADDED
@@ -0,0 +1,234 @@
+ """
+ schedulers.py: Learning rate schedulers for Langtune
+
+ Provides various learning rate scheduling strategies.
+ """
+
+ import math
+ import torch
+ from torch.optim.lr_scheduler import _LRScheduler
+ from typing import Optional, List
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class WarmupScheduler(_LRScheduler):
+     """Linear warmup scheduler."""
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         warmup_steps: int,
+         base_lr: Optional[float] = None,
+         last_epoch: int = -1
+     ):
+         self.warmup_steps = warmup_steps
+         self.base_lr = base_lr
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
+             return [base_lr * warmup_factor for base_lr in self.base_lrs]
+         return self.base_lrs
+
+
+ class CosineAnnealingWithWarmup(_LRScheduler):
+     """Cosine annealing with linear warmup."""
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         total_steps: int,
+         warmup_steps: int = 0,
+         min_lr: float = 0.0,
+         last_epoch: int = -1
+     ):
+         self.total_steps = total_steps
+         self.warmup_steps = warmup_steps
+         self.min_lr = min_lr
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             # Linear warmup
+             warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
+             return [base_lr * warmup_factor for base_lr in self.base_lrs]
+         else:
+             # Cosine annealing
+             progress = (self.last_epoch - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
+             return [
+                 self.min_lr + 0.5 * (base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
+                 for base_lr in self.base_lrs
+             ]
+
+
+ class LinearDecayWithWarmup(_LRScheduler):
+     """Linear decay with warmup."""
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         total_steps: int,
+         warmup_steps: int = 0,
+         min_lr: float = 0.0,
+         last_epoch: int = -1
+     ):
+         self.total_steps = total_steps
+         self.warmup_steps = warmup_steps
+         self.min_lr = min_lr
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             # Linear warmup
+             warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
+             return [base_lr * warmup_factor for base_lr in self.base_lrs]
+         else:
+             # Linear decay
+             progress = (self.last_epoch - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
+             return [
+                 max(self.min_lr, base_lr * (1 - progress))
+                 for base_lr in self.base_lrs
+             ]
+
+
+ class PolynomialDecayWithWarmup(_LRScheduler):
+     """Polynomial decay with warmup."""
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         total_steps: int,
+         warmup_steps: int = 0,
+         power: float = 2.0,
+         min_lr: float = 0.0,
+         last_epoch: int = -1
+     ):
+         self.total_steps = total_steps
+         self.warmup_steps = warmup_steps
+         self.power = power
+         self.min_lr = min_lr
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
+             return [base_lr * warmup_factor for base_lr in self.base_lrs]
+         else:
+             progress = (self.last_epoch - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
+             decay_factor = (1 - progress) ** self.power
+             return [
+                 max(self.min_lr, base_lr * decay_factor)
+                 for base_lr in self.base_lrs
+             ]
+
+
+ class ConstantWithWarmup(_LRScheduler):
+     """Constant learning rate with warmup."""
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         warmup_steps: int = 0,
+         last_epoch: int = -1
+     ):
+         self.warmup_steps = warmup_steps
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             warmup_factor = (self.last_epoch + 1) / max(self.warmup_steps, 1)
+             return [base_lr * warmup_factor for base_lr in self.base_lrs]
+         return self.base_lrs
+
+
+ class OneCycleLRWithWarmup(_LRScheduler):
+     """1cycle LR policy with customizable warmup."""
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         max_lr: float,
+         total_steps: int,
+         pct_start: float = 0.3,
+         div_factor: float = 25.0,
+         final_div_factor: float = 10000.0,
+         last_epoch: int = -1
+     ):
+         self.max_lr = max_lr
+         self.total_steps = total_steps
+         self.pct_start = pct_start
+         self.div_factor = div_factor
+         self.final_div_factor = final_div_factor
+
+         # Calculate phase boundaries
+         self.warmup_steps = int(total_steps * pct_start)
+         self.cooldown_steps = total_steps - self.warmup_steps
+
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             # Warmup phase: linear increase
+             progress = self.last_epoch / max(self.warmup_steps, 1)
+             start_lr = self.max_lr / self.div_factor
+             return [start_lr + (self.max_lr - start_lr) * progress for _ in self.base_lrs]
+         else:
+             # Cooldown phase: cosine decay
+             progress = (self.last_epoch - self.warmup_steps) / max(self.cooldown_steps, 1)
+             end_lr = self.max_lr / self.final_div_factor
+             return [
+                 end_lr + 0.5 * (self.max_lr - end_lr) * (1 + math.cos(math.pi * progress))
+                 for _ in self.base_lrs
+             ]
+
+
+ def get_scheduler(
+     name: str,
+     optimizer: torch.optim.Optimizer,
+     total_steps: int,
+     warmup_steps: int = 0,
+     **kwargs
+ ) -> _LRScheduler:
+     """
+     Get a scheduler by name.
+
+     Args:
+         name: Scheduler name ('cosine', 'linear', 'constant', 'polynomial', 'warmup', 'onecycle')
+         optimizer: Optimizer to schedule
+         total_steps: Total number of training steps
+         warmup_steps: Number of warmup steps
+         **kwargs: Additional scheduler arguments
+
+     Returns:
+         Learning rate scheduler
+     """
+     schedulers = {
+         "cosine": CosineAnnealingWithWarmup,
+         "linear": LinearDecayWithWarmup,
+         "constant": ConstantWithWarmup,
+         "polynomial": PolynomialDecayWithWarmup,
+         "warmup": WarmupScheduler,
+     }
+
+     if name == "onecycle":
+         max_lr = kwargs.get("max_lr", optimizer.param_groups[0]['lr'])
+         return OneCycleLRWithWarmup(
+             optimizer, max_lr, total_steps,
+             pct_start=warmup_steps / total_steps if total_steps > 0 else 0.3,
+             **{k: v for k, v in kwargs.items() if k != "max_lr"}
+         )
+
+     if name not in schedulers:
+         raise ValueError(f"Unknown scheduler: {name}. Options: {list(schedulers.keys()) + ['onecycle']}")
+
+     scheduler_cls = schedulers[name]
+
+     if name in ("warmup", "constant"):
+         return scheduler_cls(optimizer, warmup_steps, **kwargs)
+     else:
+         return scheduler_cls(optimizer, total_steps, warmup_steps, **kwargs)
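
For orientation, a minimal usage sketch of the scheduler factory above (not part of the package; the model and optimizer setup are purely illustrative):

    import torch
    from langtune.schedulers import get_scheduler

    # Illustrative setup: a stand-in model and optimizer
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    total_steps = 1000
    scheduler = get_scheduler(
        "cosine", optimizer,
        total_steps=total_steps, warmup_steps=100, min_lr=1e-5,
    )

    for step in range(total_steps):
        optimizer.step()   # a real loop would compute gradients first
        scheduler.step()   # advances the warmup/cosine schedule by one step
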
langtune/tokenizers.py ADDED
@@ -0,0 +1,275 @@
+ """
+ tokenizers.py: Tokenization utilities for Langtune
+
+ Provides tokenization helpers and wrappers for various tokenizers.
+ """
+
+ import re
+ from typing import List, Dict, Optional, Union
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class CharacterTokenizer:
+     """Simple character-level tokenizer."""
+
+     def __init__(self, vocab_size: int = 256):
+         self.vocab_size = vocab_size
+         self.pad_token_id = 0
+         self.unk_token_id = 1
+         self.bos_token_id = 2
+         self.eos_token_id = 3
+
+     def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
+         """Encode text to character IDs."""
+         tokens = []
+         if add_special_tokens:
+             tokens.append(self.bos_token_id)
+
+         for char in text:
+             token_id = ord(char) % (self.vocab_size - 4) + 4
+             tokens.append(token_id)
+
+         if add_special_tokens:
+             tokens.append(self.eos_token_id)
+         return tokens
+
+     def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
+         """Decode token IDs to text."""
+         chars = []
+         for tid in token_ids:
+             if skip_special_tokens and tid < 4:
+                 continue
+             if tid >= 4:
+                 # Inverse of the encode mapping; lossless only for characters
+                 # whose ordinal is below vocab_size - 4.
+                 chars.append(chr(tid - 4))
+         return ''.join(chars)
+
+     def __call__(self, text: str, **kwargs) -> Dict[str, List[int]]:
+         return {"input_ids": self.encode(text, **kwargs)}
+
+
+ class WordTokenizer:
+     """Simple word-level tokenizer with vocabulary."""
+
+     def __init__(self, vocab: Optional[Dict[str, int]] = None, max_vocab_size: int = 32000):
+         self.max_vocab_size = max_vocab_size
+         self.vocab = vocab or {}
+         self.inv_vocab = {v: k for k, v in self.vocab.items()}
+
+         # Special tokens
+         self.pad_token = "<pad>"
+         self.unk_token = "<unk>"
+         self.bos_token = "<s>"
+         self.eos_token = "</s>"
+
+         self.pad_token_id = 0
+         self.unk_token_id = 1
+         self.bos_token_id = 2
+         self.eos_token_id = 3
+
+         if not self.vocab:
+             self.vocab = {
+                 self.pad_token: 0,
+                 self.unk_token: 1,
+                 self.bos_token: 2,
+                 self.eos_token: 3
+             }
+             self.inv_vocab = {v: k for k, v in self.vocab.items()}
+
+     def fit(self, texts: List[str], min_freq: int = 1):
+         """Build vocabulary from texts."""
+         word_counts = {}
+         for text in texts:
+             for word in self._tokenize(text):
+                 word_counts[word] = word_counts.get(word, 0) + 1
+
+         # Sort by frequency
+         sorted_words = sorted(word_counts.items(), key=lambda x: -x[1])
+
+         # Add to vocabulary
+         for word, count in sorted_words:
+             if count < min_freq:
+                 break
+             if len(self.vocab) >= self.max_vocab_size:
+                 break
+             if word not in self.vocab:
+                 self.vocab[word] = len(self.vocab)
+
+         self.inv_vocab = {v: k for k, v in self.vocab.items()}
+         logger.info(f"Built vocabulary with {len(self.vocab)} tokens")
+
+     def _tokenize(self, text: str) -> List[str]:
+         """Split text into words."""
+         text = text.lower()
+         return re.findall(r'\b\w+\b|[^\w\s]', text)
+
+     def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
+         """Encode text to token IDs."""
+         tokens = []
+         if add_special_tokens:
+             tokens.append(self.bos_token_id)
+
+         for word in self._tokenize(text):
+             tokens.append(self.vocab.get(word, self.unk_token_id))
+
+         if add_special_tokens:
+             tokens.append(self.eos_token_id)
+         return tokens
+
+     def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
+         """Decode token IDs to text."""
+         words = []
+         for tid in token_ids:
+             if skip_special_tokens and tid < 4:
+                 continue
+             word = self.inv_vocab.get(tid, self.unk_token)
+             if not skip_special_tokens or not word.startswith("<"):
+                 words.append(word)
+         return ' '.join(words)
+
+     def __call__(self, text: str, **kwargs) -> Dict[str, List[int]]:
+         return {"input_ids": self.encode(text, **kwargs)}
+
+     def save(self, path: str):
+         """Save vocabulary to file."""
+         import json
+         with open(path, 'w') as f:
+             json.dump(self.vocab, f, indent=2)
+
+     @classmethod
+     def load(cls, path: str) -> "WordTokenizer":
+         """Load vocabulary from file."""
+         import json
+         with open(path) as f:
+             vocab = json.load(f)
+         return cls(vocab=vocab)
+
+
+ class BPETokenizer:
+     """Simple Byte-Pair Encoding tokenizer."""
+
+     def __init__(self, vocab_size: int = 8000):
+         self.vocab_size = vocab_size
+         self.merges = {}
+         self.vocab = {}
+
+         self.pad_token_id = 0
+         self.unk_token_id = 1
+         self.bos_token_id = 2
+         self.eos_token_id = 3
+
+     def _get_pairs(self, word: List[str]) -> Dict[tuple, int]:
+         """Get pairs of consecutive symbols."""
+         pairs = {}
+         for i in range(len(word) - 1):
+             pair = (word[i], word[i + 1])
+             pairs[pair] = pairs.get(pair, 0) + 1
+         return pairs
+
+     def fit(self, texts: List[str], num_merges: Optional[int] = None):
+         """Learn BPE merges from texts."""
+         num_merges = num_merges or (self.vocab_size - 256)
+
+         # Initialize vocabulary with characters
+         word_freqs = {}
+         for text in texts:
+             for word in text.split():
+                 word = ' '.join(list(word)) + ' </w>'
+                 word_freqs[word] = word_freqs.get(word, 0) + 1
+
+         # Build initial vocab
+         self.vocab = {chr(i): i for i in range(256)}
+         self.vocab['</w>'] = 256
+
+         # Learn merges
+         for i in range(num_merges):
+             pairs = {}
+             for word, freq in word_freqs.items():
+                 word_pairs = self._get_pairs(word.split())
+                 for pair, count in word_pairs.items():
+                     pairs[pair] = pairs.get(pair, 0) + count * freq
+
+             if not pairs:
+                 break
+
+             best_pair = max(pairs, key=pairs.get)
+             new_token = ''.join(best_pair)
+
+             if len(self.vocab) >= self.vocab_size:
+                 break
+
+             self.merges[best_pair] = len(self.vocab)
+             self.vocab[new_token] = len(self.vocab)
+
+             # Update word_freqs
+             new_word_freqs = {}
+             pattern = ' '.join(best_pair)
+             replacement = new_token
+             for word, freq in word_freqs.items():
+                 new_word = word.replace(pattern, replacement)
+                 new_word_freqs[new_word] = freq
+             word_freqs = new_word_freqs
+
+         logger.info(f"Learned {len(self.merges)} BPE merges")
+
+     def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
+         """Encode text using BPE."""
+         tokens = []
+         if add_special_tokens:
+             tokens.append(self.bos_token_id)
+
+         for word in text.split():
+             word = ' '.join(list(word)) + ' </w>'
+             while True:
+                 pairs = self._get_pairs(word.split())
+                 if not pairs:
+                     break
+
+                 # Find the lowest-ranked (earliest learned) pair present in merges
+                 best_pair = None
+                 best_rank = float('inf')
+                 for pair in pairs:
+                     if pair in self.merges and self.merges[pair] < best_rank:
+                         best_pair = pair
+                         best_rank = self.merges[pair]
+
+                 if best_pair is None:
+                     break
+
+                 pattern = ' '.join(best_pair)
+                 replacement = ''.join(best_pair)
+                 word = word.replace(pattern, replacement)
+
+             for token in word.split():
+                 tokens.append(self.vocab.get(token, self.unk_token_id))
+
+         if add_special_tokens:
+             tokens.append(self.eos_token_id)
+         return tokens
+
+     def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
+         """Decode BPE tokens."""
+         inv_vocab = {v: k for k, v in self.vocab.items()}
+         tokens = []
+         for tid in token_ids:
+             if skip_special_tokens and tid < 4:
+                 continue
+             token = inv_vocab.get(tid, '')
+             tokens.append(token.replace('</w>', ' '))
+         return ''.join(tokens).strip()
+
+
+ def get_tokenizer(name: str = "character", **kwargs):
+     """Get a tokenizer by name."""
+     tokenizers = {
+         "character": CharacterTokenizer,
+         "char": CharacterTokenizer,
+         "word": WordTokenizer,
+         "bpe": BPETokenizer
+     }
+
+     if name not in tokenizers:
+         raise ValueError(f"Unknown tokenizer: {name}. Options: {list(tokenizers.keys())}")
+
+     return tokenizers[name](**kwargs)
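
For reference, a minimal round-trip with the tokenizer factory above might look like the following sketch (not part of the package; the training sentences are invented):

    from langtune.tokenizers import get_tokenizer

    tokenizer = get_tokenizer("word", max_vocab_size=1000)
    tokenizer.fit(["the cat sat on the mat", "a dog sat on the rug"])

    ids = tokenizer.encode("the cat sat")   # bos/eos ids plus one id per word
    print(tokenizer.decode(ids))            # -> "the cat sat"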