rxnn-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,208 @@
import os
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordPiece, Unigram, WordLevel
from tokenizers.trainers import BpeTrainer, WordPieceTrainer, UnigramTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace, Punctuation, BertPreTokenizer, ByteLevel, Sequence as PreTokenizerSequence
from tokenizers.processors import TemplateProcessing
from tokenizers.normalizers import Lowercase, NFKC, Sequence
from transformers import PreTrainedTokenizerFast
from typing import Any

class TokenizerTrainer:
    def __init__(
        self,
        vocab_size: int = 30000,
        model_type: str = "byte-level-bpe",  # Options: "bpe", "byte-level-bpe", "wordpiece", "unigram", "wordlevel"
        special_tokens: list[str] = None,
        lowercase: bool = False,
        normalization: bool = False,
        pre_tokenizer_type: str = "bert",  # Options: "bert", "whitespace", "whitespace_punctuation"
        vocab: Any = None,
        byte_fallback: bool = False,
        max_input_chars_per_word: int = 32,
        use_post_processor: bool = True,
        post_processor_single: str = "[BOS] $A [EOS]",
        post_processor_pair: str = "[BOS] $A [EOS][BOS] $B:1 [EOS]:1",
        post_processor_special_tokens: list[str] = None,
    ):
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens if special_tokens is not None else [
            "[PAD]", "[UNK]", "[BOS]", "[EOS]", "[MASK]"
        ]
        self.model_type = model_type.lower()
        self.lowercase = lowercase
        self.normalization = normalization
        self.pre_tokenizer_type = pre_tokenizer_type.lower()

        # Initialize tokenizer model
        if self.model_type == "bpe":
            self.tokenizer = Tokenizer(BPE(unk_token="[UNK]", vocab=vocab, byte_fallback=byte_fallback))
        elif self.model_type == "wordpiece":
            self.tokenizer = Tokenizer(
                WordPiece(unk_token="[UNK]", vocab=vocab, max_input_chars_per_word=max_input_chars_per_word))
        elif self.model_type == "unigram":
            # Unigram's unk_id is an integer index into an explicit vocab, so it is not set here
            self.tokenizer = Tokenizer(Unigram(vocab=None, byte_fallback=byte_fallback))
        elif self.model_type == "wordlevel":
            self.tokenizer = Tokenizer(WordLevel(unk_token="[UNK]", vocab=None))
        elif self.model_type == "byte-level-bpe":
            self.tokenizer = Tokenizer(BPE(unk_token="[UNK]", vocab=vocab, byte_fallback=byte_fallback))
            self.tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Configure pre-tokenizer
        if self.model_type != "byte-level-bpe":
            if self.pre_tokenizer_type == "bert":
                self.tokenizer.pre_tokenizer = BertPreTokenizer()
            elif self.pre_tokenizer_type == "whitespace_punctuation":
                # Chain both pre-tokenizers instead of overwriting one with the other
                self.tokenizer.pre_tokenizer = PreTokenizerSequence([Whitespace(), Punctuation()])
            elif self.pre_tokenizer_type == "whitespace":
                self.tokenizer.pre_tokenizer = Whitespace()
            else:
                raise ValueError(f"Unsupported pre-tokenizer: {pre_tokenizer_type}")

        # Add normalization steps
        if self.normalization:
            normalizers = []
            if self.lowercase:
                normalizers.append(Lowercase())
            normalizers.append(NFKC())
            self.tokenizer.normalizer = Sequence(normalizers)

        self.use_post_processor = use_post_processor
        self.post_processor_single = post_processor_single
        self.post_processor_pair = post_processor_pair
        self.post_processor_special_tokens = post_processor_special_tokens

    def train(
        self,
        files: list[str],
        limit_alphabet: int = 1000,
        show_progress: bool = True,
        **kwargs
    ):
        # Prepare trainer based on model type
        trainer_kwargs = {
            "vocab_size": self.vocab_size,
            "special_tokens": self.special_tokens,
            "show_progress": show_progress,
            **kwargs  # Allow custom parameters
        }

        # limit_alphabet applies to the BPE and WordPiece trainers
        if self.model_type in ["bpe", "byte-level-bpe"]:
            trainer = BpeTrainer(limit_alphabet=limit_alphabet, **trainer_kwargs)
        elif self.model_type == "wordpiece":
            trainer = WordPieceTrainer(limit_alphabet=limit_alphabet, **trainer_kwargs)
        elif self.model_type == "unigram":
            trainer = UnigramTrainer(**trainer_kwargs)
        elif self.model_type == "wordlevel":
            trainer = WordLevelTrainer(**trainer_kwargs)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        # Train tokenizer
        self.tokenizer.train(files, trainer)

        if self.use_post_processor:
            post_processor_special_tokens = self.post_processor_special_tokens or ["[BOS]", "[EOS]"]
            self.tokenizer.post_processor = TemplateProcessing(
                single=self.post_processor_single,
                pair=self.post_processor_pair,
                special_tokens=[(token, self.tokenizer.token_to_id(token)) for token in post_processor_special_tokens],
            )

    def save(self, output_dir: str):
        os.makedirs(output_dir, exist_ok=True)
        self.tokenizer.save(f"{output_dir}/tokenizer.json")

    def load(self, model_path: str):
        self.tokenizer = Tokenizer.from_file(model_path)

    def get_hf_tokenizer(self):
        return PreTrainedTokenizerFast(
            tokenizer_object=self.tokenizer,
            unk_token="[UNK]",
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )

    def push_to_hub(
        self,
        repo_id: str,
        create: bool = False,
        private: bool = False,
        api_token: str = None,
        **kwargs
    ):
        """
        Push the trained tokenizer to HuggingFace Hub.

        Args:
            repo_id (str): Hub repository name (e.g., "username/my-tokenizer")
            create (bool): Whether to create the repository first
            private (bool): Whether the repo is private
            api_token (str): HuggingFace API token (optional if already logged in)
            **kwargs: Additional args for HuggingFace API
        """
        from huggingface_hub import HfApi

        # Create a temporary directory for Hub upload
        temp_dir = "temp_hub_upload"
        os.makedirs(temp_dir, exist_ok=True)
        self.save(temp_dir)  # Save tokenizer files locally

        # Push to Hub using HuggingFace API
        api = HfApi(token=api_token)
        if create:
            api.create_repo(
                repo_id=repo_id,
                private=private,
                exist_ok=True,
            )

        # Push files to the repo
        api.upload_folder(
            repo_id=repo_id,
            folder_path=temp_dir,
            repo_type="model",
            **kwargs
        )

        # Cleanup
        os.remove(Path(temp_dir) / 'tokenizer.json')
        os.rmdir(temp_dir)

    @staticmethod
    def hf_tokenizer_from_file(path: str):
        return PreTrainedTokenizerFast(
            tokenizer_file=path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )

    @classmethod
    def from_pretrained(cls, repo_id: str, **kwargs):
        """
        Load tokenizer from HuggingFace Hub.

        Args:
            repo_id (str): Hub repository name (e.g., "username/my-tokenizer")
            **kwargs: Additional args for HuggingFace API
        """
        from huggingface_hub import hf_hub_download

        # Download tokenizer.json from Hub
        tokenizer_file = hf_hub_download(
            repo_id=repo_id,
            filename="tokenizer.json",
            **kwargs
        )

        # Initialize trainer and load tokenizer
        trainer = cls()
        trainer.load(tokenizer_file)
        return trainer
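A minimal usage sketch for the TokenizerTrainer class above. The corpus path and output directory are placeholders, and the commented import path is an assumption (the file name for this module is not shown in the diff):

# Usage sketch (hypothetical paths; import location assumed, not taken from the package metadata)
# from rxnn.tokenizer import TokenizerTrainer  # assumed module path

trainer = TokenizerTrainer(vocab_size=30000, model_type="byte-level-bpe")
trainer.train(files=["data/corpus.txt"], limit_alphabet=1000)   # plain-text training files (placeholder path)
trainer.save("tokenizer_out")                                   # writes tokenizer_out/tokenizer.json

hf_tokenizer = trainer.get_hf_tokenizer()                       # wrap as PreTrainedTokenizerFast
ids = hf_tokenizer("Hello world")["input_ids"]                  # [BOS] ... [EOS] added by the post-processor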
@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding


class MultiHeadAttention(nn.Module):
    """Custom, extendable Multi-head attention layer, with RoPE support"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        rope: RotaryPositionalEmbedding = None,
        rope_only_for_query: bool = False,
        use_relative_embeddings: bool = False,
        max_seq_len: int = 1024,
        use_flash_attention: bool = False,
        is_causal: bool = False,
        use_bias: bool = False,
        *args,
        **kwargs,
    ):
        super(MultiHeadAttention, self).__init__(*args, **kwargs)
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.use_flash_attention = use_flash_attention
        self.is_causal = is_causal
        self.use_bias = use_bias
        if use_relative_embeddings:
            self.use_flash_attention = False
            self.rel_embed = RelativePositionalEmbedding(max_seq_len, embed_dim // num_heads)
            self.rope = None
            self.rope_only_for_query = False
        else:
            self.rel_embed = None
            self.rope = rope
            self.rope_only_for_query = rope_only_for_query
        self.dropout = nn.Dropout(dropout)
        self._init_q(embed_dim)
        self._init_kv(embed_dim)
        self._init_out(embed_dim)

    def _init_q(self, embed_dim: int):
        """Initialize query projection"""
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=self.use_bias)

    def _init_kv(self, embed_dim: int):
        """Initialize key and value projections"""
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=self.use_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=self.use_bias)

    def _init_out(self, embed_dim: int):
        """Initialize output projection"""
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        """Forward pass through query, key, and value projections, and split the results into heads"""
        q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
        k = self.k_proj(key).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
        v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
        return q, k, v

    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
        if self.rope is not None:
            if self.rope_only_for_query:
                q = self.rope.forward_one(q)
            else:
                q, k = self.rope(q, k)
        return q, k

    def _calculate_attn_weights(self, q: torch.Tensor, k: torch.Tensor, d: int, mask: torch.Tensor = None):
        """Calculate attention weights using scaled dot-product attention"""
        q, k = self._apply_rope(q, k)
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (d // self.num_heads) ** 0.5
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
        return F.softmax(attn_logits, dim=-1)

    def _calculate_attn_weight_with_relative_embeddings(self, q: torch.Tensor, k: torch.Tensor,
                                                        mask: torch.Tensor = None):
        """Calculate attention weights using scaled dot-product attention and apply relative embedding"""
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        rel_pos_bias = self.rel_embed(q, k)
        attn_logits += rel_pos_bias
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
        return F.softmax(attn_logits, dim=-1)

    def _calculate_output(self, attn_weights: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int):
        """Calculate the output by multiplying attention weights with values and concatenating heads"""
        return torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(b, t, d)

    def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                         mask: torch.Tensor = None, enable_gqa: bool = False):
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask if not self.is_causal else None,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=self.is_causal,
            enable_gqa=enable_gqa,
        )

        # Reshape back to (B, T, D)
        return attn_output.transpose(1, 2).contiguous().view(b, t, d)

    def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                   mask: torch.Tensor = None):
        # Compute attention with FlashAttention
        return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        b, t, d = query.size()
        q, k, v = self._forward_qkv(query, key, value, b, t, d)
        if self.use_flash_attention:
            q, k = self._apply_rope(q, k)
            attn_output = self._calculate_flash_attention(q, k, v, b, t, d, mask=mask)
        else:
            if self.rel_embed is None:
                attn_weights = self._calculate_attn_weights(q, k, d, mask=mask)
            else:
                attn_weights = self._calculate_attn_weight_with_relative_embeddings(q, k, mask=mask)

            attn_weights = self.dropout(attn_weights)

            attn_output = self._calculate_output(attn_weights, v, b, t, d)
        return self.out_proj(attn_output)


class GroupedQueryAttention(MultiHeadAttention):
    """Custom Grouped Query attention layer, with RoPE support"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_groups: int,
        dropout: float = 0.0,
        rope: RotaryPositionalEmbedding = None,
        rope_only_for_query: bool = False,
        use_relative_embeddings: bool = False,
        max_seq_len: int = 1024,
        use_flash_attention: bool = False,
        is_causal: bool = False,
        use_bias: bool = False,
        *args,
        **kwargs,
    ):
        # num_groups must be set before super().__init__, which calls the overridden _init_kv
        self.num_groups = num_groups
        super(GroupedQueryAttention, self).__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            rope_only_for_query=rope_only_for_query,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
            *args,
            **kwargs,
        )
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"

    def _init_kv(self, embed_dim: int):
        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        """Override query, key, and value projections for GQA case - split data into heads and groups"""
        head_dim = d // self.num_heads
        if self.use_flash_attention:
            q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
        else:
            group_heads = self.num_heads // self.num_groups

            # Process Q
            q = self.q_proj(query).view(b, t, self.num_groups, group_heads, head_dim).permute(0, 2, 3, 1, 4)  # (B, G, group_heads, T, head_dim)

            # Process K and V
            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)

            # Expand and flatten to 4D tensors
            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)

            q = q.flatten(start_dim=1, end_dim=2)  # (B, H, T, head_dim)
            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
        return q, k, v

    def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                   mask: torch.Tensor = None):
        return self._flash_attention(
            q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
            enable_gqa=(self.num_heads != self.num_groups)
        )


class MultiQueryAttention(MultiHeadAttention):
    """Custom Multi Query attention layer, with RoPE support"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        rope: RotaryPositionalEmbedding = None,
        rope_only_for_query: bool = False,
        use_relative_embeddings: bool = False,
        max_seq_len: int = 1024,
        use_flash_attention: bool = False,
        is_causal: bool = False,
        use_bias: bool = False,
        *args,
        **kwargs,
    ):
        super(MultiQueryAttention, self).__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            rope_only_for_query=rope_only_for_query,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
            *args,
            **kwargs
        )

    def _init_kv(self, embed_dim: int):
        """Override key/value initialization for MQA case"""
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.num_heads, bias=self.use_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.num_heads, bias=self.use_bias)

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        """Override query, key, and value projections for MQA case - use multiple heads
        for query and a single head for key/values"""
        if self.use_flash_attention:
            q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
            k = self.k_proj(key).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
            v = self.v_proj(value).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
        else:
            q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
            k = self.k_proj(key).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            v = self.v_proj(value).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        return q, k, v

    def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                   mask: torch.Tensor = None):
        return self._flash_attention(
            q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
            enable_gqa=True
        )


def init_attention(
    embed_dim: int,
    num_heads: int,
    attention_type: str,
    gqa_groups: int = 1,
    dropout: float = 0.0,
    rope: RotaryPositionalEmbedding = None,
    rope_only_for_query: bool = False,
    use_relative_embeddings: bool = False,
    max_seq_len: int = 1024,
    use_flash_attention: bool = False,
    is_causal: bool = False,
    use_bias: bool = False,
) -> MultiHeadAttention:
    assert attention_type in ('mha', 'gqa', 'mqa'), \
        "attention_type should be one of: 'mha', 'gqa', 'mqa'"

    if attention_type == "gqa":
        return GroupedQueryAttention(
            embed_dim,
            num_heads,
            gqa_groups,
            dropout=dropout,
            rope=rope,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
        )
    elif attention_type == "mqa":
        return MultiQueryAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
        )
    else:
        return MultiHeadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
        )
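A small shape-check sketch for the attention variants above. The commented import path is an assumption (this file's name is not shown in the diff); rope is left at its default of None, so no positional encoding is applied:

# Shape-check sketch (hypothetical import path; classes are used as defined above)
# from rxnn.transformers.attention import init_attention  # assumed module path
import torch

attn = init_attention(embed_dim=512, num_heads=8, attention_type='gqa', gqa_groups=2, dropout=0.1)
x = torch.randn(2, 16, 512)     # (batch, seq_len, embed_dim)
out = attn(x, x, x)             # self-attention: query = key = value
assert out.shape == (2, 16, 512)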
src/transformers/ff.py ADDED
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """Basic Feed-forward layer with activation function and optional dropout"""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float = 0.0, *args, **kwargs):
        super(FeedForward, self).__init__(*args, **kwargs)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.fc2(x)


class GatedLinearUnit(nn.Module):
    """Gated linear unit layer with configurable activation (SwiGLU, ReGLU, etc.)"""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module, *args, **kwargs):
        super(GatedLinearUnit, self).__init__(*args, **kwargs)
        self.linear = nn.Linear(embed_dim, hidden_dim * 2)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        l, g = self.linear(x).chunk(2, dim=-1)
        return l * self.activation(g)


class GatedFeedForward(nn.Module):
    """Gated feed-forward layer with activation function and optional dropout"""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float = 0.0, *args, **kwargs):
        super(GatedFeedForward, self).__init__(*args, **kwargs)
        self.fc1 = GatedLinearUnit(embed_dim, hidden_dim, activation)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.dropout(x)
        return self.fc2(x)


class LinearActivation(nn.Module):
    """Linear activation - identity function, for Bilinear Gated Unit"""

    def __init__(self, *args, **kwargs):
        super(LinearActivation, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def get_activation_layer(activation: str):
    if activation == 'relu':
        return nn.ReLU()
    elif activation == 'gelu':
        return nn.GELU()
    elif activation == 'silu' or activation == 'swish':
        return nn.SiLU()
    elif activation == 'sigmoid':
        return nn.Sigmoid()
    elif activation == 'linear':
        return LinearActivation()
    else:
        raise ValueError(f'Activation {activation} not supported')
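A brief usage sketch for the feed-forward variants: combining GatedFeedForward with the 'silu' activation from get_activation_layer gives a SwiGLU-style block. The import path is inferred from the src/transformers/ff.py path above and may differ in the published package:

# Usage sketch (import path inferred from src/transformers/ff.py; may differ in the wheel)
# from rxnn.transformers.ff import GatedFeedForward, get_activation_layer
import torch

ff = GatedFeedForward(embed_dim=512, hidden_dim=1536, activation=get_activation_layer('silu'), dropout=0.1)
x = torch.randn(2, 16, 512)     # (batch, seq_len, embed_dim)
y = ff(x)                       # SwiGLU-style gated feed-forward
assert y.shape == x.shape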