rxnn-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn-0.1.0.dist-info/LICENSE +201 -0
- rxnn-0.1.0.dist-info/METADATA +257 -0
- rxnn-0.1.0.dist-info/RECORD +23 -0
- rxnn-0.1.0.dist-info/WHEEL +4 -0
- src/experimental/attention.py +133 -0
- src/memory/norm.py +173 -0
- src/memory/stm.py +53 -0
- src/rxt/models.py +180 -0
- src/training/base.py +275 -0
- src/training/bml.py +345 -0
- src/training/callbacks.py +491 -0
- src/training/dataset.py +164 -0
- src/training/scheduler.py +19 -0
- src/training/tokenizer.py +208 -0
- src/transformers/attention.py +324 -0
- src/transformers/ff.py +72 -0
- src/transformers/layers.py +150 -0
- src/transformers/mask.py +10 -0
- src/transformers/models.py +168 -0
- src/transformers/moe.py +139 -0
- src/transformers/positional.py +105 -0
- src/transformers/sampler.py +109 -0
- src/utils.py +14 -0

src/training/tokenizer.py
ADDED
@@ -0,0 +1,208 @@
import os
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordPiece, Unigram, WordLevel
from tokenizers.trainers import BpeTrainer, WordPieceTrainer, UnigramTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace, Punctuation, BertPreTokenizer, ByteLevel, Sequence as PreTokenizerSequence
from tokenizers.processors import TemplateProcessing
from tokenizers.normalizers import Lowercase, NFKC, Sequence
from transformers import PreTrainedTokenizerFast
from typing import Any


class TokenizerTrainer:
    def __init__(
            self,
            vocab_size: int = 30000,
            model_type: str = "byte-level-bpe",  # Options: "bpe", "byte-level-bpe", "wordpiece", "unigram", "wordlevel"
            special_tokens: list[str] = None,
            lowercase: bool = False,
            normalization: bool = False,
            pre_tokenizer_type: str = "bert",  # Options: "bert", "whitespace", "whitespace_punctuation"
            vocab: Any = None,
            byte_fallback: bool = False,
            max_input_chars_per_word: int = 32,
            use_post_processor: bool = True,
            post_processor_single: str = "[BOS] $A [EOS]",
            post_processor_pair: str = "[BOS] $A [EOS][BOS] $B:1 [EOS]:1",
            post_processor_special_tokens: list[str] = None,
    ):
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens if special_tokens is not None else ["[PAD]", "[UNK]", "[BOS]", "[EOS]", "[MASK]"]
        self.model_type = model_type.lower()
        self.lowercase = lowercase
        self.normalization = normalization
        self.pre_tokenizer_type = pre_tokenizer_type.lower()

        # Initialize tokenizer model
        if self.model_type == "bpe":
            self.tokenizer = Tokenizer(BPE(unk_token="[UNK]", vocab=vocab, byte_fallback=byte_fallback))
        elif self.model_type == "wordpiece":
            self.tokenizer = Tokenizer(
                WordPiece(unk_token="[UNK]", vocab=vocab, max_input_chars_per_word=max_input_chars_per_word))
        elif self.model_type == "unigram":
            # NOTE: Unigram's `unk_id` expects an integer index into a provided vocab, not a token string
            self.tokenizer = Tokenizer(Unigram(unk_id="[UNK]", vocab=None, byte_fallback=byte_fallback))
        elif self.model_type == "wordlevel":
            self.tokenizer = Tokenizer(WordLevel(unk_token="[UNK]", vocab=None))
        elif self.model_type == "byte-level-bpe":
            self.tokenizer = Tokenizer(BPE(unk_token="[UNK]", vocab=vocab, byte_fallback=byte_fallback))
            self.tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Configure pre-tokenizer (byte-level BPE already set its own pre-tokenizer above)
        if self.model_type != "byte-level-bpe":
            if self.pre_tokenizer_type == "bert":
                self.tokenizer.pre_tokenizer = BertPreTokenizer()
            elif self.pre_tokenizer_type == "whitespace_punctuation":
                # Combine both pre-tokenizers; a plain assignment would overwrite the first one
                self.tokenizer.pre_tokenizer = PreTokenizerSequence([Whitespace(), Punctuation()])
            elif self.pre_tokenizer_type == "whitespace":
                self.tokenizer.pre_tokenizer = Whitespace()
            else:
                raise ValueError(f"Unsupported pre-tokenizer: {pre_tokenizer_type}")

        # Add normalization steps
        if self.normalization:
            normalizers = []
            if self.lowercase:
                normalizers.append(Lowercase())
            normalizers.append(NFKC())
            self.tokenizer.normalizer = Sequence(normalizers)

        self.use_post_processor = use_post_processor
        self.post_processor_single = post_processor_single
        self.post_processor_pair = post_processor_pair
        self.post_processor_special_tokens = post_processor_special_tokens

    def train(
            self,
            files: list[str],
            limit_alphabet: int = 1000,
            show_progress: bool = True,
            **kwargs
    ):
        # Prepare trainer based on model type
        trainer_kwargs = {
            "vocab_size": self.vocab_size,
            "special_tokens": self.special_tokens,
            "limit_alphabet": limit_alphabet,
            "show_progress": show_progress,
            **kwargs  # Allow custom parameters
        }

        if self.model_type in ["bpe", "byte-level-bpe"]:
            trainer = BpeTrainer(**trainer_kwargs)
        elif self.model_type == "wordpiece":
            trainer = WordPieceTrainer(**trainer_kwargs)
        elif self.model_type == "unigram":
            # UnigramTrainer does not accept `limit_alphabet`
            trainer_kwargs.pop("limit_alphabet", None)
            trainer = UnigramTrainer(**trainer_kwargs)
        elif self.model_type == "wordlevel":
            # WordLevelTrainer does not accept `limit_alphabet`
            trainer_kwargs.pop("limit_alphabet", None)
            trainer = WordLevelTrainer(**trainer_kwargs)

        # Train tokenizer
        self.tokenizer.train(files, trainer)

        if self.use_post_processor:
            post_processor_special_tokens = self.post_processor_special_tokens or ["[BOS]", "[EOS]"]
            self.tokenizer.post_processor = TemplateProcessing(
                single=self.post_processor_single,
                pair=self.post_processor_pair,
                special_tokens=[(token, self.tokenizer.token_to_id(token)) for token in post_processor_special_tokens],
            )

    def save(self, output_dir: str):
        os.makedirs(output_dir, exist_ok=True)
        self.tokenizer.save(f"{output_dir}/tokenizer.json")

    def load(self, model_path: str):
        self.tokenizer = Tokenizer.from_file(model_path)

    def get_hf_tokenizer(self):
        return PreTrainedTokenizerFast(
            tokenizer_object=self.tokenizer,
            unk_token="[UNK]",
            pad_token="[PAD]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )

    def push_to_hub(
            self,
            repo_id: str,
            create: bool = False,
            private: bool = False,
            api_token: str = None,
            **kwargs
    ):
        """
        Push the trained tokenizer to HuggingFace Hub.

        Args:
            repo_id (str): Hub repository name (e.g., "username/my-tokenizer")
            create (bool): Whether to create the repository before uploading
            private (bool): Whether the repo is private
            api_token (str): HuggingFace API token (optional if already logged in)
            **kwargs: Additional args for HuggingFace API
        """
        from huggingface_hub import HfApi

        # Create a temporary directory for Hub upload
        temp_dir = "temp_hub_upload"
        os.makedirs(temp_dir, exist_ok=True)
        self.save(temp_dir)  # Save tokenizer files locally

        # Push to Hub using HuggingFace API
        api = HfApi(token=api_token)
        if create:
            api.create_repo(
                repo_id=repo_id,
                private=private,
                exist_ok=True,
            )

        # Push files to the repo
        api.upload_folder(
            repo_id=repo_id,
            folder_path=temp_dir,
            repo_type="model",
            **kwargs
        )

        # Cleanup
        os.remove(Path(temp_dir) / 'tokenizer.json')
        os.rmdir(temp_dir)

    @staticmethod
    def hf_tokenizer_from_file(path: str):
        return PreTrainedTokenizerFast(
            tokenizer_file=path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )

    @classmethod
    def from_pretrained(cls, repo_id: str, **kwargs):
        """
        Load tokenizer from HuggingFace Hub.

        Args:
            repo_id (str): Hub repository name (e.g., "username/my-tokenizer")
            **kwargs: Additional args for HuggingFace API
        """
        from huggingface_hub import hf_hub_download

        # Download tokenizer.json from Hub
        tokenizer_file = hf_hub_download(
            repo_id=repo_id,
            filename="tokenizer.json",
            **kwargs
        )

        # Initialize trainer and load tokenizer
        trainer = cls()
        trainer.load(tokenizer_file)
        return trainer
src/transformers/attention.py
ADDED
@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding


class MultiHeadAttention(nn.Module):
    """Custom, extendable Multi-head attention layer, with RoPE support"""

    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            dropout: float = 0.0,
            rope: RotaryPositionalEmbedding = None,
            rope_only_for_query: bool = False,
            use_relative_embeddings: bool = False,
            max_seq_len: int = 1024,
            use_flash_attention: bool = False,
            is_causal: bool = False,
            use_bias: bool = False,
            *args,
            **kwargs,
    ):
        super(MultiHeadAttention, self).__init__(*args, **kwargs)
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.use_flash_attention = use_flash_attention
        self.is_causal = is_causal
        self.use_bias = use_bias
        if use_relative_embeddings:
            self.use_flash_attention = False
            self.rel_embed = RelativePositionalEmbedding(max_seq_len, embed_dim // num_heads)
            self.rope = None
            self.rope_only_for_query = False
        else:
            self.rel_embed = None
            self.rope = rope
            self.rope_only_for_query = rope_only_for_query
        self.dropout = nn.Dropout(dropout)
        self._init_q(embed_dim)
        self._init_kv(embed_dim)
        self._init_out(embed_dim)

    def _init_q(self, embed_dim: int):
        """Initialize query projection"""
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=self.use_bias)

    def _init_kv(self, embed_dim: int):
        """Initialize key and value projections"""
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=self.use_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=self.use_bias)

    def _init_out(self, embed_dim: int):
        """Initialize output projection"""
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        """Forward pass through query, key, and value projections, and split the results into heads"""
        q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
        k = self.k_proj(key).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
        v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
        return q, k, v

    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
        if self.rope is not None:
            if self.rope_only_for_query:
                q = self.rope.forward_one(q)
            else:
                q, k = self.rope(q, k)
        return q, k

    def _calculate_attn_weights(self, q: torch.Tensor, k: torch.Tensor, d: int, mask: torch.Tensor = None):
        """Calculate attention weights using scaled dot-product attention"""
        q, k = self._apply_rope(q, k)
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (d // self.num_heads) ** 0.5
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
        return F.softmax(attn_logits, dim=-1)

    def _calculate_attn_weight_with_relative_embeddings(self, q: torch.Tensor, k: torch.Tensor,
                                                        mask: torch.Tensor = None):
        """Calculate attention weights using scaled dot-product attention and apply relative embedding"""
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        rel_pos_bias = self.rel_embed(q, k)
        attn_logits += rel_pos_bias
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
        return F.softmax(attn_logits, dim=-1)

    def _calculate_output(self, attn_weights: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int):
        """Calculate the output by multiplying attention weights with values and concatenating heads"""
        return torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(b, t, d)

    def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                         mask: torch.Tensor = None, enable_gqa: bool = False):
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask if not self.is_causal else None,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=self.is_causal,
            enable_gqa=enable_gqa,
        )

        # Reshape back to (B, T, D)
        return attn_output.transpose(1, 2).contiguous().view(b, t, d)

    def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                   mask: torch.Tensor = None):
        # Compute attention with FlashAttention
        return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        b, t, d = query.size()
        q, k, v = self._forward_qkv(query, key, value, b, t, d)
        if self.use_flash_attention:
            q, k = self._apply_rope(q, k)
            attn_output = self._calculate_flash_attention(q, k, v, b, t, d, mask=mask)
        else:
            if not self.rel_embed:
                attn_weights = self._calculate_attn_weights(q, k, d, mask=mask)
            else:
                attn_weights = self._calculate_attn_weight_with_relative_embeddings(q, k, mask=mask)

            attn_weights = self.dropout(attn_weights)

            attn_output = self._calculate_output(attn_weights, v, b, t, d)
        return self.out_proj(attn_output)


class GroupedQueryAttention(MultiHeadAttention):
    """Custom Grouped Query attention layer, with RoPE support"""

    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            num_groups: int,
            dropout: float = 0.0,
            rope: RotaryPositionalEmbedding = None,
            rope_only_for_query: bool = False,
            use_relative_embeddings: bool = False,
            max_seq_len: int = 1024,
            use_flash_attention: bool = False,
            is_causal: bool = False,
            use_bias: bool = False,
            *args,
            **kwargs,
    ):
        self.num_groups = num_groups
        super(GroupedQueryAttention, self).__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            rope_only_for_query=rope_only_for_query,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
            *args,
            **kwargs,
        )
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"

    def _init_kv(self, embed_dim: int):
        """Override key/value projections for GQA case - project to num_groups shared key/value heads"""
        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        """Override query, key, and value projections for GQA case - split data into heads and groups"""
        head_dim = d // self.num_heads
        if self.use_flash_attention:
            q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
        else:
            group_heads = self.num_heads // self.num_groups

            # Process Q
            q = self.q_proj(query).view(b, t, self.num_groups, group_heads, head_dim).permute(0, 2, 3, 1, 4)  # (B, G, group_heads, T, head_dim)

            # Process K and V
            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)

            # Expand and flatten to 4D tensors
            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)

            q = q.flatten(start_dim=1, end_dim=2)  # (B, H, T, head_dim)
            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
        return q, k, v

    def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                   mask: torch.Tensor = None):
        return self._flash_attention(
            q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
            enable_gqa=(self.num_heads != self.num_groups)
        )


class MultiQueryAttention(MultiHeadAttention):
    """Custom Multi Query attention layer, with RoPE support"""

    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            dropout: float = 0.0,
            rope: RotaryPositionalEmbedding = None,
            rope_only_for_query: bool = False,
            use_relative_embeddings: bool = False,
            max_seq_len: int = 1024,
            use_flash_attention: bool = False,
            is_causal: bool = False,
            use_bias: bool = False,
            *args,
            **kwargs,
    ):
        super(MultiQueryAttention, self).__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            rope_only_for_query=rope_only_for_query,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
            *args,
            **kwargs
        )

    def _init_kv(self, embed_dim: int):
        """Override key/value initialization for MQA case"""
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.num_heads, bias=self.use_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.num_heads, bias=self.use_bias)

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        """Override query, key, and value projections for MQA case - use multiple heads
        for query and a single head for key/values"""
        if self.use_flash_attention:
            q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
            k = self.k_proj(key).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
            v = self.v_proj(value).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
        else:
            q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
            k = self.k_proj(key).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            v = self.v_proj(value).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        return q, k, v

    def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                   mask: torch.Tensor = None):
        return self._flash_attention(
            q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
            enable_gqa=True
        )


def init_attention(
        embed_dim: int,
        num_heads: int,
        attention_type: str,
        gqa_groups: int = 1,
        dropout: float = 0.0,
        rope: RotaryPositionalEmbedding = None,
        rope_only_for_query: bool = False,
        use_relative_embeddings: bool = False,
        max_seq_len: int = 1024,
        use_flash_attention: bool = False,
        is_causal: bool = False,
        use_bias: bool = False,
) -> MultiHeadAttention:
    assert attention_type in ('mha', 'gqa', 'mqa'), \
        "Error, attention type should be one of: 'mha', 'gqa', 'mqa'"

    if attention_type == "gqa":
        return GroupedQueryAttention(
            embed_dim,
            num_heads,
            gqa_groups,
            dropout=dropout,
            rope=rope,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
        )
    elif attention_type == "mqa":
        return MultiQueryAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
        )
    else:
        return MultiHeadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            rope=rope,
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
        )
src/transformers/ff.py
ADDED
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """Basic Feed-forward layer with activation function and optional dropout"""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float = 0.0, *args, **kwargs):
        super(FeedForward, self).__init__(*args, **kwargs)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.fc2(x)


class GatedLinearUnit(nn.Module):
    """Gated linear unit layer with configurable activation (SwiGLU, ReGLU, etc.)"""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module, *args, **kwargs):
        super(GatedLinearUnit, self).__init__(*args, **kwargs)
        self.linear = nn.Linear(embed_dim, hidden_dim * 2)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        l, g = self.linear(x).chunk(2, dim=-1)
        return l * self.activation(g)


class GatedFeedForward(nn.Module):
    """Gated feed-forward layer with activation function and optional dropout"""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float = 0.0, *args, **kwargs):
        super(GatedFeedForward, self).__init__(*args, **kwargs)
        self.fc1 = GatedLinearUnit(embed_dim, hidden_dim, activation)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.dropout(x)
        return self.fc2(x)


class LinearActivation(nn.Module):
    """Linear activation - identity function, for Bilinear Gated Unit"""

    def __init__(self, *args, **kwargs):
        super(LinearActivation, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def get_activation_layer(activation: str):
    if activation == 'relu':
        return nn.ReLU()
    elif activation == 'gelu':
        return nn.GELU()
    elif activation == 'silu' or activation == 'swish':
        return nn.SiLU()
    elif activation == 'sigmoid':
        return nn.Sigmoid()
    elif activation == 'linear':
        return LinearActivation()
    else:
        raise ValueError(f'Activation {activation} not supported')
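
A brief sketch of how the gated feed-forward pieces above compose into a SwiGLU-style block; the import path is assumed from the src/ layout and the dimensions are illustrative:

import torch
from rxnn.transformers.ff import GatedFeedForward, get_activation_layer  # import path assumed from the src/ layout

# SwiGLU-style block: GatedLinearUnit with SiLU gating wrapped in GatedFeedForward
ff = GatedFeedForward(embed_dim=256, hidden_dim=1024, activation=get_activation_layer('silu'), dropout=0.1)

x = torch.randn(2, 16, 256)  # (batch, seq_len, embed_dim)
print(ff(x).shape)           # torch.Size([2, 16, 256])

Using get_activation_layer('linear') instead yields a bilinear gated unit, since LinearActivation leaves the gate branch untouched.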