broccoli-ml 0.5.5__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +12 -22
- broccoli/vit.py +7 -0
- {broccoli_ml-0.5.5.dist-info → broccoli_ml-0.7.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.5.5.dist-info → broccoli_ml-0.7.0.dist-info}/RECORD +6 -6
- {broccoli_ml-0.5.5.dist-info → broccoli_ml-0.7.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.5.5.dist-info → broccoli_ml-0.7.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -223,7 +223,7 @@ class MHAttention(nn.Module):
|
|
223
223
|
|
224
224
|
class FeedforwardLayer(nn.Module):
|
225
225
|
"""
|
226
|
-
|
226
|
+
...
|
227
227
|
"""
|
228
228
|
|
229
229
|
def __init__(
|
@@ -247,6 +247,7 @@ class FeedforwardLayer(nn.Module):
|
|
247
247
|
|
248
248
|
self.process = nn.Sequential(
|
249
249
|
*[
|
250
|
+
nn.LayerNorm(input_features),
|
250
251
|
linear_module(
|
251
252
|
input_features,
|
252
253
|
(
|
@@ -256,8 +257,8 @@ class FeedforwardLayer(nn.Module):
|
|
256
257
|
),
|
257
258
|
),
|
258
259
|
self.activation,
|
259
|
-
self.dropout,
|
260
260
|
linear_module(ratio * input_features, output_features),
|
261
|
+
self.dropout,
|
261
262
|
]
|
262
263
|
)
|
263
264
|
|
@@ -323,25 +324,14 @@ class TransformerBlock(nn.Module):
|
|
323
324
|
)
|
324
325
|
|
325
326
|
# Submodules for the feedforward process
|
326
|
-
self.
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
mlp_ratio,
|
335
|
-
d_model,
|
336
|
-
activation=activation,
|
337
|
-
activation_kwargs=activation_kwargs,
|
338
|
-
dropout=0.0,
|
339
|
-
linear_module=linear_module,
|
340
|
-
),
|
341
|
-
),
|
342
|
-
("dropout", nn.Dropout(mlp_dropout)),
|
343
|
-
]
|
344
|
-
)
|
327
|
+
self.ff = FeedforwardLayer(
|
328
|
+
d_model,
|
329
|
+
mlp_ratio,
|
330
|
+
d_model,
|
331
|
+
activation=activation,
|
332
|
+
activation_kwargs=activation_kwargs,
|
333
|
+
dropout=mlp_dropout,
|
334
|
+
linear_module=linear_module,
|
345
335
|
)
|
346
336
|
|
347
337
|
@property
|
@@ -366,7 +356,7 @@ class TransformerBlock(nn.Module):
|
|
366
356
|
process_x = process_x + self.attn(
|
367
357
|
norm_process_x, norm_process_x, norm_process_x
|
368
358
|
)
|
369
|
-
process_x = process_x + self.
|
359
|
+
process_x = process_x + self.ff(process_x)
|
370
360
|
x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
|
371
361
|
|
372
362
|
return x
|
broccoli/vit.py
CHANGED
@@ -269,6 +269,13 @@ class ViTEncoder(nn.Module):
|
|
269
269
|
Rearrange( # for transformer
|
270
270
|
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
271
271
|
),
|
272
|
+
(
|
273
|
+
PadTensor(
|
274
|
+
(0, transformer_embedding_size - pooling_out_channels)
|
275
|
+
)
|
276
|
+
if not intermediate_feedforward_layer
|
277
|
+
else nn.Identity()
|
278
|
+
),
|
272
279
|
]
|
273
280
|
)
|
274
281
|
|
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
|
8
8
|
broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=RSZpbHs_K4ts5os6lWxcGDI3p0zreRwQNnk6mV8HJnk,15930
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=_oL0NRUJakyIke2g8WK5eWaiEh06gAhI67l6Wl7k1oM,15659
|
14
|
+
broccoli_ml-0.7.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.7.0.dist-info/METADATA,sha256=1QUwYpIruYYiYcMHSgj5lCf-i-FaiipD_5KSAJZeb2s,1256
|
16
|
+
broccoli_ml-0.7.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.7.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|