konfai 1.1.8-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/models/registration/registration.py

@@ -1,99 +1,168 @@
 import torch
+import torch.nn.functional as F  # noqa: N812
 from torch.nn.parameter import Parameter
-import torch.nn.functional as F
 
-from konfai.network import network, blocks
-from konfai.utils.config import config
 from konfai.models.segmentation import UNet
+from konfai.network import blocks, network
+
 
 class VoxelMorph(network.Network):
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        dim: int = 3,
+        channels: list[int] = [4, 16, 32, 32, 32],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        shape: list[int] = [192, 192, 192],
+        int_steps: int = 7,
+        int_downsize: int = 2,
+        nb_batch_per_step: int = 1,
+        rigid: bool = False,
+    ):
+        super().__init__(
+            in_channels=channels[0],
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+            nb_batch_per_step=nb_batch_per_step,
+        )
+        self.add_module("Concat", blocks.Concat(), in_branch=[0, 1], out_branch=["input_concat"])
+        self.add_module(
+            "UNetBlock_0",
+            UNet.UNetBlock(
+                channels,
+                nb_conv_per_stage,
+                block_config,
+                downsample_mode=blocks.DownsampleMode[downsample_mode],
+                upsample_mode=blocks.UpsampleMode[upsample_mode],
+                attention=attention,
+                block=blocks.ConvBlock,
+                nb_class=0,
+                dim=dim,
+            ),
+            in_branch=["input_concat"],
+            out_branch=["unet"],
+        )
 
         if rigid:
-            self.add_module(
+            self.add_module(
+                "Flow",
+                Rigid(channels[1], dim),
+                in_branch=["unet"],
+                out_branch=["pos_flow"],
+            )
         else:
-            self.add_module(
-
+            self.add_module(
+                "Flow",
+                Flow(channels[1], int_steps, int_downsize, shape, dim),
+                in_branch=["unet"],
+                out_branch=["pos_flow"],
+            )
+        self.add_module(
+            "MovingImageResample",
+            SpatialTransformer(shape, rigid=rigid),
+            in_branch=[1, "pos_flow"],
+            out_branch=["moving_image_resample"],
+        )
+
 
 class Flow(network.ModuleArgsDict):
 
-    def __init__(
+    def __init__(
+        self,
+        in_channels: int,
+        int_steps: int,
+        int_downsize: int,
+        shape: list[int],
+        dim: int,
+    ) -> None:
         super().__init__()
-        self.add_module(
+        self.add_module(
+            "Head",
+            blocks.get_torch_module("Conv", dim)(
+                in_channels=in_channels,
+                out_channels=dim,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            ),
+        )
         self["Head"].weight = Parameter(torch.distributions.Normal(0, 1e-5).sample(self["Head"].weight.shape))
         self["Head"].bias = Parameter(torch.zeros(self["Head"].bias.shape))
-
+
         if int_steps > 0 and int_downsize > 1:
             self.add_module("DownSample", ResizeTransform(int_downsize))
 
         if int_steps > 0:
-            self.add_module(
-
+            self.add_module(
+                "Integrate_pos_flow",
+                VecInt([int(dim / int_downsize) for dim in shape], int_steps),
+            )
+
         if int_steps > 0 and int_downsize > 1:
             self.add_module("Upsample_pos_flow", ResizeTransform(1 / int_downsize))
 
+
 class Rigid(network.ModuleArgsDict):
 
     def __init__(self, in_channels: int, dim: int) -> None:
         super().__init__()
         self.add_module("ToFeatures", torch.nn.Flatten(1))
-        self.add_module("Head", torch.nn.Linear(in_channels*512*512, 2))
+        self.add_module("Head", torch.nn.Linear(in_channels * 512 * 512, 2))
 
     def init(self, init_type: str, init_gain: float):
         self["Head"].weight.data.fill_(0)
         self["Head"].bias.data.copy_(torch.tensor([0, 0], dtype=torch.float))
-
+
+
 class MaskFlow(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-
+
     def forward(self, mask: torch.Tensor, *flows: torch.Tensor):
         result = torch.zeros_like(flows[0])
         for i, flow in enumerate(flows):
-            result = result+torch.where(mask == i+1, flow, torch.tensor(0))
+            result = result + torch.where(mask == i + 1, flow, torch.tensor(0))
         return result
 
+
 class SpatialTransformer(torch.nn.Module):
-
-    def __init__(self, size
+
+    def __init__(self, size: list[int], rigid: bool = False):
         super().__init__()
         self.rigid = rigid
         if not rigid:
             vectors = [torch.arange(0, s) for s in size]
-            grids = torch.meshgrid(vectors, indexing=
+            grids = torch.meshgrid(vectors, indexing="ij")
             grid = torch.stack(grids)
             grid = torch.unsqueeze(grid, 0)
             grid = grid.type(torch.float)
-            self.register_buffer(
+            self.register_buffer("grid", grid)
 
     def forward(self, src: torch.Tensor, flow: torch.Tensor):
         if self.rigid:
             new_locs = torch.zeros((flow.shape[0], 2, 3)).to(flow.device)
-            new_locs[:, 0,0] = 1
-            new_locs[:, 1,1] = 1
-            new_locs[:, 0,2] = flow[:, 0]
-            new_locs[:, 1,2] = flow[:, 1]
-            return F.grid_sample(
+            new_locs[:, 0, 0] = 1
+            new_locs[:, 1, 1] = 1
+            new_locs[:, 0, 2] = flow[:, 0]
+            new_locs[:, 1, 2] = flow[:, 1]
+            return F.grid_sample(
+                src,
+                F.affine_grid(new_locs, src.size()),
+                align_corners=True,
+                mode="bilinear",
+            )
         else:
             new_locs = self.grid + flow
             shape = flow.shape[2:]
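In the rigid branch above, the two predicted parameters are written into the translation column of an otherwise-identity 2×3 affine matrix, which F.affine_grid turns into a sampling grid. A standalone sketch of that resampling step (inputs are illustrative; align_corners is passed to both calls here, while the diffed code sets it only on grid_sample):

import torch
import torch.nn.functional as F

# Illustrative inputs: a 2D moving image and one translation pair per batch item.
src = torch.rand(1, 1, 64, 64)
flow = torch.tensor([[0.25, -0.10]])

# Identity rotation/scale with the predicted translations, as in the diff above.
theta = torch.zeros(flow.shape[0], 2, 3)
theta[:, 0, 0] = 1
theta[:, 1, 1] = 1
theta[:, 0, 2] = flow[:, 0]
theta[:, 1, 2] = flow[:, 1]

grid = F.affine_grid(theta, src.size(), align_corners=True)
warped = F.grid_sample(src, grid, align_corners=True, mode="bilinear")
print(warped.shape)  # torch.Size([1, 1, 64, 64])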
@@ -102,13 +171,15 @@ class SpatialTransformer(torch.nn.Module):
             new_locs = new_locs.permute(0, 2, 3, 1)
             return F.grid_sample(src, new_locs[..., [1, 0]], align_corners=True, mode="bilinear")
 
+
 class VecInt(torch.nn.Module):
 
-    def __init__(self, inshape, nsteps):
+    def __init__(self, inshape: list[int], nsteps: int):
         super().__init__()
-
+        if nsteps < 0:
+            raise NameError(f"nsteps should be >= 0, found: {nsteps}")
         self.nsteps = nsteps
-        self.scale = 1.0 / (2
+        self.scale = 1.0 / (2**self.nsteps)
         self.transformer = SpatialTransformer(inshape)
 
     def forward(self, vec: torch.Tensor):
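The body of VecInt.forward is outside the changed lines, so it does not appear in this diff. In VoxelMorph-style code this module integrates a stationary velocity field by scaling and squaring: scale by 1/2**nsteps, then repeatedly warp the field by itself. A sketch of that convention, assuming konfai follows the upstream VoxelMorph behavior rather than showing konfai's exact body:

def forward(self, vec: torch.Tensor) -> torch.Tensor:
    vec = vec * self.scale  # self.scale = 1.0 / (2**self.nsteps)
    for _ in range(self.nsteps):
        # each squaring step composes the field with itself,
        # doubling the integration time
        vec = vec + self.transformer(vec, vec)
    return vec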
@@ -119,16 +190,28 @@ class VecInt(torch.nn.Module):
 
 
 class ResizeTransform(torch.nn.Module):
-
-    def __init__(self, size):
+
+    def __init__(self, size: float):
         super().__init__()
         self.factor = 1.0 / size
 
     def forward(self, x: torch.Tensor):
         if self.factor < 1:
-            x = F.interpolate(
+            x = F.interpolate(
+                x,
+                align_corners=True,
+                scale_factor=self.factor,
+                mode="bilinear",
+                recompute_scale_factor=True,
+            )
             x = self.factor * x
         elif self.factor > 1:
             x = self.factor * x
-            x = F.interpolate(
-
+            x = F.interpolate(
+                x,
+                align_corners=True,
+                scale_factor=self.factor,
+                mode="bilinear",
+                recompute_scale_factor=True,
+            )
+        return x
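ResizeTransform pairs every spatial resampling with a multiplication by the same factor because displacement fields are expressed in voxels: halving the grid must also halve the length of every displacement vector. A quick check of that bookkeeping with plain torch (values are illustrative):

import torch
import torch.nn.functional as F

flow = torch.full((1, 2, 8, 8), 4.0)  # constant 4-voxel displacement field
factor = 0.5                          # ResizeTransform(2) gives factor = 1/2
small = F.interpolate(flow, align_corners=True, scale_factor=factor,
                      mode="bilinear", recompute_scale_factor=True)
small = factor * small                # 4-voxel shifts become 2-voxel shifts
print(small.shape, small[0, 0, 0, 0].item())  # torch.Size([1, 2, 4, 4]) 2.0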
konfai/models/representation/representation.py

@@ -1,21 +1,35 @@
 import torch
 
-from konfai.network import
-
+from konfai.network import blocks, network
+
 
 class ConvBlock(torch.nn.Module):
-
-    def __init__(self, in_channels
+
+    def __init__(self, in_channels: int, out_channels: int) -> None:
         super().__init__()
-        self.Conv_0 = torch.nn.Conv3d(
+        self.Conv_0 = torch.nn.Conv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
         self.Norm_0 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
         self.Activation_0 = torch.nn.LeakyReLU(negative_slope=0.01)
-        self.Conv_1 = torch.nn.Conv3d(
+        self.Conv_1 = torch.nn.Conv3d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
         self.Norm_1 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
         self.Activation_1 = torch.nn.LeakyReLU(negative_slope=0.01)
-
-    def forward(self,
-        output = self.Conv_0(
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        output = self.Conv_0(tensor)
         output = self.Norm_0(output)
         output = self.Activation_0(output)
         output = self.Conv_1(output)
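As reformatted, ConvBlock is two padded 3×3×3 convolutions, each followed by InstanceNorm3d and LeakyReLU, so spatial dimensions pass through unchanged. A smoke test, assuming ConvBlock has been imported from this module:

import torch

block = ConvBlock(in_channels=1, out_channels=32)
x = torch.rand(2, 1, 16, 16, 16)  # (N, C, D, H, W)
y = block(x)
print(y.shape)  # torch.Size([2, 32, 16, 16, 16])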
@@ -23,34 +37,54 @@ class ConvBlock(torch.nn.Module):
         output = self.Activation_1(output)
         return output
 
-
+
+class UnetCPP1Layers(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
         self.DownConvBlock_0 = ConvBlock(in_channels=1, out_channels=32)
-
-    def forward(self,
-        return self.DownConvBlock_0(
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return self.DownConvBlock_0(tensor)
+
 
 class Adaptation(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        self.Encoder_1 =
+        self.Encoder_1 = UnetCPP1Layers()
         self.ToFeatures = blocks.ToFeatures(3)
         self.FCT_1 = torch.nn.Linear(32, 32, bias=True)
 
-    def forward(
+    def forward(
+        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         self.Encoder_1.requires_grad_(False)
         self.FCT_1.requires_grad_(True)
-        return
+        return (
+            self.FCT_1(self.ToFeatures(self.Encoder_1(a))),
+            self.FCT_1(self.ToFeatures(self.Encoder_1(b))),
+            self.FCT_1(self.ToFeatures(self.Encoder_1(c))),
+        )
+
 
 class Representation(network.Network):
 
-    def __init__(
-
-
-
-
-
-
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        dim: int = 3,
+    ):
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+            init_type="kaiming",
+        )
+        self.add_module("Model", Adaptation(), in_branch=[0, 1, 2])