konfai 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +1 -1
- konfai/data/augmentation.py +41 -36
- konfai/data/data_manager.py +57 -34
- konfai/data/patching.py +37 -13
- konfai/data/transform.py +49 -21
- konfai/evaluator.py +24 -7
- konfai/main.py +5 -3
- konfai/models/segmentation/UNet.py +9 -10
- konfai/network/network.py +0 -1
- konfai/predictor.py +41 -21
- konfai/trainer.py +24 -10
- konfai/utils/dataset.py +27 -2
- konfai/utils/utils.py +49 -12
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/METADATA +1 -1
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/RECORD +19 -19
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/WHEEL +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/top_level.txt +0 -0
konfai/data/transform.py
CHANGED
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
 import torch.nn.functional as F
 from typing import Any, Union

-from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
+from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
 from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
 from konfai.utils.config import config

@@ -52,7 +52,7 @@ class Clip(Transform):
         input[torch.where(input < self.min_value)] = self.min_value
         input[torch.where(input > self.max_value)] = self.max_value
         if self.saveClip_min:
-            cache_attribute["Min"] = self
+            cache_attribute["Min"] = self.min_value
         if self.saveClip_max:
             cache_attribute["Max"] = self.max_value
         return input

@@ -211,39 +211,67 @@ class Resample(Transform, ABC):
         _ = cache_attribute.pop_np_array("Spacing")
         return self._resample(input, [int(size) for size in size_1])

-class
+class ResampleToResolution(Resample):
+
+    def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
+        self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])

-    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
-        self.spacing = torch.tensor(spacing, dtype=torch.float64)
-
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-
-
-
+        if "Spacing" not in cache_attribute:
+            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                           "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+        if len(shape) != len(self.spacing):
+            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        spacing = self.spacing
+
+        for i, s in enumerate(self.spacing):
+            if s == 0:
+                spacing[i] = image_spacing[i]
+        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]

     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-
-
-
+        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        spacing = self.spacing
+        for i, s in enumerate(self.spacing):
+            if s == 0:
+                spacing[i] = image_spacing[i]
+        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        cache_attribute["Spacing"] = spacing.flip(0)
         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
         size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(input, size)

-class
+class ResampleToSize(Resample):

     def __init__(self, size : list[int] = [100,512,512]) -> None:
         self.size = size

     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-
+        if "Spacing" not in cache_attribute:
+            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                           "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+        if len(shape) != len(self.size):
+            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+        size = self.size
+        for i, s in enumerate(self.size):
+            if s == -1:
+                size[i] = shape[i]
+        return size

     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        size = self.size
+        image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
+        for i, s in enumerate(self.size):
+            if s is None:
+                size[i] = image_size[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(torch.tensor(
-            cache_attribute["Size"] =
-            cache_attribute["Size"] =
-            return self._resample(input,
+            cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+        cache_attribute["Size"] = image_size
+        cache_attribute["Size"] = size
+        return self._resample(input, size)

 class ResampleTransform(Transform):

@@ -412,8 +440,8 @@ class FlatLabel(Transform):

 class Save(Transform):

-    def __init__(self,
-        self.
+    def __init__(self, dataset: str) -> None:
+        self.dataset = dataset

     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return input

@@ -528,7 +556,7 @@ class OneHot(Transform):
         self.num_classes = num_classes

     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(
+        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
         return result

     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
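The new ResampleToResolution keeps an axis at its original resolution when its target spacing is 0 (or None), substituting the image's own spacing before the resize factor is computed. A minimal standalone sketch of that size computation (the helper name and plain-list interface are illustrative, not KonfAI's API):

```python
import torch
from typing import Union

def resampled_size(shape: list[int], image_spacing: list[float],
                   target_spacing: list[Union[float, None]]) -> list[int]:
    # Axes whose target spacing is None or <= 0 keep the image's own spacing,
    # i.e. their size is left unchanged.
    img = torch.tensor(image_spacing, dtype=torch.float64)
    tgt = torch.tensor([float(img[i]) if s is None or s <= 0 else s
                        for i, s in enumerate(target_spacing)], dtype=torch.float64)
    resize_factor = tgt / img  # >1 coarsens the grid, <1 refines it
    return [int(x) for x in torch.tensor(shape, dtype=torch.float64) / resize_factor]

# A 100x512x512 volume at 2.5 x 0.8 x 0.8 mm resampled to 1 mm in plane,
# with the slice axis left untouched:
print(resampled_size([100, 512, 512], [2.5, 0.8, 0.8], [None, 1.0, 1.0]))  # [100, 409, 409]
```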
konfai/evaluator.py
CHANGED
@@ -9,7 +9,7 @@ import builtins
 import importlib
 from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
 from konfai.utils.config import config
-from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
+from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
 from konfai.data.data_manager import DataMetric

 class CriterionsAttr():

@@ -54,7 +54,8 @@ class Statistics():
         if name_dataset not in self.measures:
             self.measures[name_dataset] = {}
         self.measures[name_dataset][name] = value
-
+
+    @staticmethod
     def getStatistic(values: list[float]) -> dict[str, float]:
         return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}

@@ -91,7 +92,6 @@ class Evaluator(DistributedObject):
             exit(0)
         super().__init__(train_name)
         self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
-        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
         self.metricsLoader = metrics
         self.dataset = dataset
         self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}

@@ -102,10 +102,10 @@ class Evaluator(DistributedObject):
         result = {}
         for output_group in self.metrics:
             for target_group in self.metrics[output_group]:
-                targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
+                targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
                 name = data_dict[output_group][1][0]
                 for metric in self.metrics[output_group][target_group]:
-                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
+                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
         statistics.add(result, name)
         return result

@@ -126,9 +126,26 @@ class Evaluator(DistributedObject):

         self.dataloader = self.dataset.getData(world_size)

+        groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
+
+        missing_outputs = set(self.metrics.keys()) - set(groupsDest)
+        if missing_outputs:
+            raise EvaluatorError(
+                f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
+                f"Available groups: {sorted(groupsDest)}"
+            )
+
+        target_groups = {target for targets in self.metrics.values() for target in targets}
+        missing_targets = target_groups - set(groupsDest)
+        if missing_targets:
+            raise EvaluatorError(
+                f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
+                f"Available groups: {sorted(groupsDest)}"
+            )
+
     def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
         description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=
+        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
         outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)

@@ -136,7 +153,7 @@ class Evaluator(DistributedObject):
         self.statistics_train.write(outputs)
         if len(dataloaders) == 2:
             description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=
+            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
                 for _, data_dict in batch_iter:
                     batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
             outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
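The evaluator now fails fast when a configured metric references a group the dataset does not provide, instead of erroring mid-run. A small sketch of the same set-difference check with made-up group names (a plain ValueError stands in for KonfAI's EvaluatorError):

```python
# Hypothetical group names; in KonfAI they come from dataset.groups_src and self.metrics.
groups_dest = ["ct", "mask", "prediction"]
metrics = {"prediction": {"mask": ["Dice"]}}  # output group -> target group -> metric names

missing_outputs = set(metrics) - set(groups_dest)
missing_targets = {t for targets in metrics.values() for t in targets} - set(groups_dest)

if missing_outputs or missing_targets:
    raise ValueError(
        f"Metric groups {sorted(missing_outputs | missing_targets)} are not among "
        f"the available groups {sorted(groups_dest)}"
    )
```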
konfai/main.py
CHANGED
@@ -3,25 +3,27 @@ import os
 from torch.cuda import device_count
 import torch.multiprocessing as mp
 from konfai.utils.utils import setup, TensorBoard, Log
+from konfai import KONFAI_NB_CORES

 import sys
 sys.path.insert(0, os.getcwd())

 def main():
+    import tracemalloc
+
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
         with setup(parser) as distributedObject:
-            with Log(distributedObject.name):
+            with Log(distributedObject.name, 0):
                 world_size = device_count()
                 if world_size == 0:
-                    world_size =
+                    world_size = int(KONFAI_NB_CORES())
                 distributedObject.setup(world_size)
                 with TensorBoard(distributedObject.name):
                     mp.spawn(distributedObject, nprocs=world_size)
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")

-
 def cluster():
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

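`main` now falls back to a configured worker count when no CUDA device is available, so `mp.spawn` still launches the intended number of CPU processes. A reduced sketch of that launch path; reading `KONFAI_NB_CORES` from the environment is an assumption made for this example (in KonfAI it is imported from the package and called):

```python
import os
import torch.multiprocessing as mp
from torch.cuda import device_count

def _worker(rank: int) -> None:
    print(f"worker {rank} started")

def launch() -> None:
    world_size = device_count()
    if world_size == 0:
        # No GPU available: spawn one process per configured core instead.
        world_size = int(os.environ.get("KONFAI_NB_CORES", "1"))
    mp.spawn(_worker, nprocs=world_size)

if __name__ == "__main__":
    launch()
```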
konfai/models/segmentation/UNet.py
CHANGED

@@ -7,32 +7,33 @@ from konfai.data.patching import ModelPatch

 class UNetHead(network.ModuleArgsDict):

-    def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
+    def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
         super().__init__()
         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
         self.add_module("Softmax", torch.nn.Softmax(dim=1))
         self.add_module("Argmax", blocks.ArgMax(dim=1))
+

 class UNetBlock(network.ModuleArgsDict):

-    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0
+    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0) -> None:
         super().__init__()
         blockConfig_stride = blockConfig
         if i > 0:
             if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
                 self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
             else:
-                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size,
+                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
         self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
         if len(channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1
+            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1))
             self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
         if nb_class > 0:
-            self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
+            self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
         if i > 0:
             if attention:
                 self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
-            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=
+            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=2, stride=2))
             self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])

 class UNet(network.Network):

@@ -51,8 +52,6 @@ class UNet(network.Network):
                  downSampleMode: str = "MAXPOOL",
                  upSampleMode: str = "CONV_TRANSPOSE",
                  attention : bool = False,
-                 blockType: str = "Conv"
-                 mri: bool = False) -> None:
+                 blockType: str = "Conv") -> None:
         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim
-
+        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim))
konfai/network/network.py
CHANGED
@@ -39,7 +39,6 @@ class OptimizerLoader():
         self.name = name

     def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
-        torch.optim.AdamW
         return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)

 class SchedulerStep():
konfai/predictor.py
CHANGED
@@ -91,13 +91,30 @@ class OutDataset(Dataset, NeedDevice, ABC):
         super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
         self.attributes.pop(index)

+class Reduction():
+
+    def __init__(self):
+        pass
+
+class ReductionMean():
+
+    def __init__(self):
+        pass
+
+class ReductionMedian():
+
+    def __init__(self):
+        pass
+
+
+
 class OutSameAsGroupDataset(OutDataset):

     @config("OutDataset")
-    def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None,
+    def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
         super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
         self.group_src, self.group_dest = sameAsGroup.split(":")
-        self.
+        self.reduction = reduction
         self.inverse_transform = inverse_transform

     def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):

@@ -151,9 +168,9 @@ class OutSameAsGroupDataset(OutDataset):
         self.output_layer_accumulator.pop(index)
         dtype = result.dtype

-        if self.
+        if self.reduction == "mean":
             result = torch.mean(result.float(), dim=0).to(dtype)
-        elif self.
+        elif self.reduction == "median":
             result, _ = torch.median(result.float(), dim=0)
         else:
             raise NameError("Reduction method does not exist (mean, median)")

@@ -201,7 +218,7 @@ class OutDatasetLoader():

 class _Predictor():

-    def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
+    def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
         self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank

@@ -209,7 +226,7 @@ class _Predictor():
         self.modelComposite = modelComposite
         self.dataloader_prediction = dataloader_prediction
         self.outsDataset = outsDataset
-
+        self.autocast = autocast

         self.it = 0

@@ -239,21 +256,22 @@ class _Predictor():
         self.modelComposite.eval()
         self.modelComposite.module.setState(NetState.PREDICTION)
         desc = lambda : "Prediction : {}".format(description(self.modelComposite))
-        self.dataloader_prediction.dataset.load()
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=
+        self.dataloader_prediction.dataset.load("Prediction")
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
             dist.barrier()
             for it, data_dict in batch_iter:
-
-
-                self.
-
-
-
-
-
-
-
-
+                with torch.amp.autocast('cuda', enabled=self.autocast):
+                    input = self.getInput(data_dict)
+                    for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
+                        self._predict_log(data_dict)
+                        outDataset = self.outsDataset[name]
+                        for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
+                            outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
+                            if outDataset.isDone(index):
+                                outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
+
+                batch_iter.set_description(desc())
+                self.it += 1

     def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
         measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)

@@ -320,6 +338,7 @@ class Predictor(DistributedObject):
                  train_name: str = "name",
                  manual_seed : Union[int, None] = None,
                  gpu_checkpoints: Union[list[str], None] = None,
+                 autocast : bool = False,
                  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
                  images_log: list[str] = []) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":

@@ -328,6 +347,7 @@ class Predictor(DistributedObject):
         self.manual_seed = manual_seed
         self.dataset = dataset
         self.combine = combine
+        self.autocast = autocast

         self.model = model.getModel(train=False)
         self.it = 0

@@ -384,7 +404,7 @@ class Predictor(DistributedObject):

         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")

-        self.model.init(autocast
+        self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
         self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)

@@ -402,7 +422,7 @@ class Predictor(DistributedObject):
     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
         modelComposite = Network.to(self.modelComposite, local_rank*self.size)
         modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
-        with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
+        with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
            p.run()

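The predictor gains a `reduction` option that controls how the predictions accumulated for one sample (from several models, augmentations or patches) are collapsed into a single output. A standalone sketch of the two supported choices:

```python
import torch

def reduce_outputs(stacked: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    # dim 0 indexes the accumulated predictions for one sample.
    dtype = stacked.dtype
    if reduction == "mean":
        return torch.mean(stacked.float(), dim=0).to(dtype)
    elif reduction == "median":
        values, _ = torch.median(stacked.float(), dim=0)
        return values.to(dtype)
    raise NameError("Reduction method does not exist (mean, median)")

# Three accumulated predictions for the same sample, shape (1, 4, 4) each:
preds = torch.stack([torch.full((1, 4, 4), v) for v in (0.2, 0.4, 0.9)])
print(reduce_outputs(preds, "mean")[0, 0, 0])    # tensor(0.5000)
print(reduce_outputs(preds, "median")[0, 0, 0])  # tensor(0.4000)
```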
konfai/trainer.py
CHANGED
@@ -17,7 +17,6 @@ from konfai.utils.config import config
 from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
 from konfai.network.network import Network, ModelLoader, NetState, CPU_Model

-
 class EarlyStoppingBase:

     def __init__(self):

@@ -99,7 +98,7 @@ class _Trainer():
         self.autocast = autocast
         self.modelEMA = modelEMA
         self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
-
+
         self.it_validation = it_validation
         if self.it_validation is None:
             self.it_validation = len(dataloader_training)

@@ -110,6 +109,8 @@ class _Trainer():
         for data in data_log:
             self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))

+
+
     def __enter__(self):
         return self

@@ -118,13 +119,14 @@ class _Trainer():
         self.tb.close()

     def run(self) -> None:
-
+        self.dataloader_training.dataset.load("Train")
+        self.dataloader_validation.dataset.load("Validation")
+        with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
             for self.epoch in epoch_tqdm:
-                self.dataloader_training.dataset.load()
                 self.train()
                 if self.early_stopping.isStopped():
                     break
-                self.dataloader_training.dataset.resetAugmentation()
+                self.dataloader_training.dataset.resetAugmentation("Train")

     def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
         return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}

@@ -137,7 +139,8 @@ class _Trainer():
             self.modelEMA.module.setState(NetState.TRAIN)

         desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
-
+
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 with torch.amp.autocast('cuda', enabled=self.autocast):
                     input = self.getInput(data_dict)

@@ -171,8 +174,7 @@ class _Trainer():

         desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
         data_dict = None
-        self.dataloader_validation
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 input = self.getInput(data_dict)
                 self.model(input)

@@ -180,7 +182,7 @@ class _Trainer():
                     self.modelEMA.module(input)

                 batch_iter.set_description(desc())
-        self.dataloader_validation.dataset.resetAugmentation()
+        self.dataloader_validation.dataset.resetAugmentation("Validation")
         dist.barrier()
         self.model.train()
         self.model.module.setState(NetState.TRAIN)

@@ -314,6 +316,18 @@ class Trainer(DistributedObject):
         self.ema_decay = ema_decay
         self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
         self.data_log = data_log
+
+        modules = []
+        for i,_ in self.model.named_modules():
+            modules.append(i)
+        for k in self.data_log:
+            tmp = k.split("/")[0].replace(":", ".")
+            if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
+                raise TrainerError(f"Invalid key '{tmp}' in `data_log`.",
+                                   f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
+                                   f"nor a valid module name in the model ({modules}).",
+                                   "Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
+
         self.gradient_checkpoints = gradient_checkpoints
         self.gpu_checkpoints = gpu_checkpoints
         self.save_checkpoint_mode = save_checkpoint_mode

@@ -360,7 +374,7 @@ class Trainer(DistributedObject):
                 os.makedirs(dir)

             for name in sorted(os.listdir(path_checkpoint)):
-                checkpoint = torch.load(path_checkpoint+name, weights_only=False)
+                checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
                 self.model.load(checkpoint, init=False, ema=False)

                 torch.save(self.model, "{}Serialized/{}".format(path_model, name))
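The trainer now validates `data_log` keys up front: each key must name either a dataset destination group or a module of the model, otherwise training aborts immediately. A sketch of the same check with a throwaway model and made-up group names (a plain ValueError stands in for KonfAI's TrainerError):

```python
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU())
groups_dest = ["ct", "mask"]                             # hypothetical dataset destination groups
data_log = ["ct/IMAGE/8", "0/IMAGE/8", "typo/IMAGE/8"]   # hypothetical log entries

modules = [name for name, _ in model.named_modules()]    # "", "0", "1"
for entry in data_log:
    key = entry.split("/")[0].replace(":", ".")
    if key not in groups_dest and key not in modules:
        raise ValueError(f"Invalid key '{key}' in data_log: neither a dataset group "
                         f"({groups_dest}) nor a module name ({modules})")
```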
konfai/utils/dataset.py
CHANGED
@@ -348,6 +348,10 @@ class Dataset():
     @abstractmethod
     def getNames(self, group: str) -> list[str]:
         pass
+
+    @abstractmethod
+    def getGroup(self) -> list[str]:
+        pass

     @abstractmethod
     def isExist(self, group: str, name: Union[str, None] = None) -> bool:

@@ -458,6 +462,9 @@ class Dataset():
             names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
         return names

+    def getGroup(self):
+        return self.h5.keys()
+
     def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
         if h5_group is None:
             h5_group = self.h5

@@ -613,7 +620,10 @@ class Dataset():

     def getNames(self, group: str) -> list[str]:
         raise NotImplementedError()
-
+
+    def getGroup(self):
+        raise NotImplementedError()
+
     def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
         attributes = Attribute()
         if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):

@@ -736,7 +746,7 @@ class Dataset():
             subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
         return subDirectories

-    def getNames(self, groups: str, index: Union[list[int], None] = None
+    def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
         names = []
         if self.is_directory:
             for subDirectory in self._getSubDirectories(groups):

@@ -752,6 +762,21 @@ class Dataset():
                 names = file.getNames(groups)
         return [name for i, name in enumerate(sorted(names)) if index is None or i in index]

+    def getGroup(self):
+        if self.is_directory:
+            groups = set()
+            for root, _, files in os.walk(self.filename):
+                for file in files:
+                    path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
+                    parts = path.split("/")
+                    if len(parts) >= 2:
+                        del parts[-2]
+                        groups.add("/".join(parts))
+        else:
+            with Dataset.File(self.filename, True, self.format) as file:
+                groups = file.getGroup()
+        return list(groups)
+
     def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
         if self.is_directory:
             for subDirectory in self._getSubDirectories(groups):
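The new directory-backed `getGroup` infers group names from the file layout: it walks the dataset directory, drops the penultimate path component (the sample name), and keeps the rest, so a layout like `<dataset>/<sample>/<group>.mha` yields one entry per group. A self-contained sketch of that traversal (the layout itself is an assumption drawn from the code above):

```python
import os

def get_groups(dataset_dir: str) -> list[str]:
    groups = set()
    for root, _, files in os.walk(dataset_dir):
        for filename in files:
            stem = filename.split(".")[0]
            rel = os.path.relpath(os.path.join(root, stem), dataset_dir)
            parts = rel.split(os.sep)
            if len(parts) >= 2:
                del parts[-2]              # drop the sample directory, keep the group part
                groups.add("/".join(parts))
    return sorted(groups)

# dataset/patient_01/CT.mha and dataset/patient_02/CT.mha -> ["CT"]
```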