imgen-toolbox 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imgen_toolbox-0.1.0/LICENSE +21 -0
- imgen_toolbox-0.1.0/PKG-INFO +21 -0
- imgen_toolbox-0.1.0/README.md +2 -0
- imgen_toolbox-0.1.0/imgen_toolbox/__init__.py +15 -0
- imgen_toolbox-0.1.0/imgen_toolbox/data/__init__.py +0 -0
- imgen_toolbox-0.1.0/imgen_toolbox/data/datasets.py +99 -0
- imgen_toolbox-0.1.0/imgen_toolbox/losses/__init__.py +0 -0
- imgen_toolbox-0.1.0/imgen_toolbox/losses/gan_loss.py +64 -0
- imgen_toolbox-0.1.0/imgen_toolbox/models/__init__.py +0 -0
- imgen_toolbox-0.1.0/imgen_toolbox/models/patchGAN_discriminator.py +45 -0
- imgen_toolbox-0.1.0/imgen_toolbox/models/unet_generator.py +86 -0
- imgen_toolbox-0.1.0/imgen_toolbox/nn/__init__.py +0 -0
- imgen_toolbox-0.1.0/imgen_toolbox/nn/blocks.py +253 -0
- imgen_toolbox-0.1.0/imgen_toolbox/nn/embeddings.py +34 -0
- imgen_toolbox-0.1.0/imgen_toolbox.egg-info/PKG-INFO +21 -0
- imgen_toolbox-0.1.0/imgen_toolbox.egg-info/SOURCES.txt +22 -0
- imgen_toolbox-0.1.0/imgen_toolbox.egg-info/dependency_links.txt +1 -0
- imgen_toolbox-0.1.0/imgen_toolbox.egg-info/requires.txt +2 -0
- imgen_toolbox-0.1.0/imgen_toolbox.egg-info/top_level.txt +3 -0
- imgen_toolbox-0.1.0/pyproject.toml +35 -0
- imgen_toolbox-0.1.0/setup.cfg +4 -0
- imgen_toolbox-0.1.0/tests/test_blocks.py +0 -0
- imgen_toolbox-0.1.0/tests/test_init.py +0 -0
- imgen_toolbox-0.1.0/tests/test_lossess.py +0 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Simone Santoro
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: imgen_toolbox
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Repository containing building blocks and utilities for flexible Generative Neural Networks implementation
|
|
5
|
+
Author-email: Simone Santoro <simone_santoro21@outlook.it>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/SimoneSantoro21/imgen_toolbox
|
|
8
|
+
Classifier: Development Status :: 1 - Planning
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Requires-Python: >=3.11
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: numpy
|
|
17
|
+
Requires-Dist: nibabel
|
|
18
|
+
Dynamic: license-file
|
|
19
|
+
|
|
20
|
+
# imgen_toolbox
|
|
21
|
+
Repository containing building blocks and utilities for flexible Generative Neural Networks implementation
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .nn.embeddings import FiLM
|
|
2
|
+
from .nn.blocks import conv, Downsampling, Upsampling, UnetBlock
|
|
3
|
+
from .models.unet_generator import Generator_Unet3D
|
|
4
|
+
from .models.patchGAN_discriminator import PatchGAN_Discriminator
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"FiLM",
|
|
9
|
+
"conv",
|
|
10
|
+
"Downsampling",
|
|
11
|
+
"Upsampling",
|
|
12
|
+
"UnetBlock",
|
|
13
|
+
"Generator_Unet3D",
|
|
14
|
+
"PatchGAN_Discriminator"
|
|
15
|
+
]
|
|
File without changes
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import Dataset
|
|
3
|
+
import nibabel as nib
|
|
4
|
+
import numpy as np
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
class Dataset3D(Dataset):
    """
    Unified dataset over ALL center indices in one folder.

    Expects two subfolders under ``root_path``:
      - CENTERS:   center_{patientid}_{center_idx}.nii.gz
      - NEIGHBORS: neighbor_{patientid}_{center_idx}.nii.gz

    ``patientid`` must not contain underscores; ``center_idx`` is a non-negative
    integer with any number of digits. Only (center_idx, patientid) keys present
    in BOTH folders become samples; unmatched or non-matching files are ignored.
    Samples are ordered by (center_idx, patientid).

    Parameters:
        root_path (str): folder containing the CENTERS and NEIGHBORS subfolders.
        transform (callable, optional): joint transform applied in ``__getitem__``
            as ``transform(neighbor_tensor, center_tensor)``. It must return the
            (possibly modified) ``(neighbor_tensor, center_tensor)`` pair. Both
            tensors are passed together so paired random augmentations stay aligned.

    Returns (per item):
        (neighbor_tensor, center_tensor)
            neighbor_tensor: [C, D, H, W] (C can be 1 or >1)
            center_tensor:   [1, D, H, W] (or [C, D, H, W] if your centers are multi-channel too)
    """

    # Compiled once per class; the index group accepts any number of digits.
    _CENTER_PATTERN = re.compile(r'^center_([^_]+)_(\d+)\.nii\.gz$')
    _NEIGHBOR_PATTERN = re.compile(r'^neighbor_([^_]+)_(\d+)\.nii\.gz$')

    def __init__(self, root_path: str, transform=None):
        self.centers_dir = os.path.join(root_path, "CENTERS")
        self.neighbors_dir = os.path.join(root_path, "NEIGHBORS")
        self.transform = transform

        center_map = self._index_folder(self.centers_dir, self._CENTER_PATTERN)
        neighbor_map = self._index_folder(self.neighbors_dir, self._NEIGHBOR_PATTERN)

        # Pair files by (center_idx, patient_id); keys missing from either
        # folder are skipped.
        self.samples = [
            (center_idx, patient_id, cpath, neighbor_map[(center_idx, patient_id)])
            for (center_idx, patient_id), cpath in center_map.items()
            if (center_idx, patient_id) in neighbor_map
        ]
        self.samples.sort(key=lambda s: (s[0], s[1]))

    @staticmethod
    def _index_folder(folder: str, pattern) -> dict:
        """Map (center_idx, patient_id) -> path for files in ``folder`` matching ``pattern``."""
        index = {}
        for fname in os.listdir(folder):
            m = pattern.match(fname)
            if not m:
                continue
            pid, ci = m.group(1), int(m.group(2))
            index[(ci, pid)] = os.path.join(folder, fname)
        return index

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_nii(path: str) -> np.ndarray:
        """Load a NIfTI file with nibabel and return its data as a float32 array."""
        arr = nib.load(path).get_fdata().astype(np.float32)
        return arr

    @staticmethod
    def _to_torch_volume(arr: np.ndarray) -> torch.Tensor:
        """
        Convert an array in nibabel axis order to a channels-first torch volume.

        Accepts:
          - 3D: [H, W, D]      (nibabel typical)
          - 4D: [H, W, D, C]   (channels-last, e.g. axis=-1 stacking)

        Returns:
          - 3D -> [1, D, H, W]
          - 4D -> [C, D, H, W]

        Raises:
            ValueError: for arrays of any other rank.
        """
        t = torch.from_numpy(arr)

        if t.ndim == 3:
            # [H, W, D] -> [D, H, W] -> [1, D, H, W]
            return t.permute(2, 0, 1).contiguous().unsqueeze(0)

        if t.ndim == 4:
            # [H, W, D, C] -> [C, D, H, W]
            return t.permute(3, 2, 0, 1).contiguous()

        raise ValueError(f"Expected 3D or 4D volume, got shape {tuple(t.shape)}")

    def __getitem__(self, idx):
        _center_idx, _patient_id, cpath, npath = self.samples[idx]

        center_tensor = self._to_torch_volume(self._load_nii(cpath))
        neighbor_tensor = self._to_torch_volume(self._load_nii(npath))

        # Previously `transform` was stored but never used.
        if self.transform is not None:
            neighbor_tensor, center_tensor = self.transform(neighbor_tensor, center_tensor)

        return neighbor_tensor, center_tensor
|
|
File without changes
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class GANLoss(nn.Module):
    """
    Implementation of GAN Loss class.

    It supports both vanilla GAN loss (using BCEWithLogitsLoss) and LSGAN (using MSELoss).
    The discriminator output is compared against a constant target tensor filled
    with the real or fake label and shaped like the prediction.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """
        Initializing GANLoss.

        Parameters:
            gan_mode (str): Type of GAN objective ('vanilla' or 'lsgan').
            target_real_label (float): Label for real images.
            target_fake_label (float): Label for fake images.

        Raises:
            NotImplementedError: if gan_mode is neither 'vanilla' nor 'lsgan'.
        """
        super().__init__()
        # Register buffers so that these tensors are part of the state but not learnable.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError('gan mode {} not implemented'.format(gan_mode))

    def get_target_tensor(self, prediction, target_is_real):
        """
        Create label tensors with the same size as the prediction.

        Parameters:
            prediction (torch.Tensor): The output from the discriminator.
            target_is_real (bool): Whether the ground truth label is for real images.

        Returns:
            torch.Tensor: Target tensor filled with either real or fake label, on the
            same device and with the same dtype as ``prediction``.
        """
        label = self.real_label if target_is_real else self.fake_label
        # Match device/dtype of the prediction so the loss still works when this
        # module was not moved with .to(device), or the model runs in half/double
        # precision. A no-op when they already match.
        return label.to(prediction).expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """
        Calculate loss given discriminator's output and ground truth labels.

        Parameters:
            prediction (torch.Tensor): The discriminator output.
            target_is_real (bool): True if ground truth label is for real images.

        Returns:
            torch.Tensor: The calculated loss.
        """
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)
|
|
File without changes
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from ..nn.blocks import Downsampling
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PatchGAN_Discriminator(nn.Module):
    """
    PatchGAN discriminator built from ``Downsampling`` convolution blocks.

    Every layer uses kernel_size=4 and padding=1. The first layer has no
    normalization, the intermediate layers use batch normalization, and the
    final layer maps to a single channel with neither normalization nor
    activation, i.e. it outputs a raw score map (one value per patch).

    Supported variants:
      - patch_size=70: channels ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 1,
        strides 2, 2, 2, 1, 1 (receptive field 70 per spatial axis).
      - patch_size=16: channels ndf -> 2*ndf -> 4*ndf -> 1,
        strides 2, 2, 2, 1.
        NOTE(review): this stack has a receptive field of 46 per spatial axis,
        not 16; the name is kept for backward compatibility.

    Parameters:
        input_nc (int): number of input channels.
        ndf (int): channel count of the first layer.
        patch_size (int): 70 or 16, selecting one of the variants above.
        dimensionality (int): spatial dimensionality forwarded to ``Downsampling``
            (3 by default, i.e. volumetric inputs).

    Raises:
        ValueError: for any other patch_size.
    """

    def __init__(self, input_nc, ndf=64, patch_size=70, dimensionality=3):
        super().__init__()

        # One tuple per layer, input side first:
        # (out_channels, normalization, activation, stride)
        if patch_size == 70:
            layer_specs = [
                (ndf, None, True, 2),
                (ndf * 2, 'batch', True, 2),
                (ndf * 4, 'batch', True, 2),
                (ndf * 8, 'batch', True, 1),
                (1, None, False, 1),
            ]
        elif patch_size == 16:
            layer_specs = [
                (ndf, None, True, 2),
                (ndf * 2, 'batch', True, 2),
                (ndf * 4, 'batch', True, 2),
                (1, None, False, 1),
            ]
        else:
            raise ValueError("Unsupported patch size. Please choose patch_size=70 or patch_size=16.")

        stack = []
        channels_in = input_nc
        for channels_out, norm, act, step in layer_specs:
            stack.append(
                Downsampling(
                    channels_in, channels_out,
                    dimensionality=dimensionality,
                    normalization=norm,
                    activation=act,
                    kernel_size=4, stride=step, padding=1,
                )
            )
            channels_in = channels_out

        # Same nn.Sequential indices as a hand-written list, so state_dict keys are stable.
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the per-patch score map for ``x``."""
        return self.model(x)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
from ..nn.embeddings import MetaMLP
|
|
4
|
+
from ..nn.blocks import UnetBlock
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Generator_Unet3D(nn.Module):
    """
    3D U-Net generator (pix2pix-style) with optional FiLM metadata conditioning.

    Seven nested ``UnetBlock`` levels are built from the innermost outward.
    The four outer levels use stride 2 on every axis. The three inner levels use
    stride (1, 2, 2), keeping D and halving only H and W. Dropout is requested
    on the second and third innermost levels when ``use_dropout`` is True.

    Input:
        x:    (B, in_ch, D, H, W)
        meta: (B, meta_dim), by default 6 values [center_dir(3), neighbor_dir(3)].
              Required when use_film=True; ignored when use_film=False.

    Output:
        y: (B, out_ch, D, H, W)

    Raises (forward):
        ValueError: if use_film=True and meta is None.
    """

    def __init__(
        self, in_ch, out_ch, ngf=64, dimensionality=3, use_dropout=True,
        meta_dim=6, emb_dim=128, use_film=True
    ):
        super().__init__()
        self.use_film = use_film

        # Metadata encoder shared by every FiLM layer; absent when FiLM is disabled.
        self.meta_mlp = MetaMLP(meta_dim=meta_dim, emb_dim=emb_dim) if use_film else None

        # Conv geometry per level, OUTERMOST first. The size comments appear to
        # assume an input of D=80, H=W=128.
        params = [
            dict(kernel_size=(4,4,4), stride=(2,2,2), padding=(1,1,1)),  # 80->40, 128->64
            dict(kernel_size=(4,4,4), stride=(2,2,2), padding=(1,1,1)),  # 40->20, 64->32
            dict(kernel_size=(4,4,4), stride=(2,2,2), padding=(1,1,1)),  # 20->10, 32->16
            dict(kernel_size=(4,4,4), stride=(2,2,2), padding=(1,1,1)),  # 10->5, 16->8
            dict(kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1)),  # keep D=5, halve H,W: 8->4
            dict(kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1)),  # keep D=5, 4->2
            dict(kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1)),  # keep D=5, 2->1
        ]

        # (outer_nc, inner_nc, level-specific UnetBlock kwargs), INNERMOST first,
        # so it pairs element-wise with reversed(params).
        level_specs = [
            (ngf * 8, ngf * 8, dict(is_innermost=True)),
            (ngf * 8, ngf * 8, dict(use_dropout=use_dropout)),
            (ngf * 8, ngf * 8, dict(use_dropout=use_dropout)),
            (ngf * 4, ngf * 8, {}),
            (ngf * 2, ngf * 4, {}),
            (ngf, ngf * 2, {}),
            (out_ch, ngf, dict(input_nc=in_ch, is_outermost=True)),
        ]

        block = None  # the innermost level has no submodule
        for (outer_nc, inner_nc, extra), geometry in zip(level_specs, reversed(params)):
            block = UnetBlock(
                outer_nc, inner_nc,
                submodule=block,
                dimensionality=dimensionality,
                use_film=use_film,
                emb_dim=emb_dim,
                **extra,
                **geometry,
            )

        self.model = block

    def forward(self, x, meta=None):
        if not self.use_film:
            return self.model(x)
        if meta is None:
            raise ValueError("meta must be provided when use_film=True")
        return self.model(x, self.meta_mlp(meta))
|
|
File without changes
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .embeddings import FiLM
|
|
4
|
+
|
|
5
|
+
#-------------------------------- BASIC OPERATIONS -----------------------------------------
|
|
6
|
+
|
|
7
|
+
def conv(dimensionality, in_channels, out_channels, **kwargs):
|
|
8
|
+
if dimensionality == 1:
|
|
9
|
+
return nn.Conv1d(in_channels, out_channels, **kwargs)
|
|
10
|
+
elif dimensionality == 2:
|
|
11
|
+
return nn.Conv2d(in_channels, out_channels, **kwargs)
|
|
12
|
+
elif dimensionality == 3:
|
|
13
|
+
return nn.Conv3d(in_channels, out_channels, **kwargs)
|
|
14
|
+
else:
|
|
15
|
+
raise ValueError("dimensionality must be 1, 2, or 3")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def transposed_conv(dimensionality, in_channels, out_channels, **kwargs):
|
|
20
|
+
if dimensionality == 1:
|
|
21
|
+
return nn.ConvTranspose1d(in_channels, out_channels, **kwargs)
|
|
22
|
+
elif dimensionality == 2:
|
|
23
|
+
return nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
|
|
24
|
+
elif dimensionality == 3:
|
|
25
|
+
return nn.ConvTranspose3d(in_channels, out_channels, **kwargs)
|
|
26
|
+
else:
|
|
27
|
+
raise ValueError("dimensionality must be 1, 2, or 3")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
#-------------------- NORMALIZATION OPTIONS AND ACTIVATION FUNCTIONS ----------------------
|
|
32
|
+
|
|
33
|
+
class Normalization(nn.Module):
    """
    Normalization layer selected by name and spatial dimensionality.

    Parameters:
        num_channels (int): number of feature channels (C in (B, C, ...)).
        normalization (str | None):
            "batch"    -> nn.BatchNorm{1,2,3}d
            "instance" -> nn.InstanceNorm{1,2,3}d
            None       -> nn.Identity (no-op)
        dimensionality (int): 1, 2 or 3.
        **kwargs: forwarded to the torch normalization layer (unused for None).

    Raises:
        ValueError: for an unsupported normalization name or dimensionality.
            Failing here is better than leaving ``out`` unset, which only surfaced
            later as an AttributeError in forward().
    """

    # Layer classes indexed by dimensionality - 1.
    _LAYERS = {
        "batch": (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d),
        "instance": (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d),
    }

    def __init__(self, num_channels, normalization = "batch", dimensionality = 2, **kwargs):
        super().__init__()

        self.normalization = normalization

        if normalization is None:
            self.out = nn.Identity()
            return

        if normalization not in self._LAYERS:
            raise ValueError(
                f"Unsupported normalization {normalization!r}; "
                f"choose one of {sorted(self._LAYERS)} or None"
            )
        if dimensionality not in (1, 2, 3):
            raise ValueError("dimensionality must be 1, 2, or 3")

        layer_cls = self._LAYERS[normalization][int(dimensionality) - 1]
        self.out = layer_cls(num_channels, **kwargs)

    def forward(self, x):
        return self.out(x)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Activation(nn.Module):
    """
    Activation layer selected by name.

    Parameters:
        activation (str): "relu" -> nn.ReLU, "leaky_relu" -> nn.LeakyReLU.
        **kwargs: forwarded as keyword arguments to the torch layer
            (e.g. inplace=True, negative_slope=0.2).

    Raises:
        ValueError: for an unsupported activation name.
    """

    def __init__(self, activation, **kwargs):
        super().__init__()

        if activation == "relu":
            # Must be **kwargs: a single * would pass the dict KEYS positionally,
            # e.g. nn.ReLU("inplace"), silently enabling in-place mode.
            self.out = nn.ReLU(**kwargs)
        elif activation == "leaky_relu":
            self.out = nn.LeakyReLU(**kwargs)
        else:
            raise ValueError(
                f"Unsupported activation {activation!r}; choose 'relu' or 'leaky_relu'"
            )

    def forward(self, x):
        return self.out(x)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
#------------------------------- UPSAMPLING AND DOWNSAMPLING BLOCKS -------------------------
|
|
73
|
+
|
|
74
|
+
class Downsampling(nn.Module):
    """
    Convolution-based downsampling block for 1D, 2D or 3D inputs.

    Layer order: conv -> optional Normalization -> optional LeakyReLU(0.2).

    Parameters:
        in_ch (int): input channels of the convolution.
        out_ch (int): output channels of the convolution.
        dimensionality (int): 1, 2 or 3 (selects the conv / normalization class).
        normalization (str | None): name passed to ``Normalization``; a falsy
            value skips normalization entirely.
        activation (bool): when truthy, append a non-inplace LeakyReLU with slope 0.2.
        **kwargs: forwarded to the convolution (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_ch, out_ch, dimensionality = 2, normalization = None,
                 activation = False, **kwargs):
        super().__init__()

        # Build in the same order as the layers run (conv first).
        conv_stage = conv(dimensionality, in_ch, out_ch, **kwargs)
        norm_stage = [Normalization(out_ch, normalization, dimensionality)] if normalization else []
        act_stage = [nn.LeakyReLU(negative_slope=0.2, inplace=False)] if activation else []

        self.downsampling = nn.Sequential(conv_stage, *norm_stage, *act_stage)

    def forward(self, x):
        return self.downsampling(x)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Upsampling(nn.Module):
    """
    Transposed-convolution upsampling block for 1D, 2D or 3D inputs.

    Layer order: optional ReLU -> transposed conv -> optional Normalization
    -> optional Dropout(p=0.5).

    Parameters:
        in_ch (int): input channels of the transposed convolution.
        out_ch (int): output channels of the transposed convolution.
        dimensionality (int): 1, 2 or 3 (selects the conv / normalization class).
        normalization (str | None): name passed to ``Normalization``; a falsy
            value skips normalization entirely.
        activation (bool): when truthy, prepend a non-inplace ReLU.
        dropout (bool): when truthy, append Dropout with p=0.5.
        **kwargs: forwarded to the transposed convolution (kernel_size, stride, ...).
    """

    def __init__(self, in_ch, out_ch, dimensionality = 2, normalization = None,
                 activation = False, dropout = False, **kwargs):
        super().__init__()

        head = [nn.ReLU(inplace=False)] if activation else []
        core = transposed_conv(dimensionality, in_ch, out_ch, **kwargs)

        tail = []
        if normalization:
            tail.append(Normalization(out_ch, normalization, dimensionality))
        if dropout:
            tail.append(nn.Dropout(p=0.5))

        self.upsampling = nn.Sequential(*head, core, *tail)

    def forward(self, x):
        return self.upsampling(x)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
#--------------------------------- UNET SKIPCONNECTION BLOCK -------------------------
|
|
126
|
+
|
|
127
|
+
class UnetBlock(nn.Module):
    """
    U-Net block with skip connection.
    This version supports optional FiLM conditioning via an embedding vector.

    Forward signatures:
      - unconditional: y = block(x)
      - conditional:   y = block(x, emb)

    Data flow: down(x) -> [FiLM] -> submodule -> up -> [FiLM] -> [Sigmoid if outermost].
      - Outermost block: returns the Sigmoid-activated up-path output (outer_nc channels).
      - Other blocks: return cat([x, up(...)], dim=1), optionally followed by dropout
        for intermediate blocks.

    Parameters:
        outer_nc (int): channels produced by the up path.
        inner_nc (int): channels produced by the down path.
        input_nc (int, optional): channels of ``x``; defaults to outer_nc.
        submodule (nn.Module, optional): the next-inner block; ignored for the
            innermost block. Its output must have inner_nc * 2 channels, since the
            up path of non-innermost blocks expects that many.
        is_innermost, is_outermost (bool): position in the U-Net; if both are set,
            the outermost behaviour wins.
        norm_layer (str): normalization name forwarded to Downsampling/Upsampling.
        use_dropout (bool): intermediate blocks only; apply Dropout(0.5) to the
            concatenated output.
        dimensionality (int): 1, 2 or 3.
        kernel_size, stride, padding: conv geometry shared by the down and up paths
            (ints are expanded to 3-tuples when dimensionality == 3).
        use_film (bool): enable FiLM after the down and up paths.
        emb_dim (int): size of the conditioning embedding consumed by FiLM.

    NOTE(review): dropout is applied after the concatenation, so it also zeroes the
    skip features ``x``. The reference pix2pix implementation drops only the
    up-path features. Kept as-is; confirm intent.
    NOTE(review): the outermost down path has no activation, and Downsampling
    applies its activation after the conv. The first two convolutions of the
    network therefore compose without a nonlinearity in between. Confirm intent.
    """

    def __init__(
        self, outer_nc, inner_nc, input_nc=None, submodule=None,
        is_innermost=False, is_outermost=False, norm_layer="batch",
        use_dropout=False, dimensionality=2, kernel_size=4, stride=2, padding=1,
        use_film=False, emb_dim=128
    ):
        super().__init__()
        self.is_outermost = is_outermost
        self.is_innermost = is_innermost
        self.use_dropout = use_dropout
        self.use_film = use_film
        # Defined for every block position (previously only intermediate blocks
        # had it); the intermediate branch below may replace it with a Dropout.
        self.dropout = None

        if input_nc is None:
            input_nc = outer_nc

        if dimensionality == 3:
            if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size)
            if isinstance(stride, int): stride = (stride, stride, stride)
            if isinstance(padding, int): padding = (padding, padding, padding)
        geometry = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        # --- layers ---
        if is_outermost:
            self.down = Downsampling(
                input_nc, inner_nc,
                dimensionality=dimensionality,
                normalization=None,
                activation=False,
                **geometry
            )
            self.submodule = submodule
            self.up = Upsampling(
                inner_nc * 2, outer_nc,
                dimensionality=dimensionality,
                normalization=None,
                activation=True,
                dropout=False,
                **geometry
            )
            self.out_act = nn.Sigmoid()

        elif is_innermost:
            self.down = Downsampling(
                input_nc, inner_nc,
                dimensionality=dimensionality,
                normalization=None,
                activation=True,
                **geometry
            )
            self.submodule = None
            # No submodule, so the up path sees only inner_nc channels.
            self.up = Upsampling(
                inner_nc, outer_nc,
                dimensionality=dimensionality,
                normalization=norm_layer,
                activation=True,
                dropout=False,
                **geometry
            )
            self.out_act = None

        else:
            self.down = Downsampling(
                input_nc, inner_nc,
                dimensionality=dimensionality,
                normalization=norm_layer,
                activation=True,
                **geometry
            )
            self.submodule = submodule
            self.up = Upsampling(
                inner_nc * 2, outer_nc,
                dimensionality=dimensionality,
                normalization=norm_layer,
                activation=True,
                dropout=False,
                **geometry
            )
            self.dropout = nn.Dropout(0.5) if use_dropout else None
            self.out_act = None

        # --- FiLM modules (optional) ---
        if use_film:
            self.film_down = FiLM(emb_dim=emb_dim, n_ch=inner_nc)
            self.film_up = FiLM(emb_dim=emb_dim, n_ch=outer_nc)
        else:
            self.film_down = None
            self.film_up = None

    def forward(self, x, emb=None):
        # Down
        h = self.down(x)
        if self.film_down is not None and emb is not None:
            h = self.film_down(h, emb)

        # Submodule (receives the same embedding)
        if self.submodule is not None:
            h = self.submodule(h, emb)

        # Up
        u = self.up(h)
        if self.film_up is not None and emb is not None:
            u = self.film_up(u, emb)

        if self.out_act is not None:
            u = self.out_act(u)

        # Skip connections
        if self.is_outermost:
            return u

        out = torch.cat([x, u], dim=1)
        # Only intermediate blocks can hold a Dropout (see __init__).
        if self.dropout is not None:
            out = self.dropout(out)
        return out
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
#--------------------------------------- FiLM embedding -------------------------------------
|
|
5
|
+
class MetaMLP(nn.Module):
    """
    Two-layer MLP that encodes a metadata vector into the embedding used by FiLM.

    Architecture: Linear(meta_dim -> emb_dim) -> ReLU -> Linear(emb_dim -> emb_dim) -> ReLU.
    Because of the trailing ReLU, every embedding entry is non-negative.

    Shapes: (B, meta_dim) -> (B, emb_dim).
    """
    def __init__(self, meta_dim=6, emb_dim=128):
        super().__init__()
        stages = []
        width_in = meta_dim
        for _ in range(2):
            stages += [nn.Linear(width_in, emb_dim), nn.ReLU(inplace=True)]
            width_in = emb_dim
        self.net = nn.Sequential(*stages)

    def forward(self, m):
        embedding = self.net(m)
        return embedding
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FiLM(nn.Module):
    """
    Feature-wise Linear Modulation.

    A linear layer maps the embedding (B, emb_dim) to 2 * n_ch values. The first
    n_ch are gamma and the last n_ch are beta. The output is
    x * (1 + gamma) + beta, with gamma/beta broadcast over every axis of x after
    the channel axis (works for (B, C, H, W), (B, C, D, H, W), ...).
    """
    def __init__(self, emb_dim, n_ch):
        super().__init__()
        self.to_gb = nn.Linear(emb_dim, 2 * n_ch)

    def forward(self, x, emb):
        gamma, beta = self.to_gb(emb).chunk(2, dim=1)  # each (B, C)
        # Append one singleton axis per spatial dim of x so (B, C) broadcasts.
        broadcast_shape = (*gamma.shape[:2], *((1,) * (x.ndim - 2)))
        scale = gamma.view(broadcast_shape)
        shift = beta.view(broadcast_shape)
        return x * (1.0 + scale) + shift
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: imgen_toolbox
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Repository containing building blocks and utilities for flexible Generative Neural Networks implementation
|
|
5
|
+
Author-email: Simone Santoro <simone_santoro21@outlook.it>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/SimoneSantoro21/imgen_toolbox
|
|
8
|
+
Classifier: Development Status :: 1 - Planning
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Requires-Python: >=3.11
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: numpy
|
|
17
|
+
Requires-Dist: nibabel
|
|
18
|
+
Dynamic: license-file
|
|
19
|
+
|
|
20
|
+
# imgen_toolbox
|
|
21
|
+
Repository containing building blocks and utilities for flexible Generative Neural Networks implementation
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
imgen_toolbox/__init__.py
|
|
5
|
+
imgen_toolbox.egg-info/PKG-INFO
|
|
6
|
+
imgen_toolbox.egg-info/SOURCES.txt
|
|
7
|
+
imgen_toolbox.egg-info/dependency_links.txt
|
|
8
|
+
imgen_toolbox.egg-info/requires.txt
|
|
9
|
+
imgen_toolbox.egg-info/top_level.txt
|
|
10
|
+
imgen_toolbox/data/__init__.py
|
|
11
|
+
imgen_toolbox/data/datasets.py
|
|
12
|
+
imgen_toolbox/losses/__init__.py
|
|
13
|
+
imgen_toolbox/losses/gan_loss.py
|
|
14
|
+
imgen_toolbox/models/__init__.py
|
|
15
|
+
imgen_toolbox/models/patchGAN_discriminator.py
|
|
16
|
+
imgen_toolbox/models/unet_generator.py
|
|
17
|
+
imgen_toolbox/nn/__init__.py
|
|
18
|
+
imgen_toolbox/nn/blocks.py
|
|
19
|
+
imgen_toolbox/nn/embeddings.py
|
|
20
|
+
tests/test_blocks.py
|
|
21
|
+
tests/test_init.py
|
|
22
|
+
tests/test_lossess.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "imgen_toolbox"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Repository containing building blocks and utilities for flexible Generative Neural Networks implementation"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "Simone Santoro", email = "simone_santoro21@outlook.it" }
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
license = { text = "MIT" }
|
|
17
|
+
|
|
18
|
+
dependencies = [
|
|
19
|
+
"numpy",
|
|
20
|
+
"nibabel"
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
classifiers = [
|
|
24
|
+
"Development Status :: 1 - Planning",
|
|
25
|
+
"Intended Audience :: Science/Research",
|
|
26
|
+
"License :: OSI Approved :: MIT License",
|
|
27
|
+
"Operating System :: OS Independent",
|
|
28
|
+
"Programming Language :: Python :: 3.11"
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.urls]
|
|
32
|
+
Homepage = "https://github.com/SimoneSantoro21/imgen_toolbox"
|
|
33
|
+
|
|
34
|
+
[tool.setuptools.packages.find]
|
|
35
|
+
where = ["."]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|