broccoli-ml 0.1.41__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -68,12 +68,7 @@ class MHAttention(nn.Module):
         dropout=0.0,
         causal=False,
         seq_len=None,
-        share_kv=False,
         linear_module: nn.Module = nn.Linear,
-        max_subtract=False,
-        d_model_scale=True,
-        log_length_scale=False,
-        quiet=False,
         bos_tokens=0,
         rotary_embedding=None,
         source_size=None,
@@ -88,15 +83,15 @@ class MHAttention(nn.Module):
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+
         self.head_dim = self.embed_dim // self.n_heads
-        self.share_kv = share_kv
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
-        if self.share_kv:
-            self.v_proj = self.k_proj
-        else:
-            self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
+
         self.out_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
+
         self.causal = causal
         self.seq_len = seq_len
         self.dropout = nn.Dropout(dropout)
@@ -107,10 +102,6 @@ class MHAttention(nn.Module):
             .unsqueeze(0)
             .unsqueeze(0),
         )
-        self.max_subtract = max_subtract
-        self.d_model_scale = d_model_scale
-        self.log_length_scale = log_length_scale
-        self.quiet = quiet
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
         self.bos_tokens = bos_tokens
@@ -152,37 +143,57 @@ class MHAttention(nn.Module):
         # Project q, k and v
         q = self.q_proj(q)
         k = self.k_proj(k)
-        if self.share_kv:
-            v = self.k_proj(v)
-        else:
-            v = self.v_proj(v)
+        v = self.v_proj(v)

         # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:

+            if len(self.source_size) == 1:
+                spatial_dimension_names = "D1"
+                spatial_dimension_values = {"D1": self.source_size[0]}
+            elif len(self.source_size) == 2:
+                spatial_dimension_names = "D1 D2"
+                spatial_dimension_values = {
+                    "D1": self.source_size[0],
+                    "D2": self.source_size[1],
+                }
+            elif len(self.source_size) == 3:
+                spatial_dimension_names = "D1 D2 D3"
+                spatial_dimension_values = {
+                    "D1": self.source_size[0],
+                    "D2": self.source_size[1],
+                    "D3": self.source_size[2],
+                }
+            else:
+                raise NotImplementedError(
+                    "`source_size` must be a tuple of 1, 2 or 3 integers"
+                )
+
             q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
             k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]

             q_img = rearrange(
                 q_img,
-                "b (height width) d -> b height width d",
-                height=self.source_size[0],
-                width=self.source_size[1],
+                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+                **spatial_dimension_values,
             )
             k_img = rearrange(
                 k_img,
-                "b (height width) d -> b height width d",
-                height=self.source_size[0],
-                width=self.source_size[1],
-            )
-            freqs = self.rotary_embedding.get_axial_freqs(
-                self.source_size[0], self.source_size[1]
+                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+                **spatial_dimension_values,
             )
+            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
             q_img = apply_rotary_emb(freqs, q_img)
             k_img = apply_rotary_emb(freqs, k_img)

-            q_img = rearrange(q_img, "b height width d -> b (height width) d")
-            k_img = rearrange(k_img, "b height width d -> b (height width) d")
+            q_img = rearrange(
+                q_img,
+                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            )
+            k_img = rearrange(
+                k_img,
+                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            )

             # Re-combine the BOS tokens and the RoPE-enhanced image tokens
             q = torch.cat([q_bos, q_img], dim=1)
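
The hunk above generalises the RoPE reshaping from a hard-coded 2-D (height, width) grid to 1-, 2- or 3-D source_size tuples by building the einops pattern at run time. A minimal sketch (not part of the package; shapes are illustrative) of what the dynamically built pattern does for a 2-D grid:

import torch
from einops import rearrange

source_size = (8, 8)                    # hypothetical 8x8 grid of image tokens
x = torch.randn(2, 8 * 8, 64)           # (batch, flattened tokens, embed dim)

spatial_dimension_names = "D1 D2"
spatial_dimension_values = {"D1": source_size[0], "D2": source_size[1]}

# Fold the flat token axis back into its spatial axes so axial RoPE can be applied
grid = rearrange(
    x,
    f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
    **spatial_dimension_values,
)
assert grid.shape == (2, 8, 8, 64)

# Flatten again afterwards; the round trip preserves token order exactly
flat = rearrange(grid, f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d")
assert torch.equal(flat, x)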
@@ -195,26 +206,13 @@ class MHAttention(nn.Module):

         qk_scores = q @ k.transpose(-1, -2)

-        if self.d_model_scale:
-            qk_scores /= math.sqrt(self.head_dim)  # scaling
-
-        if self.log_length_scale:
-            qk_scores *= math.log(qk_scores.size(0))
-
-        if self.max_subtract:
-            max_scores, _ = torch.max(qk_scores, dim=-1, keepdim=True)
-            qk_scores -= max_scores
+        qk_scores /= math.sqrt(self.head_dim)

         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))

-        # Apply softmax and dropout
-        denominator = torch.sum(torch.exp(qk_scores), dim=-1, keepdim=True)
-        if self.quiet:
-            denominator += 1
-        numerator = torch.exp(qk_scores)
-        qk_scores = self.dropout(numerator / denominator)
+        qk_scores = F.softmax(qk_scores, dim=-1)

         output_with_heads = qk_scores @ v

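The change above collapses the optional max_subtract / log_length_scale / quiet paths into the standard scaled dot-product softmax. Subtracting the per-row maximum never changes the softmax output, so dropping that branch loses nothing; a quick check (not part of the package) of the equivalence against F.softmax:

import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 10, 10)      # (batch, heads, queries, keys), illustrative

# Old-style manual softmax with max subtraction for numerical stability
max_scores, _ = torch.max(scores, dim=-1, keepdim=True)
shifted = torch.exp(scores - max_scores)
manual = shifted / torch.sum(shifted, dim=-1, keepdim=True)

# F.softmax applies the same stabilisation internally, so the results match
assert torch.allclose(manual, F.softmax(scores, dim=-1), atol=1e-6)
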
@@ -223,6 +221,50 @@ class MHAttention(nn.Module):
         return self.out_proj(output_without_heads)


+class DenoisingAutoEncoder(nn.Module):
+    """
+    A denoising autoencoder, of the type used in transformer blocks.
+    """
+
+    def __init__(
+        self,
+        input_features,
+        ratio,
+        output_features,
+        activation=nn.ReLU,
+        activation_kwargs=None,
+        dropout=0.0,
+        linear_module=nn.Linear,
+    ):
+        super().__init__()
+
+        if activation_kwargs is not None:
+            self.activation = activation(**activation_kwargs)
+        else:
+            self.activation = activation()
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.process = nn.Sequential(
+            *[
+                linear_module(
+                    input_features,
+                    (
+                        2 * ratio * input_features
+                        if activation.__name__.endswith("GLU")
+                        else ratio * input_features
+                    ),
+                ),
+                self.activation,
+                self.dropout,
+                linear_module(ratio * input_features, output_features),
+            ]
+        )
+
+    def forward(self, x):
+        return self.process(x)
+
+
 class TransformerBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
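
The new DenoisingAutoEncoder bundles the MLP's up-projection, activation and down-projection, widening the first linear layer by 2x for *GLU-style activations (which halve their input). A usage sketch, assuming the class is importable from broccoli.transformer and using nn.GLU purely as an example of a GLU-style activation:

import torch
import torch.nn as nn
from broccoli.transformer import DenoisingAutoEncoder  # import path assumed from this diff

mlp = DenoisingAutoEncoder(
    input_features=64,
    ratio=4,
    output_features=64,
    activation=nn.GLU,              # name ends in "GLU", so the first layer is 2 * 4 * 64 wide
    activation_kwargs={"dim": -1},  # GLU halves the last dimension back to 4 * 64
    dropout=0.1,
)

x = torch.randn(2, 16, 64)          # (batch, tokens, features)
assert mlp(x).shape == (2, 16, 64)
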
@@ -247,28 +289,18 @@ class TransformerBlock(nn.Module):
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
-        share_kv=False,
-        max_subtract=False,
-        d_model_scale=True,
-        log_length_scale=False,
-        quiet_attention=False,
         linear_module=nn.Linear,
     ):
         super().__init__()

         self.identity_probability = identity_probability

-        if activation_kwargs is not None:
-            self.activation = activation(**activation_kwargs)
-        else:
-            self.activation = activation()
-
         # Submodules for applying attention
         self.layer_norm = nn.LayerNorm(d_model)

         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
-            if d_model < 48:
+            if d_model < 16:
                 dim = d_model
             else:
                 dim = 16
@@ -284,11 +316,6 @@ class TransformerBlock(nn.Module):
             dropout=msa_dropout,
             causal=causal,
             seq_len=seq_len,
-            share_kv=share_kv,
-            max_subtract=max_subtract,
-            d_model_scale=d_model_scale,
-            log_length_scale=log_length_scale,
-            quiet=quiet_attention,
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
@@ -301,20 +328,17 @@ class TransformerBlock(nn.Module):
                [
                    ("layer_norm", nn.LayerNorm(d_model)),
                    (
-                        # up_projection is appropriate to activation
-                        "up_projection",
-                        linear_module(
+                        "denoising_autoencoder",
+                        DenoisingAutoEncoder(
+                            d_model,
+                            mlp_ratio,
                             d_model,
-                            (
-                                2 * mlp_ratio * d_model
-                                if activation.__name__.endswith("GLU")
-                                else mlp_ratio * d_model
-                            ),
+                            activation=activation,
+                            activation_kwargs=activation_kwargs,
+                            dropout=0.0,
+                            linear_module=linear_module,
                         ),
                    ),
-                    # xGLU activations will halve embedding size
-                    ("activation", self.activation),
-                    ("down_projection", linear_module(mlp_ratio * d_model, d_model)),
                    ("dropout", nn.Dropout(mlp_dropout)),
                ]
            )
@@ -369,11 +393,6 @@ class TransformerEncoder(nn.Module):
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
-        share_kv=False,
-        max_subtract=False,
-        d_model_scale=True,
-        log_length_scale=False,
-        quiet_attention=False,
         linear_module=nn.Linear,
         bos_tokens=0,
     ):
@@ -419,7 +438,7 @@ class TransformerEncoder(nn.Module):
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
-                    seq_len,
+                    self.full_sequence_length,
                    d_model,
                    n_heads,
                    position_embedding_type=position_embedding_type,
@@ -432,11 +451,6 @@ class TransformerEncoder(nn.Module):
                    msa_dropout=msa_dropout,
                    identity_probability=self.stochastic_depth_probabilities[i],
                    causal=causal,
-                    share_kv=share_kv,
-                    max_subtract=max_subtract,
-                    d_model_scale=d_model_scale,
-                    log_length_scale=log_length_scale,
-                    quiet_attention=quiet_attention,
                    linear_module=linear_module,
                )
                for i in range(n_layers)