broccoli-ml 0.36.0__py3-none-any.whl → 0.38.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
broccoli/linear.py CHANGED
@@ -131,8 +131,7 @@ class WeightNormedLinear(nn.Module):
131
131
  return F.linear(x, self.weights(), self.bias)
132
132
 
133
133
  def __repr__(self) -> str:
134
- # Optional: A nice representation for printing the module.
135
134
  return (
136
- f"AnchoredLinear(in_features={self.in_features},"
135
+ f"WeightNormedLinear(in_features={self.in_features},"
137
136
  f"out_features={self.out_features}, bias={self.use_bias})"
138
137
  )
broccoli/transformer.py CHANGED
@@ -13,6 +13,45 @@ from .rope import RotaryEmbedding, apply_rotary_emb
13
13
  from .linear import AnchoredLinear, SpectralNormLinear
14
14
 
15
15
 
16
+ def drop_path(
17
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
18
+ ):
19
+ """
20
+ From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
21
+ Copyright 2019 Ross Wightman
22
+ See documentation and licence there.
23
+ """
24
+ if drop_prob == 0.0 or not training:
25
+ return x
26
+ keep_prob = 1 - drop_prob
27
+ shape = (x.shape[0],) + (1,) * (
28
+ x.ndim - 1
29
+ ) # work with diff dim tensors, not just 2D ConvNets
30
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
31
+ if keep_prob > 0.0 and scale_by_keep:
32
+ random_tensor.div_(keep_prob)
33
+ return x * random_tensor
34
+
35
+
36
+ class DropPath(nn.Module):
37
+ """
38
+ From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
39
+ Copyright 2019 Ross Wightman
40
+ See documentation and licence there.
41
+ """
42
+
43
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
44
+ super(DropPath, self).__init__()
45
+ self.drop_prob = drop_prob
46
+ self.scale_by_keep = scale_by_keep
47
+
48
+ def forward(self, x):
49
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
50
+
51
+ def extra_repr(self):
52
+ return f"drop_prob={round(self.drop_prob, 3):0.3f}"
53
+
54
+
16
55
  class MHAttention(nn.Module):
17
56
  """
18
57
  Multi-head self-attention using einops and optionally a custom linear layer.
@@ -264,6 +303,8 @@ class TransformerBlock(nn.Module):
264
303
  mlp_ratio=4,
265
304
  activation: nn.Module = nn.ReLU,
266
305
  activation_kwargs: Optional[dict] = None,
306
+ ff_linear_module_up=None,
307
+ ff_linear_module_down=None,
267
308
  mlp_dropout=0.0,
268
309
  msa_dropout=0.0,
269
310
  identity_probability=0.0,
@@ -279,10 +320,11 @@ class TransformerBlock(nn.Module):
279
320
  self.post_norm = post_norm
280
321
  self.normformer = normformer
281
322
 
282
- self.identity_probability = identity_probability
323
+ self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
283
324
 
284
325
  self.layer_norm_1 = nn.LayerNorm(d_model)
285
326
  self.layer_norm_2 = nn.LayerNorm(d_model)
327
+ self.layer_norm_3 = nn.LayerNorm(d_model)
286
328
 
287
329
  if position_embedding_type == "relative":
288
330
  max_freq = int(max(source_size) / 2) # Suggested by Gemini!
@@ -316,12 +358,20 @@ class TransformerBlock(nn.Module):
316
358
  activation=activation,
317
359
  activation_kwargs=activation_kwargs,
318
360
  dropout=mlp_dropout,
319
- linear_module_up=linear_module,
320
- linear_module_down=linear_module,
321
- pre_norm=pre_norm,
361
+ linear_module_up=(
362
+ ff_linear_module_up
363
+ if ff_linear_module_up is not None
364
+ else linear_module
365
+ ),
366
+ linear_module_down=(
367
+ ff_linear_module_down
368
+ if ff_linear_module_down is not None
369
+ else linear_module
370
+ ),
371
+ pre_norm=False, # Handled outside the block
322
372
  normformer=normformer,
323
- post_norm=post_norm,
324
- residual_path=True,
373
+ post_norm=False, # Handled outside the block
374
+ residual_path=False, # Handled outside the block
325
375
  )
326
376
 
327
377
  @property
@@ -329,34 +379,23 @@ class TransformerBlock(nn.Module):
329
379
  return self.attn._kv_distance
330
380
 
331
381
  def forward(self, x):
332
- if not self.training:
333
- identity_probability = 0.0
334
- else:
335
- identity_probability = self.identity_probability
336
-
337
- # perform the identity operation for some rows in the batch
338
- dist = torch.distributions.Binomial(x.size(0), identity_probability)
339
- identity_count = int(dist.sample().item())
340
-
341
- shuffle_indices = torch.randperm(x.size(0), device=x.device)
342
- unshuffle_indices = torch.argsort(shuffle_indices)
343
- shuffled = x[shuffle_indices, :, :]
344
- identity_x = shuffled[:identity_count, :, :]
345
- process_x = shuffled[identity_count:, :, :]
346
-
347
- residual_x = process_x
348
382
 
349
383
  if self.pre_norm:
350
- process_x = self.layer_norm_1(process_x)
351
-
352
- process_x = residual_x + self.attn(process_x, process_x, process_x)
353
-
354
- if self.post_norm:
355
- process_x = self.layer_norm_2(process_x)
356
-
357
- process_x = self.ff(process_x)
384
+ normx = self.layer_norm_1(x)
385
+ x = x + self.drop_path(self.attn(normx, normx, normx))
386
+ normx = self.layer_norm_2(x)
387
+ x = x + self.drop_path(self.ff(normx))
388
+ elif self.post_norm:
389
+ x = x + self.drop_path(self.attn(x, x, x))
390
+ x = self.layer_norm_1(x)
391
+ x = x + self.drop_path(self.ff(x))
392
+ x = self.layer_norm_2(x)
393
+ else:
394
+ x = x + self.drop_path(self.attn(x, x, x))
395
+ x = x + self.drop_path(self.ff(x))
358
396
 
359
- x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
397
+ if self.pre_norm and self.post_norm:
398
+ x = self.layer_norm_3(x)
360
399
 
361
400
  return x
362
401
 
@@ -378,6 +417,8 @@ class TransformerEncoder(nn.Module):
378
417
  mlp_ratio=4,
379
418
  activation: nn.Module = nn.ReLU,
380
419
  activation_kwargs: Optional[dict] = None,
420
+ ff_linear_module_up=None,
421
+ ff_linear_module_down=None,
381
422
  mlp_dropout=0.0,
382
423
  msa_dropout=0.0,
383
424
  stochastic_depth=0.0,
@@ -441,6 +482,8 @@ class TransformerEncoder(nn.Module):
441
482
  mlp_ratio=mlp_ratio,
442
483
  activation=activation,
443
484
  activation_kwargs=activation_kwargs,
485
+ ff_linear_module_up=ff_linear_module_up,
486
+ ff_linear_module_down=ff_linear_module_down,
444
487
  mlp_dropout=mlp_dropout,
445
488
  msa_dropout=msa_dropout,
446
489
  identity_probability=self.stochastic_depth_probabilities[i],
broccoli/vit.py CHANGED
@@ -132,6 +132,8 @@ class ViTEncoder(nn.Module):
132
132
  transformer_return_bos_tokens=False,
133
133
  transformer_activation: nn.Module = SquaredReLU,
134
134
  transformer_activation_kwargs: Optional[dict] = None,
135
+ transformer_ff_linear_module_up=None,
136
+ transformer_ff_linear_module_down=None,
135
137
  transformer_mlp_dropout=0.0,
136
138
  transformer_msa_dropout=0.1,
137
139
  transformer_stochastic_depth=0.1,
@@ -282,6 +284,8 @@ class ViTEncoder(nn.Module):
282
284
  mlp_ratio=transformer_mlp_ratio,
283
285
  activation=transformer_activation,
284
286
  activation_kwargs=transformer_activation_kwargs,
287
+ ff_linear_module_up=transformer_ff_linear_module_up,
288
+ ff_linear_module_down=transformer_ff_linear_module_down,
285
289
  mlp_dropout=transformer_mlp_dropout,
286
290
  msa_dropout=transformer_msa_dropout,
287
291
  stochastic_depth=transformer_stochastic_depth,
@@ -305,14 +309,16 @@ class ViTEncoder(nn.Module):
305
309
  activation_kwargs=transformer_activation_kwargs,
306
310
  dropout=transformer_mlp_dropout,
307
311
  linear_module_up=(
312
+ # First truthy assigned value
308
313
  transformer_initial_ff_linear_module_up
309
- if transformer_initial_ff_linear_module_up is not None
310
- else linear_module
314
+ or transformer_ff_linear_module_up
315
+ or linear_module
311
316
  ),
312
317
  linear_module_down=(
318
+ # First truthy assigned value
313
319
  transformer_initial_ff_linear_module_down
314
- if transformer_initial_ff_linear_module_down is not None
315
- else linear_module
320
+ or transformer_ff_linear_module_down
321
+ or linear_module
316
322
  ),
317
323
  pre_norm=transformer_pre_norm,
318
324
  normformer=transformer_normformer,
@@ -386,6 +392,8 @@ class ViT(nn.Module):
386
392
  transformer_return_bos_tokens=False,
387
393
  transformer_activation: nn.Module = SquaredReLU,
388
394
  transformer_activation_kwargs: Optional[dict] = None,
395
+ transformer_ff_linear_module_up=None,
396
+ transformer_ff_linear_module_down=None,
389
397
  transformer_mlp_dropout=0.0,
390
398
  transformer_msa_dropout=0.1,
391
399
  transformer_stochastic_depth=0.1,
@@ -446,6 +454,8 @@ class ViT(nn.Module):
446
454
  transformer_return_bos_tokens=transformer_return_bos_tokens,
447
455
  transformer_activation=transformer_activation,
448
456
  transformer_activation_kwargs=transformer_activation_kwargs,
457
+ transformer_ff_linear_module_up=transformer_ff_linear_module_up,
458
+ transformer_ff_linear_module_down=transformer_ff_linear_module_down,
449
459
  transformer_mlp_dropout=transformer_mlp_dropout,
450
460
  transformer_msa_dropout=transformer_msa_dropout,
451
461
  transformer_stochastic_depth=transformer_stochastic_depth,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.36.0
3
+ Version: 0.38.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -5,13 +5,13 @@ broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYV
5
5
  broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
6
6
  broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
7
7
  broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
- broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
8
+ broccoli/linear.py,sha256=zlFDij9TngqDzTTpUlZtX0PXAQgxyWWnTIXpiO1rBk0,4768
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
11
- broccoli/transformer.py,sha256=NH94U6lxHzmDGDHTTtJV2kUs7IcS2iNmFJl44_6KtQ0,15456
11
+ broccoli/transformer.py,sha256=Xw1oBLsvVeHmMqgurhorRa49nrjfooLev5uBPFeK9og,17004
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=05xqIw9xvE5easXcp4wIA1jQ0xUyRIq6h0ZDtbitXi4,17184
14
- broccoli_ml-0.36.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.36.0.dist-info/METADATA,sha256=csog4ZG1PGeRuFO5QnHdVPgmDYXsGQQJ621JgU0D83w,1257
16
- broccoli_ml-0.36.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.36.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=RaSJZh2ogqNAvAkQDuZpNqdtRTWaW_8ug4BsBCBK_f4,17728
14
+ broccoli_ml-0.38.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.38.0.dist-info/METADATA,sha256=8lcwJvmlPT_-SkJZ5Qd6nFc4Njz7mQD-JefeUxhk8Xw,1257
16
+ broccoli_ml-0.38.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.38.0.dist-info/RECORD,,