broccoli-ml 3.2.0__tar.gz → 3.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 3.2.0
3
+ Version: 3.3.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -58,7 +58,6 @@ class ClassificationHead(nn.Module):
58
58
  self,
59
59
  d_model,
60
60
  n_classes,
61
- linear_module=nn.Linear,
62
61
  logit_projection_layer=nn.Linear,
63
62
  batch_norm_logits=True,
64
63
  ):
@@ -66,10 +65,11 @@ class ClassificationHead(nn.Module):
66
65
  self.d_model = d_model
67
66
  self.summarize = GetCLSToken()
68
67
 
69
- if logit_projection_layer is not None:
70
- self.projection = logit_projection_layer(d_model, n_classes)
71
- else:
68
+ if d_model == n_classes:
69
+ # No need to project
72
70
  self.projection = nn.Identity()
71
+ else:
72
+ self.projection = logit_projection_layer(d_model, n_classes)
73
73
 
74
74
  if batch_norm_logits:
75
75
  self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
@@ -99,19 +99,17 @@ class SequencePoolClassificationHead(ClassificationHead):
99
99
  self,
100
100
  d_model,
101
101
  n_classes,
102
- linear_module=nn.Linear,
103
102
  logit_projection_layer=nn.Linear,
104
103
  batch_norm_logits=True,
105
104
  ):
106
105
  super().__init__(
107
106
  d_model,
108
107
  n_classes,
109
- linear_module=linear_module,
110
108
  logit_projection_layer=logit_projection_layer,
111
109
  batch_norm_logits=batch_norm_logits,
112
110
  )
113
111
 
114
- self.summarize = SequencePool(d_model, linear_module)
112
+ self.summarize = SequencePool(d_model, logit_projection_layer)
115
113
  # Rebuild the classification process with the correct summary module:
116
114
  self.classification_process = nn.Sequential(
117
115
  *[
@@ -513,7 +511,6 @@ class ViT(nn.Module):
513
511
  self.pool = head(
514
512
  transformer_embedding_size,
515
513
  image_classes,
516
- linear_module=linear_module,
517
514
  logit_projection_layer=logit_projection_layer,
518
515
  batch_norm_logits=batch_norm_logits,
519
516
  )
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "3.2.0"
3
+ version = "3.3.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes