broccoli-ml 0.1.33__tar.gz → 0.1.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.1.33
3
+ Version: 0.1.35
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -301,6 +301,8 @@ class ConcatPool(nn.Module):
301
301
  them channel-wise.
302
302
  """
303
303
 
304
+ # TODO: change this to use nn.Fold instead of view, which is equivalent but more readable
305
+
304
306
  def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
305
307
  super().__init__()
306
308
 
@@ -343,7 +343,7 @@ class TransformerBlock(nn.Module):
343
343
  norm_process_x, norm_process_x, norm_process_x
344
344
  )
345
345
  process_x = process_x + self.ff_process(process_x)
346
- x = torch.cat([process_x, identity_x])[unshuffle_indices, :, :].contiguous()
346
+ x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
347
347
 
348
348
  return x
349
349
 
@@ -66,23 +66,33 @@ class CCTEncoder(nn.Module):
66
66
  tranformer_d_model_scale=True,
67
67
  tranformer_log_length_scale=True,
68
68
  tranformer_quiet_attention=True,
69
- activation: nn.Module = nn.ReLU,
70
- activation_kwargs: Optional[dict] = None,
69
+ cnn_activation: nn.Module = nn.ReLU,
70
+ cnn_activation_kwargs: Optional[dict] = None,
71
+ transformer_activation: nn.Module = nn.GELU,
72
+ transformer_activation_kwargs: Optional[dict] = None,
71
73
  mlp_dropout=0.0,
72
74
  msa_dropout=0.1,
73
75
  stochastic_depth=0.1,
74
76
  linear_module=nn.Linear,
75
77
  image_channels=3,
78
+ batch_norm=False,
76
79
  ):
77
80
  if conv_pooling_type not in ["maxpool", "concat"]:
78
81
  raise NotImplementedError("Pooling type must be maxpool or concat")
79
82
 
80
83
  super().__init__()
81
84
 
82
- if activation_kwargs is not None:
83
- self.activation = activation(**activation_kwargs)
85
+ if cnn_activation_kwargs is not None:
86
+ self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
84
87
  else:
85
- self.activation = activation()
88
+ self.cnn_activation = cnn_activation()
89
+
90
+ if transformer_activation_kwargs is not None:
91
+ self.transformer_activation = transformer_activation(
92
+ **transformer_activation_kwargs
93
+ )
94
+ else:
95
+ self.transformer_activation = transformer_activation()
86
96
 
87
97
  self.image_size = image_size
88
98
 
@@ -96,43 +106,48 @@ class CCTEncoder(nn.Module):
96
106
  self.sequence_length = output_size**2
97
107
 
98
108
  if conv_pooling_type == "maxpool":
99
-
100
109
  conv_out_channels = transformer_embedding_size
101
- pooling_out_channels = transformer_embedding_size
102
- self.pooling_adapter = nn.Identity()
110
+ elif conv_pooling_type == "concat":
111
+ conv_out_channels = int(
112
+ round(transformer_embedding_size / (conv_pooling_kernel_size**2))
113
+ )
114
+
115
+ # This if block rhymes:
116
+ if cnn_activation.__name__.endswith("GLU"):
117
+ conv_out_channels *= 2
103
118
 
119
+ self.conv = ConvLayer(
120
+ image_channels,
121
+ conv_out_channels,
122
+ kernel_size=conv_kernel_size,
123
+ stride=1,
124
+ padding="same",
125
+ linear_module=linear_module,
126
+ )
127
+
128
+ if conv_pooling_type == "maxpool":
104
129
  self.pool = nn.Sequential(
105
130
  *[
106
131
  Rearrange( # rearrange in case we're using XGLU activation
107
132
  "N C H W -> N H W C"
108
133
  ),
109
- self.activation,
134
+ self.cnn_activation,
110
135
  Rearrange("N H W C -> N C H W"),
111
136
  nn.MaxPool2d(
112
137
  conv_pooling_kernel_size,
113
138
  stride=conv_pooling_kernel_stride,
114
139
  padding=conv_pooling_kernel_padding,
115
140
  ),
116
- Rearrange("N C H W -> N (H W) C"),
141
+ Rearrange("N C H W -> N (H W) C"), # for transformer
117
142
  ]
118
143
  )
119
144
 
120
145
  elif conv_pooling_type == "concat":
121
- # TODO: make all of this adapter business part of concatpool
122
- conv_out_channels = int(
123
- round(transformer_embedding_size / (conv_pooling_kernel_size**2))
124
- )
125
- pooling_out_channels = conv_pooling_kernel_size**2 * conv_out_channels
126
- pooling_adapter_out_channels = transformer_embedding_size
127
- if activation.__name__.endswith("GLU"):
128
- pooling_adapter_out_channels *= 2
129
- self.pooling_adapter = nn.Sequential(
130
- *[
131
- Rearrange("N C H W -> N (H W) C"),
132
- nn.Linear(pooling_out_channels, pooling_adapter_out_channels),
133
- self.activation,
134
- ]
146
+ concatpool_activation_output_size = (
147
+ conv_pooling_kernel_size**2 * conv_out_channels
135
148
  )
149
+ if cnn_activation.__name__.endswith("GLU"):
150
+ concatpool_activation_output_size /= 2
136
151
 
137
152
  self.pool = nn.Sequential(
138
153
  *[
@@ -141,7 +156,15 @@ class CCTEncoder(nn.Module):
141
156
  stride=conv_pooling_kernel_stride,
142
157
  padding=conv_pooling_kernel_padding,
143
158
  ),
144
- self.pooling_adapter,
159
+ Rearrange( # rearrange in case we're using XGLU activation
160
+ "N C H W -> N H W C"
161
+ ),
162
+ self.cnn_activation,
163
+ Rearrange("N H W C -> N (H W) C"),
164
+ nn.Linear(
165
+ concatpool_activation_output_size, transformer_embedding_size
166
+ ),
167
+ self.cnn_activation,
145
168
  ]
146
169
  )
147
170
 
@@ -154,8 +177,8 @@ class CCTEncoder(nn.Module):
154
177
  position_embedding_type=transformer_position_embedding,
155
178
  source_size=(output_size, output_size),
156
179
  mlp_ratio=transformer_mlp_ratio,
157
- activation=activation,
158
- activation_kwargs=activation_kwargs,
180
+ activation=transformer_activation,
181
+ activation_kwargs=transformer_activation_kwargs,
159
182
  mlp_dropout=mlp_dropout,
160
183
  msa_dropout=msa_dropout,
161
184
  stochastic_depth=stochastic_depth,
@@ -171,21 +194,9 @@ class CCTEncoder(nn.Module):
171
194
  else:
172
195
  self.transformer = nn.Identity()
173
196
 
174
- # This code block rhymes:
175
- if activation.__name__.endswith("GLU"):
176
- conv_out_channels *= 2
177
-
178
- self.conv = ConvLayer(
179
- image_channels,
180
- conv_out_channels,
181
- kernel_size=conv_kernel_size,
182
- stride=1,
183
- padding="same",
184
- linear_module=linear_module,
185
- )
186
-
187
197
  self.encoder = nn.Sequential(
188
198
  *[
199
+ nn.BatchNorm2d(image_channels) if batch_norm else nn.Identity(),
189
200
  self.conv,
190
201
  self.pool,
191
202
  self.transformer,
@@ -223,8 +234,10 @@ class CCT(nn.Module):
223
234
  tranformer_d_model_scale=True,
224
235
  tranformer_log_length_scale=True,
225
236
  tranformer_quiet_attention=True,
226
- activation: nn.Module = nn.ReLU,
227
- activation_kwargs: Optional[dict] = None,
237
+ cnn_activation: nn.Module = nn.ReLU,
238
+ cnn_activation_kwargs: Optional[dict] = None,
239
+ transformer_activation: nn.Module = nn.GELU,
240
+ transformer_activation_kwargs: Optional[dict] = None,
228
241
  mlp_dropout=0.0, # The original paper got best performance from mlp_dropout=0.
229
242
  msa_dropout=0.1, # "" msa_dropout=0.1
230
243
  stochastic_depth=0.1, # "" stochastic_depth=0.1
@@ -253,8 +266,10 @@ class CCT(nn.Module):
253
266
  tranformer_d_model_scale=tranformer_d_model_scale,
254
267
  tranformer_log_length_scale=tranformer_log_length_scale,
255
268
  tranformer_quiet_attention=tranformer_quiet_attention,
256
- activation=activation,
257
- activation_kwargs=activation_kwargs,
269
+ cnn_activation=cnn_activation,
270
+ cnn_activation_kwargs=cnn_activation_kwargs,
271
+ transformer_activation=transformer_activation,
272
+ transformer_activation_kwargs=transformer_activation_kwargs,
258
273
  mlp_dropout=mlp_dropout,
259
274
  msa_dropout=msa_dropout,
260
275
  stochastic_depth=stochastic_depth,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "0.1.33"
3
+ version = "0.1.35"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes