broccoli-ml 3.3.1__tar.gz → 4.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 3.3.1
+Version: 4.0.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -4,11 +4,20 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 from einops import rearrange
 
 from .rope import RotaryEmbedding, apply_rotary_emb
 
+try:
+    from flash_attn import flash_attn_func
+
+    FLASH_ATTN = True
+except ImportError:
+    pass
+    FLASH_ATTN = False
+
 
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
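
Note: the guarded import above makes flash-attn an optional accelerator rather than a hard dependency. A minimal sketch of the intended pattern (not taken from the package; it assumes the `FLASH_ATTN = False` assignment lives inside the except branch and drops the redundant `pass`):

    try:
        from flash_attn import flash_attn_func  # optional, CUDA-only extra
        FLASH_ATTN = True
    except ImportError:
        FLASH_ATTN = False

    def attention_backend() -> str:
        # Downstream code (e.g. MHAttention.forward) branches on this module-level flag.
        return "flash-attn" if FLASH_ATTN else "pytorch"

    print(attention_backend())
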
@@ -206,32 +215,55 @@ class MHAttention(nn.Module):
         q = torch.cat([q_bos, q_img], dim=1)
         k = torch.cat([k_bos, k_img], dim=1)
 
-        # Divide Q/K/V into heads
-        q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
-        k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
-        v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
-
-        qk_scores = q @ k.transpose(-1, -2)
-
         if self.scaling == "sqrtd":
-            qk_scores /= math.sqrt(self.head_dim)
+            scaling_factor = 1 / math.sqrt(self.head_dim)
         elif self.scaling == "d":
             # for backwards compatibility, per https://github.com/microsoft/mup
-            qk_scores *= 8 / self.head_dim
+            scaling_factor = 8 / self.head_dim
         else:
             raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
 
-        # Apply mask if causal (must come before softmax)
-        if self.causal:
-            qk_scores.masked_fill_(self.mask, float("-inf"))
+        if FLASH_ATTN:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
+
+            output_with_heads = flash_attn_func(
+                q,
+                k,
+                v,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                softmax_scale=scaling_factor,
+                causal=self.causal,
+            )
+
+            output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+        else:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
 
-        qk_scores = F.softmax(qk_scores, dim=-1)
+            qk_scores = q @ k.transpose(-1, -2)
 
-        output_with_heads = qk_scores @ v
+            qk_scores *= scaling_factor
 
-        output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+            # Apply mask if causal (must come before softmax)
+            if self.causal:
+                qk_scores.masked_fill_(self.mask, float("-inf"))
 
-        return self.out_proj(output_without_heads)
+            qk_scores = F.softmax(qk_scores, dim=-1)
+
+            qk_scores = self.dropout(qk_scores)
+
+            output_with_heads = qk_scores @ v
+
+            output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
 
 
 class FeedforwardBlock(nn.Module):
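
Note: the refactored fallback path (the else branch) computes standard scaled-dot-product attention. A small self-contained check, not from the package, that it agrees with PyTorch's fused kernel when dropout is off; the shapes, causal mask, and tolerance are made up for illustration, and the `scale` argument needs PyTorch >= 2.1:

    import math
    import torch
    import torch.nn.functional as F

    b, h, t, d = 2, 4, 8, 16
    q, k, v = (torch.randn(b, h, t, d) for _ in range(3))
    scaling_factor = 1 / math.sqrt(d)  # the "sqrtd" branch above
    causal_mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)

    # Manual path, mirroring the else branch of MHAttention.forward
    qk_scores = (q @ k.transpose(-1, -2)) * scaling_factor
    qk_scores = qk_scores.masked_fill(causal_mask, float("-inf"))
    out_manual = F.softmax(qk_scores, dim=-1) @ v

    # Fused reference
    out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scaling_factor)

    print(torch.allclose(out_manual, out_fused, atol=1e-5))  # expected: True
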
@@ -410,21 +442,20 @@ class TransformerBlock(nn.Module):
     def forward(self, x):
 
         if self.pre_norm:
-            normx = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(normx, normx, normx))
-            normx = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(normx))
-        elif self.post_norm:
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
             x = self.layer_norm_2(x)
-        else:
+        else:  # Not pre or post norm. Stand well back.
             x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
-
-        if self.pre_norm and self.post_norm:
-            x = self.layer_norm_3(x)
+            x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
 
         return x
 
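
Note: the new `checkpoint(self.ff, x, use_reentrant=False)` calls wrap only the feed-forward block. A minimal standalone sketch, not from the package, of what gradient checkpointing buys: the block's intermediate activations are discarded in the forward pass and recomputed during backward, trading extra compute for lower peak memory:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    x = torch.randn(8, 64, requires_grad=True)

    # Residual update in the same shape as the TransformerBlock branches above.
    y = x + checkpoint(ff, x, use_reentrant=False)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([8, 64])
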
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "3.3.1"
+version = "4.0.1"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}