broccoli-ml 0.37.0__tar.gz → 0.39.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.37.0
+Version: 0.39.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -34,7 +34,8 @@ class SpectralNormLinear(nn.Module):
 
     def reset_parameters(self) -> None:
         weights = torch.empty(self.out_features, self.in_features)
-        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
         if self.use_bias:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
@@ -77,7 +78,8 @@ class AnchoredLinear(nn.Module):
 
     def reset_parameters(self) -> None:
         weights = torch.empty(self.out_features, self.in_features)
-        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
         if self.use_bias:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
@@ -120,7 +122,8 @@ class WeightNormedLinear(nn.Module):
 
     def reset_parameters(self) -> None:
         weights = torch.empty(self.out_features, self.in_features)
-        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
         if self.use_bias:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
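For context, a minimal standalone sketch of the initialisation change repeated in the three layers above: the Kaiming-uniform draw is replaced by a uniform draw bounded by 1/sqrt(in_features), while bias initialisation is unchanged. The feature sizes below are arbitrary example values, not taken from the package.

import math
import torch
import torch.nn as nn

in_features, out_features = 64, 128              # arbitrary example sizes
weights = torch.empty(out_features, in_features)

# Old scheme (removed above):
# nn.init.kaiming_uniform_(weights, a=math.sqrt(5))

# New scheme in 0.39.0: uniform in [-1/sqrt(fan_in), +1/sqrt(fan_in)]
stdv = 1.0 / math.sqrt(in_features)
nn.init.uniform_(weights, a=-stdv, b=stdv)

# Bias handling is unchanged by the diff:
bias = torch.empty(out_features)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(bias, -bound, bound)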
@@ -131,8 +134,7 @@ class WeightNormedLinear(nn.Module):
         return F.linear(x, self.weights(), self.bias)
 
     def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
         return (
-            f"AnchoredLinear(in_features={self.in_features},"
+            f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
@@ -303,6 +303,8 @@ class TransformerBlock(nn.Module):
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
         mlp_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
@@ -356,8 +358,16 @@ class TransformerBlock(nn.Module):
             activation=activation,
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
-            linear_module_up=linear_module,
-            linear_module_down=linear_module,
+            linear_module_up=(
+                ff_linear_module_up
+                if ff_linear_module_up is not None
+                else linear_module
+            ),
+            linear_module_down=(
+                ff_linear_module_down
+                if ff_linear_module_down is not None
+                else linear_module
+            ),
             pre_norm=False,  # Handled outside the block
             normformer=normformer,
             post_norm=False,  # Handled outside the block
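A minimal self-contained sketch of the selection logic the block now applies for its feed-forward projections. The helper name and the stand-in class below are illustrative assumptions; the library inlines this as the conditional expressions shown above.

import torch.nn as nn

class MyCustomLinear(nn.Linear):
    """Hypothetical stand-in for a user-supplied linear module class."""

def pick(override, default):
    # New in 0.39.0: a per-projection override wins when it is not None.
    return override if override is not None else default

linear_module = nn.Linear              # block-wide default
ff_linear_module_up = MyCustomLinear   # caller overrides the up projection only
ff_linear_module_down = None           # down projection keeps the default

up_cls = pick(ff_linear_module_up, linear_module)      # -> MyCustomLinear
down_cls = pick(ff_linear_module_down, linear_module)  # -> nn.Linear
print(up_cls.__name__, down_cls.__name__)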
@@ -389,51 +399,6 @@ class TransformerBlock(nn.Module):
 
         return x
 
-        # if not self.training:
-        #     identity_probability = 0.0
-        # else:
-        #     identity_probability = self.identity_probability
-
-        # if random.random() < identity_probability:
-        #     return x
-        # else:
-        #     ...
-
-        # # perform the identity operation for some rows in the batch
-        # dist = torch.distributions.Binomial(x.size(0), identity_probability)
-        # identity_count = int(dist.sample().item())
-
-        # shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        # unshuffle_indices = torch.argsort(shuffle_indices)
-        # shuffled = x[shuffle_indices, :, :]
-        # norm_shuffled = self.layer_norm_1(shuffled)
-        # identity_x = shuffled[:identity_count, :, :]
-        # process_x = shuffled[identity_count:, :, :]
-        # residual = process_x
-
-        # if self.pre_norm:
-        #     process_x = norm_shuffled[identity_count:, :, :]
-
-        # process_x = residual + self.attn(process_x, process_x, process_x)
-        # residual = process_x
-
-        # shuffled = torch.cat([identity_x, process_x])
-        # norm_shuffled = self.layer_norm_2(shuffled)
-
-        # if self.pre_norm:
-        #     residual = process_x  # residual NOT normed
-        #     process_x = norm_shuffled[identity_count:, :, :]
-
-        # if self.post_norm:
-        #     process_x = norm_shuffled[identity_count:, :, :]
-        #     residual = process_x  # residual normed
-
-        # process_x = residual + self.ff(process_x)  # handles residual connection
-
-        # x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
-
-        # return x if not self.post_norm else self.layer_norm_3(x)
-
 
 
 class TransformerEncoder(nn.Module):
     """
@@ -452,6 +417,8 @@ class TransformerEncoder(nn.Module):
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
         mlp_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
@@ -515,6 +482,8 @@ class TransformerEncoder(nn.Module):
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
+                ff_linear_module_up=ff_linear_module_up,
+                ff_linear_module_down=ff_linear_module_down,
                 mlp_dropout=mlp_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
@@ -132,6 +132,8 @@ class ViTEncoder(nn.Module):
         transformer_return_bos_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
+        transformer_ff_linear_module_up=None,
+        transformer_ff_linear_module_down=None,
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -282,6 +284,8 @@ class ViTEncoder(nn.Module):
            mlp_ratio=transformer_mlp_ratio,
            activation=transformer_activation,
            activation_kwargs=transformer_activation_kwargs,
+            ff_linear_module_up=transformer_ff_linear_module_up,
+            ff_linear_module_down=transformer_ff_linear_module_down,
            mlp_dropout=transformer_mlp_dropout,
            msa_dropout=transformer_msa_dropout,
            stochastic_depth=transformer_stochastic_depth,
@@ -305,14 +309,16 @@ class ViTEncoder(nn.Module):
            activation_kwargs=transformer_activation_kwargs,
            dropout=transformer_mlp_dropout,
            linear_module_up=(
+                # First truthy assigned value
                transformer_initial_ff_linear_module_up
-                if transformer_initial_ff_linear_module_up is not None
-                else linear_module
+                or transformer_ff_linear_module_up
+                or linear_module
            ),
            linear_module_down=(
+                # First truthy assigned value
                transformer_initial_ff_linear_module_down
-                if transformer_initial_ff_linear_module_down is not None
-                else linear_module
+                or transformer_ff_linear_module_down
+                or linear_module
            ),
            pre_norm=transformer_pre_norm,
            normformer=transformer_normformer,
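A short self-contained sketch of the "first truthy assigned value" fallback used in the two expressions above. The stand-in class is hypothetical; note that with `or`, any falsy value (not just None) falls through to the next candidate.

import torch.nn as nn

class MyTransformerFFLinear(nn.Linear):
    """Hypothetical stand-in for a user-supplied linear module class."""

transformer_initial_ff_linear_module_up = None           # no initial-FF-specific override
transformer_ff_linear_module_up = MyTransformerFFLinear  # transformer-wide override (new in 0.39.0)
linear_module = nn.Linear                                 # package default

# Precedence: initial-FF-specific override, then the new transformer-wide
# override, then the default linear module.
chosen = (
    transformer_initial_ff_linear_module_up
    or transformer_ff_linear_module_up
    or linear_module
)
assert chosen is MyTransformerFFLinear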
@@ -386,6 +392,8 @@ class ViT(nn.Module):
         transformer_return_bos_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
+        transformer_ff_linear_module_up=None,
+        transformer_ff_linear_module_down=None,
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -446,6 +454,8 @@ class ViT(nn.Module):
            transformer_return_bos_tokens=transformer_return_bos_tokens,
            transformer_activation=transformer_activation,
            transformer_activation_kwargs=transformer_activation_kwargs,
+            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
+            transformer_ff_linear_module_down=transformer_ff_linear_module_down,
            transformer_mlp_dropout=transformer_mlp_dropout,
            transformer_msa_dropout=transformer_msa_dropout,
            transformer_stochastic_depth=transformer_stochastic_depth,
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.37.0"
+version = "0.39.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}