broccoli-ml 0.1.41__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/cnn.py +404 -322
- broccoli/transformer.py +96 -82
- broccoli/vit.py +173 -125
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -68,12 +68,7 @@ class MHAttention(nn.Module):
         dropout=0.0,
         causal=False,
         seq_len=None,
-        share_kv=False,
         linear_module: nn.Module = nn.Linear,
-        max_subtract=False,
-        d_model_scale=True,
-        log_length_scale=False,
-        quiet=False,
         bos_tokens=0,
         rotary_embedding=None,
         source_size=None,
@@ -88,15 +83,15 @@ class MHAttention(nn.Module):
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+
         self.head_dim = self.embed_dim // self.n_heads
-
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
-        if share_kv:
-            self.v_proj = None
-        else:
-            self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
+
         self.out_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
+
         self.causal = causal
         self.seq_len = seq_len
         self.dropout = nn.Dropout(dropout)
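The constructor above now gives q, k and v their own unshared projections and derives head_dim from embed_dim // n_heads (guarded by the assert). A minimal sketch of that bookkeeping, with the per-head split written out explicitly (the reshape itself is not part of this hunk, so treat it as illustrative only):

```python
import torch
import torch.nn as nn

embed_dim, n_heads = 64, 8
head_dim = embed_dim // n_heads          # 8; requires embed_dim % n_heads == 0

q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
x = torch.randn(2, 10, embed_dim)        # (batch, seq, embed_dim)

# Split the projected embedding into per-head chunks: (batch, heads, seq, head_dim)
q = q_proj(x).view(2, 10, n_heads, head_dim).transpose(1, 2)
print(q.shape)                           # torch.Size([2, 8, 10, 8])
```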
@@ -107,10 +102,6 @@ class MHAttention(nn.Module):
                 .unsqueeze(0)
                 .unsqueeze(0),
             )
-        self.max_subtract = max_subtract
-        self.d_model_scale = d_model_scale
-        self.log_length_scale = log_length_scale
-        self.quiet = quiet
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
         self.bos_tokens = bos_tokens
@@ -152,37 +143,57 @@ class MHAttention(nn.Module):
         # Project q, k and v
         q = self.q_proj(q)
         k = self.k_proj(k)
-        if self.share_kv:
-            v = self.k_proj(v)
-        else:
-            v = self.v_proj(v)
+        v = self.v_proj(v)
 
         # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:
 
+            if len(self.source_size) == 1:
+                spatial_dimension_names = "D1"
+                spatial_dimension_values = {"D1": self.source_size[0]}
+            elif len(self.source_size) == 2:
+                spatial_dimension_names = "D1 D2"
+                spatial_dimension_values = {
+                    "D1": self.source_size[0],
+                    "D2": self.source_size[1],
+                }
+            elif len(self.source_size) == 3:
+                spatial_dimension_names = "D1 D2 D3"
+                spatial_dimension_values = {
+                    "D1": self.source_size[0],
+                    "D2": self.source_size[1],
+                    "D3": self.source_size[2],
+                }
+            else:
+                raise NotImplementedError(
+                    "`source_size` must be a tuple of 1, 2 or 3 integers"
+                )
+
             q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
             k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
 
             q_img = rearrange(
                 q_img,
-                "b (height width) d -> b height width d",
-                height=self.source_size[0],
-                width=self.source_size[1],
+                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+                **spatial_dimension_values,
             )
             k_img = rearrange(
                 k_img,
-                "b (height width) d -> b height width d",
-                height=self.source_size[0],
-                width=self.source_size[1],
-            )
-            freqs = self.rotary_embedding.get_axial_freqs(
-                self.source_size[0], self.source_size[1]
+                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+                **spatial_dimension_values,
             )
+            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
             q_img = apply_rotary_emb(freqs, q_img)
             k_img = apply_rotary_emb(freqs, k_img)
 
-            q_img = rearrange(q_img, "b height width d -> b (height width) d")
-            k_img = rearrange(k_img, "b height width d -> b (height width) d")
+            q_img = rearrange(
+                q_img,
+                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            )
+            k_img = rearrange(
+                k_img,
+                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            )
 
             # Re-combine the BOS tokens and the RoPE-enhanced image tokens
             q = torch.cat([q_bos, q_img], dim=1)
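The rewritten RoPE block above generalises the einops pattern from a hard-coded height/width grid to 1-, 2- or 3-D source_size tuples by building the pattern string dynamically. A minimal standalone sketch of the same idea for a 2-D grid, assuming the rotary-embedding-torch package that get_axial_freqs and apply_rotary_emb appear to come from (all sizes here are hypothetical):

```python
import torch
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

source_size = (8, 8)                                            # hypothetical 2-D token grid
names = " ".join(f"D{i + 1}" for i in range(len(source_size)))  # "D1 D2"
values = {f"D{i + 1}": s for i, s in enumerate(source_size)}

x = torch.randn(2, 64, 32)                                      # (batch, tokens, head_dim)
rope = RotaryEmbedding(dim=8, freqs_for="pixel", max_freq=4)

x = rearrange(x, f"b ({names}) d -> b {names} d", **values)     # unflatten to the grid
freqs = rope.get_axial_freqs(*source_size)                      # one frequency set per axis
x = apply_rotary_emb(freqs, x)                                  # rotate the leading features
x = rearrange(x, f"b {names} d -> b ({names}) d")               # flatten back to a sequence
print(x.shape)                                                  # torch.Size([2, 64, 32])
```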
@@ -195,26 +206,13 @@ class MHAttention(nn.Module):
 
         qk_scores = q @ k.transpose(-1, -2)
 
-        if self.d_model_scale:
-            qk_scores /= math.sqrt(self.head_dim)  # scaling
-
-        if self.log_length_scale:
-            qk_scores *= math.log(qk_scores.size(0))
-
-        if self.max_subtract:
-            max_scores, _ = torch.max(qk_scores, dim=-1, keepdim=True)
-            qk_scores -= max_scores
+        qk_scores /= math.sqrt(self.head_dim)
 
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
-
-        denominator = torch.sum(torch.exp(qk_scores), dim=-1, keepdim=True)
-        if self.quiet:
-            denominator += 1
-        numerator = torch.exp(qk_scores)
-        qk_scores = self.dropout(numerator / denominator)
+        qk_scores = F.softmax(qk_scores, dim=-1)
 
         output_with_heads = qk_scores @ v
 
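The scoring path above replaces the hand-rolled softmax, with its optional max-subtraction and "quiet" +1 denominator, with a plain scaled dot product followed by F.softmax. Max-subtraction only affects numerical stability, not the result, so that branch folds away; the quiet denominator was a genuine behavioural option that is simply gone. A quick check of the first claim:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 8, 8)        # (batch, heads, queries, keys)

plain = F.softmax(scores, dim=-1)
stabilised = F.softmax(scores - scores.max(dim=-1, keepdim=True).values, dim=-1)
print(torch.allclose(plain, stabilised, atol=1e-6))    # True: same weights either way

# The removed "quiet" variant added 1 to the denominator, which does change the
# output (attention weights no longer sum to 1):
quiet = torch.exp(scores) / (torch.exp(scores).sum(dim=-1, keepdim=True) + 1)
print(quiet.sum(dim=-1).max() < 1)                     # True
```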
@@ -223,6 +221,50 @@ class MHAttention(nn.Module):
         return self.out_proj(output_without_heads)
 
 
+class DenoisingAutoEncoder(nn.Module):
+    """
+    A denoising autoencoder, of the type used in transformer blocks.
+    """
+
+    def __init__(
+        self,
+        input_features,
+        ratio,
+        output_features,
+        activation=nn.ReLU,
+        activation_kwargs=None,
+        dropout=0.0,
+        linear_module=nn.Linear,
+    ):
+        super().__init__()
+
+        if activation_kwargs is not None:
+            self.activation = activation(**activation_kwargs)
+        else:
+            self.activation = activation()
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.process = nn.Sequential(
+            *[
+                linear_module(
+                    input_features,
+                    (
+                        2 * ratio * input_features
+                        if activation.__name__.endswith("GLU")
+                        else ratio * input_features
+                    ),
+                ),
+                self.activation,
+                self.dropout,
+                linear_module(ratio * input_features, output_features),
+            ]
+        )
+
+    def forward(self, x):
+        return self.process(x)
+
+
 class TransformerBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
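DenoisingAutoEncoder, added above, is the block's feed-forward network: it projects up by `ratio`, applies the activation, then projects back down, doubling the first projection whenever the activation class name ends in "GLU" (GLU-style activations halve their input, as the comment removed in a later hunk notes). A small sketch of that width bookkeeping, using nn.GLU as a stand-in for the xGLU activations:

```python
import torch
import torch.nn as nn

d_model, ratio = 64, 4

for activation in (nn.ReLU, nn.GLU):
    hidden = 2 * ratio * d_model if activation.__name__.endswith("GLU") else ratio * d_model
    ffn = nn.Sequential(
        nn.Linear(d_model, hidden),
        activation(),                          # nn.GLU halves the last dimension
        nn.Linear(ratio * d_model, d_model),   # so both branches feed ratio * d_model here
    )
    out = ffn(torch.randn(2, 10, d_model))
    print(activation.__name__, tuple(out.shape))   # both print (2, 10, 64)
```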
@@ -247,28 +289,18 @@ class TransformerBlock(nn.Module):
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
-        share_kv=False,
-        max_subtract=False,
-        d_model_scale=True,
-        log_length_scale=False,
-        quiet_attention=False,
         linear_module=nn.Linear,
     ):
         super().__init__()
 
         self.identity_probability = identity_probability
 
-        if activation_kwargs is not None:
-            self.activation = activation(**activation_kwargs)
-        else:
-            self.activation = activation()
-
         # Submodules for applying attention
         self.layer_norm = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
-            if d_model <
+            if d_model < 16:
                 dim = d_model
             else:
                 dim = 16
@@ -284,11 +316,6 @@ class TransformerBlock(nn.Module):
             dropout=msa_dropout,
             causal=causal,
             seq_len=seq_len,
-            share_kv=share_kv,
-            max_subtract=max_subtract,
-            d_model_scale=d_model_scale,
-            log_length_scale=log_length_scale,
-            quiet=quiet_attention,
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
@@ -301,20 +328,17 @@ class TransformerBlock(nn.Module):
                 [
                     ("layer_norm", nn.LayerNorm(d_model)),
                     (
-                        "up_projection",
-                        linear_module(
-                            # xGLU activations will double embedding size
+                        "denoising_autoencoder",
+                        DenoisingAutoEncoder(
+                            d_model,
+                            mlp_ratio,
                             d_model,
-                            (
-                                2 * mlp_ratio * d_model
-                                if activation.__name__.endswith("GLU")
-                                else mlp_ratio * d_model
-                            ),
+                            activation=activation,
+                            activation_kwargs=activation_kwargs,
+                            dropout=0.0,
+                            linear_module=linear_module,
                         ),
                     ),
-                    # xGLU activations will halve embedding size
-                    ("activation", self.activation),
-                    ("down_projection", linear_module(mlp_ratio * d_model, d_model)),
                     ("dropout", nn.Dropout(mlp_dropout)),
                 ]
             )
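The MLP above stays an nn.Sequential built from an OrderedDict, so each stage keeps an addressable name; the change swaps the inline up-projection, activation and down-projection entries for a single named DenoisingAutoEncoder. A minimal illustration of the named-Sequential pattern (stand-in modules and sizes, not the package's own):

```python
from collections import OrderedDict
import torch.nn as nn

mlp = nn.Sequential(
    OrderedDict(
        [
            ("layer_norm", nn.LayerNorm(64)),
            ("denoising_autoencoder", nn.Linear(64, 64)),  # stand-in for DenoisingAutoEncoder
            ("dropout", nn.Dropout(0.1)),
        ]
    )
)
print(mlp.layer_norm)    # stages are reachable by their OrderedDict keys
```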
@@ -369,11 +393,6 @@ class TransformerEncoder(nn.Module):
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
-        share_kv=False,
-        max_subtract=False,
-        d_model_scale=True,
-        log_length_scale=False,
-        quiet_attention=False,
         linear_module=nn.Linear,
         bos_tokens=0,
     ):
@@ -419,7 +438,7 @@ class TransformerEncoder(nn.Module):
         self.blocks = nn.ModuleList(
             [
                 TransformerBlock(
-                    seq_len,
+                    self.full_sequence_length,
                     d_model,
                     n_heads,
                     position_embedding_type=position_embedding_type,
@@ -432,11 +451,6 @@ class TransformerEncoder(nn.Module):
                     msa_dropout=msa_dropout,
                     identity_probability=self.stochastic_depth_probabilities[i],
                     causal=causal,
-                    share_kv=share_kv,
-                    max_subtract=max_subtract,
-                    d_model_scale=d_model_scale,
-                    log_length_scale=log_length_scale,
-                    quiet_attention=quiet_attention,
                     linear_module=linear_module,
                 )
                 for i in range(n_layers)