langtune 0.1.19__py3-none-any.whl
This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/nn/transformer.py
@@ -0,0 +1,254 @@
+"""
+transformer.py: Standard Transformer implementations for Langtune
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import logging
+from typing import Optional, Dict
+
+from .layers import LoRALinear, MultiHeadAttention
+
+logger = logging.getLogger(__name__)
+
+class TransformerBlock(nn.Module):
+    """
+    Transformer block with LoRA support.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.1,
+        lora_config: Optional[Dict] = None
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        mlp_dim = int(embed_dim * mlp_ratio)
+
+        # Attention
+        self.attention = MultiHeadAttention(
+            embed_dim, num_heads, dropout, lora_config
+        )
+        self.attention_norm = nn.LayerNorm(embed_dim)
+
+        # MLP
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, mlp_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(mlp_dim, embed_dim),
+            nn.Dropout(dropout)
+        )
+
+        # LoRA for MLP
+        self.use_lora = lora_config is not None
+        if self.use_lora:
+            self.lora_mlp_fc1 = LoRALinear(
+                embed_dim, mlp_dim,
+                rank=lora_config.get('rank', 8),
+                alpha=lora_config.get('alpha', 16.0),
+                dropout=lora_config.get('dropout', 0.1)
+            )
+            self.lora_mlp_fc2 = LoRALinear(
+                mlp_dim, embed_dim,
+                rank=lora_config.get('rank', 8),
+                alpha=lora_config.get('alpha', 16.0),
+                dropout=lora_config.get('dropout', 0.1)
+            )
+
+        self.mlp_norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # Self-attention with residual connection
+        attn_out = self.attention(x, mask)
+        x = self.attention_norm(x + attn_out)
+
+        # MLP with residual connection
+        if self.use_lora:
+            mlp_out = self.lora_mlp_fc1(x)
+            mlp_out = F.gelu(mlp_out)
+            mlp_out = self.lora_mlp_fc2(mlp_out)
+            # Add original MLP output
+            mlp_out = mlp_out + self.mlp(x)
+        else:
+            mlp_out = self.mlp(x)
+
+        x = self.mlp_norm(x + mlp_out)
+        return x
+
+class LoRALanguageModel(nn.Module):
+    """
+    A complete language model with LoRA support for efficient fine-tuning.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_dim: int,
+        num_layers: int,
+        num_heads: int,
+        max_seq_len: int = 512,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.1,
+        lora_config: Optional[Dict] = None
+    ):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embed_dim = embed_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.max_seq_len = max_seq_len
+
+        # Token and positional embeddings
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
+
+        # Transformer blocks
+        self.blocks = nn.ModuleList([
+            TransformerBlock(
+                embed_dim, num_heads, mlp_ratio, dropout, lora_config
+            )
+            for _ in range(num_layers)
+        ])
+
+        # Output layer
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
+
+        # Initialize weights
+        self.apply(self._init_weights)
+
+        logger.info(f"Initialized LoRALanguageModel with {self.count_parameters()} parameters")
+        if lora_config:
+            logger.info(f"LoRA config: {lora_config}")
+
+    def _init_weights(self, module):
+        """Initialize model weights."""
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+
+    def count_parameters(self) -> int:
+        """Count total number of parameters."""
+        return sum(p.numel() for p in self.parameters())
+
+    def count_lora_parameters(self) -> int:
+        """Count LoRA-specific parameters."""
+        lora_params = 0
+        for module in self.modules():
+            if isinstance(module, LoRALinear):
+                lora_params += module.lora_A.numel() + module.lora_B.numel()
+        return lora_params
+
+    def create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        """Create causal attention mask."""
+        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
+        return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass through the model.
+        """
+        batch_size, seq_len = input_ids.shape
+        device = input_ids.device
+
+        # Create causal mask
+        causal_mask = self.create_causal_mask(seq_len, device)
+
+        # Token and positional embeddings
+        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+        x = self.token_embedding(input_ids) + self.position_embedding(positions)
+
+        # Apply attention mask if provided
+        if attention_mask is not None:
+            # Convert attention mask to 4D for broadcasting
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
+            causal_mask = causal_mask * attention_mask
+
+        # Forward through transformer blocks
+        for block in self.blocks:
+            x = block(x, causal_mask)
+
+        # Final layer norm and output projection
+        x = self.norm(x)
+        logits = self.head(x)
+
+        outputs = {"logits": logits}
+
+        # Compute loss if labels are provided
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = F.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1),
+                ignore_index=-100
+            )
+            outputs["loss"] = loss
+
+        return outputs
+
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        max_length: int = 100,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        pad_token_id: int = 0,
+        eos_token_id: int = 1
+    ) -> torch.Tensor:
+        """Generate text using the model."""
+        self.eval()
+        with torch.no_grad():
+            for _ in range(max_length - input_ids.size(1)):
+                # Forward pass
+                outputs = self.forward(input_ids)
+                logits = outputs["logits"][:, -1, :] / temperature
+
+                # Apply top-k filtering
+                if top_k is not None:
+                    top_k = min(top_k, logits.size(-1))
+                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                    logits[indices_to_remove] = -float('inf')
+
+                # Apply top-p filtering
+                if top_p is not None:
+                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above the threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                    logits[indices_to_remove] = -float('inf')
+
+                # Sample next token
+                probs = F.softmax(logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+
+                # Append to input_ids
+                input_ids = torch.cat([input_ids, next_token], dim=1)
+
+                # Check for EOS token
+                if (next_token == eos_token_id).all():
+                    break
+
+        return input_ids
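For orientation, a minimal usage sketch of the LoRALanguageModel API introduced by this file. It is based only on the signatures visible in the diff above and assumes the companion langtune.nn.layers module provides working LoRALinear and MultiHeadAttention implementations; the vocabulary size, LoRA settings, and random token ids below are illustrative placeholders rather than values shipped with the package.

import torch
from langtune.nn.transformer import LoRALanguageModel

# Illustrative hyperparameters (placeholders, not package defaults).
model = LoRALanguageModel(
    vocab_size=32000,
    embed_dim=256,
    num_layers=4,
    num_heads=8,
    max_seq_len=512,
    lora_config={"rank": 8, "alpha": 16.0, "dropout": 0.1},
)

# Placeholder batch of token ids; a real tokenizer would produce these.
input_ids = torch.randint(0, 32000, (2, 16))

# Forward pass returns a dict with "logits" and, when labels are given, "loss".
outputs = model(input_ids, labels=input_ids.clone())
print(outputs["loss"].item(), model.count_lora_parameters())

# Autoregressive sampling from a short prompt.
generated = model.generate(input_ids[:, :4], max_length=20, temperature=0.8, top_k=50)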