broccoli-ml 9.2.2__py3-none-any.whl → 13.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/activation.py +1 -4
- broccoli/linear.py +123 -72
- broccoli/transformer.py +155 -64
- broccoli/vit.py +27 -12
- {broccoli_ml-9.2.2.dist-info → broccoli_ml-13.0.1.dist-info}/METADATA +1 -1
- broccoli_ml-13.0.1.dist-info/RECORD +13 -0
- broccoli_ml-9.2.2.dist-info/RECORD +0 -13
- {broccoli_ml-9.2.2.dist-info → broccoli_ml-13.0.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-9.2.2.dist-info → broccoli_ml-13.0.1.dist-info}/WHEEL +0 -0
broccoli/activation.py
CHANGED
@@ -46,10 +46,7 @@ class GELU(nn.Module):

 class Swish(nn.Module):
     """
-    Implementation of (beta)
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """

     def __init__(self) -> None:
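The deleted docstring actually described a GLU-style activation (the arXiv link is the "GLU Variants Improve Transformer" paper, and it is the GLU that halves the feature count), so this rewrite corrects a copy-paste error rather than removing information. The diff does not show the body of `Swish`; for orientation, a beta-Swish is conventionally `x * sigmoid(beta * x)` — a sketch, not necessarily the package's exact implementation:

```python
import torch
import torch.nn as nn

class Swish(nn.Module):
    """Swish activation: x * sigmoid(beta * x); beta = 1 recovers SiLU."""

    def __init__(self, beta: float = 1.0) -> None:
        super().__init__()
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.beta * x)
```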
broccoli/linear.py
CHANGED
@@ -151,11 +151,13 @@ class RecyclingLinear(nn.Module):
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
         adaptive=False,
+        xglu=False,
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.bias = bias
+        self.xglu = xglu
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
@@ -191,28 +193,60 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0

-    def
-
-
-
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

-        if
+        if idx_tensor.numel() == 0:
+            return
+
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()

-
-        probs = torch.rand(self.linear.out_features, device=x.device)
-        mask = probs < row_recycling_rate
-        if mask.any():
-            # nonzero returns [N, 1], squeeze to get [N]
-            indices = torch.nonzero(mask).squeeze(-1)
-            self.reset_rows(indices, self.optimisers)
-
-        if col_recycling_rate > 0:
-            probs = torch.rand(self.linear.in_features, device=x.device)
-            mask = probs < col_recycling_rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.reset_columns(indices, self.optimisers)
+        target_center = target_center.expand(idx_tensor.size(0), -1)

+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
         elif self.training and not self._warned_about_registration:
             warnings.warn(
                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
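The new `reset_rows` replaces the old inline sampling in `forward` and changes the reset target: instead of re-randomising, a recycled output row is set to a centroid — the mean of the incoming activations (rescaled to the layer's init magnitude by `_mean_input_weights`, defined in the next hunk) or, absent data, the mean of the existing rows. A standalone sketch of that centroid reset, minus the package's optimiser bookkeeping:

```python
import torch

def reset_rows_to_centroid(weight: torch.Tensor, rows: torch.Tensor, x: torch.Tensor) -> None:
    """Overwrite the given weight rows with the (rescaled) mean of incoming data.

    Sketch of RecyclingLinear.reset_rows with incoming_data; optimiser-state
    resets are omitted. Assumes x has shape (batch, in_features).
    """
    stdv = 1.0 / weight.size(1) ** 0.5               # nn.Linear's default init scale
    centroid = x.detach().mean(dim=0, keepdim=True)  # (1, in_features)
    centroid = centroid * (stdv / (centroid.std() + 1e-6))  # quieten to init scale
    with torch.no_grad():
        weight[rows] = centroid.expand(rows.numel(), -1)

w = torch.randn(8, 16)
reset_rows_to_centroid(w, torch.tensor([1, 5]), torch.randn(32, 16))
```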
@@ -222,82 +256,99 @@ class RecyclingLinear(nn.Module):

         return self.linear(x)

-    def
-    self
-
-
-
-
-
-
-
-
-
-        if not isinstance(optimisers, list):
-            optimisers = [optimisers]
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")

+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
         device = self.linear.weight.device
-
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights

-
-
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)

-
-
-
-
-
-
-        self.linear.weight.data[idx_tensor] = update_data
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor

-
-
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)

-
-
-
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)

-    def
+    def _update_weights(
         self,
         indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
         optimisers: Union[
             List[torch.optim.Optimizer], torch.optim.Optimizer, None
         ] = None,
     ):
-        """
-        Update some of the weight columns to be random as though reinitialised.
-        """
         if optimisers is None:
             optimisers = []
         if not isinstance(optimisers, list):
             optimisers = [optimisers]

-
-
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

         if idx_tensor.numel() == 0:
             return

         with torch.no_grad():
-
-
-
-
-
-
-
-
-            )
-            random_weights = (random_weights - 0.5) * 2.0 * stdv
-
-            # 2. Update Weights (One-shot)
-            # We assign into the columns specified by idx_tensor
-            self.linear.weight.data[:, idx_tensor] = random_weights
-
-            # 3. Update Optimizers
-            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
-            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)

     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
         """
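Several of the new helpers exist only for the `xglu` case: when the layer feeds a GLU-style activation, its output rows are really (gate, value) pairs, with the value half stored `out_features // 2` rows below the gate half — which is why `reset_rows` recycles `idx` and `idx + out_features // 2` together, and why `_mean_value_weights` averages only the bottom half. A toy illustration of the pairing (hypothetical sizes):

```python
import torch

out_features = 8                      # 4 gate rows + 4 value rows when xglu=True
idx = torch.tensor([0, 2])            # units selected for recycling
gate_rows = idx                       # rows 0 and 2
value_rows = idx + out_features // 2  # rows 4 and 6: their paired value rows

# Downstream, a GLU-style activation combines the two halves elementwise,
# halving the feature count (8 -> 4):
h = torch.randn(3, out_features)
gate, value = h.chunk(2, dim=-1)
out = torch.sigmoid(gate) * value
```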
broccoli/transformer.py
CHANGED
@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple

@@ -20,6 +21,12 @@ except ImportError:
     FLASH_ATTN = False


+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
@@ -78,9 +85,11 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -96,10 +105,20 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None

+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta

         self.head_dim = self.embed_dim // self.n_heads

@@ -181,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-
-
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )

         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )

@@ -202,17 +230,20 @@ class MHAttention(nn.Module):

         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )

         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)

+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k

     def project_qkv(
@@ -243,7 +274,7 @@ class MHAttention(nn.Module):

         q, k, v = self.project_qkv(q, k, v)

-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
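`scale_parameters` is the helper behind the DeepNet-style initialisation further down: after each projection's usual `reset_parameters()`, its weights are multiplied by `beta` in place. The function is shown in full above, so its behaviour can be checked in isolation:

```python
import torch
import torch.nn as nn

def scale_parameters(torch_module: nn.Module, factor: float):
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)

layer = nn.Linear(4, 4)
before = layer.weight.clone()
scale_parameters(layer, 0.5)
assert torch.allclose(layer.weight, 0.5 * before)  # weight (and bias) halved in place
```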
@@ -271,12 +302,22 @@ class MHAttention(nn.Module):

         qk_scores *= self.scaling_factor

+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+            )
+
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))

         qk_scores = F.softmax(qk_scores, dim=-1)

+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+            )
+
         qk_scores = self.dropout(qk_scores)

         output_with_heads = qk_scores @ v
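These two einsums implement talking-heads attention (Shazeer et al., 2020, arXiv:2003.02436): a learned linear map across the head dimension applied to the attention logits before the softmax (`head_projection`) and to the attention weights after it (`sample_projection`). Both are initialised to the identity in `reset_parameters` below, so the mixing starts as a no-op. A self-contained check of the mixing step:

```python
import torch
import torch.nn as nn

batch, heads, seq = 2, 4, 5
qk_scores = torch.randn(batch, heads, seq, seq)

head_projection = nn.Linear(heads, heads, bias=False)
nn.init.eye_(head_projection.weight)  # identity init: mixing starts as a no-op

mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection.weight)
assert torch.allclose(mixed, qk_scores)  # holds at identity initialisation
```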
@@ -309,7 +350,14 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)


 class FeedforwardBlock(nn.Module):
@@ -329,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()

         self.checkpoint = checkpoint
-        self.
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
@@ -365,19 +410,26 @@ class FeedforwardBlock(nn.Module):
         )

         self.max_features = (
-            2 * ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
         )

         self.linear_in = linear_module_up(input_features, self.max_features)
-        self.linear_out = linear_module_down(
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
+        )

         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
-
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
                 self.linear_out,
                 self.outer_dropout,
             ]
@@ -410,33 +462,16 @@ class FeedforwardBlock(nn.Module):

         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
-
-
-
-            probs = torch.rand(self.linear_out.in_features, device=x.device)
-            mask = probs < rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.linear_out.reset_columns(indices, self.linear_out.optimisers)
-                if self.xglu:
-                    indices_in = torch.cat(
-                        [indices, indices + self.linear_out.in_features]
-                    )
-                    self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
-                else:
-                    self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)

         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
             processed = self.process(x)

-
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed

     def reset_parameters(self):
         if self.post_norm:
@@ -447,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()

+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+

-class
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -465,6 +503,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -482,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -493,15 +534,29 @@ class TransformerBlock(nn.Module):

         super().__init__()

+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer

+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)

         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -525,7 +580,9 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )

         # Submodule for the feedforward process
@@ -548,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )

         self.reset_parameters()
@@ -562,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)

         if self.pre_norm:
-
-
-
-
-
-
-
-
-
-
-
-            x =
-
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)

         return x

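The rewritten forward pass multiplies the residual stream by `alpha` before adding each sublayer's output (`x = self.alpha * x + processed`), while the new `beta` scales the init of the value/output/feedforward projections — together this is the DeepNet scheme ("DeepNet: Scaling Transformers to 1,000 Layers", Wang et al., 2022) the source comments cite. Both default to 1.0, leaving the choice to the caller; for an encoder-only stack of N layers, the paper's prescription is approximately the following (a sketch; check the paper's table for your architecture):

```python
def deepnet_alpha_beta(n_layers: int) -> tuple[float, float]:
    """DeepNet scaling for an encoder-only stack of n_layers.

    alpha multiplies the residual stream; beta shrinks the init of the value,
    output and feedforward projections.
    """
    alpha = (2 * n_layers) ** 0.25
    beta = (8 * n_layers) ** -0.25
    return alpha, beta

alpha, beta = deepnet_alpha_beta(24)  # roughly 2.63 and 0.27 for 24 layers
```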
@@ -585,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)

     def reset_parameters(self):
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()

         self.attn.reset_parameters()
         self.ff.reset_parameters()
@@ -628,12 +705,15 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -656,6 +736,13 @@ class TransformerEncoder(nn.Module):
         )

         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -700,13 +787,14 @@ class TransformerEncoder(nn.Module):

         self.blocks = nn.ModuleList(
             [
-
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
                     relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
                     utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
@@ -724,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -744,13 +834,14 @@ class TransformerEncoder(nn.Module):
         x = x

         if self.absolute_position_embedding is not None:
-
+            position_embedding = self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
                 ).unsqueeze(
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
+            x += position_embedding

         return x

broccoli/vit.py
CHANGED
@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -174,6 +173,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -187,9 +187,14 @@ class ViTEncoder(nn.Module):
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
         super().__init__()

+        self.alpha = alpha
+        self.beta = beta
+
         if cnn_activation_kwargs is not None:
             self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
         else:
@@ -345,11 +350,14 @@ class ViTEncoder(nn.Module):
                 causal=False,
                 linear_module=linear_module,
                 utility_tokens=transformer_utility_tokens,
+                talking_heads=transformer_talking_heads,
                 return_utility_tokens=transformer_return_utility_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -391,16 +399,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()

-        self.
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -410,19 +416,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )

         self.reset_parameters()

     def forward(self, x):
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)

     def attention_logits(self, x):
-        x = self.
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)

     def reset_parameters(self):
         for module in self.encoder:
@@ -456,7 +464,6 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -472,6 +479,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -488,6 +496,8 @@ class ViT(nn.Module):
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):

         super().__init__()
@@ -508,6 +518,9 @@ class ViT(nn.Module):
             "SwiGLU": SwiGLU,
         }[transformer_activation]

+        self.alpha = alpha
+        self.beta = beta
+
         self.encoder = ViTEncoder(
             input_size=input_size,
             initial_batch_norm=initial_batch_norm,
@@ -527,7 +540,6 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -543,6 +555,7 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
@@ -556,6 +569,8 @@ class ViT(nn.Module):
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
+            alpha=alpha,
+            beta=beta,
         )

         self.pool = head(
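`ViTEncoder.forward` is now three explicit stages: the convolutional `preprocess` pipeline (the transformer was removed from the `nn.Sequential`), a residual connection around the initial feedforward block — previously handled inside `FeedforwardBlock` via the now-removed `residual_path` flag — and the transformer itself. (Note that the added `nn.LayerNorm()` in `preprocess` appears with no `normalized_shape` argument, which stock PyTorch requires.) The shape of the new control flow, with stand-in modules (a sketch, not the package's classes):

```python
import torch
import torch.nn as nn

class EncoderSkeleton(nn.Module):
    """Mirrors the new ViTEncoder.forward: preprocess -> residual FF -> transformer."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.preprocess = nn.Linear(dim, dim)  # stand-in for the CNN/rearrange stem
        self.initial_ff = nn.Linear(dim, dim)  # stand-in for the initial FeedforwardBlock
        self.transformer = nn.Identity()       # stand-in for the TransformerEncoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.preprocess(x)
        x = x + self.initial_ff(x)  # the residual now lives here, not in the FF block
        return self.transformer(x)

out = EncoderSkeleton()(torch.randn(2, 5, 16))
```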
broccoli_ml-13.0.1.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
+broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
+broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
+broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
+broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
+broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=3vAQQ75SAyr4-m3e7vSru8M-RpUy2Enp5cVUafaVYMU,28410
+broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
+broccoli/vit.py,sha256=jd4e6MjL2JKB8ynSQssWRh6Hs36RuLj4uWyUNVhIMUY,22472
+broccoli_ml-13.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-13.0.1.dist-info/METADATA,sha256=gN9cKQDpwRr8JG_Ilj0ZknlPdxhq62I8xEBExVspKGw,1369
+broccoli_ml-13.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-13.0.1.dist-info/RECORD,,

broccoli_ml-9.2.2.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
-broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
-broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
-broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=IwvPAMbHqOYyz0g5WZyevPAhC1Pn0RTLniFM4E6lJoI,11511
-broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
-broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=r-ggAeNDW5QpBi9As1U9sIfxITBOx0WHk_K4zWpyTM8,26233
-broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
-broccoli_ml-9.2.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-9.2.2.dist-info/METADATA,sha256=8ySQYntl9czgYyEQN5nyPS31tjwYC8M8Mx_iYhvtbzg,1368
-broccoli_ml-9.2.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-9.2.2.dist-info/RECORD,,

{broccoli_ml-9.2.2.dist-info → broccoli_ml-13.0.1.dist-info}/LICENSE
File without changes

{broccoli_ml-9.2.2.dist-info → broccoli_ml-13.0.1.dist-info}/WHEEL
File without changes