broccoli-ml 3.1.3__tar.gz → 3.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/PKG-INFO +1 -1
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/vit.py +5 -1
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/pyproject.toml +1 -1
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/LICENSE +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/README.md +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/activation.py +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/linear.py +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/rope.py +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-3.1.3 → broccoli_ml-3.2.0}/broccoli/transformer.py +0 -0
|
@@ -58,6 +58,7 @@ class ClassificationHead(nn.Module):
|
|
|
58
58
|
self,
|
|
59
59
|
d_model,
|
|
60
60
|
n_classes,
|
|
61
|
+
linear_module=nn.Linear,
|
|
61
62
|
logit_projection_layer=nn.Linear,
|
|
62
63
|
batch_norm_logits=True,
|
|
63
64
|
):
|
|
@@ -98,17 +99,19 @@ class SequencePoolClassificationHead(ClassificationHead):
|
|
|
98
99
|
self,
|
|
99
100
|
d_model,
|
|
100
101
|
n_classes,
|
|
102
|
+
linear_module=nn.Linear,
|
|
101
103
|
logit_projection_layer=nn.Linear,
|
|
102
104
|
batch_norm_logits=True,
|
|
103
105
|
):
|
|
104
106
|
super().__init__(
|
|
105
107
|
d_model,
|
|
106
108
|
n_classes,
|
|
109
|
+
linear_module=linear_module,
|
|
107
110
|
logit_projection_layer=logit_projection_layer,
|
|
108
111
|
batch_norm_logits=batch_norm_logits,
|
|
109
112
|
)
|
|
110
113
|
|
|
111
|
-
self.summarize = SequencePool(d_model,
|
|
114
|
+
self.summarize = SequencePool(d_model, linear_module)
|
|
112
115
|
# Rebuild the classification process with the correct summary module:
|
|
113
116
|
self.classification_process = nn.Sequential(
|
|
114
117
|
*[
|
|
@@ -510,6 +513,7 @@ class ViT(nn.Module):
|
|
|
510
513
|
self.pool = head(
|
|
511
514
|
transformer_embedding_size,
|
|
512
515
|
image_classes,
|
|
516
|
+
linear_module=linear_module,
|
|
513
517
|
logit_projection_layer=logit_projection_layer,
|
|
514
518
|
batch_norm_logits=batch_norm_logits,
|
|
515
519
|
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|