broccoli-ml 0.33.2__tar.gz → 0.34.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/PKG-INFO +1 -1
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/transformer.py +30 -21
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/vit.py +16 -6
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/pyproject.toml +1 -1
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/LICENSE +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/README.md +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/activation.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/linear.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/rope.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-0.33.2 → broccoli_ml-0.34.0}/broccoli/utils.py +0 -0
--- broccoli_ml-0.33.2/broccoli/transformer.py
+++ broccoli_ml-0.34.0/broccoli/transformer.py
@@ -238,22 +238,22 @@ class FeedforwardBlock(nn.Module):
         linear_module=nn.Linear,
         pre_norm=True,
         normformer=False,
-
+        post_norm=True,
+        residual_path=True,
     ):
         super().__init__()
 
+        self.residual_path = residual_path
+        self.post_norm = post_norm
+
+        if self.post_norm:
+            self.layernorm = nn.LayerNorm(output_features)
+
         if activation_kwargs is not None:
             self.activation = activation(**activation_kwargs)
         else:
             self.activation = activation()
 
-        if raw_input:
-            self.memory_type = AnchoredLinear
-            self.memory_bias = False
-        else:
-            self.memory_type = nn.Linear
-            self.memory_bias = True
-
         self.dropout = nn.Dropout(dropout)
 
         self.max_features = (
@@ -268,15 +268,18 @@ class FeedforwardBlock(nn.Module):
                 linear_module(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                self.memory_type(
-                    ratio * output_features, output_features, bias=self.memory_bias
-                ),
+                linear_module(ratio * output_features, output_features),
                 self.dropout,
             ]
         )
 
     def forward(self, x):
-
+        if self.residual_path and self.post_norm:
+            return self.layernorm(x + self.process(x))
+        elif self.residual_path:
+            return x + self.process(x)
+        else:
+            return x
 
 
 class TransformerBlock(nn.Module):
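Taken together, these two hunks drop the raw_input / AnchoredLinear memory path in favour of a plain linear_module call and give FeedforwardBlock explicit post_norm and residual_path switches: with both enabled the block computes LayerNorm(x + process(x)); with only residual_path it computes x + process(x); with residual_path=False the input is returned unchanged. A minimal standalone sketch of that forward logic (PostNormResidualFF and its internals are illustrative stand-ins, not the package's class):

import torch
import torch.nn as nn


class PostNormResidualFF(nn.Module):
    """Illustrative stand-in for FeedforwardBlock's new forward behaviour."""

    def __init__(self, d_model, residual_path=True, post_norm=True):
        super().__init__()
        self.residual_path = residual_path
        self.post_norm = post_norm
        # Stand-in for the block's Sequential "process" stack.
        self.process = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        if self.post_norm:
            self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x):
        if self.residual_path and self.post_norm:
            return self.layernorm(x + self.process(x))  # post-norm residual
        elif self.residual_path:
            return x + self.process(x)                  # plain residual
        else:
            return x                                    # input passed through unchanged


block = PostNormResidualFF(d_model=256)
print(block(torch.randn(8, 16, 256)).shape)  # torch.Size([8, 16, 256])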
@@ -305,11 +308,14 @@ class TransformerBlock(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         super().__init__()
 
         self.pre_norm = pre_norm
+        self.post_norm = post_norm
+        self.normformer = normformer
 
         self.identity_probability = identity_probability
 
@@ -351,6 +357,8 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             pre_norm=pre_norm,
             normformer=normformer,
+            post_norm=post_norm,
+            residual_path=True,
         )
 
     @property
@@ -371,19 +379,18 @@ class TransformerBlock(nn.Module):
         identity_x = shuffled[:identity_count, :, :]
         process_x = shuffled[identity_count:, :, :]
 
+        residual_x = process_x
+
         if self.pre_norm:
-            norm_process_x = self.layer_norm_1(process_x)
-            process_x = process_x + self.attn(
-                norm_process_x, norm_process_x, norm_process_x
-            )
-            process_x = process_x + self.ff(process_x)
-        else: # post-norm
-            process_x = process_x + self.attn(process_x, process_x, process_x)
             process_x = self.layer_norm_1(process_x)
-
+
+        process_x = residual_x + self.attn(process_x, process_x, process_x)
+
+        if self.post_norm:
             process_x = self.layer_norm_2(process_x)
 
-
+        process_x = self.ff(process_x)
+
         x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
 
         return x
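The refactored forward keeps a single attention path: the un-normalised input is stashed in residual_x, layer_norm_1 is applied only when pre_norm is set, the attention output is added back onto residual_x, layer_norm_2 is applied only when post_norm is set, and the feed-forward block (which now carries its own residual and post-norm, as shown above) runs last. A runnable sketch of that ordering with a stand-in attention module; the identity-token shuffling and the package's own attn/ff classes are omitted:

import torch
import torch.nn as nn


class BlockSketch(nn.Module):
    """Illustrative stand-in for the refactored TransformerBlock attention path."""

    def __init__(self, d_model, n_heads, pre_norm=True, post_norm=False):
        super().__init__()
        self.pre_norm = pre_norm
        self.post_norm = post_norm
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.ff = nn.Sequential(  # stand-in for the package's FeedforwardBlock
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, process_x):
        residual_x = process_x
        if self.pre_norm:
            process_x = self.layer_norm_1(process_x)  # normalise only the attention input
        attn_out, _ = self.attn(process_x, process_x, process_x)
        process_x = residual_x + attn_out             # residual taken from the un-normed input
        if self.post_norm:
            process_x = self.layer_norm_2(process_x)  # classic post-norm after the residual
        return self.ff(process_x)                     # in the package, ff adds its own residual/norm


x = torch.randn(16, 4, 256)  # (sequence, batch, d_model), MultiheadAttention's default layout
print(BlockSketch(256, 8)(x).shape)  # torch.Size([16, 4, 256])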
@@ -414,6 +421,7 @@ class TransformerEncoder(nn.Module):
         bos_tokens=0,
         return_bos_tokens=False,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         if position_embedding_type == "relative":
@@ -474,6 +482,7 @@ class TransformerEncoder(nn.Module):
                 causal=causal,
                 linear_module=linear_module,
                 pre_norm=pre_norm,
+                post_norm=post_norm,
                 normformer=normformer,
             )
             for i in range(n_layers)
--- broccoli_ml-0.33.2/broccoli/vit.py
+++ broccoli_ml-0.34.0/broccoli/vit.py
@@ -53,9 +53,7 @@ class ClassificationHead(nn.Module):
     A general classification head for a ViT
     """
 
-    def __init__(
-        self, d_model, linear_module, n_classes, layer_norm=True, batch_norm=True
-    ):
+    def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
        super().__init__()
        self.d_model = d_model
        self.summarize = GetCLSToken()
@@ -67,7 +65,6 @@ class ClassificationHead(nn.Module):
 
         self.classification_process = nn.Sequential(
             *[
-                nn.LayerNorm(d_model) if layer_norm else nn.Identity(),
                 self.summarize,
                 self.projection,
                 self.batch_norm,
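With the layer_norm switch and the leading nn.LayerNorm gone, ClassificationHead now just extracts the CLS token, projects, and optionally batch-norms. A hedged construction example of the new signature (the import path and the argument values are assumptions; only the parameter names come from the diff):

import torch.nn as nn
from broccoli.vit import ClassificationHead  # assumed import path

head = ClassificationHead(
    d_model=256,          # assumed value, matches the ViT default embedding size
    linear_module=nn.Linear,
    n_classes=100,        # assumed value
    batch_norm=True,      # layer_norm= is no longer accepted as of 0.34.0
)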
@@ -120,8 +117,11 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative", # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -308,10 +308,14 @@ class ViTEncoder(nn.Module):
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
-                linear_module=linear_module,
+                linear_module=(
+                    transformer_initial_ff_linear_module
+                    if transformer_initial_ff_linear_module is not None
+                    else linear_module
+                ),
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-
+                residual_path=transformer_initial_ff_residual_path,
             )
         else:
             self.initial_ff = nn.Identity()
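The initial feed-forward block now reuses the encoder-wide linear_module unless a dedicated transformer_initial_ff_linear_module is supplied, and its residual path is controlled by transformer_initial_ff_residual_path. The None-default fallback is shown standalone below (the helper name is illustrative, not part of the package):

import torch.nn as nn


def resolve_initial_ff_linear(transformer_initial_ff_linear_module, linear_module):
    """Pick the linear layer class for the initial FF block (illustrative helper)."""
    return (
        transformer_initial_ff_linear_module
        if transformer_initial_ff_linear_module is not None
        else linear_module
    )


assert resolve_initial_ff_linear(None, nn.Linear) is nn.Linear                # default: reuse encoder-wide module
assert resolve_initial_ff_linear(nn.LazyLinear, nn.Linear) is nn.LazyLinear   # explicit override wins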
@@ -365,8 +369,11 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative", # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -421,8 +428,11 @@ class ViT(nn.Module):
            pooling_kernel_stride=pooling_kernel_stride,
            pooling_padding=pooling_padding,
            transformer_feedforward_first=transformer_feedforward_first,
+            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
+            transformer_initial_ff_linear_module=transformer_initial_ff_linear_module,
            transformer_pre_norm=transformer_pre_norm,
            transformer_normformer=transformer_normformer,
+            transformer_post_norm=transformer_post_norm,
            transformer_position_embedding=transformer_position_embedding,
            transformer_embedding_size=transformer_embedding_size,
            transformer_layers=transformer_layers,
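At the ViT level the three new keyword arguments are simply forwarded to ViTEncoder and from there to the transformer stack. Purely as illustration, the new options with their 0.34.0 defaults; all other constructor arguments are omitted:

# New ViT / ViTEncoder keyword arguments introduced in 0.34.0 (defaults as in the diff).
new_transformer_kwargs = dict(
    transformer_initial_ff_residual_path=True,   # keep the residual around the initial FF block
    transformer_initial_ff_linear_module=None,   # None -> fall back to the encoder-wide linear_module
    transformer_post_norm=False,                 # post-norm off by default for the transformer blocks
)

Assuming the remaining constructor arguments keep their defaults, these could be passed straight through as ViT(**new_transformer_kwargs).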