broccoli-ml 3.3.0__tar.gz → 4.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 3.3.0
+Version: 4.0.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -4,11 +4,20 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 from einops import rearrange
 
 from .rope import RotaryEmbedding, apply_rotary_emb
 
+try:
+    from flash_attn import flash_attn_func
+
+    FLASH_ATTN = True
+except ImportError:
+    pass
+    FLASH_ATTN = False
+
 
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
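
The hunk above is the transformer module (FeedforwardBlock and TransformerEncoder are imported from .transformer later in this diff); 4.0.0 adds torch.utils.checkpoint and an optional flash-attn dependency, with a module-level FLASH_ATTN flag recording whether the import succeeded. Below is a hedged sketch of that optional-kernel pattern; the attention helper and the device/dtype guard are illustrative additions, not broccoli-ml code, which branches on the flag alone:

import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # needs a CUDA build of flash-attn

    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False


def attention(q, k, v, causal=False):
    # q, k, v: (batch, seq, heads, head_dim); flash-attn also needs fp16/bf16 on GPU
    if HAS_FLASH and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v, causal=causal)
    # Fallback: PyTorch's fused SDPA expects (batch, heads, seq, head_dim)
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
    )
    return out.transpose(1, 2)


print(attention(*[torch.randn(2, 5, 4, 8) for _ in range(3)]).shape)  # (2, 5, 4, 8)
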
@@ -206,32 +215,53 @@ class MHAttention(nn.Module):
         q = torch.cat([q_bos, q_img], dim=1)
         k = torch.cat([k_bos, k_img], dim=1)
 
-        # Divide Q/K/V into heads
-        q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
-        k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
-        v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
-
-        qk_scores = q @ k.transpose(-1, -2)
-
         if self.scaling == "sqrtd":
-            qk_scores /= math.sqrt(self.head_dim)
+            scaling_factor = 1 / math.sqrt(self.head_dim)
         elif self.scaling == "d":
             # for backwards compatibility, per https://github.com/microsoft/mup
-            qk_scores *= 8 / self.head_dim
+            scaling_factor = 8 / self.head_dim
         else:
             raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
 
-        # Apply mask if causal (must come before softmax)
-        if self.causal:
-            qk_scores.masked_fill_(self.mask, float("-inf"))
+        if FLASH_ATTN:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
+
+            output_with_heads = flash_attn_func(
+                q,
+                k,
+                v,
+                dropout_p=self.dropout if self.training else 0.0,
+                softmax_scale=scaling_factor,
+                causal=self.causal,
+            )
+
+            output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+        else:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
+
+            qk_scores = q @ k.transpose(-1, -2)
 
-        qk_scores = F.softmax(qk_scores, dim=-1)
+            qk_scores *= scaling_factor
 
-        output_with_heads = qk_scores @ v
+            # Apply mask if causal (must come before softmax)
+            if self.causal:
+                qk_scores.masked_fill_(self.mask, float("-inf"))
 
-        output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+            qk_scores = F.softmax(qk_scores, dim=-1)
 
-        return self.out_proj(output_without_heads)
+            output_with_heads = qk_scores @ v
+
+            output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
 
 
 class FeedforwardBlock(nn.Module):
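
In the rewritten MHAttention.forward above, the scaling factor is computed once and either passed to flash_attn_func as softmax_scale or applied manually in the fallback path; note the two head layouts, (b, t, h, d) for flash-attn versus (b, h, t, d) for the manual matmul. Below is a hedged equivalence check using PyTorch's scaled_dot_product_attention as a stand-in for flash_attn_func (which needs CUDA and fp16/bf16); nothing in it comes from broccoli-ml, and the scale= argument assumes PyTorch ≥ 2.1:

import torch
import torch.nn.functional as F
from einops import rearrange

b, t, h, d = 2, 5, 4, 8
torch.manual_seed(0)
q, k, v = (torch.randn(b, t, h * d) for _ in range(3))
scale = 1 / d**0.5

# Manual path, heads-first layout: (b, h, t, d)
qh, kh, vh = (rearrange(x, "b t (h d) -> b h t d", h=h) for x in (q, k, v))
scores = (qh @ kh.transpose(-1, -2)) * scale
manual = rearrange(F.softmax(scores, dim=-1) @ vh, "b h t d -> b t (h d)")

# Fused path (SDPA standing in for the flash-attn kernel)
fused = F.scaled_dot_product_attention(qh, kh, vh, scale=scale)
fused = rearrange(fused, "b h t d -> b t (h d)")

assert torch.allclose(manual, fused, atol=1e-5)  # same attention, either layout
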
@@ -259,6 +289,13 @@ class FeedforwardBlock(nn.Module):
         self.residual_path = residual_path
         self.post_norm = post_norm
 
+        if self.residual_path and (output_features < input_features):
+            raise ValueError(
+                "If the number of output features will be less than "
+                "the number of input features, then `residual_path` "
+                "should be set to False."
+            )
+
         if self.post_norm:
             self.layernorm = nn.LayerNorm(output_features)
 
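
The new guard above rejects a configuration that cannot work: a residual connection adds the block's input to its output, so the output cannot be narrower than the input. A generic illustration of the failure follows (plain tensors, not FeedforwardBlock's internals; the zero-padding at the end is just one common way to widen the input):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16)        # 16 input features
narrower = torch.randn(2, 8)  # 8 output features: the residual add cannot broadcast
wider = torch.randn(2, 32)    # 32 output features: the input can be widened to match

try:
    _ = x + narrower
except RuntimeError as err:
    print("residual add fails:", err)

_ = wider + F.pad(x, (0, 32 - 16))  # zero-pad the feature dim, then the add works
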
@@ -403,21 +440,20 @@ class TransformerBlock(nn.Module):
     def forward(self, x):
 
         if self.pre_norm:
-            normx = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(normx, normx, normx))
-            normx = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(normx))
-        elif self.post_norm:
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
             x = self.layer_norm_2(x)
-        else:
+        else:  # Not pre or post norm. Stand well back.
             x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
-
-        if self.pre_norm and self.post_norm:
-            x = self.layer_norm_3(x)
+            x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
 
         return x
 
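
In the rewritten forward above, every branch now routes the feedforward sublayer through torch.utils.checkpoint, so its intermediate activations are recomputed during backward instead of being stored; use_reentrant=False selects the non-reentrant implementation PyTorch recommends. (Note also that the pre-norm branch now writes the LayerNorm output back to x before the residual add, where 3.3.0 kept a separate normx.) A minimal, self-contained checkpointing sketch; the Sequential below is illustrative, not the package's FeedforwardBlock:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

y = x + checkpoint(ff, x, use_reentrant=False)  # residual + checkpointed feedforward
y.sum().backward()                              # ff's forward is re-run here to get grads
print(x.grad.shape)                             # torch.Size([8, 64])
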
@@ -0,0 +1,15 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PadTensor(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.args = args
+        self.kwargs = kwargs
+
+    def forward(self, x):
+        if sum(self.args[0]) == 0:
+            return x
+        else:
+            return F.pad(x, *self.args, **self.kwargs)
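
The new file above defines PadTensor, a thin nn.Module wrapper around torch.nn.functional.pad that short-circuits to the identity when every pad amount is zero; the next hunk removes the old inline definition and imports it from .utils instead. A usage sketch follows (the class body is copied from the hunk above; shapes and layer choices are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F


class PadTensor(nn.Module):  # copied from the new file above
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        if sum(self.args[0]) == 0:
            return x
        else:
            return F.pad(x, *self.args, **self.kwargs)


# Pad spec follows F.pad: (left, right) pairs, starting from the last dimension.
x = torch.randn(2, 3, 8, 8)                 # (batch, channels, height, width)
print(PadTensor((1, 1, 0, 0))(x).shape)     # width padded by 1 each side -> [2, 3, 8, 10]
print(PadTensor((0, 0, 0, 0))(x) is x)      # all-zero pad is a no-op -> True
layer = nn.Sequential(PadTensor((1, 1, 1, 1)), nn.Conv2d(3, 16, kernel_size=3))
print(layer(x).shape)                       # torch.Size([2, 16, 8, 8])
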
@@ -4,25 +4,12 @@ from typing import Optional
 from .transformer import TransformerEncoder, FeedforwardBlock
 from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
 from .activation import ReLU, SquaredReLU, GELU, SwiGLU
+from .utils import PadTensor
 
 from einops import einsum
 from einops.layers.torch import Rearrange
 
 import torch.nn as nn
-import torch.nn.functional as F
-
-
-class PadTensor(nn.Module):
-    def __init__(self, *args, **kwargs):
-        super().__init__()
-        self.args = args
-        self.kwargs = kwargs
-
-    def forward(self, x):
-        if sum(self.args[0]) == 0:
-            return x
-        else:
-            return F.pad(x, *self.args, **self.kwargs)
 
 
 class GetCLSToken(nn.Module):
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "3.3.0"
+version = "4.0.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}