broccoli-ml 0.29.1__py3-none-any.whl → 10.0.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
broccoli/utils.py CHANGED
@@ -1,9 +1,15 @@
- import importlib.resources
- import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
 
 
- def get_weights(name: str) -> torch.Tensor:
-     resource_path = importlib.resources.files("broccoli.assets") / name
-     with importlib.resources.as_file(resource_path) as path_to_weights:
-         weights = torch.load(path_to_weights)
-     return weights
+ class PadTensor(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         self.args = args
+         self.kwargs = kwargs
+
+     def forward(self, x):
+         if sum(self.args[0]) == 0:
+             return x
+         else:
+             return F.pad(x, *self.args, **self.kwargs)
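
The new `broccoli.utils.PadTensor` module wraps `torch.nn.functional.pad` so that padding can be composed inside `nn.Sequential` pipelines, and it returns its input unchanged when the requested padding sums to zero. A minimal usage sketch (shapes and values here are illustrative, not taken from the package):

import torch
from broccoli.utils import PadTensor

# Pad the last dimension by one element on each side; extra keyword
# arguments are forwarded to F.pad.
pad = PadTensor((1, 1), mode="constant", value=0.0)
x = torch.randn(2, 4, 8)
print(pad(x).shape)  # torch.Size([2, 4, 10])

# An all-zero pad is a no-op, so the module can sit in a pipeline
# unconditionally.
assert PadTensor((0, 0))(x) is x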
broccoli/vit.py CHANGED
@@ -4,24 +4,13 @@ from typing import Optional
  from .transformer import TransformerEncoder, FeedforwardBlock
  from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
  from .activation import ReLU, SquaredReLU, GELU, SwiGLU
- from .linear import SpectralNormLinear
+ from .utils import PadTensor
+
  from einops import einsum
  from einops.layers.torch import Rearrange
- import torch.nn as nn
- import torch.nn.functional as F
 
-
- class PadTensor(nn.Module):
-     def __init__(self, *args, **kwargs):
-         super().__init__()
-         self.args = args
-         self.kwargs = kwargs
-
-     def forward(self, x):
-         if sum(self.args[0]) == 0:
-             return x
-         else:
-             return F.pad(x, *self.args, **self.kwargs)
+ import torch
+ import torch.nn as nn
 
 
  class GetCLSToken(nn.Module):
@@ -43,22 +32,45 @@ class SequencePool(nn.Module):
              ]
          )
 
+         self.reset_parameters()
+
      def forward(self, x):
          weights = self.attention(x)
          return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+     def attention_scores(self, x):
+         return self.attention(x)
+
+     def reset_parameters(self):
+         # Iterate over modules in the sequential block
+         for module in self.attention:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()
+
 
  class ClassificationHead(nn.Module):
      """
      A general classification head for a ViT
      """
 
-     def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
+     def __init__(
+         self,
+         d_model,
+         n_classes,
+         logit_projection_layer=nn.Linear,
+         batch_norm_logits=True,
+     ):
          super().__init__()
          self.d_model = d_model
          self.summarize = GetCLSToken()
-         self.projection = nn.Linear(d_model, n_classes)
-         if batch_norm:
+
+         if d_model == n_classes:
+             # No need to project
+             self.projection = nn.Identity()
+         else:
+             self.projection = logit_projection_layer(d_model, n_classes)
+
+         if batch_norm_logits:
              self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
          else:
              self.batch_norm = nn.Identity()
@@ -71,9 +83,16 @@ class ClassificationHead(nn.Module):
              ]
          )
 
+         self.reset_parameters()
+
      def forward(self, x):
          return self.classification_process(x)
 
+     def reset_parameters(self):
+         for module in self.classification_process:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()
+
 
  class SequencePoolClassificationHead(ClassificationHead):
      """
@@ -82,9 +101,31 @@ class SequencePoolClassificationHead(ClassificationHead):
      as a generalisation of average pooling.
      """
 
-     def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
-         super().__init__(d_model, linear_module, out_dim, batch_norm=batch_norm)
-         self.summarize = SequencePool(d_model, linear_module)
+     def __init__(
+         self,
+         d_model,
+         n_classes,
+         logit_projection_layer=nn.Linear,
+         batch_norm_logits=True,
+     ):
+         super().__init__(
+             d_model,
+             n_classes,
+             logit_projection_layer=logit_projection_layer,
+             batch_norm_logits=batch_norm_logits,
+         )
+
+         self.summarize = SequencePool(d_model, logit_projection_layer)
+         # Rebuild the classification process with the correct summary module:
+         self.classification_process = nn.Sequential(
+             *[
+                 self.summarize,
+                 self.projection,
+                 self.batch_norm,
+             ]
+         )
+
+         self.reset_parameters()
 
 
  class ViTEncoder(nn.Module):
@@ -117,20 +158,36 @@ class ViTEncoder(nn.Module):
          pooling_kernel_stride=2,
          pooling_padding=1,
          transformer_feedforward_first=True,
+         transformer_initial_ff_residual_path=True,
+         transformer_initial_ff_linear_module_up=None,
+         transformer_initial_ff_linear_module_down=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
-         transformer_position_embedding="relative", # absolute or relative
+         transformer_post_norm=False,
+         transformer_absolute_position_embedding=False,
+         transformer_relative_position_embedding=True,
          transformer_embedding_size=256,
          transformer_layers=7,
          transformer_heads=4,
          transformer_mlp_ratio=2,
-         transformer_bos_tokens=0,
-         transformer_return_bos_tokens=False,
+         transformer_utility_tokens=0,
+         transformer_talking_heads=False,
+         transformer_return_utility_tokens=False,
          transformer_activation: nn.Module = SquaredReLU,
          transformer_activation_kwargs: Optional[dict] = None,
-         transformer_mlp_dropout=0.0,
+         transformer_ff_linear_module_up=None,
+         transformer_ff_linear_module_down=None,
+         transformer_msa_scaling="d",
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
+         transformer_checkpoint_ff=True,
+         transformer_layerscale=True,
          linear_module=nn.Linear,
      ):
          super().__init__()
@@ -232,13 +289,7 @@
 
          if pooling_type is None:
              pooling_out_channels = cnn_activation_out_channels
-             self.pool = nn.Sequential(
-                 *[
-                     Rearrange(
-                         f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                     ), # for transformer
-                 ]
-             )
+             self.pool = nn.Identity()
 
          elif pooling_type == "max":
              pooling_out_channels = cnn_activation_out_channels
@@ -279,20 +330,30 @@
                  transformer_embedding_size,
                  transformer_layers,
                  transformer_heads,
-                 position_embedding_type=transformer_position_embedding,
+                 absolute_position_embedding=transformer_absolute_position_embedding,
+                 relative_position_embedding=transformer_relative_position_embedding,
                  source_size=pooling_output_size,
                  mlp_ratio=transformer_mlp_ratio,
                  activation=transformer_activation,
                  activation_kwargs=transformer_activation_kwargs,
-                 mlp_dropout=transformer_mlp_dropout,
+                 ff_linear_module_up=transformer_ff_linear_module_up,
+                 ff_linear_module_down=transformer_ff_linear_module_down,
+                 msa_scaling=transformer_msa_scaling,
+                 ff_dropout=transformer_ff_dropout,
+                 ff_inner_dropout=transformer_ff_inner_dropout,
+                 ff_outer_dropout=transformer_ff_outer_dropout,
                  msa_dropout=transformer_msa_dropout,
                  stochastic_depth=transformer_stochastic_depth,
                  causal=False,
                  linear_module=linear_module,
-                 bos_tokens=transformer_bos_tokens,
-                 return_bos_tokens=transformer_return_bos_tokens,
+                 utility_tokens=transformer_utility_tokens,
+                 talking_heads=transformer_talking_heads,
+                 return_utility_tokens=transformer_return_utility_tokens,
                  pre_norm=transformer_pre_norm,
                  normformer=transformer_normformer,
+                 post_norm=transformer_post_norm,
+                 checkpoint_ff=transformer_checkpoint_ff,
+                 layerscale=transformer_layerscale,
              )
          else:
              self.transformer = nn.Identity()
@@ -304,11 +365,41 @@
                  transformer_embedding_size,
                  activation=transformer_activation,
                  activation_kwargs=transformer_activation_kwargs,
-                 dropout=transformer_mlp_dropout,
-                 linear_module=linear_module,
+                 dropout=(
+                     # First truthy assigned value
+                     transformer_initial_ff_dropout
+                     if transformer_initial_ff_dropout is not None
+                     else transformer_ff_dropout
+                 ),
+                 inner_dropout=(
+                     # First truthy assigned value
+                     transformer_initial_ff_inner_dropout
+                     if transformer_initial_ff_inner_dropout is not None
+                     else transformer_ff_inner_dropout
+                 ),
+                 outer_dropout=(
+                     # First truthy assigned value
+                     transformer_initial_ff_outer_dropout
+                     if transformer_initial_ff_outer_dropout is not None
+                     else transformer_ff_outer_dropout
+                 ),
+                 linear_module_up=(
+                     # First truthy assigned value
+                     transformer_initial_ff_linear_module_up
+                     or transformer_ff_linear_module_up
+                     or linear_module
+                 ),
+                 linear_module_down=(
+                     # First truthy assigned value
+                     transformer_initial_ff_linear_module_down
+                     or transformer_ff_linear_module_down
+                     or linear_module
+                 ),
                  pre_norm=transformer_pre_norm,
                  normformer=transformer_normformer,
-                 # raw_input=not cnn,
+                 post_norm=transformer_post_norm,
+                 residual_path=transformer_initial_ff_residual_path,
+                 checkpoint=transformer_checkpoint_ff,
              )
          else:
              self.initial_ff = nn.Identity()
@@ -328,17 +419,24 @@
              ]
          )
 
+         self.reset_parameters()
+
      def forward(self, x):
          return self.encoder(x)
 
+     def attention_logits(self, x):
+         x = self.encoder[:-1](x)
+         return self.encoder[-1].attention_logits(x)
+
+     def reset_parameters(self):
+         for module in self.encoder:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()
+
 
  class ViT(nn.Module):
      """
-     Denoising convolutional transformer
-     Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
-     *''Escaping the Big Data Paradigm with Compact Transformers''*](
-     https://arxiv.org/abs/2104.05704). It's a convolutional neural network
-     leading into a transformer encoder, followed by a sequence pooling layer.
+     ...
      """
 
      def __init__(
@@ -362,23 +460,40 @@
          pooling_kernel_stride=2,
          pooling_padding=1,
          transformer_feedforward_first=True,
+         transformer_initial_ff_residual_path=True,
+         transformer_initial_ff_linear_module_up=None,
+         transformer_initial_ff_linear_module_down=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
-         transformer_position_embedding="relative", # absolute or relative
+         transformer_post_norm=False,
+         transformer_absolute_position_embedding=False,
+         transformer_relative_position_embedding=True,
          transformer_embedding_size=256,
          transformer_layers=7,
          transformer_heads=4,
          transformer_mlp_ratio=2,
-         transformer_bos_tokens=0,
-         transformer_return_bos_tokens=False,
+         transformer_utility_tokens=0,
+         transformer_talking_heads=False,
+         transformer_return_utility_tokens=False,
          transformer_activation: nn.Module = SquaredReLU,
          transformer_activation_kwargs: Optional[dict] = None,
-         transformer_mlp_dropout=0.0,
+         transformer_ff_linear_module_up=None,
+         transformer_ff_linear_module_down=None,
+         transformer_msa_scaling="d",
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
-         batch_norm_outputs=True,
-         linear_module=SpectralNormLinear,
+         transformer_checkpoint_ff=True,
+         transformer_layerscale=True,
          head=SequencePoolClassificationHead,
+         batch_norm_logits=True,
+         logit_projection_layer=nn.Linear,
+         linear_module=nn.Linear,
      ):
 
          super().__init__()
@@ -418,33 +533,76 @@
              pooling_kernel_stride=pooling_kernel_stride,
              pooling_padding=pooling_padding,
              transformer_feedforward_first=transformer_feedforward_first,
+             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
+             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
+             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
+             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+             transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+             transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
              transformer_pre_norm=transformer_pre_norm,
              transformer_normformer=transformer_normformer,
-             transformer_position_embedding=transformer_position_embedding,
+             transformer_post_norm=transformer_post_norm,
+             transformer_absolute_position_embedding=transformer_absolute_position_embedding,
+             transformer_relative_position_embedding=transformer_relative_position_embedding,
              transformer_embedding_size=transformer_embedding_size,
              transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
              transformer_mlp_ratio=transformer_mlp_ratio,
-             transformer_bos_tokens=transformer_bos_tokens,
-             transformer_return_bos_tokens=transformer_return_bos_tokens,
+             transformer_utility_tokens=transformer_utility_tokens,
+             transformer_talking_heads=transformer_talking_heads,
+             transformer_return_utility_tokens=transformer_return_utility_tokens,
              transformer_activation=transformer_activation,
              transformer_activation_kwargs=transformer_activation_kwargs,
-             transformer_mlp_dropout=transformer_mlp_dropout,
+             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
+             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
+             transformer_msa_scaling=transformer_msa_scaling,
+             transformer_ff_dropout=transformer_ff_dropout,
+             transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+             transformer_ff_outer_dropout=transformer_ff_outer_dropout,
              transformer_msa_dropout=transformer_msa_dropout,
              transformer_stochastic_depth=transformer_stochastic_depth,
+             transformer_checkpoint_ff=transformer_checkpoint_ff,
+             transformer_layerscale=transformer_layerscale,
              linear_module=linear_module,
          )
 
          self.pool = head(
              transformer_embedding_size,
-             linear_module,
              image_classes,
-             batch_norm=batch_norm_outputs,
+             logit_projection_layer=logit_projection_layer,
+             batch_norm_logits=batch_norm_logits,
          )
 
+         self.reset_parameters()
+
      @property
      def sequence_length(self):
          return self.encoder.sequence_length
 
      def forward(self, x):
          return self.pool(self.encoder(x))
+
+     def attention_logits(self, x):
+         return self.encoder.attention_logits(x)
+
+     def pool_attention(self, x):
+         if hasattr(self.pool.summarize, "attention"):
+             return self.pool.summarize.attention(self.encoder(x))
+         else:
+             raise NotImplementedError(
+                 "`pool_attention` is currently only implemented where"
+                 " head class is SequencePoolClassificationHead"
+             )
+
+     def head_to_utility_token_attention_logits(self, x):
+         all_attention = self.attention_logits(x)
+         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
+         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
+         n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+         return sequence_averages[
+             :, :, :n_utility_tokens
+         ] # (layer, head, utility_tokens)
+
+     def reset_parameters(self):
+         self.encoder.reset_parameters()
+         self.pool.reset_parameters()
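
Beyond the new helper methods (`attention_logits`, `pool_attention`, `reset_parameters`), this release renames several `ViT`/`ViTEncoder` keyword arguments: `transformer_bos_tokens` becomes `transformer_utility_tokens`, `transformer_mlp_dropout` becomes `transformer_ff_dropout` (with new inner/outer variants), the `transformer_position_embedding` string is split into two boolean flags, `batch_norm_outputs` becomes `batch_norm_logits`, and the default `linear_module` changes from `SpectralNormLinear` to `nn.Linear`. The sketch below is a small porting helper for call sites; it is not part of the broccoli API, just a convenience derived from the renames visible in this diff:

# Hypothetical helper (not part of broccoli): map 0.29.x ViT keyword
# arguments to their 10.x names as seen in this diff.
OLD_TO_NEW = {
    "transformer_bos_tokens": "transformer_utility_tokens",
    "transformer_return_bos_tokens": "transformer_return_utility_tokens",
    "transformer_mlp_dropout": "transformer_ff_dropout",
    "batch_norm_outputs": "batch_norm_logits",
}


def port_vit_kwargs(old_kwargs: dict) -> dict:
    """Rename 0.29.x keyword arguments to their 10.x equivalents."""
    new_kwargs = {}
    for key, value in old_kwargs.items():
        if key == "transformer_position_embedding":
            # The old string option is replaced by two independent flags.
            new_kwargs["transformer_absolute_position_embedding"] = value == "absolute"
            new_kwargs["transformer_relative_position_embedding"] = value == "relative"
        else:
            new_kwargs[OLD_TO_NEW.get(key, key)] = value
    return new_kwargs


print(port_vit_kwargs(
    {"transformer_bos_tokens": 1, "transformer_position_embedding": "relative"}
))
# {'transformer_utility_tokens': 1,
#  'transformer_absolute_position_embedding': False,
#  'transformer_relative_position_embedding': True}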
@@ -1,17 +1,19 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 0.29.1
+ Version: 10.0.1
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
- Requires-Python: >=3.11
+ Requires-Python: >=3.8
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
  Requires-Dist: einops (>=0.8.1,<0.9.0)
- Requires-Dist: numpy (>=2.0.2,<2.1.0)
  Description-Content-Type: text/markdown
 
  # broccoli
@@ -0,0 +1,13 @@
+ broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
+ broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
+ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
+ broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
+ broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
+ broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+ broccoli/transformer.py,sha256=lnfiv7UIYbABiClIluy6CefGxaiYMrvBcj2Ul0uU6xE,27693
+ broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
+ broccoli/vit.py,sha256=EGbQb-atuzG3JAx7kdTaJEbWvQR-4XgyYvwjKkN5C38,22612
+ broccoli_ml-10.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+ broccoli_ml-10.0.1.dist-info/METADATA,sha256=65GKe2Jor5jgUZ8zxROntJ_t0XwAlaukrvpT7nxS0lQ,1369
+ broccoli_ml-10.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ broccoli_ml-10.0.1.dist-info/RECORD,,
broccoli/eigenpatches.py DELETED
@@ -1,49 +0,0 @@
- """
- Jordan (2024) was able to train a CNN to 94% accuracy on CIFAR-10 in 3.29 seconds
- on a single A100 GPU by using carefully-tuned hyperparameters and a number of
- techniques to increase learning efficiency. The author notes that applying fixed
- weights to the first layer of the network that approximate a whitening
- transformation on image patches, following tsyam-code, (2023), was "the single
- most impactful feature... [and] more than doubles training speed".
-
- The usefulness of a fixed layer that whitens image patches can be justified
- according to the work of Chowers & Weiss (2022), who find that the first layer
- weights of a convolutional neural network will asymptotically approach a whitening
- transformation regardless of the details of the rest of the network architecture
- or the training data. This effectively functions as a bandpass filter layer,
- reminiscent of the way neurons in the human primary visual cortex work (Kristensen
- & Sandberg, 2021).
-
- The `eigenvectors` function here is adapted from
- https://github.com/KellerJordan/cifar10-airbench/blob/master/airbench96_faster.py
- using https://datascienceplus.com/understanding-the-covariance-matrix/
- """
-
- import torch
- import torch.nn as nn
- from einops import rearrange
-
-
- def eigenvectors(images: torch.Tensor, patch_size: int, eps=5e-4) -> torch.Tensor:
-     """
-     Adapted from
-     github.com/KellerJordan/cifar10-airbench/blob/master/airbench96_faster.py
-     using https://datascienceplus.com/understanding-the-covariance-matrix/
-
-     Args:
-         images: a batch of training images (the bigger and more representative the better!)
-         patch_size: the size of the eigenvectors we want to create (i.e. the patch/kernel
-             size of the model we will initialise with the eigenvectors)
-         eps: a small number to avoid division by zero
-     """
-     with torch.no_grad():
-         unfolder = nn.Unfold(kernel_size=patch_size, stride=1)
-         patches = unfolder(images) # (N, patch_elements, patches_per_image)
-         patches = rearrange(patches, "N elements patches -> (N patches) elements")
-         n = patches.size(0)
-         centred = patches - patches.mean(dim=1, keepdim=True)
-         covariance_matrix = (
-             centred.T @ centred
-         ) / n # https://datascienceplus.com/understanding-the-covariance-matrix/
-         _, eigenvectors = torch.linalg.eigh(covariance_matrix)
-         return eigenvectors
@@ -1,17 +0,0 @@
- broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
- broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
- broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
- broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
- broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
- broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
- broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
- broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
- broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
- broccoli/tensor.py,sha256=zhSOo9W24FEgN7U35wy3ZIJHnw3u4cepJO5heCw6vwU,4590
- broccoli/transformer.py,sha256=jQGpj_e5WAEU_zEPjCU0OyD_08O3HwsMBg3pbrCzp4E,16924
- broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
- broccoli/vit.py,sha256=m4Wa8B8L25xSODh91ViVyLmwLOBZayp5S7S9f8pIvZo,16109
- broccoli_ml-0.29.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
- broccoli_ml-0.29.1.dist-info/METADATA,sha256=AojdLjmBwqW9of7D6RSjCdWbDRJO2GM3NbXd1TYeOJY,1257
- broccoli_ml-0.29.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- broccoli_ml-0.29.1.dist-info/RECORD,,