broccoli_ml-9.5.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
broccoli/utils.py ADDED
@@ -0,0 +1,15 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class PadTensor(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         self.args = args
+         self.kwargs = kwargs
+
+     def forward(self, x):
+         if sum(self.args[0]) == 0:
+             return x
+         else:
+             return F.pad(x, *self.args, **self.kwargs)
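A minimal usage sketch of `PadTensor`, assuming (as `forward` above reads) that the first positional argument is the padding tuple passed through to `torch.nn.functional.pad`:

```python
import torch

from broccoli.utils import PadTensor

# Pad the last dimension by one element on each side.
pad = PadTensor((1, 1), mode="constant", value=0.0)
x = torch.ones(2, 3)
print(pad(x).shape)  # torch.Size([2, 5])

# When every pad amount is zero, forward() short-circuits and
# returns the input tensor unchanged.
noop = PadTensor((0, 0))
print(noop(x) is x)  # True
```

The zero-padding short-circuit matters in `ViTEncoder` below, where `PadTensor` becomes a no-op whenever the pooled channel count already matches the transformer embedding size.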
broccoli/vit.py ADDED
@@ -0,0 +1,600 @@
+ import math
+ from typing import Optional
+
+ from .transformer import TransformerEncoder, FeedforwardBlock
+ from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
+ from .activation import ReLU, SquaredReLU, GELU, SwiGLU
+ from .utils import PadTensor
+
+ from einops import einsum
+ from einops.layers.torch import Rearrange
+
+ import torch
+ import torch.nn as nn
+
+
+ class GetCLSToken(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return x[:, 0, :]
+
+
+ class SequencePool(nn.Module):
+     def __init__(self, d_model, linear_module=nn.Linear):
+         super().__init__()
+         self.attention = nn.Sequential(
+             *[
+                 linear_module(d_model, 1),
+                 Rearrange("batch seq 1 -> batch seq"),
+                 nn.Softmax(dim=-1),
+             ]
+         )
+
+         self.reset_parameters()
+
+     def forward(self, x):
+         weights = self.attention(x)
+         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
+
+     def attention_scores(self, x):
+         return self.attention(x)
+
+     def reset_parameters(self):
+         # Iterate over modules in the sequential block
+         for module in self.attention:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()
+
+
+ class ClassificationHead(nn.Module):
+     """
+     A general classification head for a ViT
+     """
+
+     def __init__(
+         self,
+         d_model,
+         n_classes,
+         logit_projection_layer=nn.Linear,
+         batch_norm_logits=True,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.summarize = GetCLSToken()
+
+         if d_model == n_classes:
+             # No need to project
+             self.projection = nn.Identity()
+         else:
+             self.projection = logit_projection_layer(d_model, n_classes)
+
+         if batch_norm_logits:
+             self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
+         else:
+             self.batch_norm = nn.Identity()
+
+         self.classification_process = nn.Sequential(
+             *[
+                 self.summarize,
+                 self.projection,
+                 self.batch_norm,
+             ]
+         )
+
+         self.reset_parameters()
+
+     def forward(self, x):
+         return self.classification_process(x)
+
+     def reset_parameters(self):
+         for module in self.classification_process:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()
+
+
+ class SequencePoolClassificationHead(ClassificationHead):
+     """
+     As described in [Hassani et al. (2021) *''Escaping the Big Data Paradigm with
+     Compact Transformers''*](https://arxiv.org/abs/2104.05704). It can be viewed
+     as a generalisation of average pooling.
+     """
+
+     def __init__(
+         self,
+         d_model,
+         n_classes,
+         logit_projection_layer=nn.Linear,
+         batch_norm_logits=True,
+     ):
+         super().__init__(
+             d_model,
+             n_classes,
+             logit_projection_layer=logit_projection_layer,
+             batch_norm_logits=batch_norm_logits,
+         )
+
+         self.summarize = SequencePool(d_model, logit_projection_layer)
+         # Rebuild the classification process with the correct summary module:
+         self.classification_process = nn.Sequential(
+             *[
+                 self.summarize,
+                 self.projection,
+                 self.batch_norm,
+             ]
+         )
+
+         self.reset_parameters()
+
+
+ class ViTEncoder(nn.Module):
+     """
+     Based on the Compact Convolutional Transformer (CCT) of [Hassani et al. (2021)
+     *''Escaping the Big Data Paradigm with Compact Transformers''*](
+     https://arxiv.org/abs/2104.05704). It's basically a convolutional neural
+     network leading into a transformer encoder. To make it like the full CCT
+     we would finish it off with a sequence pooling layer, but we won't always
+     want to do that.
+     """
+
+     def __init__(
+         self,
+         input_size=(32, 32),
+         in_channels=3,
+         initial_batch_norm=True,
+         cnn=True,
+         cnn_out_channels=16,
+         cnn_kernel_size=3,
+         cnn_kernel_stride=1,
+         cnn_padding="same",
+         cnn_kernel_dilation=1,
+         cnn_kernel_groups=1,
+         cnn_activation: nn.Module = ReLU,
+         cnn_activation_kwargs: Optional[dict] = None,
+         cnn_dropout=0.0,
+         pooling_type="concat",  # max, average, concat or None
+         pooling_kernel_size=3,
+         pooling_kernel_stride=2,
+         pooling_padding=1,
+         transformer_feedforward_first=True,
+         transformer_initial_ff_residual_path=True,
+         transformer_initial_ff_linear_module_up=None,
+         transformer_initial_ff_linear_module_down=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
+         transformer_pre_norm=True,
+         transformer_normformer=False,
+         transformer_post_norm=False,
+         transformer_absolute_position_embedding=False,
+         transformer_relative_position_embedding=True,
+         transformer_embedding_size=256,
+         transformer_layers=7,
+         transformer_heads=4,
+         transformer_mlp_ratio=2,
+         transformer_utility_tokens=0,
+         transformer_return_utility_tokens=False,
+         transformer_activation: nn.Module = SquaredReLU,
+         transformer_activation_kwargs: Optional[dict] = None,
+         transformer_ff_linear_module_up=None,
+         transformer_ff_linear_module_down=None,
+         transformer_msa_scaling="d",
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
+         transformer_msa_dropout=0.1,
+         transformer_stochastic_depth=0.1,
+         transformer_checkpoint_ff=True,
+         linear_module=nn.Linear,
+     ):
+         super().__init__()
+
+         if cnn_activation_kwargs is not None:
+             self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
+         else:
+             self.cnn_activation = cnn_activation()
+
+         if transformer_activation_kwargs is not None:
+             self.transformer_activation = transformer_activation(
+                 **transformer_activation_kwargs
+             )
+         else:
+             self.transformer_activation = transformer_activation()
+
+         self.input_size = input_size
+         self.spatial_dimensions = len(self.input_size)
+
+         if self.spatial_dimensions == 1:
+             maxpoolxd = nn.MaxPool1d
+             avgpoolxd = nn.AvgPool1d
+             convxd = nn.Conv1d
+             batchnormxd = nn.BatchNorm1d
+             spatial_dim_names = "D1"
+         elif self.spatial_dimensions == 2:
+             maxpoolxd = nn.MaxPool2d
+             avgpoolxd = nn.AvgPool2d
+             convxd = nn.Conv2d
+             batchnormxd = nn.BatchNorm2d
+             spatial_dim_names = "D1 D2"
+         elif self.spatial_dimensions == 3:
+             maxpoolxd = nn.MaxPool3d
+             avgpoolxd = nn.AvgPool3d
+             convxd = nn.Conv3d
+             batchnormxd = nn.BatchNorm3d
+             spatial_dim_names = "D1 D2 D3"
+         else:
+             raise NotImplementedError(
+                 "`input_size` must be a tuple of length 1, 2, or 3."
+             )
+
+         if cnn:
+             # This block rhymes:
+             if cnn_activation.__name__.endswith("GLU"):
+                 cnn_out_channels *= 2
+             cnn_output_size = calculate_output_spatial_size(
+                 input_size,
+                 kernel_size=cnn_kernel_size,
+                 stride=cnn_kernel_stride,
+                 padding=cnn_padding,
+                 dilation=cnn_kernel_dilation,
+             )
+             self.cnn = convxd(
+                 in_channels,
+                 cnn_out_channels,
+                 cnn_kernel_size,
+                 stride=cnn_kernel_stride,
+                 padding=cnn_padding,
+                 dilation=cnn_kernel_dilation,
+                 groups=cnn_kernel_groups,
+                 bias=True,
+                 padding_mode="zeros",
+             )
+             cnn_activation_out_channels = cnn_out_channels
+             self.activate_and_dropout = nn.Sequential(
+                 *[
+                     Rearrange(  # rearrange in case we're using XGLU activation
+                         f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
+                     ),
+                     self.cnn_activation,
+                     Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
+                     nn.Dropout(cnn_dropout),
+                     batchnormxd(cnn_activation_out_channels),
+                 ]
+             )
+         else:
+             self.cnn = nn.Identity()
+             self.activate_and_dropout = nn.Identity()
+             cnn_output_size = input_size
+             cnn_out_channels = in_channels
+             cnn_activation_out_channels = in_channels
+
+         pooling_kernel_voxels = math.prod(
+             spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
+         )
+
+         pooling_output_size = (
+             cnn_output_size
+             if pooling_type is None
+             else calculate_output_spatial_size(
+                 cnn_output_size,
+                 kernel_size=pooling_kernel_size,
+                 stride=pooling_kernel_stride,
+                 padding=pooling_padding,
+                 dilation=1,
+             )
+         )
+
+         if pooling_type is None:
+             pooling_out_channels = cnn_activation_out_channels
+             self.pool = nn.Identity()
+
+         elif pooling_type == "max":
+             pooling_out_channels = cnn_activation_out_channels
+             self.pool = maxpoolxd(
+                 pooling_kernel_size,
+                 stride=pooling_kernel_stride,
+                 padding=pooling_padding,
+             )
+         elif pooling_type == "average":
+             pooling_out_channels = cnn_activation_out_channels
+             self.pool = avgpoolxd(
+                 pooling_kernel_size,
+                 stride=pooling_kernel_stride,
+                 padding=pooling_padding,
+             )
+         elif pooling_type == "concat":
+             pooling_out_channels = pooling_kernel_voxels * cnn_activation_out_channels
+             self.pool = SpaceToDepth(
+                 pooling_kernel_size,
+                 stride=pooling_kernel_stride,
+                 padding=pooling_padding,
+                 spatial_dimensions=self.spatial_dimensions,
+             )
+         else:
+             raise NotImplementedError(
+                 "Pooling type must be max, average, concat or None"
+             )
+
+         self.pooling_channels_padding = PadTensor(
+             (0, max(0, transformer_embedding_size - pooling_out_channels))
+         )
+
+         self.sequence_length = math.prod(pooling_output_size)  # One token per voxel
+
+         if transformer_layers > 0:
+             self.transformer = TransformerEncoder(
+                 self.sequence_length,
+                 transformer_embedding_size,
+                 transformer_layers,
+                 transformer_heads,
+                 absolute_position_embedding=transformer_absolute_position_embedding,
+                 relative_position_embedding=transformer_relative_position_embedding,
+                 source_size=pooling_output_size,
+                 mlp_ratio=transformer_mlp_ratio,
+                 activation=transformer_activation,
+                 activation_kwargs=transformer_activation_kwargs,
+                 ff_linear_module_up=transformer_ff_linear_module_up,
+                 ff_linear_module_down=transformer_ff_linear_module_down,
+                 msa_scaling=transformer_msa_scaling,
+                 ff_dropout=transformer_ff_dropout,
+                 ff_inner_dropout=transformer_ff_inner_dropout,
+                 ff_outer_dropout=transformer_ff_outer_dropout,
+                 msa_dropout=transformer_msa_dropout,
+                 stochastic_depth=transformer_stochastic_depth,
+                 causal=False,
+                 linear_module=linear_module,
+                 utility_tokens=transformer_utility_tokens,
+                 return_utility_tokens=transformer_return_utility_tokens,
+                 pre_norm=transformer_pre_norm,
+                 normformer=transformer_normformer,
+                 post_norm=transformer_post_norm,
+                 checkpoint_ff=transformer_checkpoint_ff,
+             )
+         else:
+             self.transformer = nn.Identity()
+
+         if transformer_feedforward_first:
+             self.initial_ff = FeedforwardBlock(
+                 max(transformer_embedding_size, pooling_out_channels),
+                 transformer_mlp_ratio,
+                 transformer_embedding_size,
+                 activation=transformer_activation,
+                 activation_kwargs=transformer_activation_kwargs,
+                 dropout=(
+                     # First value that is not None
+                     transformer_initial_ff_dropout
+                     if transformer_initial_ff_dropout is not None
+                     else transformer_ff_dropout
+                 ),
+                 inner_dropout=(
+                     # First value that is not None
+                     transformer_initial_ff_inner_dropout
+                     if transformer_initial_ff_inner_dropout is not None
+                     else transformer_ff_inner_dropout
+                 ),
+                 outer_dropout=(
+                     # First value that is not None
+                     transformer_initial_ff_outer_dropout
+                     if transformer_initial_ff_outer_dropout is not None
+                     else transformer_ff_outer_dropout
+                 ),
+                 linear_module_up=(
+                     # First truthy assigned value
+                     transformer_initial_ff_linear_module_up
+                     or transformer_ff_linear_module_up
+                     or linear_module
+                 ),
+                 linear_module_down=(
+                     # First truthy assigned value
+                     transformer_initial_ff_linear_module_down
+                     or transformer_ff_linear_module_down
+                     or linear_module
+                 ),
+                 pre_norm=transformer_pre_norm,
+                 normformer=transformer_normformer,
+                 post_norm=transformer_post_norm,
+                 residual_path=transformer_initial_ff_residual_path,
+                 checkpoint=transformer_checkpoint_ff,
+             )
+         else:
+             self.initial_ff = nn.Identity()
+
+         self.encoder = nn.Sequential(
+             *[
+                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
+                 self.cnn,
+                 self.activate_and_dropout,
+                 self.pool,
+                 Rearrange(  # for transformer
+                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
+                 ),
+                 self.pooling_channels_padding,
+                 self.initial_ff,
+                 self.transformer,
+             ]
+         )
+
+         self.reset_parameters()
+
+     def forward(self, x):
+         return self.encoder(x)
+
+     def attention_logits(self, x):
+         x = self.encoder[:-1](x)
+         return self.encoder[-1].attention_logits(x)
+
+     def reset_parameters(self):
+         for module in self.encoder:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()
+
+
+ class ViT(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(
+         self,
+         input_size=(32, 32),
+         image_classes=100,
+         in_channels=3,
+         initial_batch_norm=True,
+         cnn=True,
+         cnn_out_channels=16,
+         cnn_kernel_size=3,
+         cnn_kernel_stride=1,
+         cnn_padding="same",
+         cnn_kernel_dilation=1,
+         cnn_kernel_groups=1,
+         cnn_activation: nn.Module = ReLU,
+         cnn_activation_kwargs: Optional[dict] = None,
+         cnn_dropout=0.0,
+         pooling_type="concat",  # max, average, concat or None
+         pooling_kernel_size=3,
+         pooling_kernel_stride=2,
+         pooling_padding=1,
+         transformer_feedforward_first=True,
+         transformer_initial_ff_residual_path=True,
+         transformer_initial_ff_linear_module_up=None,
+         transformer_initial_ff_linear_module_down=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
+         transformer_pre_norm=True,
+         transformer_normformer=False,
+         transformer_post_norm=False,
+         transformer_absolute_position_embedding=False,
+         transformer_relative_position_embedding=True,
+         transformer_embedding_size=256,
+         transformer_layers=7,
+         transformer_heads=4,
+         transformer_mlp_ratio=2,
+         transformer_utility_tokens=0,
+         transformer_return_utility_tokens=False,
+         transformer_activation: nn.Module = SquaredReLU,
+         transformer_activation_kwargs: Optional[dict] = None,
+         transformer_ff_linear_module_up=None,
+         transformer_ff_linear_module_down=None,
+         transformer_msa_scaling="d",
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
+         transformer_msa_dropout=0.1,
+         transformer_stochastic_depth=0.1,
+         transformer_checkpoint_ff=True,
+         head=SequencePoolClassificationHead,
+         batch_norm_logits=True,
+         logit_projection_layer=nn.Linear,
+         linear_module=nn.Linear,
+     ):
+
+         super().__init__()
+
+         if isinstance(cnn_activation, str):
+             cnn_activation = {
+                 "ReLU": ReLU,
+                 "SquaredReLU": SquaredReLU,
+                 "GELU": GELU,
+                 "SwiGLU": SwiGLU,
+             }[cnn_activation]
+
+         if isinstance(transformer_activation, str):
+             transformer_activation = {
+                 "ReLU": ReLU,
+                 "SquaredReLU": SquaredReLU,
+                 "GELU": GELU,
+                 "SwiGLU": SwiGLU,
+             }[transformer_activation]
+
+         self.encoder = ViTEncoder(
+             input_size=input_size,
+             initial_batch_norm=initial_batch_norm,
+             in_channels=in_channels,
+             cnn=cnn,
+             cnn_out_channels=cnn_out_channels,
+             cnn_kernel_size=cnn_kernel_size,
+             cnn_kernel_stride=cnn_kernel_stride,
+             cnn_padding=cnn_padding,
+             cnn_kernel_dilation=cnn_kernel_dilation,
+             cnn_kernel_groups=cnn_kernel_groups,
+             cnn_activation=cnn_activation,
+             cnn_activation_kwargs=cnn_activation_kwargs,
+             cnn_dropout=cnn_dropout,
+             pooling_type=pooling_type,
+             pooling_kernel_size=pooling_kernel_size,
+             pooling_kernel_stride=pooling_kernel_stride,
+             pooling_padding=pooling_padding,
+             transformer_feedforward_first=transformer_feedforward_first,
+             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
+             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
+             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
+             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+             transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+             transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
+             transformer_pre_norm=transformer_pre_norm,
+             transformer_normformer=transformer_normformer,
+             transformer_post_norm=transformer_post_norm,
+             transformer_absolute_position_embedding=transformer_absolute_position_embedding,
+             transformer_relative_position_embedding=transformer_relative_position_embedding,
+             transformer_embedding_size=transformer_embedding_size,
+             transformer_layers=transformer_layers,
+             transformer_heads=transformer_heads,
+             transformer_mlp_ratio=transformer_mlp_ratio,
+             transformer_utility_tokens=transformer_utility_tokens,
+             transformer_return_utility_tokens=transformer_return_utility_tokens,
+             transformer_activation=transformer_activation,
+             transformer_activation_kwargs=transformer_activation_kwargs,
+             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
+             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
+             transformer_msa_scaling=transformer_msa_scaling,
+             transformer_ff_dropout=transformer_ff_dropout,
+             transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+             transformer_ff_outer_dropout=transformer_ff_outer_dropout,
+             transformer_msa_dropout=transformer_msa_dropout,
+             transformer_stochastic_depth=transformer_stochastic_depth,
+             transformer_checkpoint_ff=transformer_checkpoint_ff,
+             linear_module=linear_module,
+         )
+
+         self.pool = head(
+             transformer_embedding_size,
+             image_classes,
+             logit_projection_layer=logit_projection_layer,
+             batch_norm_logits=batch_norm_logits,
+         )
+
+         self.reset_parameters()
+
+     @property
+     def sequence_length(self):
+         return self.encoder.sequence_length
+
+     def forward(self, x):
+         return self.pool(self.encoder(x))
+
+     def attention_logits(self, x):
+         return self.encoder.attention_logits(x)
+
+     def pool_attention(self, x):
+         if hasattr(self.pool.summarize, "attention"):
+             return self.pool.summarize.attention(self.encoder(x))
+         else:
+             raise NotImplementedError(
+                 "`pool_attention` is currently only implemented where"
+                 " head class is SequencePoolClassificationHead"
+             )
+
+     def head_to_utility_token_attention_logits(self, x):
+         all_attention = self.attention_logits(x)
+         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
+         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
+         n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+         return sequence_averages[
+             :, :, :n_utility_tokens
+         ]  # (layer, head, utility_tokens)
+
+     def reset_parameters(self):
+         self.encoder.reset_parameters()
+         self.pool.reset_parameters()
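To make the flow above concrete, here is a minimal forward-pass sketch using the defaults defined in this file. It assumes the sibling modules (`broccoli.transformer`, `broccoli.cnn`, `broccoli.activation`) behave as their call sites here suggest:

```python
import torch

from broccoli.vit import ViT

# Defaults above: 32x32 RGB inputs, a small CNN stem, concat
# (space-to-depth) pooling, a 7-layer transformer encoder, and a
# sequence-pooling classification head over 100 classes.
model = ViT(input_size=(32, 32), image_classes=100, in_channels=3)

images = torch.randn(8, 3, 32, 32)  # (batch, channels, height, width)

model.eval()  # the BatchNorm layers behave deterministically in eval mode
with torch.no_grad():
    logits = model(images)

print(logits.shape)  # expected: torch.Size([8, 100])
```

The `SequencePoolClassificationHead` used here computes `softmax(linear(x))` weights over the token sequence and returns the weighted sum of tokens, which is why its docstring calls it a generalisation of average pooling: uniform weights would recover a plain mean over tokens.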
broccoli_ml-9.5.1.dist-info/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 nicholasbailey87
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
broccoli_ml-9.5.1.dist-info/METADATA ADDED
@@ -0,0 +1,43 @@
+ Metadata-Version: 2.3
+ Name: broccoli-ml
+ Version: 9.5.1
+ Summary: Some useful Pytorch models, circa 2025
+ License: MIT
+ Author: Nicholas Bailey
+ Requires-Python: >=3.8
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Dist: einops (>=0.8.1,<0.9.0)
+ Description-Content-Type: text/markdown
+
+ # broccoli
+
+ Some useful PyTorch models, circa 2025.
+
+ ![broccoli](broccoli.png "Image of a rockstar made of broccoli")
+
+ # Getting started
+
+ You can install broccoli with
+
+ ```
+ pip install broccoli-ml
+ ```
+
+ PyTorch is a peer dependency of `broccoli`, which means
+ * You will need to make sure you have PyTorch installed in order to use `broccoli`
+ * PyTorch will **not** be installed automatically when you install `broccoli`
+
+ We take this approach because PyTorch versioning is environment-specific and we don't know where you will want to use `broccoli`. If we automatically installed PyTorch for you, there's a good chance we would get it wrong!
+
+ Therefore, please also make sure you install PyTorch.
+
+ # Usage examples
+
+ ...
broccoli_ml-9.5.1.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
+ broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
+ broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
+ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
+ broccoli/linear.py,sha256=i4U7ZC4ZWEH82YpDasx0Qs1pc3gkyL-3ajuyKCbsGTM,12649
+ broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
+ broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+ broccoli/transformer.py,sha256=ULk-QQX3hAI14-aCKhp9QSebzX4KUjlisEGup2Eycck,25565
+ broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
+ broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
+ broccoli_ml-9.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+ broccoli_ml-9.5.1.dist-info/METADATA,sha256=HXRWnuc_-Gs_g37_RP3-POTLmi7sZamlzYv5SJEun1Y,1368
+ broccoli_ml-9.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ broccoli_ml-9.5.1.dist-info/RECORD,,
broccoli_ml-9.5.1.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: poetry-core 2.1.3
+ Root-Is-Purelib: true
+ Tag: py3-none-any