mortm-4.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mortm/__init__.py +0 -0
- mortm/constants.py +31 -0
- mortm/models/__init__.py +0 -0
- mortm/models/bertm.py +294 -0
- mortm/models/modules/PositionalEncoding.py +27 -0
- mortm/models/modules/__init__.py +0 -0
- mortm/models/modules/attention.py +300 -0
- mortm/models/modules/audio_patch.py +44 -0
- mortm/models/modules/config.py +77 -0
- mortm/models/modules/layers.py +471 -0
- mortm/models/modules/progress.py +52 -0
- mortm/models/mortm.py +338 -0
- mortm/models/mortm_live.py +26 -0
- mortm/models/v_mortm.py +65 -0
- mortm/train/__init__.py +0 -0
- mortm/train/config.py +55 -0
- mortm/train/custom_token.py +603 -0
- mortm/train/datasets.py +321 -0
- mortm/train/epoch.py +20 -0
- mortm/train/noam.py +7 -0
- mortm/train/rl/__init__.py +0 -0
- mortm/train/rl/reinforcement.py +207 -0
- mortm/train/tokenizer.py +204 -0
- mortm/train/train.py +686 -0
- mortm/train/utils/__init__.py +0 -0
- mortm/train/utils/chord_midi.py +47 -0
- mortm/train/utils/loss.py +135 -0
- mortm/utils/__init__.py +0 -0
- mortm/utils/convert.py +1220 -0
- mortm/utils/de_convert.py +40 -0
- mortm/utils/eval.py +155 -0
- mortm/utils/generate.py +149 -0
- mortm/utils/gmail_messanger.py +66 -0
- mortm/utils/key.py +354 -0
- mortm/utils/messager.py +21 -0
- mortm/utils/pianoroll_convert.py +182 -0
- mortm/utils/tag.py +97 -0
- mortm-4.5.dist-info/METADATA +254 -0
- mortm-4.5.dist-info/RECORD +41 -0
- mortm-4.5.dist-info/WHEEL +5 -0
- mortm-4.5.dist-info/top_level.txt +1 -0
mortm/__init__.py
ADDED
|
File without changes
|
mortm/constants.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
# Raw value limits / sentinel sizes used by the tokenizer.
PITCH_MAX = 128
VELO = 128
LENGTH = 999
LENGTH_HALF = 999
BEGIN = 999
BEGIN_HALF = 999
ROOT = 99
START_SEQ_TOKEN = "<S_SEQ>"
END_SEQ_TOKEN = "<E_SEQ>"
PADDING_TOKEN = "<PAD>"

MODEL_NAME = "MORTM"

# Number of token IDs reserved for each token family.
SPECIAL_TOKEN_COUNT = 2
PITCH_TOKEN_COUNT = 128
VELOCITY_TOKEN_COUNT = 128
DURATION_TOKEN_COUNT = 100
START_TOKEN_COUNT = 32
SHIFT_TOKEN_COUNT = 4

# First ID of each family = first ID of the previous family + the number of
# IDs the previous family uses.
PADDING_BEGIN_ID = 0
SPECIAL_BEGIN_ID = PADDING_BEGIN_ID + 1
PITCH_BEGIN_ID = SPECIAL_BEGIN_ID + SPECIAL_TOKEN_COUNT
VELOCITY_BEGIN_ID = PITCH_BEGIN_ID + PITCH_TOKEN_COUNT
DURATION_BEGIN_ID = VELOCITY_BEGIN_ID + VELOCITY_TOKEN_COUNT
START_BEGIN_ID = DURATION_BEGIN_ID + DURATION_TOKEN_COUNT
SHIFT_BEGIN_ID = START_BEGIN_ID + START_TOKEN_COUNT


# ID range of each family. range() excludes its end, so sizing each group by
# exactly its token count keeps neighbouring groups disjoint: the last pitch ID
# is VELOCITY_BEGIN_ID - 1, and so on.
PITCH_GROUP = range(PITCH_BEGIN_ID, PITCH_BEGIN_ID + PITCH_TOKEN_COUNT)
VELOCITY_GROUP = range(VELOCITY_BEGIN_ID, VELOCITY_BEGIN_ID + VELOCITY_TOKEN_COUNT)
DURATION_GROUP = range(DURATION_BEGIN_ID, DURATION_BEGIN_ID + DURATION_TOKEN_COUNT)
START_GROUP = range(START_BEGIN_ID, START_BEGIN_ID + START_TOKEN_COUNT)
SHIFT_GROUP = range(SHIFT_BEGIN_ID, SHIFT_BEGIN_ID + SHIFT_TOKEN_COUNT)
|
|
31
|
+
|
mortm/models/__init__.py
ADDED
|
File without changes
|
mortm/models/bertm.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from torch.distributions import Categorical
|
|
4
|
+
from .modules.layers import *
|
|
5
|
+
from .modules.config import MORTMArgs
|
|
6
|
+
|
|
7
|
+
from flash_attn.bert_padding import unpad_input, pad_input
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ActorCritic(nn.Module):
    """Decoder-only actor-critic model over MORTM token sequences.

    A shared ``MORTMDecoder`` trunk feeds two heads:

    * ``Wout`` - next-token logits (the policy / actor);
    * ``critic_hidden`` -> ReLU -> ``critic_out`` - one scalar value per
      position (the critic).
    """

    # End-of-sequence token IDs used by this model. ``eval_seq`` trims each
    # generated row after the first ``EOS_TOKEN_ID``; ``is_end_point`` stops
    # generation once every row contains any of ``END_TOKEN_IDS``.
    EOS_TOKEN_ID = 585
    END_TOKEN_IDS = (585, 586)

    def __init__(self, args: MORTMArgs, progress):
        super(ActorCritic, self).__init__()
        self.args = args
        self.progress = progress
        self.e_layer = args.e_layer
        self.d_layer = args.d_layer
        self.num_heads = args.num_heads
        self.d_model = args.d_model
        self.dim_feedforward = args.dim_feedforward
        self.dropout = args.dropout
        self.use_lora = args.use_lora

        self.decoder = MORTMDecoder(args, progress=progress)

        device = self.progress.get_device()
        print(f"Input Vocab Size:{args.vocab_size}")
        # Token ID 0 is padding; eval_seq builds its prompt mask from padding_idx.
        self.embedding: nn.Embedding = nn.Embedding(args.vocab_size, self.d_model, padding_idx=0).to(device)
        if not self.use_lora:
            self.Wout: nn.Linear = nn.Linear(self.d_model, args.vocab_size).to(device)
        else:
            self.Wout: lora.Linear = lora.Linear(self.d_model, args.vocab_size, r=args.lora_r,
                                                 lora_alpha=args.lora_alpha).to(device)

        # Critic head: d_model -> d_model // 2 -> 1 scalar value per position.
        # Placed on the same device as the other heads.
        self.critic_hidden = nn.Linear(self.d_model, self.d_model // 2).to(device)
        self.critic_out = nn.Linear(self.d_model // 2, 1).to(device)

        self.softmax: nn.Softmax = nn.Softmax(dim=-1).to(device)

    def evaluate_actions(self, sequence_tensors, padding_mask):
        """Re-score token sequences under the current policy (teacher forcing).

        ``sequence_tensors[:, :-1]`` is fed through the model. The function
        returns the log-probability of each observed next token
        (``sequence_tensors[:, 1:]``) and the critic value at each input
        position.

        Args:
            sequence_tensors: Token IDs indexed as (batch, seq_len).
            padding_mask: Optional mask aligned with ``sequence_tensors``,
                truthy for real tokens, or None.

        Returns:
            ``(log_probs, values)``, both shaped (batch, seq_len - 1).
            Log-probs at padded target positions are zeroed.
        """
        # 1. Inputs and targets are the same sequence shifted by one.
        input_ids = sequence_tensors[:, :-1]
        target_ids = sequence_tensors[:, 1:]

        # Shift the mask the same way as the inputs.
        input_mask = padding_mask[:, :-1] if padding_mask is not None else None

        # 2. Per-step logits and values for input_ids.
        logits, new_values = self.forward(input_ids, padding_mask=input_mask, is_causal=True)

        # Score all positions at once with a single Categorical.
        flat_logits = logits.reshape(-1, self.args.vocab_size)
        flat_targets = target_ids.reshape(-1).long()
        log_probs = Categorical(logits=flat_logits).log_prob(flat_targets)

        # Back to (batch, seq_len - 1).
        log_probs = log_probs.view(logits.size(0), logits.size(1))

        # Zero the log-probs of padded targets.
        if padding_mask is not None:
            log_probs = log_probs * padding_mask[:, 1:]
        return log_probs, new_values.reshape(new_values.shape[0], new_values.shape[1])

    @torch.no_grad()
    def eval_seq(self, src, print_log=True):
        """Generate tokens autoregressively with the KV cache (batched).

        The prompt is run once to fill the attention KV cache. After that, one
        token per row is sampled per step with top-p (p=0.95). Generation stops
        when every row contains an end token (see ``is_end_point``) or the
        sequence length exceeds ``args.position_length``.

        Runs under ``torch.no_grad()``. The results are returned as numpy
        arrays, so gradients could not be carried anyway.

        Args:
            src: Prompt token IDs (tensor, numpy array or nested sequence),
                either 1-D for one prompt or 2-D (batch, prompt_len),
                padded with token ID 0.
            print_log: If True, print progress/debug information.

        Returns:
            ``(tokens, values, log_probs)``: three lists with one 1-D numpy
            array per batch row. ``tokens`` holds the prompt and generated IDs
            with the padding run and everything after the first
            ``EOS_TOKEN_ID`` removed. The other two hold the critic value and
            the log-probability of the following token at each kept position.

        Raises:
            ValueError: If a generated token ID is outside the vocabulary.
        """
        self.eval()
        device = self.progress.get_device()

        # Accept numpy arrays, sequences or tensors and make sure the prompt
        # lives on the model device.
        src = torch.as_tensor(src, device=device)
        if src.dim() == 1:
            src = src.unsqueeze(0)
        batch_size, prompt_len = src.shape
        rows = torch.arange(batch_size, device=device)

        # --- 1. Prompt processing (pre-fill) ---
        if print_log: print("--- Pre-fill Phase ---")
        prompt_padding_mask = (src != self.embedding.padding_idx)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, values = self.forward(src, padding_mask=prompt_padding_mask, is_causal=True, is_save_cache=True)
        prompt_dist = Categorical(logits=logits)

        # Sample the first new token from each row's last *real* prompt position.
        # pad_input zero-fills padded positions, so logits there do not depend
        # on the prompt. For rows without right padding this is position
        # prompt_len - 1, as before.
        positions = torch.arange(prompt_len, device=device)
        last_real = torch.where(prompt_padding_mask, positions, torch.zeros_like(positions)).max(dim=1).values
        last_token_logits = logits[rows, last_real]
        next_tokens = self.top_p_sampling(last_token_logits, p=0.95, temperature=1.0)  # (batch,)

        # Tensors holding every token / value / log-prob produced so far.
        all_tokens = torch.cat([src, next_tokens.unsqueeze(1)], dim=1)
        all_values = values.clone()
        all_probs = prompt_dist.log_prob(all_tokens[:, 1:])
        # Slot prompt_len - 1 is the one that predicts the first generated token
        # (the trimming in _trim_row relies on this). Store the value/log-prob
        # of the position that actually produced that token there.
        all_values[:, -1] = values[rows, last_real]
        all_probs[:, -1] = Categorical(logits=last_token_logits).log_prob(next_tokens)

        # --- 2. Token generation (decoding) ---
        # ``step`` is the current sequence length, not the batch size. It bounds
        # generation so the attention KV cache (allocated from
        # args.position_length) cannot overflow.
        step = all_tokens.shape[1]
        if print_log: print("\n--- Decoding Phase ---")
        while True:
            # Input is the token generated in the previous step, shape (batch,).
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, values = self.forward(next_tokens, padding_mask=None, is_causal=True, is_save_cache=True)
            step_logits = logits.squeeze(1)
            step_dist = Categorical(logits=step_logits)
            next_tokens = self.top_p_sampling(step_logits, p=0.95, temperature=1.0)

            all_tokens = torch.cat([all_tokens, next_tokens.unsqueeze(1)], dim=1)
            all_values = torch.cat([all_values, values.unsqueeze(1)], dim=1)
            all_probs = torch.cat([all_probs, step_dist.log_prob(next_tokens).unsqueeze(1)], dim=1)

            if print_log: print(f"\r Step {step+1}: Generated tokens {next_tokens.tolist()}", end="")

            if self.is_end_point(all_tokens) or step > self.args.position_length:
                break
            step += 1

        if print_log:
            print(all_tokens.shape, all_values.shape, all_probs.shape)
            print(all_tokens.max())

        np_all_tokens, np_all_values, np_all_probs = [], [], []
        for b in range(batch_size):
            np_seq, np_value, np_probs = self._trim_row(all_tokens[b], all_values[b], all_probs[b])
            if print_log: print(f"Row {b} LEN : {np_seq.shape, np_value.shape, np_probs.shape}")

            np_all_tokens.append(np_seq)
            np_all_values.append(np_value)
            np_all_probs.append(np_probs)
            # Valid IDs are 0 .. vocab_size - 1.
            if np_seq.size and np_seq.max() >= self.args.vocab_size:
                raise ValueError(
                    f"生成されたトークンIDが語彙サイズ({self.args.vocab_size})を超えています: {np_seq.max()}"
                )

        return np_all_tokens, np_all_values, np_all_probs

    def _trim_row(self, seq: Tensor, values_row: Tensor, probs_row: Tensor):
        """Drop the padding run and everything after the first EOS from one row.

        ``values_row[t]`` and ``probs_row[t]`` belong to position ``t``, which
        predicts token ``t + 1``, so their slices stop one position before the
        token slices.

        Returns:
            ``(tokens, values, log_probs)`` as 1-D numpy arrays
            (int64, float64, float64).
        """
        pad_pos = (seq == 0).nonzero(as_tuple=True)[0]
        eos_pos = (seq == self.EOS_TOKEN_ID).nonzero(as_tuple=True)[0]
        # Cut right after the first EOS; without one, keep the whole row.
        eseq = eos_pos[0].item() if len(eos_pos) > 0 else len(seq) - 1

        if len(pad_pos) == 0:
            tokens = seq[:eseq + 1]
            values = values_row[:eseq]
            probs = probs_row[:eseq]
        else:
            # Keep the real tokens before and after the padding run.
            start, end = pad_pos[0].item(), pad_pos[-1].item()
            # Clamp at 0 so a padding run that starts at position 0 keeps no
            # prefix (start - 1 == -1 would otherwise slice from the end).
            prefix = max(start - 1, 0)
            tokens = torch.cat([seq[:start], seq[end + 1:eseq + 1]])
            values = torch.cat([values_row[:prefix], values_row[end:eseq]])
            probs = torch.cat([probs_row[:prefix], probs_row[end:eseq]])

        return (
            tokens.cpu().numpy().astype(np.int64),
            values.float().cpu().numpy().reshape(-1).astype(np.float64),
            probs.float().cpu().numpy().reshape(-1).astype(np.float64),
        )

    def forward(self, x, padding_mask=None, is_causal=False, is_save_cache=False):
        """Run the decoder trunk and both heads.

        Args:
            x: Token IDs. With ``padding_mask`` given: (batch, seq_len).
                Without it: a 1-D tensor handled as one flat token stream
                (e.g. (batch,) during cached decoding).
            padding_mask: Truthy for real tokens. Used to unpad the batch into
                packed variable-length form and to re-pad the output.
            is_causal: Use causal self-attention.
            is_save_cache: Forwarded to the decoder to fill/use its KV cache.

        Returns:
            ``(logits, values)``, whose last dimensions are vocab_size and 1.
        """
        x: Tensor = self.embedding(x).to(dtype=torch.bfloat16)
        if padding_mask is not None:
            batch, tgt_len, _ = x.size()
            x, indices, cu_seqlens, max_s, _ = unpad_input(x, padding_mask)
        else:
            tgt_len, _ = x.size()
            batch = None
            indices = cu_seqlens = max_s = None
        out = self.decoder(tgt=x, tgt_is_causal=is_causal, cu_seqlens=cu_seqlens, max_seqlen=max_s,
                           batch_size=batch, indices=indices, is_save_cache=is_save_cache)
        if padding_mask is not None:
            out = pad_input(out, indices, batch, tgt_len)

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            score: Tensor = self.Wout(out)
            critic_score = self.critic_out(F.relu(self.critic_hidden(out)))
        return score, critic_score

    def is_end_point(self, x: torch.Tensor) -> bool:
        """Return True when every row of ``x`` contains at least one end token.

        Args:
            x: Token IDs indexed as (batch, seq_len).

        Returns:
            A Python bool: True if each row contains any of ``END_TOKEN_IDS``.
        """
        mask = torch.zeros_like(x, dtype=torch.bool)
        for token_id in self.END_TOKEN_IDS:
            mask |= x == token_id
        return bool(mask.any(dim=1).all())

    def top_p_sampling(self, logits: Tensor, p=0.9, temperature=1.0) -> Tensor:
        """Batched nucleus (top-p) sampling, one token per row.

        The highest-probability tokens are kept up to and including the one at
        which the cumulative probability first exceeds ``p`` (the top token is
        always kept). The kept probabilities are renormalized and one token is
        sampled.

        Args:
            logits: Unnormalized scores, (batch, vocab_size).
            p: Cumulative probability threshold.
            temperature: Divides the logits before the softmax.

        Returns:
            Sampled token IDs, shape (batch,).

        Raises:
            ValueError: If a sampled ID is not a valid vocabulary index.
        """
        logits = logits / temperature
        probs = self.softmax(logits)

        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Shift right so the token that crosses p is still kept.
        sorted_probs_to_remove = cumulative_probs > p
        sorted_probs_to_remove[..., 1:] = sorted_probs_to_remove[..., :-1].clone()
        sorted_probs_to_remove[..., 0] = 0

        probs_to_keep = sorted_probs.masked_fill(sorted_probs_to_remove, 0)

        # Small epsilon in the denominator guards against division by zero.
        probs_sum = probs_to_keep.sum(dim=-1, keepdim=True)
        renormalized_probs = probs_to_keep / (probs_sum + 1e-9)

        sampled_next_indices = torch.multinomial(renormalized_probs, num_samples=1)
        sampled_original_indices = torch.gather(sorted_indices, dim=-1, index=sampled_next_indices)

        r = sampled_original_indices.squeeze(-1)

        # Valid IDs are 0 .. vocab_size - 1.
        vocab_size = logits.shape[-1]
        if r.max().item() >= vocab_size:
            raise ValueError(
                f"サンプリングされたトークンIDが語彙サイズ({vocab_size})以上です: {r.max().item()}"
            )

        return r
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class BERTM(nn.Module):
    """Sequence scorer built on the MORTM decoder stack.

    Tokens are embedded and run through ``MORTMDecoder`` with non-causal
    attention. The result is pooled by ``Pool`` and projected to one scalar by
    a small ReLU MLP (d_model -> d_model // 2 -> 1).
    """

    def __init__(self, args: MORTMArgs, progress):
        super().__init__()
        self.args = args  # kept for later reference
        width = args.d_model
        self.embedding = nn.Embedding(args.vocab_size, width)
        self.decoder = MORTMDecoder(args=args, progress=progress)
        self.attn_pool = Pool(args)
        self.hidden = nn.Linear(width, width // 2)
        self.Wout = nn.Linear(width // 2, 1)  # matches the hidden layer's output width

    def forward(self, x: Tensor, padding_mask=None):
        """Score token sequences.

        Args:
            x: Token IDs. With ``padding_mask``, the padded batch that
                ``unpad_input`` packs. Without it, a single flat sequence.
            padding_mask: Truthy for real tokens, or None.

        Returns:
            The output of ``Wout``: one score per pooled row.
        """
        emb: Tensor = self.embedding(x).to(dtype=torch.bfloat16)

        cu_seqlens = None
        max_s = None
        if padding_mask is not None:
            emb, _indices, cu_seqlens, max_s, _used_seqlens = unpad_input(emb, padding_mask)

        hidden_states = self.decoder(tgt=emb, tgt_is_causal=False, cu_seqlens=cu_seqlens, max_seqlen=max_s)

        if cu_seqlens is None:
            # Without a mask the whole input is treated as one sequence
            # spanning [0, len(emb)).
            boundaries = torch.tensor([0, len(emb)], dtype=torch.int32, device=emb.device)
        else:
            boundaries = cu_seqlens
        pooled = self.attn_pool(hidden_states, boundaries)

        return self.Wout(F.relu(self.hidden(pooled)))
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from .progress import LearningProgress
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The table is stored as buffer ``pe`` of shape (max_len, 1, d_model).
    ``forward`` adds ``pe[:x.size(0)]``, so the first dimension of ``x``
    indexes position and the table broadcasts over dimension 1. Odd
    ``d_model`` is supported.
    """

    def __init__(self, d_model, progress: "LearningProgress", dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        device = progress.get_device()
        # Build the positional-encoding table.
        pe = torch.zeros(max_len, d_model, device=device)
        position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than there are
        # frequencies, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then apply dropout."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
|
|
File without changes
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from torch.nn.parameter import Parameter
|
|
4
|
+
from torch.nn.init import *
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
|
+
import loralib.layers as lora
|
|
7
|
+
|
|
8
|
+
from torch.nn.functional import linear, softmax, dropout
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
import math
|
|
14
|
+
from einops import rearrange
|
|
15
|
+
|
|
16
|
+
from .config import MORTMArgs
|
|
17
|
+
try:
|
|
18
|
+
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb
|
|
19
|
+
IS_NOT_LINUX = True #本当はFlase
|
|
20
|
+
except ImportError as i:
|
|
21
|
+
IS_NOT_LINUX = True
|
|
22
|
+
print(f"モジュールをインストールできませんでした。(WindowsではFlashを利用できません)\n {i.name}")
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from flash_attn.bert_padding import pad_input, unpad_input
|
|
26
|
+
from flash_attn.flash_attn_interface import *
|
|
27
|
+
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func, flash_attn_qkvpacked_func
|
|
28
|
+
except ImportError as i:
|
|
29
|
+
print(f"モジュールをインストールできませんでした。\n {i.name}")
|
|
30
|
+
|
|
31
|
+
# FlashAttention2 の関数(flash_attn_func)をインポート
|
|
32
|
+
# (ライブラリがダウンロード済みであると仮定)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def marge_cache(kv_cache: Optional[Tuple[Tensor, Tensor]], cache_seqlens: Optional[Tensor],
                k: Tensor, v: Tensor) -> Tuple[Optional[Tuple[Tensor, Tensor]], Optional[Tensor]]:
    """Append one key/value step per batch row into a padded KV cache.

    For each row ``i``, ``k[i, 0]`` / ``v[i, 0]`` are written to slot
    ``cache_seqlens[i]`` (dimension 1) of ``kv_cache[0]`` / ``kv_cache[1]``,
    and ``cache_seqlens[i]`` is incremented in place. If that slot lies past
    the current cache length, the caches are first extended with zeros.

    Args:
        kv_cache: ``(k_cache, v_cache)`` as a tuple or list of 4-D tensors
            indexed [batch, slot, ...]. Must not be None.
        cache_seqlens: Filled-slot count per row; mutated in place.
        k, v: New entries; only index 0 of dimension 1 is used.

    Returns:
        ``(kv_cache, cache_seqlens)``. A list input is updated in place and
        returned. Otherwise a new tuple is returned, because the cache tensors
        may have been reallocated (tuples cannot be assigned into).
    """
    k_cache, v_cache = kv_cache
    for i in range(k.shape[0]):
        pos = int(cache_seqlens[i])  # slot for this row's new entry
        missing = pos + 1 - k_cache.shape[1]
        if missing > 0:
            # Grow by as many slots as needed, not just one.
            k_cache = torch.cat(
                [k_cache, k_cache.new_zeros((k_cache.shape[0], missing, *k_cache.shape[2:]))], dim=1)
            v_cache = torch.cat(
                [v_cache, v_cache.new_zeros((v_cache.shape[0], missing, *v_cache.shape[2:]))], dim=1)

        k_cache[i, pos, :, :] = k[i, 0]
        v_cache[i, pos, :, :] = v[i, 0]
        cache_seqlens[i] += 1

    if isinstance(kv_cache, list):
        kv_cache[0], kv_cache[1] = k_cache, v_cache
        return kv_cache, cache_seqlens
    return (k_cache, v_cache), cache_seqlens
|
|
49
|
+
|
|
50
|
+
def get_alibi_slopes(n_heads):
    """Return the per-head ALiBi slopes as a list of floats.

    For a power-of-two head count the slopes are a geometric sequence whose
    first term and ratio both equal ``2 ** (-8 / n_heads)``. Otherwise the
    slopes for the largest lower power of two are taken, and the result is
    padded with every other slope of twice that power until ``n_heads`` values
    exist.
    """
    def _geometric(count):
        base = 2 ** (-2 ** -(math.log2(count) - 3))
        return [base * (base ** idx) for idx in range(count)]

    exponent = math.log2(n_heads)
    if exponent.is_integer():
        return _geometric(n_heads)

    lower = 2 ** math.floor(exponent)
    interleaved = _geometric(2 * lower)[0::2]
    return _geometric(lower) + interleaved[: n_heads - lower]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class QKVLinear(nn.Module):
    """Input and output projections for multi-head attention.

    * Self-attention (default): a fused ``qkv_weight`` projection
      (d_model -> 3 * d_model, LoRA-wrapped when ``args.use_lora``). Its
      output is reshaped to (tokens, 3, num_heads, head_dim).
    * Cross-attention: separate ``q_weight`` and ``kv_weight`` projections
      (plain ``nn.Linear``; LoRA is not applied here).

    ``W_o`` is the output projection applied by ``comp``. All layers are
    created in bfloat16.
    """

    def __init__(self, args: MORTMArgs, use_cross_attention: bool=False):
        super().__init__()
        self.num_heads = args.num_heads
        self.drop_out = nn.Dropout(args.dropout)  # NOTE(review): not used by forward/comp
        self.use_cross_attention = use_cross_attention

        width = args.d_model
        half = torch.bfloat16
        if use_cross_attention:
            self.q_weight = nn.Linear(width, width, bias=True, dtype=half)
            self.kv_weight = nn.Linear(width, 2 * width, bias=True, dtype=half)
            self.W_o = nn.Linear(width, width, dtype=half)
        elif args.use_lora:
            self.qkv_weight = lora.Linear(width, 3 * width, r=args.lora_r, lora_alpha=args.lora_alpha, bias=False, dtype=half)
            self.W_o = lora.Linear(width, width, r=args.lora_r, lora_alpha=args.lora_alpha, dtype=half)
        else:
            self.qkv_weight = nn.Linear(width, 3 * width, bias=False, dtype=half)
            self.W_o = nn.Linear(width, width, dtype=half)

    def forward(self, q: Tensor, kv: Tensor = None):
        """Project 2-D (tokens, dim) inputs into per-head tensors.

        Returns the fused qkv tensor (tokens, 3, heads, head_dim) in
        self-attention mode. In cross-attention mode it returns
        ``(q, kv)`` with shapes (tokens_q, heads, head_dim) and
        (tokens_kv, 2, heads, head_dim).
        """
        if self.use_cross_attention:
            n_q, dim_q = q.size()
            n_kv, dim_kv = kv.size()
            q_heads = self.q_weight(q).view(n_q, self.num_heads, dim_q // self.num_heads)
            kv_heads = self.kv_weight(kv).view(n_kv, 2, self.num_heads, dim_kv // self.num_heads)
            return q_heads, kv_heads

        n_tokens, dim = q.size()
        return self.qkv_weight(q).view(n_tokens, 3, self.num_heads, dim // self.num_heads)

    def comp(self, o: Tensor):
        """Apply the output projection ``W_o``."""
        return self.W_o(o)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class FlashSelfAttentionM(nn.Module):
    """Multi-head self-attention on FlashAttention-2 with an optional KV cache.

    ``forward`` selects one of three modes from its arguments:

    * ``cu_seqlens`` given: packed variable-length batch, x of shape
      (total_tokens, d_model). With ``is_save_cache=True`` the keys/values are
      also written to the KV cache (prompt pre-fill).
    * ``cu_seqlens`` None and ``is_save_cache=True``: one new token per
      sequence, x of shape (batch_size, d_model). It attends over the cache
      and appends to it.
    * otherwise: x is one unpadded sequence of shape (seq_len, d_model).

    Positions are encoded with ALiBi when ``args.use_rope`` is False, and with
    rotary embeddings otherwise, in every mode.
    """

    def __init__(self, args: MORTMArgs, progress=None):
        super(FlashSelfAttentionM, self).__init__()
        self.batch_first = True
        self._qkv_same_embed_dim = True
        self.in_proj_bias = None
        self.args = args

        self.embed_dim = args.d_model
        self.qkv_block = QKVLinear(args)
        self.drop = args.dropout
        # KV cache: (k_cache, v_cache), each (batch, max_seq_len, heads, head_dim),
        # plus the number of filled slots per row (int32).
        self.kv_cache: Optional[Tuple[Tensor, Tensor]] = None
        self.cache_seqlens: Tensor = None

        if not self.args.use_rope:
            print("FlashAttention2のALiBiを使用します。")
            self.alibi_slopes = torch.tensor(get_alibi_slopes(args.num_heads), dtype=torch.float32, device=progress.get_device())
        else:
            print("FlashAttention2のRoPEを使用します。")
            head_dim = args.d_model // args.num_heads
            device = progress.get_device() if progress else None
            self.rotary_emb = RotaryEmbedding(dim=head_dim, base=10000.0, interleaved=False, device=device)

    def _init_kv_cache(self, batch_size, device, dtype):
        """Allocate an uninitialized KV cache sized for ``batch_size`` rows."""
        max_seq_len = self.args.position_length + 100  # maximum cached length
        head_dim = self.args.d_model // self.args.num_heads
        shape = (batch_size, max_seq_len, self.args.num_heads, head_dim)

        # torch.empty only reserves memory; slots are written before they are read.
        self.kv_cache = (
            torch.empty(shape, device=device, dtype=dtype),
            torch.empty(shape, device=device, dtype=dtype)
        )
        self.cache_seqlens = torch.zeros(batch_size, device=device, dtype=torch.int32)

    def _rope_tables(self, seqlen, device, dtype):
        """Make sure the rotary cos/sin tables cover ``seqlen`` positions and return them."""
        self.rotary_emb._update_cos_sin_cache(seqlen, device=device, dtype=dtype)
        return self.rotary_emb._cos_cached, self.rotary_emb._sin_cached

    def forward(self, x: Tensor, is_causal=False, cu_seqlens=None, max_seqlen=None,
                batch_size=None, indices=None, is_save_cache=False):
        """Dispatch to the packed, cached-decode or single-sequence path.

        ``indices`` is accepted for interface compatibility and is unused.
        """
        if x.dtype == torch.float32:
            x = x.to(torch.bfloat16)

        # The flash-attention kernels apply dropout_p unconditionally, so
        # disable it outside training.
        dropout_p = self.drop if self.training else 0.0

        if cu_seqlens is not None:
            return self._forward_varlen(x, is_causal, cu_seqlens, max_seqlen, batch_size, is_save_cache, dropout_p)
        if is_save_cache:
            return self._forward_decode(x)
        return self._forward_single(x, is_causal, dropout_p)

    def _forward_varlen(self, x, is_causal, cu_seqlens, max_seqlen, batch_size, is_save_cache, dropout_p):
        """Attention over a packed variable-length batch; optionally pre-fills the KV cache."""
        # Reallocate the cache only when the batch size changes.
        if is_save_cache and (self.kv_cache is None or self.kv_cache[0].shape[0] != batch_size):
            self._init_kv_cache(batch_size, x.device, x.dtype)

        qkv: Tensor = self.qkv_block(q=x)  # (total_tokens, 3, heads, head_dim)

        if self.args.use_rope:
            q, k, v = qkv.unbind(1)
            cos, sin = self._rope_tables(max_seqlen, qkv.device, qkv.dtype)
            # Varlen rotary needs both cu_seqlens and max_seqlen.
            q = apply_rotary_emb(q, cos, sin, interleaved=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            k = apply_rotary_emb(k, cos, sin, interleaved=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            # Rebind qkv so the cache below stores the *rotated* keys; the
            # decode kernel rotates only the new key it appends.
            qkv = torch.stack([q, k, v], dim=1)
            out = flash_attn_varlen_qkvpacked_func(qkv, dropout_p=dropout_p, causal=is_causal,
                                                   cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        else:
            out = flash_attn_varlen_qkvpacked_func(qkv, dropout_p=dropout_p, causal=is_causal,
                                                   cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
                                                   alibi_slopes=self.alibi_slopes)

        if is_save_cache:
            self._store_prefill(qkv, cu_seqlens, batch_size)

        out = rearrange(out, "total h d -> total (h d)")
        return self.qkv_block.comp(out)

    def _store_prefill(self, qkv, cu_seqlens, batch_size):
        """Copy each packed sequence's K/V to the front of its cache row and record its length."""
        with torch.no_grad():
            _, k_unpad, v_unpad = qkv.unbind(dim=1)
            for i in range(batch_size):
                start, end = cu_seqlens[i], cu_seqlens[i + 1]
                seq_len = end - start
                self.kv_cache[0][i, :seq_len] = k_unpad[start:end]
                self.kv_cache[1][i, :seq_len] = v_unpad[start:end]
            self.cache_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int32)

    def _forward_decode(self, x):
        """Attend one new token per sequence over the KV cache and append it."""
        qkv: Tensor = self.qkv_block(q=x)  # (batch, 3, heads, head_dim)
        # (batch, heads, head_dim) -> (batch, 1, heads, head_dim), the layout
        # flash_attn_with_kvcache expects.
        q, k, v = (t.unsqueeze(1) for t in qkv.unbind(dim=1))

        if self.args.use_rope:
            # The kernel rotates q and the new k at position cache_seqlens, so
            # the tables must cover the whole cache length.
            cos, sin = self._rope_tables(self.kv_cache[0].shape[1], x.device, x.dtype)
            position_kwargs = {"rotary_cos": cos, "rotary_sin": sin, "rotary_interleaved": False}
        else:
            position_kwargs = {"alibi_slopes": self.alibi_slopes}

        # Computes attention and writes the new k/v into the cache in place.
        out = flash_attn_with_kvcache(
            q,
            k_cache=self.kv_cache[0],
            v_cache=self.kv_cache[1],
            k=k,
            v=v,
            cache_seqlens=self.cache_seqlens,
            causal=True,
            **position_kwargs
        )

        # Advance the valid cache length by the one token just appended.
        self.cache_seqlens += 1

        # (batch, 1, heads, head_dim) -> (batch, heads, head_dim)
        out = out.squeeze(1)
        out = rearrange(out, "total h d -> total (h d)")
        return self.qkv_block.comp(out)

    def _forward_single(self, x, is_causal, dropout_p):
        """Attention over one unpadded sequence x of shape (seq_len, d_model)."""
        qkv = self.qkv_block(q=x).unsqueeze(0)  # (1, seq_len, 3, heads, head_dim)
        if self.args.use_rope:
            q, k, v = qkv.unbind(dim=2)
            cos, sin = self._rope_tables(qkv.shape[1], qkv.device, qkv.dtype)
            q = apply_rotary_emb(q, cos, sin, interleaved=False)
            k = apply_rotary_emb(k, cos, sin, interleaved=False)
            qkv = torch.stack([q, k, v], dim=2)
            out = flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p, causal=is_causal)
        else:
            out = flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p, causal=is_causal,
                                            alibi_slopes=self.alibi_slopes)
        out = rearrange(out, "b s h d -> (b s) (h d)")
        return self.qkv_block.comp(out)

    def compute_cache_seqlens(self, k: torch.Tensor) -> torch.Tensor:
        """Count, per row, the cache slots whose keys are not entirely zero.

        Args:
            k: Tensor of shape [batch_size, max_seq_len, num_heads, head_dim].

        Returns:
            Tensor of shape [batch_size] with the number of non-zero slots.
        """
        # A slot is "filled" if any element of its (heads, head_dim) block is non-zero.
        is_nonzero = k.abs().sum(dim=(-1, -2)) != 0  # shape: [batch_size, max_seq_len]
        seqlens = is_nonzero.sum(dim=1)  # shape: [batch_size]
        return seqlens
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class FlashCrossAttentionM(nn.Module):
    """Multi-head cross-attention on FlashAttention-2.

    Queries come from ``x`` and keys/values from ``encoder_x``. Attention is
    non-causal and uses no positional bias. With ``cu_seqlens_q`` given, both
    inputs are packed variable-length batches. Otherwise each input is
    treated as a single sequence.
    """

    def __init__(self, args: MORTMArgs, progress=None):
        super(FlashCrossAttentionM, self).__init__()
        self.batch_first = True
        self._qkv_same_embed_dim = True
        self.in_proj_bias = None
        self.args = args

        self.embed_dim = args.d_model
        self.qkv_block = QKVLinear(args, use_cross_attention=True)
        self.drop = args.dropout

    def forward(self, x: Tensor, encoder_x: Tensor, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=None,
                max_seqlen_k=None):
        """Attend from ``x`` (tokens_q, d_model) to ``encoder_x`` (tokens_kv, d_model)."""
        if x.dtype == torch.float32:
            x = x.to(torch.bfloat16)
        if encoder_x.dtype == torch.float32:
            encoder_x = encoder_x.to(torch.bfloat16)

        # The flash-attention kernels apply dropout_p unconditionally, so
        # disable it outside training.
        dropout_p = self.drop if self.training else 0.0

        q, kv = self.qkv_block(q=x, kv=encoder_x)

        if cu_seqlens_q is not None:
            # Packed variable-length path (training or prompt processing).
            out = flash_attn_varlen_kvpacked_func(
                q=q,
                kv=kv,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=False,
                dropout_p=dropout_p
            )
            out = rearrange(out, "total h d -> total (h d)")
            return self.qkv_block.comp(out)

        # Single-sequence path: add a batch dimension of 1.
        out = flash_attn_kvpacked_func(q=q.unsqueeze(0), kv=kv.unsqueeze(0), dropout_p=dropout_p, causal=False)
        out = rearrange(out, "b s h d -> (b s) (h d)")
        return self.qkv_block.comp(out)
|