broccoli-ml 13.0.4__tar.gz → 13.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/PKG-INFO +1 -1
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/transformer.py +14 -14
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/vit.py +3 -1
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/pyproject.toml +1 -1
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/LICENSE +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/README.md +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/__init__.py +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/activation.py +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/cnn.py +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/linear.py +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/rope.py +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/tensor.py +0 -0
- {broccoli_ml-13.0.4 → broccoli_ml-13.0.6}/broccoli/utils.py +0 -0
|
@@ -200,26 +200,26 @@ class MHAttention(nn.Module):
|
|
|
200
200
|
"`source_size` must be a tuple of 1, 2 or 3 integers"
|
|
201
201
|
)
|
|
202
202
|
|
|
203
|
-
q = rearrange(q, "b t (h d) -> b t
|
|
204
|
-
k = rearrange(k, "b t (h d) -> b t
|
|
203
|
+
q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
|
|
204
|
+
k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
|
|
205
205
|
|
|
206
206
|
q_util, q_img = (
|
|
207
|
-
q[:, : self.utility_tokens,
|
|
208
|
-
q[:, self.utility_tokens :,
|
|
207
|
+
q[:, :, : self.utility_tokens, :],
|
|
208
|
+
q[:, :, self.utility_tokens :, :],
|
|
209
209
|
)
|
|
210
210
|
k_util, k_img = (
|
|
211
|
-
k[:, : self.utility_tokens,
|
|
212
|
-
k[:, self.utility_tokens :,
|
|
211
|
+
k[:, :, : self.utility_tokens, :],
|
|
212
|
+
k[:, :, self.utility_tokens :, :],
|
|
213
213
|
)
|
|
214
214
|
|
|
215
215
|
q_img = rearrange(
|
|
216
216
|
q_img,
|
|
217
|
-
f"b ({spatial_dimension_names})
|
|
217
|
+
f"b h ({spatial_dimension_names}) d -> b h {spatial_dimension_names} d",
|
|
218
218
|
**spatial_dimension_values,
|
|
219
219
|
)
|
|
220
220
|
k_img = rearrange(
|
|
221
221
|
k_img,
|
|
222
|
-
f"b ({spatial_dimension_names})
|
|
222
|
+
f"b h ({spatial_dimension_names}) d -> b h {spatial_dimension_names} d",
|
|
223
223
|
**spatial_dimension_values,
|
|
224
224
|
)
|
|
225
225
|
|
|
@@ -230,19 +230,19 @@ class MHAttention(nn.Module):
|
|
|
230
230
|
|
|
231
231
|
q_img = rearrange(
|
|
232
232
|
q_img,
|
|
233
|
-
f"b {spatial_dimension_names}
|
|
233
|
+
f"b h {spatial_dimension_names} d -> b h ({spatial_dimension_names}) d",
|
|
234
234
|
)
|
|
235
235
|
k_img = rearrange(
|
|
236
236
|
k_img,
|
|
237
|
-
f"b {spatial_dimension_names}
|
|
237
|
+
f"b h {spatial_dimension_names} d -> b h ({spatial_dimension_names}) d",
|
|
238
238
|
)
|
|
239
239
|
|
|
240
240
|
# Re-combine the utility tokens and the RoPE-enhanced sequence tokens
|
|
241
|
-
q = torch.cat([q_util, q_img], dim=
|
|
242
|
-
k = torch.cat([k_util, k_img], dim=
|
|
241
|
+
q = torch.cat([q_util, q_img], dim=2)
|
|
242
|
+
k = torch.cat([k_util, k_img], dim=2)
|
|
243
243
|
|
|
244
|
-
q = rearrange(q, "b t
|
|
245
|
-
k = rearrange(k, "b t
|
|
244
|
+
q = rearrange(q, "b h t d -> b t (h d)")
|
|
245
|
+
k = rearrange(k, "b h t d -> b t (h d)")
|
|
246
246
|
|
|
247
247
|
return q, k
|
|
248
248
|
|
|
@@ -433,9 +433,11 @@ class ViTEncoder(nn.Module):
|
|
|
433
433
|
return self.transformer.attention_logits(x)
|
|
434
434
|
|
|
435
435
|
def reset_parameters(self):
|
|
436
|
-
for module in self.
|
|
436
|
+
for module in self.preprocess:
|
|
437
437
|
if hasattr(module, "reset_parameters"):
|
|
438
438
|
module.reset_parameters()
|
|
439
|
+
self.initial_ff.reset_parameters()
|
|
440
|
+
self.transformer.reset_parameters()
|
|
439
441
|
|
|
440
442
|
|
|
441
443
|
class ViT(nn.Module):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|