langtune-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,254 @@
+ """
+ transformer.py: Standard Transformer implementations for Langtune
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import logging
+ from typing import Optional, Dict
+
+ from .layers import LoRALinear, MultiHeadAttention
+
+ logger = logging.getLogger(__name__)
+
+ class TransformerBlock(nn.Module):
+     """
+     Transformer block with LoRA support.
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         dropout: float = 0.1,
+         lora_config: Optional[Dict] = None
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         mlp_dim = int(embed_dim * mlp_ratio)
+
+         # Attention
+         self.attention = MultiHeadAttention(
+             embed_dim, num_heads, dropout, lora_config
+         )
+         self.attention_norm = nn.LayerNorm(embed_dim)
+
+         # MLP
+         self.mlp = nn.Sequential(
+             nn.Linear(embed_dim, mlp_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(mlp_dim, embed_dim),
+             nn.Dropout(dropout)
+         )
+
+         # LoRA for MLP
+         self.use_lora = lora_config is not None
+         if self.use_lora:
+             self.lora_mlp_fc1 = LoRALinear(
+                 embed_dim, mlp_dim,
+                 rank=lora_config.get('rank', 8),
+                 alpha=lora_config.get('alpha', 16.0),
+                 dropout=lora_config.get('dropout', 0.1)
+             )
+             self.lora_mlp_fc2 = LoRALinear(
+                 mlp_dim, embed_dim,
+                 rank=lora_config.get('rank', 8),
+                 alpha=lora_config.get('alpha', 16.0),
+                 dropout=lora_config.get('dropout', 0.1)
+             )
+
+         self.mlp_norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Self-attention with residual connection
+         attn_out = self.attention(x, mask)
+         x = self.attention_norm(x + attn_out)
+
+         # MLP with residual connection
+         if self.use_lora:
+             mlp_out = self.lora_mlp_fc1(x)
+             mlp_out = F.gelu(mlp_out)
+             mlp_out = self.lora_mlp_fc2(mlp_out)
+             # Add original MLP output
+             mlp_out = mlp_out + self.mlp(x)
+         else:
+             mlp_out = self.mlp(x)
+
+         x = self.mlp_norm(x + mlp_out)
+         return x
+
+ class LoRALanguageModel(nn.Module):
+     """
+     A complete language model with LoRA support for efficient fine-tuning.
+     """
+
+     def __init__(
+         self,
+         vocab_size: int,
+         embed_dim: int,
+         num_layers: int,
+         num_heads: int,
+         max_seq_len: int = 512,
+         mlp_ratio: float = 4.0,
+         dropout: float = 0.1,
+         lora_config: Optional[Dict] = None
+     ):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.embed_dim = embed_dim
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.max_seq_len = max_seq_len
+
+         # Token and positional embeddings
+         self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+         self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
+
+         # Transformer blocks
+         self.blocks = nn.ModuleList([
+             TransformerBlock(
+                 embed_dim, num_heads, mlp_ratio, dropout, lora_config
+             )
+             for _ in range(num_layers)
+         ])
+
+         # Output layer
+         self.norm = nn.LayerNorm(embed_dim)
+         self.head = nn.Linear(embed_dim, vocab_size, bias=False)
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+         logger.info(f"Initialized LoRALanguageModel with {self.count_parameters()} parameters")
+         if lora_config:
+             logger.info(f"LoRA config: {lora_config}")
+
+     def _init_weights(self, module):
+         """Initialize model weights."""
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.zeros_(module.bias)
+             torch.nn.init.ones_(module.weight)
+
+     def count_parameters(self) -> int:
+         """Count total number of parameters."""
+         return sum(p.numel() for p in self.parameters())
+
+     def count_lora_parameters(self) -> int:
+         """Count LoRA-specific parameters."""
+         lora_params = 0
+         for module in self.modules():
+             if isinstance(module, LoRALinear):
+                 lora_params += module.lora_A.numel() + module.lora_B.numel()
+         return lora_params
+
+     def create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """Create causal attention mask."""
+         mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
+         return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Forward pass through the model.
+         """
+         batch_size, seq_len = input_ids.shape
+         device = input_ids.device
+
+         # Create causal mask
+         causal_mask = self.create_causal_mask(seq_len, device)
+
+         # Token and positional embeddings
+         positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+         x = self.token_embedding(input_ids) + self.position_embedding(positions)
+
+         # Apply attention mask if provided
+         if attention_mask is not None:
+             # Convert attention mask to 4D for broadcasting
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
+             causal_mask = causal_mask * attention_mask
+
+         # Forward through transformer blocks
+         for block in self.blocks:
+             x = block(x, causal_mask)
+
+         # Final layer norm and output projection
+         x = self.norm(x)
+         logits = self.head(x)
+
+         outputs = {"logits": logits}
+
+         # Compute loss if labels are provided
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+                 ignore_index=-100
+             )
+             outputs["loss"] = loss
+
+         return outputs
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_length: int = 100,
+         temperature: float = 1.0,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         pad_token_id: int = 0,
+         eos_token_id: int = 1
+     ) -> torch.Tensor:
+         """Generate text using the model."""
+         self.eval()
+         with torch.no_grad():
+             for _ in range(max_length - input_ids.size(1)):
+                 # Forward pass
+                 outputs = self.forward(input_ids)
+                 logits = outputs["logits"][:, -1, :] / temperature
+
+                 # Apply top-k filtering
+                 if top_k is not None:
+                     top_k = min(top_k, logits.size(-1))
+                     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                     logits[indices_to_remove] = -float('inf')
+
+                 # Apply top-p filtering
+                 if top_p is not None:
+                     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                     # Remove tokens with cumulative probability above the threshold
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     logits[indices_to_remove] = -float('inf')
+
+                 # Sample next token
+                 probs = F.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Append to input_ids
+                 input_ids = torch.cat([input_ids, next_token], dim=1)
+
+                 # Check for EOS token
+                 if (next_token == eos_token_id).all():
+                     break
+
+         return input_ids
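
For orientation, the sketch below shows one plausible way to exercise the LoRALanguageModel added in this version: build a small model with a lora_config, take a training-style step through forward() (which returns both "logits" and "loss" when labels are passed), and sample with generate(). The lora_config keys (rank, alpha, dropout) and the default eos_token_id of 1 are taken from the code in the diff; the import path, dimensions, and data are illustrative assumptions, not part of the package.

import torch
from langtune.transformer import LoRALanguageModel  # assumed import path; only transformer.py appears in this diff

# LoRA settings read via lora_config.get(...) in the code above
lora_config = {"rank": 8, "alpha": 16.0, "dropout": 0.1}

# Small illustrative dimensions (embed_dim must be divisible by num_heads)
model = LoRALanguageModel(
    vocab_size=1000,
    embed_dim=128,
    num_layers=2,
    num_heads=4,
    max_seq_len=64,
    lora_config=lora_config,
)
print(model.count_parameters(), model.count_lora_parameters())

# Training-style step: forward() computes a shifted cross-entropy loss when labels are given
input_ids = torch.randint(0, 1000, (2, 16))
outputs = model(input_ids, labels=input_ids.clone())
outputs["loss"].backward()

# Sampling: keep max_length <= max_seq_len, since position embeddings are learned
# only up to max_seq_len and generate() does not truncate the running context
prompt = torch.randint(0, 1000, (1, 4))
generated = model.generate(prompt, max_length=32, temperature=0.8, top_k=50)
print(generated.shape)  # (1, n) with n <= 32; generation stops early once eos_token_id is sampled

Two details visible in the code itself are worth noting: when lora_config is set, TransformerBlock runs the dense MLP and the LoRA pair in parallel and sums their outputs, and nothing in this file freezes the base weights, so choosing which parameters to train is left to the caller. count_lora_parameters() also assumes the LoRALinear layers imported from .layers expose lora_A and lora_B tensors.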