broccoli-ml 0.33.2__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +36 -25
- broccoli/vit.py +26 -6
- {broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/RECORD +6 -6
- {broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -235,25 +235,26 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
-        …
+        linear_module_up=nn.Linear,
+        linear_module_down=nn.Linear,
         pre_norm=True,
         normformer=False,
-        …
+        post_norm=True,
+        residual_path=True,
     ):
         super().__init__()

+        self.residual_path = residual_path
+        self.post_norm = post_norm
+
+        if self.post_norm:
+            self.layernorm = nn.LayerNorm(output_features)
+
         if activation_kwargs is not None:
             self.activation = activation(**activation_kwargs)
         else:
             self.activation = activation()

-        if raw_input:
-            self.memory_type = AnchoredLinear
-            self.memory_bias = False
-        else:
-            self.memory_type = nn.Linear
-            self.memory_bias = True
-
         self.dropout = nn.Dropout(dropout)

         self.max_features = (
@@ -265,18 +266,21 @@ class FeedforwardBlock(nn.Module):
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                …
+                linear_module_up(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                self.memory_type(
-                    ratio * output_features, output_features, bias=self.memory_bias
-                ),
+                linear_module_down(ratio * output_features, output_features),
                 self.dropout,
             ]
         )

     def forward(self, x):
-        …
+        if self.residual_path and self.post_norm:
+            return self.layernorm(x + self.process(x))
+        elif self.residual_path:
+            return x + self.process(x)
+        else:
+            return x


 class TransformerBlock(nn.Module):
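Taken together, these two hunks move the residual connection and the optional output LayerNorm into FeedforwardBlock itself, and replace the raw_input/AnchoredLinear machinery with caller-supplied up- and down-projection classes. A condensed sketch of the new control flow (illustrative names, not the package's exact signature; `process` stands for the LayerNorm → up-projection → activation → down-projection → dropout pipeline shown above):

    import torch.nn as nn

    class FFSketch(nn.Module):
        """Illustrative stand-in for FeedforwardBlock's new forward contract."""

        def __init__(self, d_model, ratio=4, residual_path=True, post_norm=True):
            super().__init__()
            self.residual_path = residual_path
            self.post_norm = post_norm
            if post_norm:
                self.layernorm = nn.LayerNorm(d_model)
            self.process = nn.Sequential(
                nn.LayerNorm(d_model),                # pre_norm position
                nn.Linear(d_model, ratio * d_model),  # linear_module_up
                nn.ReLU(),
                nn.Linear(ratio * d_model, d_model),  # linear_module_down
                nn.Dropout(0.0),
            )

        def forward(self, x):
            if self.residual_path and self.post_norm:
                return self.layernorm(x + self.process(x))  # residual add, then norm
            elif self.residual_path:
                return x + self.process(x)                  # plain residual
            else:
                return x  # note: the processing pipeline is skipped entirely

One detail worth noting: as written in 0.35.0, residual_path=False makes forward return its input untouched rather than self.process(x) without the residual, so the block only transforms its input when the residual path is enabled.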
@@ -305,11 +309,14 @@ class TransformerBlock(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         super().__init__()

         self.pre_norm = pre_norm
+        self.post_norm = post_norm
+        self.normformer = normformer

         self.identity_probability = identity_probability

@@ -340,7 +347,7 @@ class TransformerBlock(nn.Module):
             bos_tokens=bos_tokens,
         )

-        #
+        # Submodule for the feedforward process
         self.ff = FeedforwardBlock(
             d_model,
             mlp_ratio,
@@ -348,9 +355,12 @@ class TransformerBlock(nn.Module):
             activation=activation,
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
-            …
+            linear_module_up=linear_module,
+            linear_module_down=linear_module,
             pre_norm=pre_norm,
             normformer=normformer,
+            post_norm=post_norm,
+            residual_path=True,
         )

     @property
@@ -371,19 +381,18 @@ class TransformerBlock(nn.Module):
         identity_x = shuffled[:identity_count, :, :]
         process_x = shuffled[identity_count:, :, :]

+        residual_x = process_x
+
         if self.pre_norm:
-            norm_process_x = self.layer_norm_1(process_x)
-            process_x = process_x + self.attn(
-                norm_process_x, norm_process_x, norm_process_x
-            )
-            process_x = process_x + self.ff(process_x)
-        else:  # post-norm
-            process_x = process_x + self.attn(process_x, process_x, process_x)
             process_x = self.layer_norm_1(process_x)
-
+
+        process_x = residual_x + self.attn(process_x, process_x, process_x)
+
+        if self.post_norm:
             process_x = self.layer_norm_2(process_x)

-
+        process_x = self.ff(process_x)
+
         x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()

         return x
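This hunk collapses TransformerBlock's separate pre-norm and post-norm branches into a single path: save the residual, optionally normalise, attend, add the residual back, optionally normalise again, then hand off to the feedforward block, which now carries its own residual. A free-function restatement of the same order of operations (the function name is ours; attn, ln1 and ln2 stand for the block's attention and LayerNorm submodules):

    def attention_sublayer(x, attn, ln1, ln2, pre_norm, post_norm):
        """Order of operations in TransformerBlock.forward after this change."""
        residual = x
        if pre_norm:                  # pre-LN: normalise the attention input
            x = ln1(x)
        x = residual + attn(x, x, x)  # the residual add uses the unnormalised input
        if post_norm:                 # post-LN: normalise after the residual add
            x = ln2(x)
        return x                      # then x = self.ff(x), which adds its own residual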
@@ -414,6 +423,7 @@ class TransformerEncoder(nn.Module):
         bos_tokens=0,
         return_bos_tokens=False,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         if position_embedding_type == "relative":
@@ -474,6 +484,7 @@ class TransformerEncoder(nn.Module):
                 causal=causal,
                 linear_module=linear_module,
                 pre_norm=pre_norm,
+                post_norm=post_norm,
                 normformer=normformer,
             )
             for i in range(n_layers)
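With post_norm threaded through TransformerEncoder → TransformerBlock → FeedforwardBlock, the stack can now be configured as either a pre-LN or a classic post-LN transformer. A hedged usage sketch (only the normalisation flags come from this diff; the remaining constructor arguments are elided):

    pre_ln_flags  = dict(pre_norm=True,  post_norm=False)  # default in this release
    post_ln_flags = dict(pre_norm=False, post_norm=True)   # original Transformer placement

    # encoder = TransformerEncoder(..., **post_ln_flags)
    # Every TransformerBlock, and the FeedforwardBlock inside it, receives the same flags.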
broccoli/vit.py
CHANGED
@@ -53,9 +53,7 @@ class ClassificationHead(nn.Module):
     A general classification head for a ViT
     """

-    def __init__(
-        self, d_model, linear_module, n_classes, layer_norm=True, batch_norm=True
-    ):
+    def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
         super().__init__()
         self.d_model = d_model
         self.summarize = GetCLSToken()
@@ -67,7 +65,6 @@ class ClassificationHead(nn.Module):

         self.classification_process = nn.Sequential(
             *[
-                nn.LayerNorm(d_model) if layer_norm else nn.Identity(),
                 self.summarize,
                 self.projection,
                 self.batch_norm,
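ClassificationHead therefore no longer normalises its own input; any LayerNorm is expected to happen upstream of the head. A sketch of what remains, with a stand-in for GetCLSToken (the class name, the CLS-token index, and the use of nn.Linear in place of the injected linear_module are assumptions):

    import torch.nn as nn

    class HeadSketch(nn.Module):
        """Illustrative: the head's pipeline after the layer_norm option was dropped."""

        def __init__(self, d_model, n_classes, batch_norm=True):
            super().__init__()
            self.projection = nn.Linear(d_model, n_classes)  # linear_module in the package
            self.batch_norm = nn.BatchNorm1d(n_classes) if batch_norm else nn.Identity()

        def forward(self, x):              # x: (batch, tokens, d_model)
            cls = x[:, 0, :]               # stand-in for GetCLSToken()
            return self.batch_norm(self.projection(cls))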
@@ -120,8 +117,12 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module_up=None,
+        transformer_initial_ff_linear_module_down=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -296,6 +297,7 @@ class ViTEncoder(nn.Module):
                 return_bos_tokens=transformer_return_bos_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
+                post_norm=transformer_post_norm,
             )
         else:
             self.transformer = nn.Identity()
@@ -308,10 +310,20 @@ class ViTEncoder(nn.Module):
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
-                …
+                linear_module_up=(
+                    transformer_initial_ff_linear_module_up
+                    if transformer_initial_ff_linear_module_up is not None
+                    else linear_module
+                ),
+                linear_module_down=(
+                    transformer_initial_ff_linear_module_down
+                    if transformer_initial_ff_linear_module_down is not None
+                    else linear_module
+                ),
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                …
+                post_norm=transformer_post_norm,
+                residual_path=transformer_initial_ff_residual_path,
             )
         else:
             self.initial_ff = nn.Identity()
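The two new linear-module arguments follow a None-means-inherit pattern: left as None they fall back to the encoder-wide linear_module, so the initial feedforward can use a different projection class from the rest of the stack. Restated (the resolve helper is ours, not package API):

    def resolve(override, default):
        """None-means-inherit, as applied to both initial_ff linear modules above."""
        return override if override is not None else default

    # Hypothetical override: give only the initial feedforward's up-projection a
    # custom class, leaving the down-projection and the transformer blocks on the
    # encoder-wide linear_module:
    #
    #   ViTEncoder(
    #       ...,  # usual arguments
    #       transformer_initial_ff_linear_module_up=MyCustomLinear,
    #   )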
@@ -365,8 +377,12 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module_up=None,
+        transformer_initial_ff_linear_module_down=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -421,8 +437,12 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
+            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
+            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
+            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
+            transformer_post_norm=transformer_post_norm,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
{broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=ks2TRCdS10k2XvxEieh2sj_LzjTNRuiO6gekKFTtziI,4533
-broccoli/transformer.py,sha256=…
+broccoli/transformer.py,sha256=t0gsADJC9UOlwjm7tDKdy0pAZ8l3clTcCnes86zvH-k,17203
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=…
-broccoli_ml-0.…
-broccoli_ml-0.…
-broccoli_ml-0.…
-broccoli_ml-0.…
+broccoli/vit.py,sha256=c-ZRHiLDOoQDJO9OJ51zD9HqaluG33flIwTXQQfms-g,17389
+broccoli_ml-0.35.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.35.0.dist-info/METADATA,sha256=v0JSpcubSGwxA5dFPbDwz2r2oGZWSeqYND1Mu8WOiJY,1257
+broccoli_ml-0.35.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.35.0.dist-info/RECORD,,
{broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/LICENSE
File without changes
{broccoli_ml-0.33.2.dist-info → broccoli_ml-0.35.0.dist-info}/WHEEL
File without changes