broccoli-ml 3.2.0__tar.gz → 3.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/PKG-INFO +1 -1
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/vit.py +5 -8
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/pyproject.toml +1 -1
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/LICENSE +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/README.md +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/activation.py +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/linear.py +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/rope.py +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-3.2.0 → broccoli_ml-3.3.0}/broccoli/transformer.py +0 -0
|
--- broccoli_ml-3.2.0/broccoli/vit.py
+++ broccoli_ml-3.3.0/broccoli/vit.py
@@ -58,7 +58,6 @@ class ClassificationHead(nn.Module):
 self,
 d_model,
 n_classes,
-linear_module=nn.Linear,
 logit_projection_layer=nn.Linear,
 batch_norm_logits=True,
 ):
@@ -66,10 +65,11 @@ class ClassificationHead(nn.Module):
 self.d_model = d_model
 self.summarize = GetCLSToken()
 
-if
-
-else:
+if d_model == n_classes:
+# No need to project
 self.projection = nn.Identity()
+else:
+self.projection = logit_projection_layer(d_model, n_classes)
 
 if batch_norm_logits:
 self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
@@ -99,19 +99,17 @@ class SequencePoolClassificationHead(ClassificationHead):
 self,
 d_model,
 n_classes,
-linear_module=nn.Linear,
 logit_projection_layer=nn.Linear,
 batch_norm_logits=True,
 ):
 super().__init__(
 d_model,
 n_classes,
-linear_module=linear_module,
 logit_projection_layer=logit_projection_layer,
 batch_norm_logits=batch_norm_logits,
 )
 
-self.summarize = SequencePool(d_model,
+self.summarize = SequencePool(d_model, logit_projection_layer)
 # Rebuild the classification process with the correct summary module:
 self.classification_process = nn.Sequential(
 *[
@@ -513,7 +511,6 @@ class ViT(nn.Module):
 self.pool = head(
 transformer_embedding_size,
 image_classes,
-linear_module=linear_module,
 logit_projection_layer=logit_projection_layer,
 batch_norm_logits=batch_norm_logits,
 )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes