flownets-0.0.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flownets-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Tommaso Giacometti
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
flownets-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,30 @@
+ Metadata-Version: 2.4
+ Name: flownets
+ Version: 0.0.1
+ Summary: A new beginning for my models :)
+ Home-page: https://github.com/TommyGiak/FlowNets
+ Author: Tommaso Giacometti
+ Author-email: tommaso.giak@gmail.com
+ License: MIT
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy
+ Requires-Dist: torch
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # Here I will insert a better description of the package...
+ But this is the null version so you must wait a bit :(
flownets-0.0.1/README.md ADDED
@@ -0,0 +1,2 @@
+ # I just want my architectures to be used in my projects
+ Complete documentation will follow soon (I hope)
flownets-0.0.1/flownets/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .model2D import HybridUNetCFM2D
+ from .model3D import HybridUNetCFM3D
+
+ from .version import __version__
+
+ __all__ = [
+     'HybridUNetCFM2D',
+     'HybridUNetCFM3D',
+     '__version__'
+ ]
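For orientation, here is a minimal usage sketch (not part of the package) that instantiates the exported 2D model and runs one forward pass. The values are illustrative assumptions: the spatial side just has to survive the four stride-2 downsamplings and remain divisible by the attention patch size at the bottleneck.

import torch
from flownets import HybridUNetCFM2D

# 64x64 single-channel latents: 64 -> 32 -> 16 -> 8 -> 4 across the four
# stride-2 downsamplings, and 4 is divisible by the default patch size 2
model = HybridUNetCFM2D(img_size=(64, 64), in_channels=1)

z_t = torch.randn(2, 1, 64, 64)  # (B, C, H, W) latent at time t
t = torch.rand(2)                # one flow-matching time per sample, in [0, 1]
v = model(t, z_t)                # predicted velocity field, shape (2, 1, 64, 64)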
flownets-0.0.1/flownets/model2D.py ADDED
@@ -0,0 +1,300 @@
+ # model2D.py
+ # 2D Hybrid U-Net + Transformer (bottleneck) for Conditional Flow Matching (CFM)
+ # Works in latent space (2D maps). Supports cross-attention over condition
+ # tokens (e.g. clinical embeddings) for global conditioning at the bottleneck.
+
+ import math
+ import warnings
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ # ---------------------------- Utilities ---------------------------------
+ # Sinusoidal time embedding
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+         self.half = self.dim // 2
+
+         # clearer check: the dimension must be even
+         assert self.dim % 2 == 0, f'The dimension of the positional embedding must be even, not {dim}!'
+
+     def forward(self, t):
+         # t: (B,) floats in [0,1]
+         device = t.device
+         # use 'half' in the denominator so we don't divide by zero when half == 1
+         freqs = torch.exp(torch.arange(self.half, device=device, dtype=torch.float32) * -(math.log(100.0) / self.half))
+         args = t.view(-1, 1) * freqs[None, :]
+         emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+         return emb
+
+
+ class TimeSequential(nn.Sequential):
+     def __init__(self, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         pass
+
+     def forward(self, input, t_emb=None):
+         for module in self:
+             input = module(input, t_emb)
+         return input
+
+
+ # --------------------------- 2D building blocks -------------------------
+ class Conv2dBlock(nn.Module):
+     def __init__(self, in_ch, out_ch, kernel=3, padding=1, use_norm=True, num_groups=8):
+         super().__init__()
+         self.seq = nn.Sequential()
+         if use_norm:
+             if num_groups > out_ch:
+                 num_groups = out_ch
+             assert out_ch % num_groups == 0, f'In Conv2dBlock, the number of output channels ({out_ch}) must be a multiple of the number of groups ({num_groups})!'
+             self.seq.append(nn.Conv2d(in_ch, out_ch, kernel, padding=padding, bias=False))
+             self.seq.append(nn.GroupNorm(num_groups=num_groups, num_channels=out_ch))
+         else:
+             self.seq.append(nn.Conv2d(in_ch, out_ch, kernel, padding=padding))
+         self.seq.append(nn.SiLU())
+
+     def forward(self, x):
+         return self.seq(x)
+
+
+ class Residual2dBlock(nn.Module):
+     def __init__(self, ch, time_emb_dim=None):
+         super().__init__()
+         self.conv1 = Conv2dBlock(ch, ch)
+         self.time_proj = nn.Linear(time_emb_dim, ch) if time_emb_dim is not None else None
+         self.conv2 = Conv2dBlock(ch, ch)
+
+     def forward(self, x, t_emb=None):
+         h = self.conv1(x)
+         if self.time_proj is not None and t_emb is not None:
+             # broadcast time embed -> (B, C, 1, 1)
+             proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1)
+             h = h + proj
+         elif self.time_proj is not None and t_emb is None:
+             raise RuntimeError('A time-projection embedding layer is defined in this module, but no time embedding was passed to forward()!')
+         h = self.conv2(h)
+         return x + h
+
+
+ # Simple Downsample/Upsample
+ class Downsample2d(nn.Module):
+     def __init__(self, in_ch, out_ch=None, conv_layer=True):
+         super().__init__()
+         if conv_layer:
+             out_ch = out_ch or in_ch*2
+             self.op = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
+         else:
+             if out_ch is not None:
+                 raise ValueError(f"Without a conv layer, the output channels must equal the input channels; do not pass out_ch ({out_ch})")
+             self.op = nn.MaxPool2d(kernel_size=2, stride=2)
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ class Upsample2d(nn.Module):
+     def __init__(self, in_ch, out_ch=None, conv_layer=True, upsample_mode='nearest'):
+         super().__init__()
+         if conv_layer:
+             out_ch = out_ch or in_ch//2
+             self.op = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
+         else:
+             if out_ch is not None:
+                 raise ValueError(f"Without a conv layer, the output channels must equal the input channels; do not pass out_ch ({out_ch})")
+             self.op = nn.Upsample(scale_factor=2, mode=upsample_mode)
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ # ----------------------- Cross-Attention modules -----------------------
+ # A simple tokenized cross-attention where spatial features are flattened
+ # to tokens and attend to condition tokens (mask tokens or clinical tokens).
+
+ class CrossAttention(nn.Module):
+     def __init__(self, dim, cond_dim, heads=8, dropout=0.0, bias=False):
+         super().__init__()
+         if (dim % heads != 0) or (cond_dim % heads != 0):
+             raise ValueError(f'The embedding dimensions ({dim} and {cond_dim}) in CrossAttention must be multiples of the number of heads ({heads})!')
+         self.mha = nn.MultiheadAttention(num_heads=heads, embed_dim=dim, kdim=cond_dim, vdim=cond_dim,
+                                          dropout=dropout, bias=bias, add_zero_attn=False, batch_first=True)
+
+     def forward(self, x, cond):
+         # return only the attention output; the residual connection is applied in the transformer block
+         attn_out, _ = self.mha(x, cond, cond, need_weights=False)
+         return attn_out
+
+
+ # ------------------------- Transformer (bottleneck) -------------------
+ class SimpleTransformerBlock(nn.Module):
+     def __init__(self, dim, heads=8, mlp_ratio=4.0, dropout=0.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
+         self.norm2 = nn.LayerNorm(dim)
+         hidden = int(dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, hidden),
+             nn.SiLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden, dim),
+             nn.Dropout(dropout),
+         )
+         warnings.warn('SimpleTransformerBlock uses LayerNorm, not GroupNorm, but why? IDK, maybe we are crazy')
+         warnings.warn('The attention mechanism is the stock torch implementation; you may want to use flash attention... Flash attention is wonderful!')
+         warnings.warn('The transformer block does not use ADAPTIVE layer/group norm... We do not like mutants')
+
+     def forward(self, x):
+         # x: (B, N, dim)
+         h = self.norm1(x)
+         attn_out, _ = self.attn(h, h, h, need_weights=False)
+         x = x + attn_out
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+ class CrossTransformerBlock(nn.Module):
+     def __init__(self, dim, cond_dim, heads=8, mlp_ratio=4.0, dropout=0.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.cross_attn = CrossAttention(dim, cond_dim, heads=heads, dropout=dropout, bias=False)
+         self.norm2 = nn.LayerNorm(dim)
+         hidden = int(dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, hidden),
+             nn.SiLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden, dim),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x, cond_tokens):
+         # cross-attention with residual (out-of-place add: an in-place += would corrupt tensors saved for norm1's backward)
+         h = self.norm1(x)
+         attn_out = self.cross_attn(h, cond_tokens)
+         x = x + attn_out
+         # feed-forward with its own normalization and residual
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ # --------------------------- Hybrid U-Net CFM --------------------------
+ class HybridUNetCFM2D(nn.Module):
+     def __init__(self,
+                  img_size : tuple, # 2-tuple (H, W)
+                  in_channels=1,
+                  channels_per_down=[8,16,32,64],
+                  n_residuals_blocks=1,
+                  time_emb_dim=256,
+                  cond_global_dim=0,
+                  attn_patch_size=2,
+                  num_bottleneck_blocks=2):
+
+         super().__init__()
+         self.img_size = img_size  #! not really needed!!!
+         self.patch_size = attn_patch_size
+         self.in_channels = in_channels
+         self.chs = channels_per_down
+         self.n_residuals_blocks = n_residuals_blocks
+
+         self.time_mlp = nn.Sequential(SinusoidalPosEmb(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.SiLU())
+
+         # Encoder
+         self.res_blocks_encoder = nn.ModuleList(
+             TimeSequential(*[Residual2dBlock(ch=c, time_emb_dim=time_emb_dim)
+                              for _ in range(n_residuals_blocks)])
+             for c in [in_channels] + self.chs[:-1]
+         )
+
+         self.downs = nn.ModuleList([Downsample2d(in_ch=ic, out_ch=oc) for ic, oc in zip([in_channels]+self.chs[:-1], self.chs)])
+
+         assert len(self.res_blocks_encoder) == len(self.downs)
+
+         # Bottleneck
+         bottleneck_ch = self.chs[-1]
+         self.bottleneck_res = Residual2dBlock(bottleneck_ch, time_emb_dim)
+         self.transformer_blocks = nn.ModuleList([
+             CrossTransformerBlock(bottleneck_ch*(self.patch_size**2), cond_global_dim) if cond_global_dim > 0 else SimpleTransformerBlock(bottleneck_ch*(self.patch_size**2))
+             for _ in range(num_bottleneck_blocks)
+         ])
+
+         # Up path
+         self.ups = nn.ModuleList([Upsample2d(in_ch=ic, out_ch=oc) for ic, oc in zip(reversed(self.chs), reversed([in_channels]+self.chs[:-1]))])
+         self.conv_skip_connection_decoder = nn.ModuleList([Conv2dBlock(in_ch=oc*2, out_ch=oc) for oc in reversed([in_channels]+self.chs[:-1])])
+         self.res_blocks_decoder = nn.ModuleList(
+             TimeSequential(*[Residual2dBlock(ch=c, time_emb_dim=time_emb_dim)
+                              for _ in range(n_residuals_blocks)])
+             for c in reversed([in_channels] + self.chs[:-1])
+         )  # fine for now, but a positional encoding may be needed later to also accept different image sizes...!!! This damn attention took a lot of effort
+
+         assert len(self.res_blocks_decoder) == len(self.ups) == len(self.conv_skip_connection_decoder)
+
+         pass
+
+     def patching_and_tokenizer(self, layer, channels, current_img_size):
+         B, C, H, W = layer.shape
+         assert (C, H, W) == (channels, *current_img_size)  # TO REMOVE
+         p = self.patch_size
+
+         layer = layer.unfold(2, p, p).unfold(3, p, p)
+         layer = layer.reshape(B, C, -1, p, p).transpose(1, 2).flatten(start_dim=2).contiguous()  # It took me about 45 minutes to get this damn unfolding right, so handle it with care (it might not work anyway, damn it)
+         return layer
+
+     def tokens_to_patches(self, tokens, channels, current_img_size):
+         B, N, D = tokens.shape
+         H, W = current_img_size
+         assert N*D == math.prod((channels, *current_img_size))  # TO REMOVE
+         p = self.patch_size
+
+         tokens = tokens.unflatten(-1, (channels, p, p)).transpose(1, 2)  # (B, C, N, p, p)
+         tokens = tokens.view(B, channels, H//p, W//p, p, p).permute(0, 1, 2, 4, 3, 5).reshape(B, channels, *current_img_size)
+         assert tokens.shape == (B, channels, H, W)  # TO REMOVE
+         return tokens
+
+     def get_parameters_number(self):
+         return sum(p.numel() for p in self.parameters())
+
+     def forward(self, t, z_t, cond_global: Optional[torch.Tensor] = None):
+         # z_t: (B, C, H, W)
+         t_emb = self.time_mlp(t)
+
+         # Encoder
+         skips = []
+         x = z_t
+         for res, down in zip(self.res_blocks_encoder, self.downs):
+             x = res(x, t_emb)
+             skips.append(x)
+             x = down(x)
+
+         # Bottleneck
+         x = self.bottleneck_res(x, t_emb)
+
+         # tokens
+         B, C, H, W = x.shape
+         x = self.patching_and_tokenizer(layer=x, channels=C, current_img_size=(H, W))  # (B, N, D), N = (H//p)*(W//p), D = C*(p**2)
+
+         # prepare global cond tokens if provided
+         cond_tokens = None
+         if cond_global is not None:
+             # cond_global: (B, cond_dim) -> expand to M=1 token
+             cond_tokens = cond_global.unsqueeze(1)  # (B, 1, cond_dim)
+         for tb in self.transformer_blocks:
+             if isinstance(tb, CrossTransformerBlock) and cond_tokens is not None:
+                 x = tb(x, cond_tokens)
+             else:
+                 x = tb(x)
+         x = self.tokens_to_patches(x, channels=C, current_img_size=(H, W))
+
+         # Decoder
+         for up, conv_skip, res in zip(self.ups, self.conv_skip_connection_decoder, self.res_blocks_decoder):
+             x = up(x)
+             skip = skips.pop()
+             x = torch.cat([x, skip], dim=1)  # channel concat
+             x = conv_skip(x)
+             x = res(x, t_emb)
+
+         return x
+
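model2D.py defines only the network; version 0.0.1 ships no training code. For context, a conditional flow matching step along the common linear path z_t = (1 - t) * x0 + t * x1, whose target velocity is x1 - x0, could look like the sketch below. This is my reconstruction of the intended objective, not code from the package; the path, loss, and optimizer are assumptions.

import torch
import torch.nn.functional as F
from flownets import HybridUNetCFM2D

model = HybridUNetCFM2D(img_size=(64, 64), in_channels=1)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

def cfm_step(x1):                      # x1: (B, 1, 64, 64) data batch
    x0 = torch.randn_like(x1)          # noise endpoint of the path
    t = torch.rand(x1.shape[0])        # one time per batch element
    tt = t.view(-1, 1, 1, 1)
    z_t = (1 - tt) * x0 + tt * x1      # linear interpolation path
    target = x1 - x0                   # its (constant) velocity
    loss = F.mse_loss(model(t, z_t), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()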
flownets-0.0.1/flownets/model3D.py ADDED
@@ -0,0 +1,299 @@
+ # model3D.py
+ # 3D Hybrid U-Net + Transformer (bottleneck) for Conditional Flow Matching (CFM)
+ # Works in latent space (3D volumes). Supports cross-attention over condition
+ # tokens (e.g. clinical embeddings) for global conditioning at the bottleneck.
+
+ import math
+ import warnings
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ # ---------------------------- Utilities ---------------------------------
+ # Sinusoidal time embedding
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+         self.half = self.dim // 2
+         # clearer check: the dimension must be even
+         assert self.dim % 2 == 0, f'The dimension of the positional embedding must be even, not {dim}!'
+
+     def forward(self, t):
+         # t: (B,) floats in [0,1]
+         device = t.device
+         # use 'half' in the denominator so we don't divide by zero when half == 1
+         freqs = torch.exp(torch.arange(self.half, device=device, dtype=torch.float32) * -(math.log(100.0) / self.half))
+         args = t.view(-1, 1) * freqs[None, :]
+         emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+         return emb
+
+
+ class TimeSequential(nn.Sequential):
+     def __init__(self, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         pass
+
+     def forward(self, input, t_emb=None):
+         for module in self:
+             input = module(input, t_emb)
+         return input
+
+
+ # --------------------------- 3D building blocks -------------------------
+ class Conv3dBlock(nn.Module):
+     def __init__(self, in_ch, out_ch, kernel=3, padding=1, use_norm=True, num_groups=8):
+         super().__init__()
+         self.seq = nn.Sequential()
+         if use_norm:
+             if num_groups > out_ch:
+                 num_groups = out_ch
+             assert out_ch % num_groups == 0, f'In Conv3dBlock, the number of output channels ({out_ch}) must be a multiple of the number of groups ({num_groups})!'
+             self.seq.append(nn.Conv3d(in_ch, out_ch, kernel, padding=padding, bias=False))
+             self.seq.append(nn.GroupNorm(num_groups=num_groups, num_channels=out_ch))
+         else:
+             self.seq.append(nn.Conv3d(in_ch, out_ch, kernel, padding=padding))
+         self.seq.append(nn.SiLU())
+
+     def forward(self, x):
+         return self.seq(x)
+
+
+ class Residual3dBlock(nn.Module):
+     def __init__(self, ch, time_emb_dim=None):
+         super().__init__()
+         self.conv1 = Conv3dBlock(ch, ch)
+         self.time_proj = nn.Linear(time_emb_dim, ch) if time_emb_dim is not None else None
+         self.conv2 = Conv3dBlock(ch, ch)
+
+     def forward(self, x, t_emb=None):
+         h = self.conv1(x)
+         if self.time_proj is not None and t_emb is not None:
+             # broadcast time embed -> (B, C, 1, 1, 1)
+             proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+             h = h + proj
+         elif self.time_proj is not None and t_emb is None:
+             raise RuntimeError('A time-projection embedding layer is defined in this module, but no time embedding was passed to forward()!')
+         h = self.conv2(h)
+         return x + h
+
+
+ # Simple Downsample/Upsample
+ class Downsample3d(nn.Module):
+     def __init__(self, in_ch, out_ch=None, conv_layer=True):
+         super().__init__()
+         if conv_layer:
+             out_ch = out_ch or in_ch*2
+             self.op = nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
+         else:
+             if out_ch is not None:
+                 raise ValueError(f"Without a conv layer, the output channels must equal the input channels; do not pass out_ch ({out_ch})")
+             self.op = nn.MaxPool3d(kernel_size=2, stride=2)
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ class Upsample3d(nn.Module):
+     def __init__(self, in_ch, out_ch=None, conv_layer=True, upsample_mode='nearest'):
+         super().__init__()
+         if conv_layer:
+             out_ch = out_ch or in_ch//2
+             self.op = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
+         else:
+             if out_ch is not None:
+                 raise ValueError(f"Without a conv layer, the output channels must equal the input channels; do not pass out_ch ({out_ch})")
+             self.op = nn.Upsample(scale_factor=2, mode=upsample_mode)
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ # ----------------------- Cross-Attention modules -----------------------
+ # A simple tokenized cross-attention where spatial features are flattened
+ # to tokens and attend to condition tokens (mask tokens or clinical tokens).
+
+ class CrossAttention(nn.Module):
+     def __init__(self, dim, cond_dim, heads=8, dropout=0.0, bias=False):
+         super().__init__()
+         if (dim % heads != 0) or (cond_dim % heads != 0):
+             raise ValueError(f'The embedding dimensions ({dim} and {cond_dim}) in CrossAttention must be multiples of the number of heads ({heads})!')
+         self.mha = nn.MultiheadAttention(num_heads=heads, embed_dim=dim, kdim=cond_dim, vdim=cond_dim,
+                                          dropout=dropout, bias=bias, add_zero_attn=False, batch_first=True)
+
+     def forward(self, x, cond):
+         # return only the attention output; the residual connection is applied in the transformer block
+         attn_out, _ = self.mha(x, cond, cond, need_weights=False)
+         return attn_out
+
+
+ # ------------------------- Transformer (bottleneck) -------------------
+ class SimpleTransformerBlock(nn.Module):
+     def __init__(self, dim, heads=8, mlp_ratio=4.0, dropout=0.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
+         self.norm2 = nn.LayerNorm(dim)
+         hidden = int(dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, hidden),
+             nn.SiLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden, dim),
+             nn.Dropout(dropout),
+         )
+         warnings.warn('SimpleTransformerBlock uses LayerNorm, not GroupNorm, but why? IDK, maybe we are crazy')
+         warnings.warn('The attention mechanism is the stock torch implementation; you may want to use flash attention... Flash attention is wonderful!')
+         warnings.warn('The transformer block does not use ADAPTIVE layer/group norm... We do not like mutants')
+
+     def forward(self, x):
+         # x: (B, N, dim)
+         h = self.norm1(x)
+         attn_out, _ = self.attn(h, h, h, need_weights=False)
+         x = x + attn_out
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+ class CrossTransformerBlock(nn.Module):
+     def __init__(self, dim, cond_dim, heads=8, mlp_ratio=4.0, dropout=0.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.cross_attn = CrossAttention(dim, cond_dim, heads=heads, dropout=dropout, bias=False)
+         self.norm2 = nn.LayerNorm(dim)
+         hidden = int(dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, hidden),
+             nn.SiLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden, dim),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x, cond_tokens):
+         # cross-attention with residual (out-of-place add: an in-place += would corrupt tensors saved for norm1's backward)
+         h = self.norm1(x)
+         attn_out = self.cross_attn(h, cond_tokens)
+         x = x + attn_out
+         # feed-forward with its own normalization and residual
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ # --------------------------- Hybrid U-Net CFM --------------------------
+ class HybridUNetCFM3D(nn.Module):
+     def __init__(self,
+                  img_size : tuple, # 3-tuple of (axial, coronal, sagittal) slice counts
+                  in_channels=1,
+                  channels_per_down=[8,16,32,64],
+                  n_residuals_blocks=1,
+                  time_emb_dim=256,
+                  cond_global_dim=0,
+                  attn_patch_size=2,
+                  num_bottleneck_blocks=2):
+
+         super().__init__()
+         self.img_size = img_size  #! not really needed!!!
+         self.patch_size = attn_patch_size
+         self.in_channels = in_channels
+         self.chs = channels_per_down
+         self.n_residuals_blocks = n_residuals_blocks
+
+         self.time_mlp = nn.Sequential(SinusoidalPosEmb(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.SiLU())
+
+         # Encoder
+         self.res_blocks_encoder = nn.ModuleList(
+             TimeSequential(*[Residual3dBlock(ch=c, time_emb_dim=time_emb_dim)
+                              for _ in range(n_residuals_blocks)])
+             for c in [in_channels] + self.chs[:-1]
+         )
+
+         self.downs = nn.ModuleList([Downsample3d(in_ch=ic, out_ch=oc) for ic, oc in zip([in_channels]+self.chs[:-1], self.chs)])
+
+         assert len(self.res_blocks_encoder) == len(self.downs)
+
+         # Bottleneck
+         bottleneck_ch = self.chs[-1]
+         self.bottleneck_res = Residual3dBlock(bottleneck_ch, time_emb_dim)
+         self.transformer_blocks = nn.ModuleList([
+             CrossTransformerBlock(bottleneck_ch*(self.patch_size**3), cond_global_dim) if cond_global_dim > 0 else SimpleTransformerBlock(bottleneck_ch*(self.patch_size**3))
+             for _ in range(num_bottleneck_blocks)
+         ])
+
+         # Up path
+         self.ups = nn.ModuleList([Upsample3d(in_ch=ic, out_ch=oc) for ic, oc in zip(reversed(self.chs), reversed([in_channels]+self.chs[:-1]))])
+         self.conv_skip_connection_decoder = nn.ModuleList([Conv3dBlock(in_ch=oc*2, out_ch=oc) for oc in reversed([in_channels]+self.chs[:-1])])
+         self.res_blocks_decoder = nn.ModuleList(
+             TimeSequential(*[Residual3dBlock(ch=c, time_emb_dim=time_emb_dim)
+                              for _ in range(n_residuals_blocks)])
+             for c in reversed([in_channels] + self.chs[:-1])
+         )  # fine for now, but a positional encoding may be needed later to also accept different image sizes...!!! This damn attention took a lot of effort
+
+         assert len(self.res_blocks_decoder) == len(self.ups) == len(self.conv_skip_connection_decoder)
+
+         pass
+
+     def patching_and_tokenizer(self, layer, channels, current_img_size):
+         B, C, Z, H, W = layer.shape
+         assert (C, Z, H, W) == (channels, *current_img_size)  # TO REMOVE
+         p = self.patch_size
+
+         layer = layer.unfold(2, p, p).unfold(3, p, p).unfold(4, p, p)
+         layer = layer.reshape(B, C, -1, p, p, p).transpose(1, 2).flatten(start_dim=2).contiguous()  # It took me about 45 minutes to get this damn unfolding right, so handle it with care (it might not work anyway, damn it)
+         return layer
+
+     def tokens_to_patches(self, tokens, channels, current_img_size):
+         B, N, D = tokens.shape
+         Z, H, W = current_img_size
+         assert N*D == math.prod((channels, *current_img_size))  # TO REMOVE
+         p = self.patch_size
+
+         tokens = tokens.unflatten(-1, (channels, p, p, p)).transpose(1, 2)  # (B, C, N, p, p, p)
+         tokens = tokens.view(B, channels, Z//p, H//p, W//p, p, p, p).permute(0, 1, 2, 5, 3, 6, 4, 7).reshape(B, channels, *current_img_size)
+         assert tokens.shape == (B, channels, Z, H, W)  # TO REMOVE
+         return tokens
+
+     def get_parameters_number(self):
+         return sum(p.numel() for p in self.parameters())
+
+     def forward(self, t, z_t, cond_global: Optional[torch.Tensor] = None):
+         # z_t: (B, C, Z, H, W)
+         t_emb = self.time_mlp(t)
+
+         # Encoder
+         skips = []
+         x = z_t
+         for res, down in zip(self.res_blocks_encoder, self.downs):
+             x = res(x, t_emb)
+             skips.append(x)
+             x = down(x)
+
+         # Bottleneck
+         x = self.bottleneck_res(x, t_emb)
+
+         # tokens
+         B, C, Z, H, W = x.shape
+         x = self.patching_and_tokenizer(layer=x, channels=C, current_img_size=(Z, H, W))  # (B, N, D), N = (Z//p)*(H//p)*(W//p), D = C*(p**3)
+
+         # prepare global cond tokens if provided
+         cond_tokens = None
+         if cond_global is not None:
+             # cond_global: (B, cond_dim) -> expand to M=1 token
+             cond_tokens = cond_global.unsqueeze(1)  # (B, 1, cond_dim)
+         for tb in self.transformer_blocks:
+             if isinstance(tb, CrossTransformerBlock) and cond_tokens is not None:
+                 x = tb(x, cond_tokens)
+             else:
+                 x = tb(x)
+         x = self.tokens_to_patches(x, channels=C, current_img_size=(Z, H, W))
+
+         # Decoder
+         for up, conv_skip, res in zip(self.ups, self.conv_skip_connection_decoder, self.res_blocks_decoder):
+             x = up(x)
+             skip = skips.pop()
+             x = torch.cat([x, skip], dim=1)  # channel concat
+             x = conv_skip(x)
+             x = res(x, t_emb)
+
+         return x
+
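model3D.py mirrors the 2D file with volumetric blocks. Once a velocity field is trained, sampling amounts to integrating dx/dt = v(t, x) from noise at t = 0 to t = 1; a plain Euler integrator is the simplest option. Again a sketch under the same assumptions, not shipped code; the step count and volume size are arbitrary.

import torch
from flownets import HybridUNetCFM3D

model = HybridUNetCFM3D(img_size=(32, 32, 32), in_channels=1)

@torch.no_grad()
def sample(n_steps=50, batch=2):
    x = torch.randn(batch, 1, 32, 32, 32)  # start from noise at t = 0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((batch,), i * dt)
        x = x + dt * model(t, x)            # Euler step along the learned flow
    return x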
flownets-0.0.1/flownets/version.py ADDED
@@ -0,0 +1 @@
+ __version__ = '0.0.1'
flownets-0.0.1/flownets.egg-info/PKG-INFO ADDED
@@ -0,0 +1,30 @@
+ Metadata-Version: 2.4
+ Name: flownets
+ Version: 0.0.1
+ Summary: A new beginning for my models :)
+ Home-page: https://github.com/TommyGiak/FlowNets
+ Author: Tommaso Giacometti
+ Author-email: tommaso.giak@gmail.com
+ License: MIT
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy
+ Requires-Dist: torch
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # Here I will insert a better description of the package...
+ But this is the null version so you must wait a bit :(
flownets-0.0.1/flownets.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,12 @@
+ LICENSE
+ README.md
+ setup.py
+ flownets/__init__.py
+ flownets/model2D.py
+ flownets/model3D.py
+ flownets/version.py
+ flownets.egg-info/PKG-INFO
+ flownets.egg-info/SOURCES.txt
+ flownets.egg-info/dependency_links.txt
+ flownets.egg-info/requires.txt
+ flownets.egg-info/top_level.txt
flownets-0.0.1/flownets.egg-info/requires.txt ADDED
@@ -0,0 +1,2 @@
+ numpy
+ torch
flownets-0.0.1/flownets.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ flownets
flownets-0.0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
flownets-0.0.1/setup.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ from setuptools import setup, find_packages
+
+ package_name = "flownets"
+ HERE = os.path.dirname(os.path.abspath(__file__))
+
+ version_py = os.path.join(HERE, package_name, "version.py")
+ version = open(version_py).read().strip().split(' ')[-1][1:-1]  # strip() guards against a trailing newline; [1:-1] drops the quotes
+ requirements = [r for r in open(os.path.join(HERE, "requirements.txt")).read().splitlines() if r.strip()]
+
+ setup(
+     name=package_name,
+     version=version,
+     description="A new beginning for my models :)",
+     long_description=open(os.path.join(HERE, "flownets", "README.md"), encoding="utf-8").read(),
+     long_description_content_type="text/markdown",
+     url="https://github.com/TommyGiak/FlowNets",
+     author="Tommaso Giacometti",
+     author_email="tommaso.giak@gmail.com",
+     license="MIT",
+     packages=find_packages(exclude=("tests", "docs")),
+     include_package_data=True,
+     install_requires=requirements,
+     python_requires=">=3.8",
+     classifiers=[
+         "Programming Language :: Python :: 3",
+         "License :: OSI Approved :: MIT License",
+         "Operating System :: OS Independent",
+     ],
+ )
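Two caveats on setup.py: requirements.txt is read at build time but is not listed in SOURCES.txt, so it may be absent from the sdist; and parsing version.py by splitting on spaces is brittle. A more defensive version reader (an alternative sketch, not what the package does) would be:

import re

def read_version(path):
    # matches __version__ = '0.0.1' or "0.0.1", tolerating extra whitespace
    text = open(path, encoding='utf-8').read()
    match = re.search(r"__version__\s*=\s*['\"]([^'\"]+)['\"]", text)
    if match is None:
        raise RuntimeError(f'No __version__ found in {path}')
    return match.group(1)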