broccoli-ml 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +10 -6
- broccoli/vit.py +2 -9
- {broccoli_ml-0.19.0.dist-info → broccoli_ml-0.21.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.19.0.dist-info → broccoli_ml-0.21.0.dist-info}/RECORD +6 -6
- {broccoli_ml-0.19.0.dist-info → broccoli_ml-0.21.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.19.0.dist-info → broccoli_ml-0.21.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -255,20 +255,24 @@ class FeedforwardBlock(nn.Module):
|
|
255
255
|
|
256
256
|
if regularise_values:
|
257
257
|
self.memory_type = SpectralNormLinear
|
258
|
-
self.bias_memories = False
|
259
258
|
else:
|
260
259
|
self.memory_type = nn.Linear
|
261
|
-
self.bias_memories = True
|
262
260
|
|
263
261
|
self.process = nn.Sequential(
|
264
262
|
*[
|
265
|
-
|
263
|
+
(
|
264
|
+
nn.LayerNorm(input_features)
|
265
|
+
if not self.regularise_values
|
266
|
+
else nn.Identity()
|
267
|
+
),
|
266
268
|
linear_module(input_features, self.max_features),
|
267
269
|
self.activation,
|
268
|
-
|
269
|
-
|
270
|
-
|
270
|
+
(
|
271
|
+
nn.LayerNorm(input_features)
|
272
|
+
if not self.regularise_values
|
273
|
+
else nn.Identity()
|
271
274
|
),
|
275
|
+
self.memory_type(ratio * output_features, output_features),
|
272
276
|
self.dropout,
|
273
277
|
]
|
274
278
|
)
|
broccoli/vit.py
CHANGED
@@ -30,7 +30,7 @@ class GetCLSToken(nn.Module):
|
|
30
30
|
|
31
31
|
|
32
32
|
class SequencePool(nn.Module):
|
33
|
-
def __init__(self, d_model, linear_module):
|
33
|
+
def __init__(self, d_model, linear_module=nn.Linear):
|
34
34
|
super().__init__()
|
35
35
|
self.attention = nn.Sequential(
|
36
36
|
*[
|
@@ -54,13 +54,6 @@ class ClassificationHead(nn.Module):
|
|
54
54
|
super().__init__()
|
55
55
|
self.d_model = d_model
|
56
56
|
self.summarize = GetCLSToken()
|
57
|
-
self.process = nn.Sequential(
|
58
|
-
*[
|
59
|
-
linear_module(d_model, 1),
|
60
|
-
Rearrange("batch seq 1 -> batch seq"),
|
61
|
-
nn.Softmax(dim=-1),
|
62
|
-
]
|
63
|
-
)
|
64
57
|
self.projection = nn.Linear(d_model, n_classes)
|
65
58
|
if batch_norm:
|
66
59
|
self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
|
@@ -87,7 +80,7 @@ class SequencePoolClassificationHead(ClassificationHead):
|
|
87
80
|
"""
|
88
81
|
|
89
82
|
def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
|
90
|
-
super().__init__(d_model, linear_module, out_dim, batch_norm=
|
83
|
+
super().__init__(d_model, linear_module, out_dim, batch_norm=batch_norm)
|
91
84
|
self.summarize = SequencePool(d_model, linear_module)
|
92
85
|
|
93
86
|
|
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
|
8
8
|
broccoli/linear.py,sha256=g8YrxNl6g_WcHrWVmbaBHJU5hv6daFS0r4TxAoPJ9UE,3012
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=MUvXtwD2f1sPTBym4FB0x_ZfsJUBNLgULUlN8btV8GI,1943
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=2O9EALlZPQXYCtOGvkCgoWv0-K8rNtsp-MTon1h_4n8,16343
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=vQtOcC0Dd8y6PTWx0xCbnE4ymYkL_HfYrerqaJ0hs1k,16404
|
14
|
+
broccoli_ml-0.21.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.21.0.dist-info/METADATA,sha256=jtoA3JoVE-R_VYlofiCzNxnWtzvV1Tu1HjR_Bm9T8lA,1257
|
16
|
+
broccoli_ml-0.21.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.21.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|