flownets-0.0.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flownets-0.0.1/LICENSE +21 -0
- flownets-0.0.1/PKG-INFO +30 -0
- flownets-0.0.1/README.md +2 -0
- flownets-0.0.1/flownets/__init__.py +10 -0
- flownets-0.0.1/flownets/model2D.py +300 -0
- flownets-0.0.1/flownets/model3D.py +299 -0
- flownets-0.0.1/flownets/version.py +1 -0
- flownets-0.0.1/flownets.egg-info/PKG-INFO +30 -0
- flownets-0.0.1/flownets.egg-info/SOURCES.txt +12 -0
- flownets-0.0.1/flownets.egg-info/dependency_links.txt +1 -0
- flownets-0.0.1/flownets.egg-info/requires.txt +2 -0
- flownets-0.0.1/flownets.egg-info/top_level.txt +1 -0
- flownets-0.0.1/setup.cfg +4 -0
- flownets-0.0.1/setup.py +30 -0
flownets-0.0.1/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Tommaso Giacometti

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
flownets-0.0.1/PKG-INFO
ADDED
@@ -0,0 +1,30 @@
Metadata-Version: 2.4
Name: flownets
Version: 0.0.1
Summary: A new beginning for my models :)
Home-page: https://github.com/TommyGiak/FlowNets
Author: Tommaso Giacometti
Author-email: tommaso.giak@gmail.com
License: MIT
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy
Requires-Dist: torch
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# Here I will insert a better description of the package...
But this is the null version so you must wait a bit :(
flownets-0.0.1/README.md
ADDED
@@ -0,0 +1,2 @@
# Here I will insert a better description of the package...
But this is the null version so you must wait a bit :(
flownets-0.0.1/flownets/model2D.py
ADDED
@@ -0,0 +1,300 @@
# hybrid_unet_cfm_2d.py
# 2D Hybrid U-Net + Transformer (bottleneck) for Conditional Flow Matching (CFM)
# Works in latent space (2D volumes). Includes spatial cross-attention for masks
# and global conditioning for clinical embeddings. Training loop skeleton included.

import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn

# ---------------------------- Utilities ---------------------------------
# Sinusoidal time embedding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half = self.dim // 2

        # clearer check: dimension must be even
        assert self.dim % 2 == 0, f'The dimension of the positional embedding must be a multiple of 2, not {dim}!'

    def forward(self, t):
        # t: (B,) floats in [0,1]
        device = t.device
        # use 'half' in the denominator so we don't divide by zero when half == 1
        freqs = torch.exp(torch.arange(self.half, device=device, dtype=torch.float32) * -(math.log(100.0) / self.half))
        args = t.view(-1, 1) * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return emb


class TimeSequential(nn.Sequential):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, input, t_emb=None):
        for module in self:
            input = module(input, t_emb)
        return input


# --------------------------- 2D building blocks -------------------------
class Conv2dBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=3, padding=1, use_norm=True, num_groups=8):
        super().__init__()
        self.seq = nn.Sequential()
        if use_norm:
            if num_groups > out_ch:
                num_groups = out_ch
            assert out_ch % num_groups == 0, f'In the Conv2dBlock, the number of output channels ({out_ch}) must be a multiple of the number of groups ({num_groups})!'
            self.seq.append(nn.Conv2d(in_ch, out_ch, kernel, padding=padding, bias=False))
            self.seq.append(nn.GroupNorm(num_groups=num_groups, num_channels=out_ch))
        else:
            self.seq.append(nn.Conv2d(in_ch, out_ch, kernel, padding=padding))
        self.seq.append(nn.SiLU())

    def forward(self, x):
        return self.seq(x)


class Residual2dBlock(nn.Module):
    def __init__(self, ch, time_emb_dim=None):
        super().__init__()
        self.conv1 = Conv2dBlock(ch, ch)
        self.time_proj = nn.Linear(time_emb_dim, ch) if time_emb_dim is not None else None
        self.conv2 = Conv2dBlock(ch, ch)

    def forward(self, x, t_emb=None):
        h = self.conv1(x)
        if self.time_proj is not None and t_emb is not None:
            # broadcast time embed -> (B, C, 1, 1)
            proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1)
            h = h + proj
        elif self.time_proj is not None and t_emb is None:
            raise RuntimeError('A time projection layer is defined in this module, but no time embedding was passed to forward!')
        h = self.conv2(h)
        return x + h


# Simple Downsample/Upsample
class Downsample2d(nn.Module):
    def __init__(self, in_ch, out_ch=None, conv_layer=True):
        super().__init__()
        if conv_layer:
            out_ch = out_ch or in_ch * 2
            self.op = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        else:
            if out_ch is not None:
                raise ValueError(f"If you don't use a conv layer, the output channels must be the same as the input channels, not {out_ch}")
            self.op = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.op(x)


class Upsample2d(nn.Module):
    def __init__(self, in_ch, out_ch=None, conv_layer=True, upsample_mode='nearest'):
        super().__init__()
        if conv_layer:
            out_ch = out_ch or in_ch // 2
            self.op = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            if out_ch is not None:
                raise ValueError(f"If you don't use a conv layer, the output channels must be the same as the input channels, not {out_ch}")
            self.op = nn.Upsample(scale_factor=2, mode=upsample_mode)

    def forward(self, x):
        return self.op(x)


# ----------------------- Cross-Attention modules -----------------------
# A simple tokenized cross-attention where spatial features are flattened
# to tokens and attend to condition tokens (mask tokens or clinical tokens).

class CrossAttention(nn.Module):
    def __init__(self, dim, cond_dim, heads=8, dropout=0.0, bias=False):
        super().__init__()
        if (dim % heads != 0) or (cond_dim % heads != 0):
            raise ValueError(f'The dimensions of the embeddings ({dim} and {cond_dim}) in CrossAttention must be multiples of the number of heads ({heads})!')
        self.mha = nn.MultiheadAttention(num_heads=heads, embed_dim=dim, kdim=cond_dim, vdim=cond_dim,
                                         dropout=dropout, bias=bias, add_zero_attn=False, batch_first=True)

    def forward(self, x, cond):
        # return only the attention output; residual connection is applied in the transformer block
        attn_out, _ = self.mha(x, cond, cond, need_weights=False)
        return attn_out


# ------------------------- Transformer (bottleneck) -------------------
class SimpleTransformerBlock(nn.Module):
    def __init__(self, dim, heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        warnings.warn('In SimpleTransformerBlock LayerNorm is used, not GroupNorm, but why? IDK, maybe we are crazy')
        warnings.warn('The attention mechanism is implemented with torch, you may want to use flash attention... Flash attention is wonderful!')
        warnings.warn('The transformer block does not use ADAPTIVE layer/group norm... We do not like mutants')

    def forward(self, x):
        # x: (B, N, dim)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class CrossTransformerBlock(nn.Module):
    def __init__(self, dim, cond_dim, heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn = CrossAttention(dim, cond_dim, heads=heads, dropout=dropout, bias=False)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, cond_tokens):
        # cross-attention with residual (out-of-place adds, so autograd is not disturbed)
        h = self.norm1(x)
        attn_out = self.cross_attn(h, cond_tokens)
        x = x + attn_out
        # feed-forward with its own normalization and residual
        x = x + self.mlp(self.norm2(x))
        return x


# --------------------------- Hybrid U-Net CFM --------------------------
class HybridUNetCFM2D(nn.Module):
    def __init__(self,
                 img_size: tuple,  # 2D tuple of spatial sizes (H, W)
                 in_channels=1,
                 channels_per_down=[8, 16, 32, 64],
                 n_residuals_blocks=1,
                 time_emb_dim=256,
                 cond_global_dim=0,
                 attn_patch_size=2,
                 num_bottleneck_blocks=2):

        super().__init__()
        self.img_size = img_size  #! not really needed!!!
        self.patch_size = attn_patch_size
        self.in_channels = in_channels
        self.chs = channels_per_down
        self.n_residuals_blocks = n_residuals_blocks

        self.time_mlp = nn.Sequential(SinusoidalPosEmb(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.SiLU())

        # Encoder
        self.res_blocks_encoder = nn.ModuleList(
            TimeSequential(*[Residual2dBlock(ch=c, time_emb_dim=time_emb_dim)
                             for _ in range(n_residuals_blocks)])
            for c in [in_channels] + self.chs[:-1]
        )

        self.downs = nn.ModuleList([Downsample2d(in_ch=ic, out_ch=oc) for ic, oc in zip([in_channels] + self.chs[:-1], self.chs)])

        assert len(self.res_blocks_encoder) == len(self.downs)

        # Bottleneck
        bottleneck_ch = self.chs[-1]
        self.bottleneck_res = Residual2dBlock(bottleneck_ch, time_emb_dim)
        self.transformer_blocks = nn.ModuleList([
            CrossTransformerBlock(bottleneck_ch * (self.patch_size ** 2), cond_global_dim) if cond_global_dim > 0 else SimpleTransformerBlock(bottleneck_ch * (self.patch_size ** 2))
            for _ in range(num_bottleneck_blocks)
        ])

        # Up path
        self.ups = nn.ModuleList([Upsample2d(in_ch=ic, out_ch=oc) for ic, oc in zip(reversed(self.chs), reversed([in_channels] + self.chs[:-1]))])
        self.conv_skip_connection_decored = nn.ModuleList([Conv2dBlock(in_ch=oc * 2, out_ch=oc) for oc in reversed([in_channels] + self.chs[:-1])])
        self.res_blocks_decoder = nn.ModuleList(
            TimeSequential(*[Residual2dBlock(ch=c, time_emb_dim=time_emb_dim)
                             for _ in range(n_residuals_blocks)])
            for c in reversed([in_channels] + self.chs[:-1])
        )  # fine for now, but a positional encoding may need to be added later to also accept different image sizes...!!! This damned attention is a lot of work

        assert len(self.res_blocks_decoder) == len(self.ups) == len(self.conv_skip_connection_decored)

    def patching_and_tokenizer(self, layer, channels, current_img_size):
        B, C, H, W = layer.shape
        assert (C, H, W) == (channels, *current_img_size)  # TO REMOVE
        p = self.patch_size

        layer = layer.unfold(2, p, p).unfold(3, p, p)
        layer = layer.reshape(B, C, -1, p, p).transpose(1, 2).flatten(start_dim=2).contiguous()  # I spent about 45 minutes getting this unfolding right, so handle it with care (it might still not work anyway)
        return layer

    def tokens_to_patches(self, tokens, channels, current_img_size):
        B, N, D = tokens.shape
        H, W = current_img_size
        assert N * D == math.prod((channels, *current_img_size))  # TO REMOVE
        p = self.patch_size

        tokens = tokens.unflatten(-1, (channels, p, p)).transpose(1, 2)  # (B, C, N, p, p)
        tokens = tokens.view(B, channels, H // p, W // p, p, p).permute(0, 1, 2, 4, 3, 5).reshape(B, channels, *current_img_size)
        assert tokens.shape == (B, channels, H, W)  # TO REMOVE
        return tokens

    def get_parameters_number(self):
        return sum(p.numel() for p in self.parameters())

    def forward(self, t, z_t, cond_global: Optional[torch.Tensor] = None):
        # z_t: (B, C, H, W)
        t_emb = self.time_mlp(t)

        # Encoder
        skips = []
        x = z_t
        for i, (res, down) in enumerate(zip(self.res_blocks_encoder, self.downs)):
            x = res(x, t_emb)
            skips.append(x)
            x = down(x)

        # Bottleneck
        x = self.bottleneck_res(x, t_emb)

        # tokens
        B, C, H, W = x.shape
        x = self.patching_and_tokenizer(layer=x, channels=C, current_img_size=(H, W))  # (B, N, D), N = (H//p)*(W//p), D = C*(p**2)

        # prepare global cond tokens if provided
        cond_tokens = None
        if cond_global is not None:
            # cond_global: (B, cond_dim) -> expand to tokens M=1
            cond_tokens = cond_global.unsqueeze(1)  # (B, 1, cond_dim)
        for tb in self.transformer_blocks:
            if isinstance(tb, CrossTransformerBlock) and cond_tokens is not None:
                x = tb(x, cond_tokens)
            else:
                x = tb(x)
        x = self.tokens_to_patches(x, channels=C, current_img_size=(H, W))

        # Decoder
        for up, conv_skip, res in zip(self.ups, self.conv_skip_connection_decored, self.res_blocks_decoder):
            x = up(x)
            skip = skips.pop()
            x = torch.cat([x, skip], dim=1)  # channel concat
            x = conv_skip(x)
            x = res(x, t_emb)

        return x
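For orientation, here is a minimal smoke-test sketch of the 2D model's interface; it is not part of the package. It assumes the class can be imported from flownets.model2D (matching SOURCES.txt), a single-channel 64x64 input so that the four stride-2 downsamplings and the 2x2 attention patching divide evenly, and the default channel schedule; a cond_global tensor of shape (B, cond_global_dim) would only be passed when cond_global_dim > 0.

import torch
from flownets.model2D import HybridUNetCFM2D  # assumed import path

# build the model with illustrative sizes (not values prescribed by the package)
model = HybridUNetCFM2D(img_size=(64, 64), in_channels=1,
                        channels_per_down=[8, 16, 32, 64],
                        cond_global_dim=0, attn_patch_size=2)

B = 4
z_t = torch.randn(B, 1, 64, 64)        # latent at time t
t = torch.rand(B)                      # flow-matching times in [0, 1]

v_pred = model(t, z_t)                 # predicted velocity field, same shape as the input
print(v_pred.shape)                    # torch.Size([4, 1, 64, 64])
print(model.get_parameters_number())   # total parameter count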
flownets-0.0.1/flownets/model3D.py
ADDED
@@ -0,0 +1,299 @@
# hybrid_unet_cfm_3d.py
# 3D Hybrid U-Net + Transformer (bottleneck) for Conditional Flow Matching (CFM)
# Works in latent space (3D volumes). Includes spatial cross-attention for masks
# and global conditioning for clinical embeddings. Training loop skeleton included.

import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn

# ---------------------------- Utilities ---------------------------------
# Sinusoidal time embedding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half = self.dim // 2
        # clearer check: dimension must be even
        assert self.dim % 2 == 0, f'The dimension of the positional embedding must be a multiple of 2, not {dim}!'

    def forward(self, t):
        # t: (B,) floats in [0,1]
        device = t.device
        # use 'half' in the denominator so we don't divide by zero when half == 1
        freqs = torch.exp(torch.arange(self.half, device=device, dtype=torch.float32) * -(math.log(100.0) / self.half))
        args = t.view(-1, 1) * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return emb


class TimeSequential(nn.Sequential):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, input, t_emb=None):
        for module in self:
            input = module(input, t_emb)
        return input


# --------------------------- 3D building blocks -------------------------
class Conv3dBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=3, padding=1, use_norm=True, num_groups=8):
        super().__init__()
        self.seq = nn.Sequential()
        if use_norm:
            if num_groups > out_ch:
                num_groups = out_ch
            assert out_ch % num_groups == 0, f'In the Conv3dBlock, the number of output channels ({out_ch}) must be a multiple of the number of groups ({num_groups})!'
            self.seq.append(nn.Conv3d(in_ch, out_ch, kernel, padding=padding, bias=False))
            self.seq.append(nn.GroupNorm(num_groups=num_groups, num_channels=out_ch))
        else:
            self.seq.append(nn.Conv3d(in_ch, out_ch, kernel, padding=padding))
        self.seq.append(nn.SiLU())

    def forward(self, x):
        return self.seq(x)


class Residual3dBlock(nn.Module):
    def __init__(self, ch, time_emb_dim=None):
        super().__init__()
        self.conv1 = Conv3dBlock(ch, ch)
        self.time_proj = nn.Linear(time_emb_dim, ch) if time_emb_dim is not None else None
        self.conv2 = Conv3dBlock(ch, ch)

    def forward(self, x, t_emb=None):
        h = self.conv1(x)
        if self.time_proj is not None and t_emb is not None:
            # broadcast time embed -> (B, C, 1, 1, 1)
            proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            h = h + proj
        elif self.time_proj is not None and t_emb is None:
            raise RuntimeError('A time projection layer is defined in this module, but no time embedding was passed to forward!')
        h = self.conv2(h)
        return x + h


# Simple Downsample/Upsample
class Downsample3d(nn.Module):
    def __init__(self, in_ch, out_ch=None, conv_layer=True):
        super().__init__()
        if conv_layer:
            out_ch = out_ch or in_ch * 2
            self.op = nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        else:
            if out_ch is not None:
                raise ValueError(f"If you don't use a conv layer, the output channels must be the same as the input channels, not {out_ch}")
            self.op = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.op(x)


class Upsample3d(nn.Module):
    def __init__(self, in_ch, out_ch=None, conv_layer=True, upsample_mode='nearest'):
        super().__init__()
        if conv_layer:
            out_ch = out_ch or in_ch // 2
            self.op = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            if out_ch is not None:
                raise ValueError(f"If you don't use a conv layer, the output channels must be the same as the input channels, not {out_ch}")
            self.op = nn.Upsample(scale_factor=2, mode=upsample_mode)

    def forward(self, x):
        return self.op(x)


# ----------------------- Cross-Attention modules -----------------------
# A simple tokenized cross-attention where spatial features are flattened
# to tokens and attend to condition tokens (mask tokens or clinical tokens).

class CrossAttention(nn.Module):
    def __init__(self, dim, cond_dim, heads=8, dropout=0.0, bias=False):
        super().__init__()
        if (dim % heads != 0) or (cond_dim % heads != 0):
            raise ValueError(f'The dimensions of the embeddings ({dim} and {cond_dim}) in CrossAttention must be multiples of the number of heads ({heads})!')
        self.mha = nn.MultiheadAttention(num_heads=heads, embed_dim=dim, kdim=cond_dim, vdim=cond_dim,
                                         dropout=dropout, bias=bias, add_zero_attn=False, batch_first=True)

    def forward(self, x, cond):
        # return only the attention output; residual connection is applied in the transformer block
        attn_out, _ = self.mha(x, cond, cond, need_weights=False)
        return attn_out


# ------------------------- Transformer (bottleneck) -------------------
class SimpleTransformerBlock(nn.Module):
    def __init__(self, dim, heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        warnings.warn('In SimpleTransformerBlock LayerNorm is used, not GroupNorm, but why? IDK, maybe we are crazy')
        warnings.warn('The attention mechanism is implemented with torch, you may want to use flash attention... Flash attention is wonderful!')
        warnings.warn('The transformer block does not use ADAPTIVE layer/group norm... We do not like mutants')

    def forward(self, x):
        # x: (B, N, dim)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class CrossTransformerBlock(nn.Module):
    def __init__(self, dim, cond_dim, heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn = CrossAttention(dim, cond_dim, heads=heads, dropout=dropout, bias=False)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, cond_tokens):
        # cross-attention with residual (out-of-place adds, so autograd is not disturbed)
        h = self.norm1(x)
        attn_out = self.cross_attn(h, cond_tokens)
        x = x + attn_out
        # feed-forward with its own normalization and residual
        x = x + self.mlp(self.norm2(x))
        return x


# --------------------------- Hybrid U-Net CFM --------------------------
class HybridUNetCFM3D(nn.Module):
    def __init__(self,
                 img_size: tuple,  # 3D tuple (axial slices, coronal slices, sagittal slices)
                 in_channels=1,
                 channels_per_down=[8, 16, 32, 64],
                 n_residuals_blocks=1,
                 time_emb_dim=256,
                 cond_global_dim=0,
                 attn_patch_size=2,
                 num_bottleneck_blocks=2):

        super().__init__()
        self.img_size = img_size  #! not really needed!!!
        self.patch_size = attn_patch_size
        self.in_channels = in_channels
        self.chs = channels_per_down
        self.n_residuals_blocks = n_residuals_blocks

        self.time_mlp = nn.Sequential(SinusoidalPosEmb(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.SiLU())

        # Encoder
        self.res_blocks_encoder = nn.ModuleList(
            TimeSequential(*[Residual3dBlock(ch=c, time_emb_dim=time_emb_dim)
                             for _ in range(n_residuals_blocks)])
            for c in [in_channels] + self.chs[:-1]
        )

        self.downs = nn.ModuleList([Downsample3d(in_ch=ic, out_ch=oc) for ic, oc in zip([in_channels] + self.chs[:-1], self.chs)])

        assert len(self.res_blocks_encoder) == len(self.downs)

        # Bottleneck
        bottleneck_ch = self.chs[-1]
        self.bottleneck_res = Residual3dBlock(bottleneck_ch, time_emb_dim)
        self.transformer_blocks = nn.ModuleList([
            CrossTransformerBlock(bottleneck_ch * (self.patch_size ** 3), cond_global_dim) if cond_global_dim > 0 else SimpleTransformerBlock(bottleneck_ch * (self.patch_size ** 3))
            for _ in range(num_bottleneck_blocks)
        ])

        # Up path
        self.ups = nn.ModuleList([Upsample3d(in_ch=ic, out_ch=oc) for ic, oc in zip(reversed(self.chs), reversed([in_channels] + self.chs[:-1]))])
        self.conv_skip_connection_decored = nn.ModuleList([Conv3dBlock(in_ch=oc * 2, out_ch=oc) for oc in reversed([in_channels] + self.chs[:-1])])
        self.res_blocks_decoder = nn.ModuleList(
            TimeSequential(*[Residual3dBlock(ch=c, time_emb_dim=time_emb_dim)
                             for _ in range(n_residuals_blocks)])
            for c in reversed([in_channels] + self.chs[:-1])
        )  # fine for now, but a positional encoding may need to be added later to also accept different image sizes...!!! This damned attention is a lot of work

        assert len(self.res_blocks_decoder) == len(self.ups) == len(self.conv_skip_connection_decored)

    def patching_and_tokenizer(self, layer, channels, current_img_size):
        B, C, Z, H, W = layer.shape
        assert (C, Z, H, W) == (channels, *current_img_size)  # TO REMOVE
        p = self.patch_size

        layer = layer.unfold(2, p, p).unfold(3, p, p).unfold(4, p, p)
        layer = layer.reshape(B, C, -1, p, p, p).transpose(1, 2).flatten(start_dim=2).contiguous()  # I spent about 45 minutes getting this unfolding right, so handle it with care (it might still not work anyway)
        return layer

    def tokens_to_patches(self, tokens, channels, current_img_size):
        B, N, D = tokens.shape
        Z, H, W = current_img_size
        assert N * D == math.prod((channels, *current_img_size))  # TO REMOVE
        p = self.patch_size

        tokens = tokens.unflatten(-1, (channels, p, p, p)).transpose(1, 2)  # (B, C, N, p, p, p)
        tokens = tokens.view(B, channels, Z // p, H // p, W // p, p, p, p).permute(0, 1, 2, 5, 3, 6, 4, 7).reshape(B, channels, *current_img_size)
        assert tokens.shape == (B, channels, Z, H, W)  # TO REMOVE
        return tokens

    def get_parameters_number(self):
        return sum(p.numel() for p in self.parameters())

    def forward(self, t, z_t, cond_global: Optional[torch.Tensor] = None):
        # z_t: (B, C, Z, H, W)
        t_emb = self.time_mlp(t)

        # Encoder
        skips = []
        x = z_t
        for i, (res, down) in enumerate(zip(self.res_blocks_encoder, self.downs)):
            x = res(x, t_emb)
            skips.append(x)
            x = down(x)

        # Bottleneck
        x = self.bottleneck_res(x, t_emb)

        # tokens
        B, C, Z, H, W = x.shape
        x = self.patching_and_tokenizer(layer=x, channels=C, current_img_size=(Z, H, W))  # (B, N, D), N = (Z//p)*(H//p)*(W//p), D = C*(p**3)

        # prepare global cond tokens if provided
        cond_tokens = None
        if cond_global is not None:
            # cond_global: (B, cond_dim) -> expand to tokens M=1
            cond_tokens = cond_global.unsqueeze(1)  # (B, 1, cond_dim)
        for tb in self.transformer_blocks:
            if isinstance(tb, CrossTransformerBlock) and cond_tokens is not None:
                x = tb(x, cond_tokens)
            else:
                x = tb(x)
        x = self.tokens_to_patches(x, channels=C, current_img_size=(Z, H, W))

        # Decoder
        for up, conv_skip, res in zip(self.ups, self.conv_skip_connection_decored, self.res_blocks_decoder):
            x = up(x)
            skip = skips.pop()
            x = torch.cat([x, skip], dim=1)  # channel concat
            x = conv_skip(x)
            x = res(x, t_emb)

        return x
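The file headers mention a training loop skeleton, but none ships in these modules. As a reference, here is a hedged sketch of a single Conditional Flow Matching training step with the 3D model, using the standard linear interpolation path z_t = (1 - t) * x0 + t * x1 and the constant velocity target x1 - x0; the batch size, the 32x32x32 volume, the optimizer settings, and the random stand-in data are illustrative assumptions, not values taken from the package.

import torch
from flownets.model3D import HybridUNetCFM3D  # assumed import path

model = HybridUNetCFM3D(img_size=(32, 32, 32), in_channels=1,
                        channels_per_down=[8, 16, 32, 64],
                        cond_global_dim=0, attn_patch_size=2)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10):                    # stand-in for a real data loader
    x1 = torch.randn(2, 1, 32, 32, 32)    # "data" latents (random placeholders here)
    x0 = torch.randn_like(x1)             # source / noise samples
    t = torch.rand(x1.shape[0])           # per-sample times in [0, 1]

    # linear interpolation path and its constant velocity target
    t_ = t.view(-1, 1, 1, 1, 1)
    z_t = (1 - t_) * x0 + t_ * x1
    v_target = x1 - x0

    v_pred = model(t, z_t)                # model predicts the velocity at (t, z_t)
    loss = torch.nn.functional.mse_loss(v_pred, v_target)

    opt.zero_grad()
    loss.backward()
    opt.step()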
flownets-0.0.1/flownets/version.py
ADDED
@@ -0,0 +1 @@
__version__ = '0.0.1'
flownets-0.0.1/flownets.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,30 @@
Metadata-Version: 2.4
Name: flownets
Version: 0.0.1
Summary: A new beginning for my models :)
Home-page: https://github.com/TommyGiak/FlowNets
Author: Tommaso Giacometti
Author-email: tommaso.giak@gmail.com
License: MIT
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy
Requires-Dist: torch
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# Here I will insert a better description of the package...
But this is the null version so you must wait a bit :(
flownets-0.0.1/flownets.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,12 @@
LICENSE
README.md
setup.py
flownets/__init__.py
flownets/model2D.py
flownets/model3D.py
flownets/version.py
flownets.egg-info/PKG-INFO
flownets.egg-info/SOURCES.txt
flownets.egg-info/dependency_links.txt
flownets.egg-info/requires.txt
flownets.egg-info/top_level.txt
flownets-0.0.1/flownets.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@

flownets-0.0.1/flownets.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
flownets
flownets-0.0.1/setup.cfg
ADDED
flownets-0.0.1/setup.py
ADDED
@@ -0,0 +1,30 @@
import os
from setuptools import setup, find_packages

package_name = "flownets"
HERE = os.path.dirname(os.path.abspath(__file__))

version_py = os.path.join(os.path.dirname(__file__), package_name, "version.py")
version = open(version_py).read().strip().split(' ')[-1][1:-1]
requirements = open(os.path.join(HERE, "requirements.txt")).read().split("\n")

setup(
    name=package_name,
    version=version,
    description="A new beginning for my models :)",
    long_description=open(os.path.join(HERE, "flownets", "README.md"), encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    url="https://github.com/TommyGiak/FlowNets",
    author="Tommaso Giacometti",
    author_email="tommaso.giak@gmail.com",
    license="MIT",
    packages=find_packages(exclude=("tests", "docs")),
    include_package_data=True,
    install_requires=requirements,
    python_requires=">=3.8",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
)