broccoli-ml 3.1.2.tar.gz → 3.2.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/PKG-INFO +1 -1
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/rope.py +2 -2
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/vit.py +5 -1
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/pyproject.toml +1 -1
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/LICENSE +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/README.md +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/activation.py +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/linear.py +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-3.1.2 → broccoli_ml-3.2.0}/broccoli/transformer.py +0 -0
--- broccoli_ml-3.1.2/broccoli/rope.py
+++ broccoli_ml-3.2.0/broccoli/rope.py
@@ -41,12 +41,12 @@ except ImportError:
     # Fallback: For PyTorch 1.6 to 1.9
     from torch.cuda.amp import autocast
 
-    def autocast_factory(_, enabled=True
+    def autocast_factory(_, enabled=True):
         """
         A wrapper that mimics the modern autocast signature but calls the older
         torch.cuda.amp.autocast, ignoring the device_type argument.
         """
-        return autocast(enabled=enabled
+        return autocast(enabled=enabled)
 
 
 from einops import rearrange, repeat
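For context, this hunk sits inside an `except ImportError:` branch (visible in the `@@` header), so `autocast_factory` is the fallback that gives PyTorch 1.6 to 1.9 the modern `autocast` calling convention. A minimal sketch of the full version gate is below; the `try` branch is an assumption inferred from the hunk context, since only the fallback body appears in this diff:

    try:
        # Modern API (PyTorch >= 1.10): torch.autocast takes a device_type
        # positional argument, e.g. autocast("cuda", enabled=True).
        from torch import autocast as autocast_factory
    except ImportError:
        # Fallback: For PyTorch 1.6 to 1.9
        from torch.cuda.amp import autocast

        def autocast_factory(_, enabled=True):
            """
            A wrapper that mimics the modern autocast signature but calls the
            older torch.cuda.amp.autocast, ignoring the device_type argument.
            """
            return autocast(enabled=enabled)

    # Either branch then supports the same call site:
    # with autocast_factory("cuda", enabled=True):
    #     ...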
--- broccoli_ml-3.1.2/broccoli/vit.py
+++ broccoli_ml-3.2.0/broccoli/vit.py
@@ -58,6 +58,7 @@ class ClassificationHead(nn.Module):
         self,
         d_model,
         n_classes,
+        linear_module=nn.Linear,
         logit_projection_layer=nn.Linear,
         batch_norm_logits=True,
     ):
@@ -98,17 +99,19 @@ class SequencePoolClassificationHead(ClassificationHead):
         self,
         d_model,
         n_classes,
+        linear_module=nn.Linear,
         logit_projection_layer=nn.Linear,
         batch_norm_logits=True,
     ):
         super().__init__(
             d_model,
             n_classes,
+            linear_module=linear_module,
             logit_projection_layer=logit_projection_layer,
             batch_norm_logits=batch_norm_logits,
         )
 
-        self.summarize = SequencePool(d_model,
+        self.summarize = SequencePool(d_model, linear_module)
         # Rebuild the classification process with the correct summary module:
         self.classification_process = nn.Sequential(
             *[
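The one-line change above threads the new `linear_module` into `SequencePool`. The diff does not show `SequencePool` itself; below is a hedged sketch of the standard sequence-pooling pattern (Hassani et al., "Escaping the Big Data Paradigm with Compact Transformers") that a module with this name and signature plausibly follows, to show which layer the argument would swap. All internals here are assumptions:

    import torch
    import torch.nn as nn

    class SequencePoolSketch(nn.Module):
        """Illustrative only: broccoli's real SequencePool internals are not
        part of this diff. Sequence pooling scores each token with a learned
        linear layer and returns a softmax-weighted sum over the sequence."""

        def __init__(self, d_model, linear_module=nn.Linear):
            super().__init__()
            # The layer the new linear_module argument would parameterise.
            self.attention = linear_module(d_model, 1)

        def forward(self, x):  # x: (batch, tokens, d_model)
            weights = torch.softmax(self.attention(x), dim=1)  # (batch, tokens, 1)
            return (weights * x).sum(dim=1)                    # (batch, d_model)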
@@ -510,6 +513,7 @@ class ViT(nn.Module):
         self.pool = head(
             transformer_embedding_size,
             image_classes,
+            linear_module=linear_module,
             logit_projection_layer=logit_projection_layer,
             batch_norm_logits=batch_norm_logits,
         )
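Taken together, 3.2.0 exposes `linear_module` at every level of the classification-head stack: `ViT` forwards it to its head, and `SequencePoolClassificationHead` forwards it both to `ClassificationHead` and to `SequencePool`. A hypothetical usage sketch follows; the custom module and the argument values are illustrative, and only the class and keyword names come from the diff:

    import torch.nn as nn
    from broccoli.vit import SequencePoolClassificationHead

    class HalfScaleLinear(nn.Linear):
        """Toy nn.Linear drop-in; anything constructible as
        module(in_features, out_features) should fit."""
        def forward(self, x):
            return super().forward(x) * 0.5

    head = SequencePoolClassificationHead(
        d_model=256,
        n_classes=10,
        linear_module=HalfScaleLinear,  # new in 3.2.0; reaches SequencePool too
    )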