broccoli-ml 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +42 -4
- broccoli/vit.py +22 -8
- {broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/METADATA +1 -1
- {broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/RECORD +6 -6
- {broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
```diff
@@ -73,7 +73,15 @@ class MHAttention(nn.Module):
         bos_tokens=0,
         rotary_embedding=None,
         source_size=None,
+        scaling="d",
     ):
+        """
+        Args:
+            scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
         super().__init__()
 
         if rotary_embedding is not None:
@@ -84,6 +92,7 @@ class MHAttention(nn.Module):
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+        self.scaling = scaling
 
         self.head_dim = self.embed_dim // self.n_heads
 
@@ -207,7 +216,12 @@ class MHAttention(nn.Module):
 
         qk_scores = q @ k.transpose(-1, -2)
 
-
+        if self.scaling == "sqrtd":
+            qk_scores /= math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            qk_scores /= self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
 
         # Apply mask if causal (must come before softmax)
         if self.causal:
@@ -305,6 +319,7 @@ class TransformerBlock(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
+        msa_scaling="d",
         mlp_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
@@ -314,6 +329,14 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
     ):
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
         super().__init__()
 
         self.pre_norm = pre_norm
@@ -348,6 +371,7 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             bos_tokens=bos_tokens,
+            scaling=msa_scaling,
         )
 
         # Submodule for the feedforward process
@@ -429,9 +453,21 @@ class TransformerEncoder(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        msa_scaling="d",
     ):
-
-
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
+        if (position_embedding_type == "relative") and (source_size is None):
+            raise ValueError(
+                "`source_size` for TransformerEncoder cannot be None if"
+                " `position_embedding_type` is relative"
+            )
 
         super().__init__()
         self.seq_len = seq_len
@@ -461,7 +497,8 @@ class TransformerEncoder(nn.Module):
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
-        assert isinstance(n_layers, int)
+        assert isinstance(n_layers, int)
+
         if n_layers == 1:
             self.stochastic_depth_probabilities = [0.0]
         else:
@@ -484,6 +521,7 @@ class TransformerEncoder(nn.Module):
                 activation_kwargs=activation_kwargs,
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
+                msa_scaling=msa_scaling,
                 mlp_dropout=mlp_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
```
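The new `scaling` option selects between the classic `1/sqrt(d)` logit scaling of "Attention Is All You Need" and the `1/d` scaling recommended in "Tensor Programs V" (μP). A minimal sketch of the branch added above, assuming PyTorch; `scale_attention_logits` is an illustrative helper, not part of the broccoli API:

```python
import math

import torch


def scale_attention_logits(qk_scores, head_dim, scaling="d"):
    # "sqrtd": Attention Is All You Need convention, logits / sqrt(head_dim)
    # "d":     Tensor Programs V (muP) convention, logits / head_dim
    # Note: MHAttention divides by the per-head dimension (self.head_dim),
    # even though its docstring speaks of the embedding dimension.
    if scaling == "sqrtd":
        return qk_scores / math.sqrt(head_dim)
    if scaling == "d":
        return qk_scores / head_dim
    raise ValueError('`scaling` must be "d" or "sqrtd"')


q = torch.randn(2, 4, 16, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 16, 8)
raw = q @ k.transpose(-1, -2)
# The two modes differ only by a factor of sqrt(head_dim):
assert torch.allclose(
    scale_attention_logits(raw, 8, "d") * math.sqrt(8),
    scale_attention_logits(raw, 8, "sqrtd"),
)
```

The rough μP argument for the new default: after training, q and k become correlated, so q·k grows like d rather than sqrt(d), and dividing by d keeps the logits O(1) as width scales.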
broccoli/vit.py
CHANGED
```diff
@@ -53,12 +53,20 @@ class ClassificationHead(nn.Module):
     A general classification head for a ViT
     """
 
-    def __init__(
+    def __init__(
+        self,
+        d_model,
+        n_classes,
+        linear_module=nn.Linear,
+        logit_projection_layer=nn.Linear,
+        batch_norm_logits=True,
+    ):
         super().__init__()
         self.d_model = d_model
         self.summarize = GetCLSToken()
-        self.projection =
-
+        self.projection = logit_projection_layer(d_model, n_classes)
+
+        if batch_norm_logits:
             self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
         else:
             self.batch_norm = nn.Identity()
@@ -83,7 +91,7 @@ class SequencePoolClassificationHead(ClassificationHead):
     """
 
     def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
-        super().__init__(d_model, linear_module, out_dim,
+        super().__init__(d_model, linear_module, out_dim, batch_norm_logits=batch_norm)
         self.summarize = SequencePool(d_model, linear_module)
         # Rebuild the classification process with the correct summary module:
         self.classification_process = nn.Sequential(
```
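`ClassificationHead` now exposes the logit projection module and the logit batch-norm as constructor options. A self-contained sketch of the wiring shown above; the class name `HeadSketch` and the shapes are illustrative, not package API:

```python
import torch
import torch.nn as nn


class HeadSketch(nn.Module):
    """Mirrors the 1.2.0 ClassificationHead wiring: a configurable logit
    projection followed by an optional non-affine BatchNorm over the logits."""

    def __init__(self, d_model, n_classes,
                 logit_projection_layer=nn.Linear, batch_norm_logits=True):
        super().__init__()
        self.projection = logit_projection_layer(d_model, n_classes)
        # affine=False: normalise the logits without learnable scale/shift
        self.batch_norm = (
            nn.BatchNorm1d(n_classes, affine=False) if batch_norm_logits
            else nn.Identity()
        )

    def forward(self, x):  # x: (batch, d_model) pooled summary token
        return self.batch_norm(self.projection(x))


print(HeadSketch(64, 10)(torch.randn(8, 64)).shape)  # torch.Size([8, 10])
```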
```diff
@@ -143,6 +151,7 @@ class ViTEncoder(nn.Module):
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
+        transformer_msa_scaling="d",
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -295,6 +304,7 @@ class ViTEncoder(nn.Module):
             activation_kwargs=transformer_activation_kwargs,
             ff_linear_module_up=transformer_ff_linear_module_up,
             ff_linear_module_down=transformer_ff_linear_module_down,
+            msa_scaling=transformer_msa_scaling,
             mlp_dropout=transformer_mlp_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
@@ -405,12 +415,14 @@ class ViT(nn.Module):
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
+        transformer_msa_scaling="d",
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
-        batch_norm_outputs=True,
-        linear_module=SpectralNormLinear,
         head=SequencePoolClassificationHead,
+        batch_norm_logits=True,
+        logit_projection_layer=nn.Linear,
+        linear_module=nn.Linear,
     ):
 
         super().__init__()
@@ -468,6 +480,7 @@ class ViT(nn.Module):
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
+            transformer_msa_scaling=transformer_msa_scaling,
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
@@ -476,9 +489,10 @@ class ViT(nn.Module):
 
         self.pool = head(
             transformer_embedding_size,
-            linear_module,
             image_classes,
-
+            linear_module=linear_module,
+            logit_projection_layer=logit_projection_layer,
+            batch_norm=batch_norm_logits,
         )
 
     @property
```
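For downstream callers, the `ViT` constructor change amounts to renamed and added keyword arguments. Collected here as a plain dict for illustration; the constructor arguments this diff does not show are assumed unchanged from 1.0.0:

```python
import torch.nn as nn

# New or changed ViT keyword arguments in 1.2.0, per the diff above.
# Pass these alongside the constructor arguments the diff does not show.
vit_kwargs_1_2_0 = dict(
    transformer_msa_scaling="d",       # new; "sqrtd" restores 1/sqrt(d) scaling
    batch_norm_logits=True,            # replaces 1.0.0's batch_norm_outputs
    logit_projection_layer=nn.Linear,  # new in 1.2.0
    linear_module=nn.Linear,           # default was SpectralNormLinear in 1.0.0
)
```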
{broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/RECORD
CHANGED
```diff
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=w021EDzWVDEu9odzrf9QwBZ3G8Ydu3nroV8soIJeRng,4894
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=KhwrV5Nz0utXJQKFFf6MIifmymgOaylc_2_hs9kQt7g,18455
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=
-broccoli_ml-1.
-broccoli_ml-1.
-broccoli_ml-1.
-broccoli_ml-1.
+broccoli/vit.py,sha256=9dfOPO7y-o96OYpy_ancH4-Y2_Y62fqgzJvI3-86W_o,18471
+broccoli_ml-1.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-1.2.0.dist-info/METADATA,sha256=am67vHH2HDQdq3G336ng56xX6-ys1LyZzLRIWaf8dE8,1256
+broccoli_ml-1.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-1.2.0.dist-info/RECORD,,
```
{broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/LICENSE
File without changes

{broccoli_ml-1.0.0.dist-info → broccoli_ml-1.2.0.dist-info}/WHEEL
File without changes