mortm 4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mortm/__init__.py ADDED
File without changes
mortm/constants.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
PITCH_MAX = 128
VELO = 128
LENGTH = 999
LENGTH_HALF = 999
BEGIN = 999
BEGIN_HALF = 999
ROOT = 99
START_SEQ_TOKEN = "<S_SEQ>"
END_SEQ_TOKEN = "<E_SEQ>"
PADDING_TOKEN = "<PAD>"

MODEL_NAME = "MORTM"

# Number of token IDs reserved for each category.
_SPECIAL_SIZE = 2  # START_SEQ_TOKEN and END_SEQ_TOKEN slots
_PITCH_SIZE = 128
_VELOCITY_SIZE = 128
_DURATION_SIZE = 100
_START_SIZE = 32
_SHIFT_SIZE = 4  # NOTE(review): size assumed to be 4 -- confirm against the tokenizer

# Each *_BEGIN_ID is the previous category's first ID plus the previous
# category's size.
PADDING_BEGIN_ID = 0
SPECIAL_BEGIN_ID = PADDING_BEGIN_ID + 1
PITCH_BEGIN_ID = SPECIAL_BEGIN_ID + _SPECIAL_SIZE
VELOCITY_BEGIN_ID = PITCH_BEGIN_ID + _PITCH_SIZE
DURATION_BEGIN_ID = VELOCITY_BEGIN_ID + _VELOCITY_SIZE
START_BEGIN_ID = DURATION_BEGIN_ID + _DURATION_SIZE
SHIFT_BEGIN_ID = START_BEGIN_ID + _START_SIZE


# ``range`` excludes its stop value, so each group holds exactly its category's
# IDs and stops where the next category begins. No ID belongs to two groups.
PITCH_GROUP = range(PITCH_BEGIN_ID, PITCH_BEGIN_ID + _PITCH_SIZE)
VELOCITY_GROUP = range(VELOCITY_BEGIN_ID, VELOCITY_BEGIN_ID + _VELOCITY_SIZE)
DURATION_GROUP = range(DURATION_BEGIN_ID, DURATION_BEGIN_ID + _DURATION_SIZE)
START_GROUP = range(START_BEGIN_ID, START_BEGIN_ID + _START_SIZE)
SHIFT_GROUP = range(SHIFT_BEGIN_ID, SHIFT_BEGIN_ID + _SHIFT_SIZE)
31
+
File without changes
mortm/models/bertm.py ADDED
@@ -0,0 +1,294 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.distributions import Categorical
4
+ from .modules.layers import *
5
+ from .modules.config import MORTMArgs
6
+
7
+ from flash_attn.bert_padding import unpad_input, pad_input
8
+
9
+
10
+
11
class ActorCritic(nn.Module):
    """Actor-critic model built on ``MORTMDecoder``.

    ``Wout`` maps decoder outputs to next-token logits (the actor).
    ``critic_hidden`` -> ReLU -> ``critic_out`` maps them to one scalar value
    per position (the critic).
    """

    def __init__(self, args: MORTMArgs, progress):
        super(ActorCritic, self).__init__()
        self.args = args
        self.progress = progress
        self.e_layer = args.e_layer
        self.d_layer = args.d_layer
        self.num_heads = args.num_heads
        self.d_model = args.d_model
        self.dim_feedforward = args.dim_feedforward
        self.dropout = args.dropout
        self.use_lora = args.use_lora

        self.decoder = MORTMDecoder(args, progress=progress)

        print(f"Input Vocab Size:{args.vocab_size}")
        # Token ID 0 is the padding index.
        self.embedding: nn.Embedding = nn.Embedding(args.vocab_size, self.d_model, padding_idx=0).to(self.progress.get_device())
        if not self.use_lora:
            self.Wout: nn.Linear = nn.Linear(self.d_model, args.vocab_size).to(self.progress.get_device())
        else:
            self.Wout: lora.Linear = lora.Linear(self.d_model, args.vocab_size, r=args.lora_r, lora_alpha=args.lora_alpha)

        self.critic_hidden = nn.Linear(self.d_model, self.d_model // 2)
        self.critic_out = nn.Linear(self.d_model // 2, 1)  # one value per position

        self.softmax: nn.Softmax = nn.Softmax(dim=-1).to(self.progress.get_device())

    def evaluate_actions(self, sequence_tensors, padding_mask):
        """Re-score complete sequences with teacher forcing.

        Args:
            sequence_tensors: Token IDs of shape ``(batch, seq)``.
            padding_mask: Boolean/0-1 mask of the same shape (True = real
                token), or None.

        Returns:
            ``(log_probs, values)``, both of shape ``(batch, seq - 1)``.
            ``log_probs[:, t]`` is the log-probability of token ``t + 1``
            given the prefix, and is zeroed where the target is padding.
        """
        # 1. Inputs are all tokens but the last; targets are shifted by one.
        input_ids = sequence_tensors[:, :-1]
        target_ids = sequence_tensors[:, 1:]

        # Shift the mask the same way as the inputs.
        input_mask = padding_mask[:, :-1] if padding_mask is not None else None

        # 2. One causal forward pass gives logits and values for every step.
        logits, new_values = self.forward(input_ids, padding_mask=input_mask, is_causal=True)

        reshaped_logits = logits.reshape(-1, self.args.vocab_size)
        reshaped_targets = target_ids.reshape(-1).long()
        # Compute all log-probabilities at once with a Categorical distribution.
        dist = Categorical(logits=reshaped_logits)
        log_probs = dist.log_prob(reshaped_targets)

        # Back to (batch, seq_len - 1).
        log_probs = log_probs.view(logits.size(0), logits.size(1))

        # Zero the log-probabilities of padded targets.
        if padding_mask is not None:
            log_probs = log_probs * padding_mask[:, 1:]
        return log_probs, new_values.reshape(new_values.shape[0], new_values.shape[1])

    @torch.no_grad()
    def eval_seq(self, src, print_log=True):
        """Sample a continuation of ``src`` using the attention KV cache (batched).

        Runs under ``torch.no_grad()``: the results are converted to NumPy,
        which is impossible for tensors that require grad, and no gradients
        are needed for generation.

        Args:
            src: Prompt token IDs as a ``np.ndarray`` or ``Tensor`` of shape
                ``(seq,)`` or ``(batch, seq)``. ID 0 (the embedding padding
                index) is treated as padding.
            print_log: When True, print progress and debug information.

        Returns:
            ``(tokens, values, log_probs)``: three lists with one NumPy array
            per batch row, with padding and anything after the first 585 token
            removed. ``values``/``log_probs`` have one element fewer than
            ``tokens`` because the first token has no prediction.

        Generation stops once every row contains an end token (see
        ``is_end_point``) or the sequence length exceeds
        ``args.position_length``.
        """
        self.eval()
        device = self.progress.get_device()

        # Use the module's own ``np`` binding; ``numpy`` is not bound in this file.
        if isinstance(src, np.ndarray):
            src = torch.tensor(src, device=device)
        if src.dim() == 1:
            src = src.unsqueeze(0)

        # --- 1. Prompt processing (pre-fill): one pass that also fills the KV cache ---
        if print_log: print("--- Pre-fill Phase ---")
        prompt_padding_mask = (src != self.embedding.padding_idx)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, values = self.forward(src, padding_mask=prompt_padding_mask, is_causal=True, is_save_cache=True)
        prob_list = Categorical(logits=logits)

        # next_tokens has shape (batch_size,).
        next_tokens = self.top_p_sampling(logits[:, -1, :], p=0.95, temperature=1.0)

        all_tokens = torch.cat([src, next_tokens.unsqueeze(1)], dim=1)
        all_values = values
        all_probs = prob_list.log_prob(all_tokens[:, 1:])

        # --- 2. Decoding: feed back one token per row ---
        # ``i`` tracks the sequence length (dim 1), not the batch size, so the
        # ``position_length`` limit also keeps generation inside the KV cache.
        i = all_tokens.shape[1]
        if print_log: print("\n--- Decoding Phase ---")
        while True:
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, values = self.forward(next_tokens, padding_mask=None, is_causal=True, is_save_cache=True)
            step_logits = logits.squeeze(1)
            probs_list = Categorical(logits=step_logits)
            next_tokens = self.top_p_sampling(step_logits, p=0.95, temperature=1.0)

            all_tokens = torch.cat([all_tokens, next_tokens.unsqueeze(1)], dim=1)
            all_values = torch.cat([all_values, values.unsqueeze(1)], dim=1)
            all_probs = torch.cat([all_probs, probs_list.log_prob(next_tokens).unsqueeze(1)], dim=1)

            if print_log: print(f"\r Step {i+1}: Generated tokens {next_tokens.tolist()}", end="")

            finished = self.is_end_point(all_tokens) or i > self.args.position_length
            i += 1
            if finished:
                break

        if print_log:
            print(all_tokens.shape, all_values.shape, all_probs.shape)
            print(all_tokens.max())

        np_all_tokens = []
        np_all_values = []
        np_all_probs = []
        for row, seq in enumerate(all_tokens):
            np_seq, np_value, np_probs = self._trim_row(seq, all_values[row], all_probs[row], print_log)
            np_all_tokens.append(np_seq)
            np_all_values.append(np_value)
            np_all_probs.append(np_probs)
            # Valid IDs are 0 .. vocab_size - 1.
            if np_seq.max() >= self.args.vocab_size:
                raise ValueError(
                    f"生成されたトークンIDが語彙サイズ({self.args.vocab_size})を超えています: {np_seq.max()}"
                )

        return np_all_tokens, np_all_values, np_all_probs

    def _trim_row(self, seq, values, probs, print_log):
        """Drop padding and everything after the first 585 token from one generated row.

        ``values``/``probs`` are one position behind ``seq`` (entry ``t``
        belongs to the prediction of ``seq[t + 1]``). Their slice bounds are
        therefore one less than the token slice bounds.
        """
        np_seq = np.array([], dtype=int)
        np_value = np.array([], dtype=float)
        np_probs = np.array([], dtype=float)

        pad = (seq == 0).nonzero(as_tuple=True)[0]
        end_hits = (seq == 585).nonzero(as_tuple=True)[0]
        # Index of the last token to keep: the first 585 token, else the final token.
        eseq = end_hits[0].item() if len(end_hits) > 0 else len(seq) - 1

        if len(pad) != 0:
            # Keep the part before the first pad and the part after the last pad (up to eseq).
            start = pad[0]
            end = pad[-1]
            np_seq = np.append(np_seq, seq[:start].cpu().numpy())
            np_value = values[:start - 1].cpu().numpy()
            np_probs = probs[:start - 1].cpu().numpy()

            np_seq = np.append(np_seq, seq[end + 1:eseq + 1].cpu().numpy())
            np_value = np.append(np_value, values[end:eseq].cpu().numpy())
            np_probs = np.append(np_probs, probs[end:eseq].cpu().numpy())
            if print_log:print(f"!WDFG : {np_seq.shape, np_value.shape, np_probs.shape}")
        else:
            np_seq = np.append(np_seq, seq[:eseq + 1].cpu().numpy())
            np_value = np.append(np_value, values[:eseq].cpu().numpy())
            np_probs = np.append(np_probs, probs[:eseq].cpu().numpy())
            if print_log: print(f"Not ESEQ LEN : {np_seq.shape, np_value.shape, np_probs.shape}")

        return np_seq, np_value, np_probs

    def forward(self, x, padding_mask=None, is_causal=False, is_save_cache=False):
        """Return ``(logits, values)`` for token IDs ``x``.

        With ``padding_mask``, ``x`` is ``(batch, seq)``. Tokens are unpadded
        for the decoder and re-padded afterwards, so the outputs are
        ``(batch, seq, ...)``. Without a mask, the embedded ``x`` must be 2-D,
        i.e. ``x`` is a 1-D vector of token IDs (``eval_seq`` passes one new
        token per batch row).
        """
        x: Tensor = self.embedding(x).to(dtype=torch.bfloat16)
        if padding_mask is not None:
            batch, tgt_len, embed_dim = x.size()
            x, indices, cu_seqlens, max_s, used_seqlens = unpad_input(x, padding_mask)
        else:
            tgt_len, embed_dim = x.size()
            batch = None
            indices = cu_seqlens = max_s = used_seqlens = None
        out = self.decoder(tgt=x, tgt_is_causal=is_causal, cu_seqlens=cu_seqlens, max_seqlen=max_s,
                           batch_size=batch, indices=indices, is_save_cache=is_save_cache)
        if padding_mask is not None:
            out = pad_input(out, indices, batch, tgt_len)

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            score: Tensor = self.Wout(out)
            hidden = self.critic_hidden(out)
            hidden = F.relu(hidden)
            critic_score = self.critic_out(hidden)
        return score, critic_score

    def is_end_point(self, x: torch.Tensor) -> bool:
        """Return True when every row of ``x`` contains an end token.

        Args:
            x: 2-D tensor of token IDs, one sequence per row.

        Returns:
            True if each row contains at least one 585 or 586, else False.
        """
        mask = (x == 585) | (x == 586)
        return bool(mask.any(dim=1).all())

    def top_p_sampling(self, logits: Tensor, p=0.9, temperature=1.0) -> Tensor:
        """Batched nucleus (top-p) sampling.

        Args:
            logits: ``(batch, vocab)`` scores.
            p: Keep the smallest set of tokens whose cumulative probability
                exceeds ``p`` (the top token is always kept).
            temperature: Divides the logits before the softmax.

        Returns:
            Sampled token IDs of shape ``(batch,)``.
        """
        logits = logits / temperature
        probs = self.softmax(logits)

        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Shift right by one so the token that crosses ``p`` is still kept.
        sorted_probs_to_remove = cumulative_probs > p
        sorted_probs_to_remove[..., 1:] = sorted_probs_to_remove[..., :-1].clone()
        sorted_probs_to_remove[..., 0] = 0

        probs_to_keep = sorted_probs.masked_fill(sorted_probs_to_remove, 0)

        # The small epsilon avoids division by zero.
        probs_sum = probs_to_keep.sum(dim=-1, keepdim=True)
        renormalized_probs = probs_to_keep / (probs_sum + 1e-9)

        sampled_next_indices = torch.multinomial(renormalized_probs, num_samples=1)
        sampled_original_indices = torch.gather(sorted_indices, dim=-1, index=sampled_next_indices)

        r = sampled_original_indices.squeeze(-1)

        # Valid IDs are 0 .. vocab_size - 1.
        vocab_size = logits.shape[-1]
        if r.max().item() >= vocab_size:
            raise ValueError(
                f"サンプリングされたトークンIDが語彙サイズ({vocab_size})以上です: {r.max().item()}"
            )

        return r
264
+
265
+
266
class BERTM(nn.Module):
    """Non-causal MORTM model that maps a token sequence to a single score.

    Pipeline: embedding (bfloat16) -> ``MORTMDecoder`` with
    ``tgt_is_causal=False`` -> ``Pool`` over the sequence boundaries
    -> Linear -> ReLU -> Linear to 1 output.
    """

    def __init__(self, args: MORTMArgs, progress):
        super().__init__()
        self.args = args  # kept so the configuration is reachable later
        width = args.d_model
        self.embedding = nn.Embedding(args.vocab_size, width)
        self.decoder = MORTMDecoder(args=args, progress=progress)
        self.attn_pool = Pool(args)
        self.hidden = nn.Linear(width, width // 2)
        self.Wout = nn.Linear(width // 2, 1)  # input size matches ``hidden``'s output

    def forward(self, x: Tensor, padding_mask=None):
        """Score ``x``.

        With ``padding_mask`` the tokens are unpadded and ``cu_seqlens`` marks
        the sequence boundaries for the pool. Without it, the whole embedded
        input is treated as one sequence ``[0, len)``.
        """
        emb: Tensor = self.embedding(x).to(dtype=torch.bfloat16)

        cu_seqlens = None
        max_s = None
        if padding_mask is not None:
            emb, _indices, cu_seqlens, max_s, _used = unpad_input(emb, padding_mask)

        decoded = self.decoder(tgt=emb, tgt_is_causal=False, cu_seqlens=cu_seqlens, max_seqlen=max_s)

        if cu_seqlens is None:
            boundaries = torch.tensor([0, len(emb)], dtype=torch.int32, device=emb.device)
        else:
            boundaries = cu_seqlens
        pooled = self.attn_pool(decoded, boundaries)

        activated = F.relu(self.hidden(pooled))
        return self.Wout(activated)
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from .progress import LearningProgress
6
+
7
+
8
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a sequence-first input, then dropout.

    The table is stored in the ``pe`` buffer with shape ``(max_len, 1, d_model)``.
    ``forward`` adds ``pe[:x.size(0)]``, so dim 0 of ``x`` is the position axis,
    e.g. ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, progress: "LearningProgress", dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        device = progress.get_device()
        # Build the positional-encoding table.
        pe = torch.zeros(max_len, d_model, device=device)
        position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)

        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # Odd columns number floor(d_model / 2); trimming avoids a shape
        # mismatch when d_model is odd.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions and apply dropout."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
File without changes
@@ -0,0 +1,300 @@
1
+ from typing import Optional
2
+
3
+ from torch.nn.parameter import Parameter
4
+ from torch.nn.init import *
5
+ from typing import Optional, Tuple
6
+ import loralib.layers as lora
7
+
8
+ from torch.nn.functional import linear, softmax, dropout
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch import Tensor
13
+ import math
14
+ from einops import rearrange
15
+
16
+ from .config import MORTMArgs
17
+ try:
18
+ from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb
19
+ IS_NOT_LINUX = True #本当はFlase
20
+ except ImportError as i:
21
+ IS_NOT_LINUX = True
22
+ print(f"モジュールをインストールできませんでした。(WindowsではFlashを利用できません)\n {i.name}")
23
+
24
+ try:
25
+ from flash_attn.bert_padding import pad_input, unpad_input
26
+ from flash_attn.flash_attn_interface import *
27
+ from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func, flash_attn_qkvpacked_func
28
+ except ImportError as i:
29
+ print(f"モジュールをインストールできませんでした。\n {i.name}")
30
+
31
+ # FlashAttention2 の関数(flash_attn_func)をインポート
32
+ # (ライブラリがダウンロード済みであると仮定)
33
+
34
+
35
+
36
def marge_cache(kv_cache: Optional[Tuple[Tensor, Tensor]], cache_seqlens: Optional[Tensor],
                k: Tensor, v: Tensor) -> Tuple[Optional[Tuple[Tensor, Tensor]], Optional[Tensor]]:
    """Append one new key/value per batch row to a padded KV cache.

    Args:
        kv_cache: ``(k_cache, v_cache)`` indexed as
            ``cache[batch, position, head, head_dim]``.
        cache_seqlens: Filled length per batch row; incremented in place.
        k, v: New keys/values. Row ``i``'s ``k[i, 0]`` / ``v[i, 0]`` is
            written at position ``cache_seqlens[i]``.

    Returns:
        ``(kv_cache, cache_seqlens)``. If the cache is too short it is grown
        with zero slots along dim 1. A list ``kv_cache`` is updated in place
        and returned; otherwise a new tuple is returned.

    Raises:
        ValueError: If ``kv_cache`` or ``cache_seqlens`` is None.
    """
    if kv_cache is None or cache_seqlens is None:
        raise ValueError("marge_cache requires an initialised kv_cache and cache_seqlens")

    # Unpack to locals: a tuple cache cannot be assigned into.
    k_cache, v_cache = kv_cache
    for i in range(k.shape[0]):
        pos = int(cache_seqlens[i])  # write position within row i
        missing = pos + 1 - k_cache.shape[1]
        if missing > 0:
            # Grow both caches so that ``pos`` is addressable.
            k_cache = torch.cat(
                [k_cache, k_cache.new_zeros((k_cache.shape[0], missing, *k_cache.shape[2:]))], dim=1)
            v_cache = torch.cat(
                [v_cache, v_cache.new_zeros((v_cache.shape[0], missing, *v_cache.shape[2:]))], dim=1)

        k_cache[i, pos, :, :] = k[i, 0]
        v_cache[i, pos, :, :] = v[i, 0]
        cache_seqlens[i] += 1

    if isinstance(kv_cache, list):
        kv_cache[0], kv_cache[1] = k_cache, v_cache
        return kv_cache, cache_seqlens
    return (k_cache, v_cache), cache_seqlens
49
+
50
def get_alibi_slopes(n_heads):
    """Return the per-head ALiBi slopes as a list of ``n_heads`` floats.

    For a power-of-two head count the slopes form a geometric sequence.
    Otherwise the sequence for the largest power of two below ``n_heads`` is
    extended with every other slope of the sequence for twice that count.
    """
    def _geometric(count):
        base = 2 ** (-2 ** -(math.log2(count) - 3))
        return [base * base ** k for k in range(count)]

    exponent = math.log2(n_heads)
    if exponent.is_integer():
        return _geometric(n_heads)

    lower = 2 ** math.floor(exponent)
    every_other = get_alibi_slopes(2 * lower)[::2]
    return _geometric(lower) + every_other[:n_heads - lower]
68
+
69
+
70
class QKVLinear(nn.Module):
    """Input and output projections for multi-head attention (bfloat16 weights).

    Self-attention mode: a fused ``qkv_weight`` projection without bias,
    built with ``loralib`` layers when ``args.use_lora`` is set.
    Cross-attention mode: ``q_weight`` for the queries and a fused
    ``kv_weight`` for keys/values, both with bias and never LoRA.
    ``W_o`` is the output projection used by ``comp``.
    """

    def __init__(self, args: MORTMArgs, use_cross_attention: bool=False):
        super().__init__()
        self.num_heads = args.num_heads
        self.drop_out = nn.Dropout(args.dropout)  # NOTE(review): not used by forward/comp
        self.use_cross_attention = use_cross_attention

        d = args.d_model
        bf16 = torch.bfloat16
        if use_cross_attention:
            self.q_weight = nn.Linear(d, d, bias=True, dtype=bf16)
            self.kv_weight = nn.Linear(d, 2 * d, bias=True, dtype=bf16)
            self.W_o = nn.Linear(d, d, dtype=bf16)
        elif args.use_lora:
            lora_kwargs = dict(r=args.lora_r, lora_alpha=args.lora_alpha)
            self.qkv_weight = lora.Linear(d, 3 * d, bias=False, dtype=bf16, **lora_kwargs)
            self.W_o = lora.Linear(d, d, dtype=bf16, **lora_kwargs)
        else:
            self.qkv_weight = nn.Linear(d, 3 * d, bias=False, dtype=bf16)
            self.W_o = nn.Linear(d, d, dtype=bf16)

    def forward(self, q: Tensor, kv: Tensor = None):
        """Project packed tokens into per-head tensors.

        Self-attention: ``q`` is ``(total, D)`` and the result is
        ``(total, 3, num_heads, D // num_heads)``.
        Cross-attention: returns ``q`` as ``(total_q, num_heads, head_dim)``
        and ``kv`` as ``(total_kv, 2, num_heads, head_dim)``.
        """
        heads = self.num_heads
        if self.use_cross_attention:
            n_q, dim_q = q.size()
            n_kv, dim_kv = kv.size()
            q_heads = self.q_weight(q).view(n_q, heads, dim_q // heads)
            kv_heads = self.kv_weight(kv).view(n_kv, 2, heads, dim_kv // heads)
            return q_heads, kv_heads

        n_tokens, dim = q.size()
        return self.qkv_weight(q).view(n_tokens, 3, heads, dim // heads)

    def comp(self, o: Tensor):
        """Apply the output projection ``W_o`` to the merged attention heads."""
        return self.W_o(o)
107
+
108
+
109
class FlashSelfAttentionM(nn.Module):
    """Multi-head self-attention on FlashAttention-2 with an optional KV cache for decoding.

    Position information is ALiBi when ``args.use_rope`` is False and RoPE
    when it is True. ``forward`` picks one of three paths:

    * ``cu_seqlens`` given: variable-length attention over unpadded (packed)
      tokens ``x`` of shape ``(total, d_model)``. With ``is_save_cache``,
      the keys/values of each sequence are also written to the KV cache
      (prompt pre-fill).
    * ``cu_seqlens`` None and ``is_save_cache``: one decoding step. ``x`` holds
      one new token per batch row, ``(batch_size, d_model)``, and attends over
      the KV cache.
    * Otherwise: all rows of ``x`` are attended as a single sequence.
    """

    def __init__(self, args: MORTMArgs, progress=None):
        super(FlashSelfAttentionM, self).__init__()
        self.batch_first = True
        self._qkv_same_embed_dim = True
        self.in_proj_bias = None
        self.args = args

        self.embed_dim = args.d_model
        self.qkv_block = QKVLinear(args)
        self.drop = args.dropout
        self.kv_cache: Optional[Tuple[Tensor, Tensor]] = None
        self.cache_seqlens: Tensor = None

        # ``progress`` defaults to None, so both branches must tolerate it.
        device = progress.get_device() if progress else None
        if not self.args.use_rope:
            print("FlashAttention2のALiBiを使用します。")
            self.alibi_slopes = torch.tensor(get_alibi_slopes(args.num_heads), dtype=torch.float32, device=device)
        else:
            print("FlashAttention2のRoPEを使用します。")
            head_dim = args.d_model // args.num_heads
            self.rotary_emb = RotaryEmbedding(dim=head_dim, base=10000.0, interleaved=False, device=device)

    def _dropout_p(self) -> float:
        """Return the attention dropout: ``args.dropout`` while training, 0 in eval mode.

        The flash-attn functions apply ``dropout_p`` unconditionally, so the
        module's training flag has to be honoured here.
        """
        return self.drop if self.training else 0.0

    def _init_kv_cache(self, batch_size, device, dtype):
        """Allocate an uninitialised KV cache sized for ``batch_size`` sequences."""
        max_seq_len = self.args.position_length + 100  # room beyond the configured position length
        head_dim = self.args.d_model // self.args.num_heads
        shape = (batch_size, max_seq_len, self.args.num_heads, head_dim)

        # torch.empty only reserves memory; the slots are written before being read.
        self.kv_cache = (
            torch.empty(shape, device=device, dtype=dtype),
            torch.empty(shape, device=device, dtype=dtype)
        )
        self.cache_seqlens = torch.zeros(batch_size, device=device, dtype=torch.int32)

    def forward(self, x: Tensor, is_causal=False, cu_seqlens=None, max_seqlen=None,
                batch_size=None, indices=None, is_save_cache=False):
        """Run self-attention on ``x`` and apply the output projection.

        ``indices`` is accepted for interface compatibility and is not used.
        """
        if x.dtype == torch.float32:
            x = x.to(torch.bfloat16)

        if cu_seqlens is not None:
            # Phase 1: training, or prompt processing at inference time.
            out = self._attend_varlen(x, is_causal, cu_seqlens, max_seqlen, batch_size, is_save_cache)
        elif is_save_cache:
            # Phase 2: decode one token per batch row against the KV cache.
            out = self._attend_decode_step(x)
        else:
            out = self._attend_dense(x, is_causal)
        return self.qkv_block.comp(out)

    def _rotate(self, q, k, seqlen, cu_seqlens=None, max_seqlen=None):
        """Apply RoPE to ``q`` and ``k`` using cos/sin tables covering ``seqlen`` positions."""
        self.rotary_emb._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
        cos, sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        q = apply_rotary_emb(q, cos, sin, interleaved=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        k = apply_rotary_emb(k, cos, sin, interleaved=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        return q, k

    def _attend_varlen(self, x, is_causal, cu_seqlens, max_seqlen, batch_size, is_save_cache):
        """Variable-length attention over packed tokens; optionally fills the KV cache."""
        # Re-allocate the cache when none exists or the batch size changed.
        if is_save_cache and (self.kv_cache is None or self.kv_cache[0].shape[0] != batch_size):
            self._init_kv_cache(batch_size, x.device, x.dtype)

        qkv: Tensor = self.qkv_block(q=x)  # (total, 3, num_heads, head_dim)
        if self.args.use_rope:
            q, k, v = qkv.unbind(1)
            # apply_rotary_emb needs max_seqlen whenever cu_seqlens is given.
            q, k = self._rotate(q, k, max_seqlen, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            qkv = torch.stack([q, k, v], dim=1)
            out = flash_attn_varlen_qkvpacked_func(qkv, dropout_p=self._dropout_p(), causal=is_causal,
                                                   cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        else:
            out = flash_attn_varlen_qkvpacked_func(qkv, dropout_p=self._dropout_p(), causal=is_causal,
                                                   cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
                                                   alibi_slopes=self.alibi_slopes)

        if is_save_cache:
            # With RoPE, ``qkv`` now holds rotated keys. flash_attn_with_kvcache
            # rotates new keys before appending them, so cached keys must
            # already be rotated too.
            self._write_prompt_cache(qkv, cu_seqlens, batch_size)

        return rearrange(out, "total h d -> total (h d)")

    def _write_prompt_cache(self, qkv, cu_seqlens, batch_size):
        """Copy each sequence's keys/values to the start of its cache row and record the lengths."""
        with torch.no_grad():
            _, k_unpad, v_unpad = qkv.unbind(dim=1)
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int32)
            for b in range(batch_size):
                start, end = cu_seqlens[b], cu_seqlens[b + 1]
                seq_len = end - start
                self.kv_cache[0][b, :seq_len] = k_unpad[start:end]
                self.kv_cache[1][b, :seq_len] = v_unpad[start:end]
            self.cache_seqlens = seqlens

    def _attend_decode_step(self, x):
        """Attend one new token per batch row over the KV cache and append its key/value."""
        qkv: Tensor = self.qkv_block(q=x)  # (batch_size, 3, num_heads, head_dim)
        # Each becomes (batch_size, 1, num_heads, head_dim), the layout flash_attn_with_kvcache expects.
        q, k, v = (t.unsqueeze(1) for t in qkv.unbind(1))

        # Choose the position scheme from the model config (not the platform).
        if self.args.use_rope:
            # The rotary tables must cover the whole cache length.
            self.rotary_emb._update_cos_sin_cache(self.kv_cache[0].shape[1], device=x.device, dtype=x.dtype)
            position_kwargs = {
                "rotary_cos": self.rotary_emb._cos_cached,
                "rotary_sin": self.rotary_emb._sin_cached,
                "rotary_interleaved": False,
            }
        else:
            position_kwargs = {"alibi_slopes": self.alibi_slopes}

        # Computes the attention output and writes k/v into the cache at cache_seqlens.
        out = flash_attn_with_kvcache(
            q,
            k_cache=self.kv_cache[0],
            v_cache=self.kv_cache[1],
            k=k,
            v=v,
            cache_seqlens=self.cache_seqlens,
            causal=True,
            **position_kwargs
        )

        # The valid cache length grows by one token per row.
        self.cache_seqlens += 1

        # (batch_size, 1, h, d) -> (batch_size, h * d)
        return rearrange(out.squeeze(1), "total h d -> total (h d)")

    def _attend_dense(self, x, is_causal):
        """Attend all rows of ``x`` as one sequence, with the same position scheme as the varlen path."""
        qkv = self.qkv_block(q=x).unsqueeze(0)  # (1, total, 3, num_heads, head_dim)
        if self.args.use_rope:
            q, k, v = qkv.unbind(2)
            q, k = self._rotate(q, k, qkv.shape[1])
            qkv = torch.stack([q, k, v], dim=2)
            out = flash_attn_qkvpacked_func(qkv, dropout_p=self._dropout_p(), causal=is_causal)
        else:
            out = flash_attn_qkvpacked_func(qkv, dropout_p=self._dropout_p(), causal=is_causal,
                                            alibi_slopes=self.alibi_slopes)
        return rearrange(out, "b s h d -> (b s) (h d)")

    def compute_cache_seqlens(self, k: torch.Tensor) -> torch.Tensor:
        """Count, per batch row, the positions whose key vectors are not entirely zero.

        Args:
            k: Cache tensor shaped ``[batch_size, max_seq_len, num_heads, head_dim]``.

        Returns:
            ``[batch_size]`` counts. These equal the sequence lengths only when
            all-zero positions are trailing padding.
        """
        # A position counts when any element across heads/head_dim is non-zero.
        is_nonzero = k.abs().sum(dim=(-1, -2)) != 0  # shape: [batch_size, max_seq_len]
        seqlens = is_nonzero.sum(dim=1)  # shape: [batch_size]
        return seqlens
254
+
255
+
256
class FlashCrossAttentionM(nn.Module):
    """Multi-head cross-attention (decoder queries over encoder keys/values) on FlashAttention-2.

    With ``cu_seqlens_q`` given, ``x`` and ``encoder_x`` are packed
    (unpadded) tokens described by the ``cu_seqlens_*``/``max_seqlen_*``
    arguments. Otherwise all rows of each are treated as one sequence.
    Attention is never causal.
    """

    def __init__(self, args: MORTMArgs, progress=None):
        super(FlashCrossAttentionM, self).__init__()
        self.batch_first = True
        self._qkv_same_embed_dim = True
        self.in_proj_bias = None
        self.args = args

        self.embed_dim = args.d_model
        self.qkv_block = QKVLinear(args, use_cross_attention=True)
        self.drop = args.dropout

    def forward(self, x: Tensor, encoder_x: Tensor, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=None,
                max_seqlen_k=None):
        """Attend ``x`` (queries) over ``encoder_x`` (keys/values) and apply the output projection."""
        if x.dtype == torch.float32:
            x = x.to(torch.bfloat16)
        if encoder_x.dtype == torch.float32:
            encoder_x = encoder_x.to(torch.bfloat16)

        # flash-attn applies dropout_p unconditionally, so disable it outside training.
        dropout_p = self.drop if self.training else 0.0
        q, kv = self.qkv_block(q=x, kv=encoder_x)

        if cu_seqlens_q is not None:
            # Phase 1: training, or prompt processing at inference time (packed sequences).
            out = flash_attn_varlen_kvpacked_func(
                q=q,
                kv=kv,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=False,
                dropout_p=dropout_p
            )
            out = rearrange(out, "total h d -> total (h d)")
        else:
            out = flash_attn_kvpacked_func(q=q.unsqueeze(0), kv=kv.unsqueeze(0), dropout_p=dropout_p, causal=False)
            out = rearrange(out, "b s h d -> (b s) (h d)")

        # Final output projection.
        return self.qkv_block.comp(out)