konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/predictor.py
CHANGED
|
@@ -1,328 +1,639 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
1
|
import builtins
|
|
2
|
+
import copy
|
|
3
3
|
import importlib
|
|
4
|
+
import os
|
|
4
5
|
import shutil
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
5
10
|
import torch
|
|
6
11
|
import tqdm
|
|
7
|
-
import
|
|
12
|
+
from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
from torch.utils.tensorboard.writer import SummaryWriter
|
|
8
15
|
|
|
9
|
-
from konfai import
|
|
10
|
-
from konfai.utils.config import config
|
|
11
|
-
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
|
|
12
|
-
from konfai.utils.dataset import Dataset, Attribute
|
|
16
|
+
from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
|
|
13
17
|
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
14
18
|
from konfai.data.patching import Accumulator, PathCombine
|
|
15
|
-
from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
|
|
16
19
|
from konfai.data.transform import Transform, TransformLoader
|
|
20
|
+
from konfai.network.network import CPUModel, ModelLoader, NetState, Network
|
|
21
|
+
from konfai.utils.config import config
|
|
22
|
+
from konfai.utils.dataset import Attribute, Dataset
|
|
23
|
+
from konfai.utils.utils import DataLog, DistributedObject, NeedDevice, PredictorError, State, description, get_module
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Reduction(ABC):
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
def __init__(self):
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
@abstractmethod
|
|
33
|
+
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
34
|
+
pass
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Mean(Reduction):
|
|
38
|
+
|
|
39
|
+
def __init__(self):
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
43
|
+
return torch.mean(result.float(), dim=0)
|
|
17
44
|
|
|
18
|
-
from torch.utils.tensorboard.writer import SummaryWriter
|
|
19
|
-
from typing import Union
|
|
20
|
-
import numpy as np
|
|
21
|
-
import torch.distributed as dist
|
|
22
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
23
|
-
from torch.utils.data import DataLoader
|
|
24
|
-
import importlib
|
|
25
|
-
import copy
|
|
26
|
-
from collections import defaultdict
|
|
27
45
|
|
|
28
|
-
class
|
|
46
|
+
class Median(Reduction):
|
|
47
|
+
|
|
48
|
+
def __init__(self):
|
|
49
|
+
pass
|
|
50
|
+
|
|
51
|
+
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
52
|
+
return torch.median(result.float(), dim=0).values
|
|
29
53
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
54
|
+
|
|
55
|
+
class OutputDataset(Dataset, NeedDevice, ABC):
|
|
56
|
+
|
|
57
|
+
def __init__(
|
|
58
|
+
self,
|
|
59
|
+
filename: str,
|
|
60
|
+
group: str,
|
|
61
|
+
before_reduction_transforms: dict[str, TransformLoader],
|
|
62
|
+
after_reduction_transforms: dict[str, TransformLoader],
|
|
63
|
+
final_transforms: dict[str, TransformLoader],
|
|
64
|
+
patch_combine: str | None,
|
|
65
|
+
reduction: str,
|
|
66
|
+
) -> None:
|
|
67
|
+
filename, file_format = filename.split(":")
|
|
68
|
+
super().__init__(filename, file_format)
|
|
33
69
|
self.group = group
|
|
34
70
|
self._before_reduction_transforms = before_reduction_transforms
|
|
35
71
|
self._after_reduction_transforms = after_reduction_transforms
|
|
36
72
|
self._final_transforms = final_transforms
|
|
37
|
-
self.
|
|
73
|
+
self._patch_combine = patch_combine
|
|
38
74
|
self.reduction_classpath = reduction
|
|
39
|
-
self.reduction
|
|
40
|
-
|
|
41
|
-
self.before_reduction_transforms
|
|
42
|
-
self.after_reduction_transforms
|
|
43
|
-
self.final_transforms
|
|
44
|
-
self.
|
|
75
|
+
self.reduction: Reduction
|
|
76
|
+
|
|
77
|
+
self.before_reduction_transforms: list[Transform] = []
|
|
78
|
+
self.after_reduction_transforms: list[Transform] = []
|
|
79
|
+
self.final_transforms: list[Transform] = []
|
|
80
|
+
self.patch_combine: PathCombine | None = None
|
|
45
81
|
|
|
46
82
|
self.output_layer_accumulator: dict[int, dict[int, Accumulator]] = {}
|
|
47
83
|
self.attributes: dict[int, dict[int, dict[int, Attribute]]] = {}
|
|
48
84
|
self.names: dict[int, str] = {}
|
|
49
85
|
self.nb_data_augmentation = 0
|
|
50
86
|
|
|
87
|
+
@abstractmethod
|
|
51
88
|
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
52
|
-
transforms_type = [
|
|
53
|
-
|
|
54
|
-
|
|
89
|
+
transforms_type = [
|
|
90
|
+
"before_reduction_transforms",
|
|
91
|
+
"after_reduction_transforms",
|
|
92
|
+
"final_transforms",
|
|
93
|
+
]
|
|
94
|
+
for name, _transform_type, transform_type in [
|
|
95
|
+
(k, getattr(self, f"_{k}"), getattr(self, k)) for k in transforms_type
|
|
96
|
+
]:
|
|
97
|
+
|
|
55
98
|
if _transform_type is not None:
|
|
56
99
|
for classpath, transform in _transform_type.items():
|
|
57
|
-
transform = transform.
|
|
58
|
-
|
|
100
|
+
transform = transform.get_transform(
|
|
101
|
+
classpath,
|
|
102
|
+
konfai_args=f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{name}",
|
|
103
|
+
)
|
|
104
|
+
transform.set_datasets(datasets)
|
|
59
105
|
transform_type.append(transform)
|
|
60
106
|
|
|
61
|
-
if self.
|
|
62
|
-
module, name =
|
|
63
|
-
self.
|
|
107
|
+
if self._patch_combine is not None:
|
|
108
|
+
module, name = get_module(self._patch_combine, "konfai.data.patching")
|
|
109
|
+
self.patch_combine = config(f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset")(
|
|
110
|
+
getattr(importlib.import_module(module), name)
|
|
111
|
+
)(config=None)
|
|
64
112
|
|
|
65
|
-
module, name =
|
|
113
|
+
module, name = get_module(self.reduction_classpath, "konfai.predictor")
|
|
66
114
|
if module == "konfai.predictor":
|
|
67
115
|
self.reduction = getattr(importlib.import_module(module), name)()
|
|
68
116
|
else:
|
|
69
|
-
self.reduction = config(
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
117
|
+
self.reduction = config(
|
|
118
|
+
f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{self.reduction_classpath}"
|
|
119
|
+
)(getattr(importlib.import_module(module), name))(config=None)
|
|
120
|
+
|
|
121
|
+
def set_patch_config(
|
|
122
|
+
self,
|
|
123
|
+
patch_size: list[int] | None,
|
|
124
|
+
overlap: int | None,
|
|
125
|
+
nb_data_augmentation: int,
|
|
126
|
+
) -> None:
|
|
127
|
+
if patch_size is not None and overlap is not None:
|
|
128
|
+
if self.patch_combine is not None:
|
|
129
|
+
self.patch_combine.set_patch_config(patch_size, overlap)
|
|
76
130
|
else:
|
|
77
|
-
self.
|
|
131
|
+
self.patch_combine = None
|
|
78
132
|
self.nb_data_augmentation = nb_data_augmentation
|
|
79
|
-
|
|
80
|
-
def
|
|
81
|
-
super().
|
|
82
|
-
transforms_type = [
|
|
133
|
+
|
|
134
|
+
def to(self, device: torch.device):
|
|
135
|
+
super().to(device)
|
|
136
|
+
transforms_type = [
|
|
137
|
+
"before_reduction_transforms",
|
|
138
|
+
"after_reduction_transforms",
|
|
139
|
+
"final_transforms",
|
|
140
|
+
]
|
|
83
141
|
for transform_type in [(getattr(self, k)) for k in transforms_type]:
|
|
84
142
|
if transform_type is not None:
|
|
85
143
|
for transform in transform_type:
|
|
86
|
-
transform.
|
|
144
|
+
transform.to(device)
|
|
87
145
|
|
|
88
146
|
@abstractmethod
|
|
89
|
-
def
|
|
147
|
+
def add_layer(
|
|
148
|
+
self,
|
|
149
|
+
index_dataset: int,
|
|
150
|
+
index_augmentation: int,
|
|
151
|
+
index_patch: int,
|
|
152
|
+
layer: torch.Tensor,
|
|
153
|
+
dataset: DatasetIter,
|
|
154
|
+
):
|
|
90
155
|
pass
|
|
91
156
|
|
|
92
|
-
def
|
|
93
|
-
return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all(
|
|
157
|
+
def is_done(self, index: int) -> bool:
|
|
158
|
+
return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all(
|
|
159
|
+
acc.is_full() for acc in self.output_layer_accumulator[index].values()
|
|
160
|
+
)
|
|
94
161
|
|
|
95
162
|
@abstractmethod
|
|
96
|
-
def
|
|
163
|
+
def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
|
|
97
164
|
pass
|
|
98
165
|
|
|
99
|
-
def
|
|
166
|
+
def write_prediction(self, index: int, name: str, layer: torch.Tensor) -> None:
|
|
100
167
|
super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
|
|
101
168
|
self.attributes.pop(index)
|
|
102
169
|
|
|
103
|
-
class Reduction():
|
|
104
|
-
|
|
105
|
-
def __init__(self):
|
|
106
|
-
pass
|
|
107
|
-
|
|
108
|
-
class Mean(Reduction):
|
|
109
|
-
|
|
110
|
-
def __init__(self):
|
|
111
|
-
pass
|
|
112
|
-
|
|
113
|
-
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
114
|
-
return torch.mean(result.float(), dim=0)
|
|
115
|
-
|
|
116
|
-
class Median(Reduction):
|
|
117
|
-
|
|
118
|
-
def __init__(self):
|
|
119
|
-
pass
|
|
120
|
-
|
|
121
|
-
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
122
|
-
return torch.median(result.float(), dim=0).values
|
|
123
170
|
|
|
124
|
-
class OutSameAsGroupDataset(
|
|
125
|
-
|
|
126
|
-
@config("
|
|
127
|
-
def __init__(
|
|
128
|
-
|
|
129
|
-
|
|
171
|
+
class OutSameAsGroupDataset(OutputDataset):
|
|
172
|
+
|
|
173
|
+
@config("OutputDataset")
|
|
174
|
+
def __init__(
|
|
175
|
+
self,
|
|
176
|
+
same_as_group: str = "default",
|
|
177
|
+
dataset_filename: str = "default:./Dataset:mha",
|
|
178
|
+
group: str = "default",
|
|
179
|
+
before_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
|
|
180
|
+
after_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
|
|
181
|
+
final_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
|
|
182
|
+
patch_combine: str | None = None,
|
|
183
|
+
reduction: str = "Mean",
|
|
184
|
+
inverse_transform: bool = True,
|
|
185
|
+
) -> None:
|
|
186
|
+
super().__init__(
|
|
187
|
+
dataset_filename,
|
|
188
|
+
group,
|
|
189
|
+
before_reduction_transforms,
|
|
190
|
+
after_reduction_transforms,
|
|
191
|
+
final_transforms,
|
|
192
|
+
patch_combine,
|
|
193
|
+
reduction,
|
|
194
|
+
)
|
|
195
|
+
self.group_src, self.group_dest = same_as_group.split(":")
|
|
130
196
|
self.inverse_transform = inverse_transform
|
|
131
197
|
|
|
132
|
-
def
|
|
133
|
-
|
|
134
|
-
|
|
198
|
+
def add_layer(
|
|
199
|
+
self,
|
|
200
|
+
index_dataset: int,
|
|
201
|
+
index_augmentation: int,
|
|
202
|
+
index_patch: int,
|
|
203
|
+
layer: torch.Tensor,
|
|
204
|
+
dataset: DatasetIter,
|
|
205
|
+
):
|
|
206
|
+
if (
|
|
207
|
+
index_dataset not in self.output_layer_accumulator
|
|
208
|
+
or index_augmentation not in self.output_layer_accumulator[index_dataset]
|
|
209
|
+
):
|
|
210
|
+
input_dataset = dataset.get_dataset_from_index(self.group_dest, index_dataset)
|
|
135
211
|
if index_dataset not in self.output_layer_accumulator:
|
|
136
212
|
self.output_layer_accumulator[index_dataset] = {}
|
|
137
213
|
self.attributes[index_dataset] = {}
|
|
138
214
|
self.names[index_dataset] = input_dataset.name
|
|
139
215
|
self.attributes[index_dataset][index_augmentation] = {}
|
|
140
216
|
|
|
141
|
-
self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(
|
|
217
|
+
self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(
|
|
218
|
+
input_dataset.patch.get_patch_slices(index_augmentation),
|
|
219
|
+
input_dataset.patch.patch_size,
|
|
220
|
+
self.patch_combine,
|
|
221
|
+
batch=False,
|
|
222
|
+
)
|
|
142
223
|
|
|
143
|
-
for i in range(len(input_dataset.patch.
|
|
224
|
+
for i in range(len(input_dataset.patch.get_patch_slices(index_augmentation))):
|
|
144
225
|
self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])
|
|
145
226
|
|
|
146
227
|
if self.inverse_transform:
|
|
147
|
-
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].
|
|
148
|
-
layer = transform.inverse(
|
|
149
|
-
|
|
150
|
-
|
|
228
|
+
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
|
|
229
|
+
layer = transform.inverse(
|
|
230
|
+
self.names[index_dataset],
|
|
231
|
+
layer,
|
|
232
|
+
self.attributes[index_dataset][index_augmentation][index_patch],
|
|
233
|
+
)
|
|
234
|
+
self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)
|
|
151
235
|
|
|
152
236
|
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
153
237
|
super().load(name_layer, datasets, groups)
|
|
154
|
-
|
|
238
|
+
|
|
155
239
|
if self.group_src not in groups.keys():
|
|
156
|
-
raise PredictorError(
|
|
157
|
-
f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
|
|
158
|
-
)
|
|
240
|
+
raise PredictorError(f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}.")
|
|
159
241
|
|
|
160
242
|
if self.group_dest not in groups[self.group_src]:
|
|
161
243
|
raise PredictorError(
|
|
162
244
|
f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
|
|
163
245
|
)
|
|
164
|
-
|
|
165
|
-
def
|
|
246
|
+
|
|
247
|
+
def _get_output(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
|
|
166
248
|
layer = self.output_layer_accumulator[index][index_augmentation].assemble()
|
|
167
249
|
if index_augmentation > 0:
|
|
168
|
-
|
|
250
|
+
|
|
169
251
|
i = 0
|
|
170
|
-
index_augmentation_tmp = index_augmentation-1
|
|
171
|
-
for
|
|
172
|
-
if index_augmentation_tmp >= i and index_augmentation_tmp < i+
|
|
173
|
-
for
|
|
174
|
-
layer =
|
|
252
|
+
index_augmentation_tmp = index_augmentation - 1
|
|
253
|
+
for data_augmentations in dataset.data_augmentations_list:
|
|
254
|
+
if index_augmentation_tmp >= i and index_augmentation_tmp < i + data_augmentations.nb:
|
|
255
|
+
for data_augmentation in reversed(data_augmentations.data_augmentations):
|
|
256
|
+
layer = data_augmentation.inverse(index, index_augmentation_tmp - i, layer)
|
|
175
257
|
break
|
|
176
|
-
i +=
|
|
258
|
+
i += data_augmentations.nb
|
|
177
259
|
|
|
178
260
|
for transform in self.before_reduction_transforms:
|
|
179
261
|
layer = transform(self.names[index], layer, self.attributes[index][index_augmentation][0])
|
|
180
|
-
|
|
262
|
+
|
|
181
263
|
return layer
|
|
182
264
|
|
|
183
|
-
def
|
|
184
|
-
result = torch.cat(
|
|
265
|
+
def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
|
|
266
|
+
result = torch.cat(
|
|
267
|
+
[
|
|
268
|
+
self._get_output(index, index_augmentation, dataset).unsqueeze(0)
|
|
269
|
+
for index_augmentation in self.output_layer_accumulator[index].keys()
|
|
270
|
+
],
|
|
271
|
+
dim=0,
|
|
272
|
+
)
|
|
185
273
|
self.output_layer_accumulator.pop(index)
|
|
186
274
|
result = self.reduction(result.float()).to(result.dtype)
|
|
187
275
|
for transform in self.after_reduction_transforms:
|
|
188
276
|
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
189
277
|
|
|
190
278
|
if self.inverse_transform:
|
|
191
|
-
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].
|
|
279
|
+
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
|
|
192
280
|
result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
|
|
193
|
-
|
|
281
|
+
|
|
194
282
|
for transform in self.final_transforms:
|
|
195
|
-
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
283
|
+
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
196
284
|
return result
|
|
197
285
|
|
|
198
|
-
class OutDatasetLoader():
|
|
199
286
|
|
|
200
|
-
|
|
287
|
+
class OutputDatasetLoader:
|
|
288
|
+
|
|
289
|
+
@config("OutputDataset")
|
|
201
290
|
def __init__(self, name_class: str = "OutSameAsGroupDataset") -> None:
|
|
202
291
|
self.name_class = name_class
|
|
203
292
|
|
|
204
|
-
def
|
|
205
|
-
return getattr(importlib.import_module("konfai.predictor"), self.name_class)(
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
293
|
+
def get_output_dataset(self, layer_name: str) -> OutputDataset:
|
|
294
|
+
return getattr(importlib.import_module("konfai.predictor"), self.name_class)(
|
|
295
|
+
config=None, konfai_args=f"Predictor.outputs_dataset.{layer_name}"
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class _Predictor:
|
|
300
|
+
"""
|
|
301
|
+
Internal class that runs distributed inference over a dataset using a composite model.
|
|
302
|
+
|
|
303
|
+
This class handles patch-wise prediction, output accumulation, logging to TensorBoard, and
|
|
304
|
+
writing final predictions to disk. It is designed to be used as a context manager and
|
|
305
|
+
supports model ensembles via `ModelComposite`.
|
|
306
|
+
|
|
307
|
+
Args:
|
|
308
|
+
world_size (int): Total number of processes or GPUs used.
|
|
309
|
+
global_rank (int): Rank of the current process across all nodes.
|
|
310
|
+
local_rank (int): Local GPU index within a single node.
|
|
311
|
+
autocast (bool): Whether to use automatic mixed precision (AMP).
|
|
312
|
+
predict_path (str): Output directory path where predictions and metrics are saved.
|
|
313
|
+
data_log (list[str] | None): List of logging targets in the format 'group/DataLogType/N'.
|
|
314
|
+
outputs_dataset (dict[str, OutputDataset]): Dictionary of output datasets to store predictions.
|
|
315
|
+
model_composite (DDP): Distributed model container that wraps the prediction model(s).
|
|
316
|
+
dataloader_prediction (DataLoader): DataLoader that provides prediction batches.
|
|
317
|
+
"""
|
|
318
|
+
|
|
319
|
+
def __init__(
|
|
320
|
+
self,
|
|
321
|
+
world_size: int,
|
|
322
|
+
global_rank: int,
|
|
323
|
+
local_rank: int,
|
|
324
|
+
autocast: bool,
|
|
325
|
+
predict_path: str,
|
|
326
|
+
data_log: list[str] | None,
|
|
327
|
+
outputs_dataset: dict[str, OutputDataset],
|
|
328
|
+
model_composite: DDP,
|
|
329
|
+
dataloader_prediction: DataLoader,
|
|
330
|
+
) -> None:
|
|
331
|
+
self.world_size = world_size
|
|
211
332
|
self.global_rank = global_rank
|
|
212
333
|
self.local_rank = local_rank
|
|
213
334
|
|
|
214
|
-
self.
|
|
335
|
+
self.model_composite = model_composite
|
|
215
336
|
self.dataloader_prediction = dataloader_prediction
|
|
216
|
-
self.
|
|
337
|
+
self.outputs_dataset = outputs_dataset
|
|
217
338
|
self.autocast = autocast
|
|
218
|
-
|
|
339
|
+
|
|
219
340
|
self.it = 0
|
|
220
341
|
|
|
221
|
-
self.device = self.
|
|
342
|
+
self.device = self.model_composite.device
|
|
222
343
|
self.dataset: DatasetIter = self.dataloader_prediction.dataset
|
|
223
|
-
patch_size, overlap = self.dataset.
|
|
224
|
-
for
|
|
225
|
-
|
|
226
|
-
|
|
344
|
+
patch_size, overlap = self.dataset.get_patch_config()
|
|
345
|
+
for output_dataset in self.outputs_dataset.values():
|
|
346
|
+
output_dataset.set_patch_config(
|
|
347
|
+
[size for size in patch_size if size > 1] if patch_size else None,
|
|
348
|
+
overlap,
|
|
349
|
+
np.max(
|
|
350
|
+
[
|
|
351
|
+
int(
|
|
352
|
+
np.sum([data_augmentation.nb for data_augmentation in self.dataset.data_augmentations_list])
|
|
353
|
+
+ 1
|
|
354
|
+
),
|
|
355
|
+
1,
|
|
356
|
+
]
|
|
357
|
+
),
|
|
358
|
+
)
|
|
359
|
+
self.data_log: dict[str, tuple[DataLog, int]] = {}
|
|
227
360
|
if data_log is not None:
|
|
228
361
|
for data in data_log:
|
|
229
|
-
self.data_log[data.split("/")[0].replace(":", ".")] = (
|
|
230
|
-
|
|
231
|
-
|
|
362
|
+
self.data_log[data.split("/")[0].replace(":", ".")] = (
|
|
363
|
+
DataLog[data.split("/")[1]],
|
|
364
|
+
int(data.split("/")[2]),
|
|
365
|
+
)
|
|
366
|
+
self.tb = (
|
|
367
|
+
SummaryWriter(log_dir=predict_path + "Metric/")
|
|
368
|
+
if len(
|
|
369
|
+
[
|
|
370
|
+
network
|
|
371
|
+
for network in self.model_composite.module.get_networks().values()
|
|
372
|
+
if network.measure is not None
|
|
373
|
+
]
|
|
374
|
+
)
|
|
375
|
+
or len(self.data_log)
|
|
376
|
+
else None
|
|
377
|
+
)
|
|
378
|
+
|
|
232
379
|
def __enter__(self):
|
|
380
|
+
"""
|
|
381
|
+
Enters the prediction context and returns the predictor instance.
|
|
382
|
+
"""
|
|
233
383
|
return self
|
|
234
|
-
|
|
235
|
-
def __exit__(self,
|
|
384
|
+
|
|
385
|
+
def __exit__(self, exc_type, value, traceback):
|
|
386
|
+
"""
|
|
387
|
+
Closes the TensorBoard writer upon exit.
|
|
388
|
+
"""
|
|
236
389
|
if self.tb:
|
|
237
390
|
self.tb.close()
|
|
238
391
|
|
|
239
|
-
def
|
|
240
|
-
|
|
241
|
-
|
|
392
|
+
def get_input(
|
|
393
|
+
self,
|
|
394
|
+
data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
|
|
395
|
+
) -> dict[tuple[str, bool], torch.Tensor]:
|
|
396
|
+
return {(k, v[5][0]): v[0] for k, v in data_dict.items()}
|
|
397
|
+
|
|
242
398
|
@torch.no_grad()
|
|
243
399
|
def run(self):
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
400
|
+
"""
|
|
401
|
+
Run the full prediction loop.
|
|
402
|
+
|
|
403
|
+
Iterates over the prediction DataLoader, performs inference using the composite model,
|
|
404
|
+
applies reduction (e.g., mean), and writes the final results using each `OutputDataset`.
|
|
405
|
+
|
|
406
|
+
Also logs intermediate data and metrics to TensorBoard if enabled.
|
|
407
|
+
"""
|
|
408
|
+
|
|
409
|
+
self.model_composite.eval()
|
|
410
|
+
self.model_composite.module.set_state(NetState.PREDICTION)
|
|
247
411
|
self.dataloader_prediction.dataset.load("Prediction")
|
|
248
|
-
with tqdm.tqdm(
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
412
|
+
with tqdm.tqdm(
|
|
413
|
+
iterable=enumerate(self.dataloader_prediction),
|
|
414
|
+
leave=True,
|
|
415
|
+
desc=f"Prediction : {description(self.model_composite)}",
|
|
416
|
+
total=len(self.dataloader_prediction),
|
|
417
|
+
ncols=0,
|
|
418
|
+
) as batch_iter:
|
|
419
|
+
for _, data_dict in batch_iter:
|
|
420
|
+
with torch.amp.autocast("cuda", enabled=self.autocast):
|
|
421
|
+
input_tensor = self.get_input(data_dict)
|
|
422
|
+
for name, output in self.model_composite(input_tensor, list(self.outputs_dataset.keys())):
|
|
253
423
|
self._predict_log(data_dict)
|
|
254
|
-
|
|
255
|
-
for i, (index, patch_augmentation, patch_index) in enumerate(
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
424
|
+
output_dataset = self.outputs_dataset[name]
|
|
425
|
+
for i, (index, patch_augmentation, patch_index) in enumerate(
|
|
426
|
+
[
|
|
427
|
+
(int(index), int(patch_augmentation), int(patch_index))
|
|
428
|
+
for index, patch_augmentation, patch_index in zip(
|
|
429
|
+
list(data_dict.values())[0][1],
|
|
430
|
+
list(data_dict.values())[0][2],
|
|
431
|
+
list(data_dict.values())[0][3],
|
|
432
|
+
)
|
|
433
|
+
]
|
|
434
|
+
):
|
|
435
|
+
output_dataset.add_layer(
|
|
436
|
+
index,
|
|
437
|
+
patch_augmentation,
|
|
438
|
+
patch_index,
|
|
439
|
+
output[i].cpu(),
|
|
440
|
+
self.dataset,
|
|
441
|
+
)
|
|
442
|
+
if output_dataset.is_done(index):
|
|
443
|
+
output_dataset.write_prediction(
|
|
444
|
+
index,
|
|
445
|
+
self.dataset.get_dataset_from_index(list(data_dict.keys())[0], index).name.split(
|
|
446
|
+
"/"
|
|
447
|
+
)[-1],
|
|
448
|
+
output_dataset.get_output(index, self.dataset),
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
batch_iter.set_description(f"Prediction : {description(self.model_composite)}")
|
|
261
452
|
self.it += 1
|
|
262
|
-
|
|
263
|
-
def _predict_log(
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
453
|
+
|
|
454
|
+
def _predict_log(
|
|
455
|
+
self,
|
|
456
|
+
data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
|
|
457
|
+
):
|
|
458
|
+
"""
|
|
459
|
+
Log prediction results to TensorBoard, including images and metrics.
|
|
460
|
+
|
|
461
|
+
This method handles:
|
|
462
|
+
- Logging image-like data (e.g., inputs, outputs, masks) using `DataLog` instances,
|
|
463
|
+
based on the `data_log` configuration.
|
|
464
|
+
- Logging scalar loss and metric values (if present in the network) under the `Prediction/` namespace.
|
|
465
|
+
- Dynamically retrieving additional feature maps or intermediate layers if requested via `data_log`.
|
|
466
|
+
|
|
467
|
+
Logging is performed only on the global rank 0 process and only if `TensorBoard` is active.
|
|
468
|
+
|
|
469
|
+
Args:
|
|
470
|
+
data_dict (dict): Dictionary mapping group names to 6-tuples containing:
|
|
471
|
+
- input tensor,
|
|
472
|
+
- index,
|
|
473
|
+
- patch_augmentation,
|
|
474
|
+
- patch_index,
|
|
475
|
+
- metadata (list of strings),
|
|
476
|
+
- `requires_grad` flag (as a tensor).
|
|
477
|
+
"""
|
|
478
|
+
measures = DistributedObject.get_measure(
|
|
479
|
+
self.world_size,
|
|
480
|
+
self.global_rank,
|
|
481
|
+
self.local_rank,
|
|
482
|
+
{"": self.model_composite.module},
|
|
483
|
+
1,
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
if self.global_rank == 0 and self.tb is not None:
|
|
487
|
+
data_log = []
|
|
488
|
+
if len(self.data_log):
|
|
269
489
|
for name, data_type in self.data_log.items():
|
|
270
490
|
if name in data_dict:
|
|
271
|
-
data_type[0](
|
|
491
|
+
data_type[0](
|
|
492
|
+
self.tb,
|
|
493
|
+
f"Prediction/{name}",
|
|
494
|
+
data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
|
|
495
|
+
self.it,
|
|
496
|
+
)
|
|
272
497
|
else:
|
|
273
|
-
|
|
498
|
+
data_log.append(name.replace(":", "."))
|
|
274
499
|
|
|
275
|
-
for name, network in self.
|
|
500
|
+
for name, network in self.model_composite.module.get_networks().items():
|
|
276
501
|
if network.measure is not None:
|
|
277
|
-
self.tb.add_scalars(
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
502
|
+
self.tb.add_scalars(
|
|
503
|
+
f"Prediction/{name}/Loss",
|
|
504
|
+
{k: v[1] for k, v in measures[name][0].items()},
|
|
505
|
+
self.it,
|
|
506
|
+
)
|
|
507
|
+
self.tb.add_scalars(
|
|
508
|
+
f"Prediction/{name}/Metric",
|
|
509
|
+
{k: v[1] for k, v in measures[name][1].items()},
|
|
510
|
+
self.it,
|
|
511
|
+
)
|
|
512
|
+
if len(data_log):
|
|
513
|
+
for name, layer, _ in self.model_composite.module.get_layers(
|
|
514
|
+
[v.to(0) for k, v in self.get_input(data_dict).items() if k[1]], data_log
|
|
515
|
+
):
|
|
516
|
+
self.data_log[name][0](
|
|
517
|
+
self.tb,
|
|
518
|
+
f"Prediction/{name}",
|
|
519
|
+
layer[: self.data_log[name][1]].detach().cpu().numpy(),
|
|
520
|
+
self.it,
|
|
521
|
+
)
|
|
522
|
+
|
|
282
523
|
|
|
283
524
|
class ModelComposite(Network):
|
|
525
|
+
"""
|
|
526
|
+
A composite model that replicates a given base network multiple times and combines their outputs.
|
|
527
|
+
|
|
528
|
+
This class is designed to handle model ensembles or repeated predictions from the same architecture.
|
|
529
|
+
It creates `nb_models` deep copies of the input `model`, each with its own name and output branch,
|
|
530
|
+
and aggregates their outputs using a provided `Reduction` strategy (e.g., mean, median).
|
|
531
|
+
|
|
532
|
+
Args:
|
|
533
|
+
model (Network): The base network to replicate.
|
|
534
|
+
nb_models (int): Number of copies of the model to create.
|
|
535
|
+
combine (Reduction): The reduction method used to combine outputs from all model replicas.
|
|
536
|
+
|
|
537
|
+
Attributes:
|
|
538
|
+
combine (Reduction): The reduction method used during forward inference.
|
|
539
|
+
"""
|
|
284
540
|
|
|
285
541
|
def __init__(self, model: Network, nb_models: int, combine: Reduction):
|
|
286
|
-
super().__init__(
|
|
542
|
+
super().__init__(
|
|
543
|
+
model.in_channels,
|
|
544
|
+
model.optimizer,
|
|
545
|
+
model.lr_schedulers_loader,
|
|
546
|
+
model.outputs_criterions_loader,
|
|
547
|
+
model.patch,
|
|
548
|
+
model.nb_batch_per_step,
|
|
549
|
+
model.init_type,
|
|
550
|
+
model.init_gain,
|
|
551
|
+
model.dim,
|
|
552
|
+
)
|
|
287
553
|
self.combine = combine
|
|
288
554
|
for i in range(nb_models):
|
|
289
|
-
self.add_module(
|
|
555
|
+
self.add_module(
|
|
556
|
+
f"Model_{i}",
|
|
557
|
+
copy.deepcopy(model),
|
|
558
|
+
in_branch=[0],
|
|
559
|
+
out_branch=[f"output_{i}"],
|
|
560
|
+
)
|
|
561
|
+
|
|
562
|
+
def load(self, state_dicts: list[dict[str, dict[str, torch.Tensor]]]):
|
|
563
|
+
"""
|
|
564
|
+
Load weights for each sub-model in the composite from the corresponding state dictionaries.
|
|
290
565
|
|
|
291
|
-
|
|
566
|
+
Args:
|
|
567
|
+
state_dicts (list): A list of state dictionaries, one for each model replica.
|
|
568
|
+
"""
|
|
292
569
|
for i, state_dict in enumerate(state_dicts):
|
|
293
|
-
self["Model_{}"
|
|
294
|
-
self["Model_{}"
|
|
295
|
-
|
|
296
|
-
def forward(
|
|
570
|
+
self[f"Model_{i}"].load(state_dict, init=False)
|
|
571
|
+
self[f"Model_{i}"].set_name(f"{self[f'Model_{i}'].get_name()}_{i}")
|
|
572
|
+
|
|
573
|
+
def forward(
|
|
574
|
+
self,
|
|
575
|
+
data_dict: dict[tuple[str, bool], torch.Tensor],
|
|
576
|
+
output_layers: list[str] = [],
|
|
577
|
+
) -> list[tuple[str, torch.Tensor]]:
|
|
578
|
+
"""
|
|
579
|
+
Perform a forward pass on all model replicas and aggregate their outputs.
|
|
580
|
+
|
|
581
|
+
Args:
|
|
582
|
+
data_dict (dict): A dictionary mapping (group_name, requires_grad) to input tensors.
|
|
583
|
+
output_layers (list): List of output layer names to extract from each sub-model.
|
|
584
|
+
|
|
585
|
+
Returns:
|
|
586
|
+
list[tuple[str, torch.Tensor]]: Aggregated output for each layer, after applying the reduction.
|
|
587
|
+
"""
|
|
297
588
|
result = {}
|
|
298
589
|
for name, module in self.items():
|
|
299
590
|
result[name] = module(data_dict, output_layers)
|
|
300
|
-
|
|
591
|
+
|
|
301
592
|
aggregated = defaultdict(list)
|
|
302
593
|
for module_outputs in result.values():
|
|
303
594
|
for key, tensor in module_outputs:
|
|
304
595
|
aggregated[key].append(tensor)
|
|
305
|
-
|
|
596
|
+
|
|
306
597
|
final_outputs = []
|
|
307
598
|
for key, tensors in aggregated.items():
|
|
308
599
|
final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))
|
|
309
600
|
|
|
310
601
|
return final_outputs
|
|
311
|
-
|
|
602
|
+
|
|
312
603
|
|
|
313
604
|
class Predictor(DistributedObject):
|
|
605
|
+
"""
|
|
606
|
+
KonfAI's main prediction controller.
|
|
607
|
+
|
|
608
|
+
This class orchestrates the prediction phase by:
|
|
609
|
+
- Loading model weights from checkpoint(s) or URL(s)
|
|
610
|
+
- Preparing datasets and output configurations
|
|
611
|
+
- Managing distributed inference with optional multi-GPU support
|
|
612
|
+
- Applying transformations and saving predictions
|
|
613
|
+
- Optionally logging results to TensorBoard
|
|
614
|
+
|
|
615
|
+
Attributes:
|
|
616
|
+
model (Network): The neural network model to use for prediction.
|
|
617
|
+
dataset (DataPrediction): Dataset manager for prediction data.
|
|
618
|
+
combine_classpath (str): Path to the reduction strategy (e.g., "Mean").
|
|
619
|
+
autocast (bool): Whether to enable AMP inference.
|
|
620
|
+
outputs_dataset (dict[str, OutputDataset]): Mapping from layer names to output writers.
|
|
621
|
+
data_log (list[str] | None): List of tensors to log during inference.
|
|
622
|
+
"""
|
|
314
623
|
|
|
315
624
|
@config("Predictor")
|
|
316
|
-
def __init__(
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
625
|
+
def __init__(
|
|
626
|
+
self,
|
|
627
|
+
model: ModelLoader = ModelLoader(),
|
|
628
|
+
dataset: DataPrediction = DataPrediction(),
|
|
629
|
+
combine: str = "Mean",
|
|
630
|
+
train_name: str = "name",
|
|
631
|
+
manual_seed: int | None = None,
|
|
632
|
+
gpu_checkpoints: list[str] | None = None,
|
|
633
|
+
autocast: bool = False,
|
|
634
|
+
outputs_dataset: dict[str, OutputDatasetLoader] | None = {"default:Default": OutputDatasetLoader()},
|
|
635
|
+
data_log: list[str] | None = None,
|
|
636
|
+
) -> None:
|
|
326
637
|
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
327
638
|
exit(0)
|
|
328
639
|
super().__init__(train_name)
|
|
@@ -331,99 +642,194 @@ class Predictor(DistributedObject):
|
|
|
331
642
|
self.combine_classpath = combine
|
|
332
643
|
self.autocast = autocast
|
|
333
644
|
|
|
334
|
-
self.model = model.
|
|
645
|
+
self.model = model.get_model(train=False)
|
|
335
646
|
self.it = 0
|
|
336
|
-
self.
|
|
337
|
-
self.
|
|
647
|
+
self.outputs_dataset_loader = outputs_dataset if outputs_dataset else {}
|
|
648
|
+
self.outputs_dataset = {
|
|
649
|
+
name.replace(":", "."): value.get_output_dataset(name)
|
|
650
|
+
for name, value in self.outputs_dataset_loader.items()
|
|
651
|
+
}
|
|
338
652
|
|
|
339
653
|
self.datasets_filename = []
|
|
340
|
-
self.predict_path =
|
|
341
|
-
self.
|
|
342
|
-
|
|
343
|
-
self.
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
654
|
+
self.predict_path = predictions_directory() + self.name + "/"
|
|
655
|
+
for output_dataset in self.outputs_dataset.values():
|
|
656
|
+
self.datasets_filename.append(output_dataset.filename)
|
|
657
|
+
output_dataset.filename = f"{self.predict_path}{output_dataset.filename}"
|
|
658
|
+
|
|
659
|
+
self.data_log = data_log
|
|
660
|
+
modules = []
|
|
661
|
+
for i, _ in self.model.named_modules():
|
|
662
|
+
modules.append(i)
|
|
663
|
+
if self.data_log is not None:
|
|
664
|
+
for k in self.data_log:
|
|
665
|
+
tmp = k.split("/")[0].replace(":", ".")
|
|
666
|
+
if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
|
|
667
|
+
raise PredictorError(
|
|
668
|
+
f"Invalid key '{tmp}' in `data_log`.",
|
|
669
|
+
f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
|
|
670
|
+
f"nor a valid module name in the model ({modules}).",
|
|
671
|
+
"Please check your `data_log` configuration,"
|
|
672
|
+
" it should reference either a model output or a dataset group.",
|
|
673
|
+
)
|
|
674
|
+
|
|
347
675
|
self.gpu_checkpoints = gpu_checkpoints
|
|
348
676
|
|
|
349
677
|
def _load(self) -> list[dict[str, dict[str, torch.Tensor]]]:
    """
    Load pretrained model weights from configured paths or URLs.

    `path_to_models()` is read as a ":"-separated list of model sources. Each entry is
    resolved as follows:
    - An entry starting with "https://" is downloaded with
      `torch.hub.load_state_dict_from_url` (hash check enabled).
    - Any other non-empty entry is used as a local file path.
    - An empty entry falls back to the models directory: when `self.name` ends with
      ".pt" that file is looked up in the sibling "StateDict/" directory, otherwise the
      last file (in sorted order) of "<models_directory><name>/StateDict/" is used.

    Every state dict is deserialized onto the CPU.

    Returns:
        list[dict[str, dict[str, torch.Tensor]]]: A list of state dictionaries, one per model.

    Raises:
        Exception: If a URL cannot be loaded, a local file does not exist, or the default
            StateDict directory contains no file.
    """
    import re

    # Split on ":" but not on the ":" of a "://" scheme separator; a plain split(":")
    # cuts "https://..." into "https" and "//...", so the URL branch could never match.
    # NOTE: a URL with an explicit port ("https://host:8080/...") is still ambiguous.
    model_paths = re.split(r":(?!//)", path_to_models())
    state_dicts = []
    for model_path in model_paths:
        if model_path.startswith("https://"):
            try:
                state_dicts.append(
                    torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True)
                )
            except Exception as err:
                # Keep the original failure (network, hash mismatch, ...) as the cause.
                raise Exception(f"Model : {model_path} does not exist !") from err
            continue

        if model_path != "":
            path = ""
            name = model_path
        elif self.name.endswith(".pt"):
            path = models_directory() + "/".join(self.name.split("/")[:-1]) + "/StateDict/"
            name = self.name.split("/")[-1]
        else:
            path = models_directory() + self.name + "/StateDict/"
            candidates = sorted(os.listdir(path))
            if not candidates:
                raise Exception(f"Model : {path} does not contain any checkpoint !")
            name = candidates[-1]

        if not os.path.exists(path + name):
            raise Exception(f"Model : {path + name} does not exist !")
        # map_location="cpu" matches the URL branch and lets a checkpoint saved on a GPU
        # be loaded on a host where that device is not available.
        state_dicts.append(torch.load(path + name, weights_only=True, map_location="cpu"))
    return state_dicts


def setup(self, world_size: int):
    """
    Set up the predictor for inference.

    Steps, in order:
    - For each output dataset path under `self.predict_path`: if it already exists and
      the KONFAI_OVERWRITE environment variable is not "True", ask for confirmation;
      create the directory when it is missing.
    - Copy the current configuration file to "<predict_path>Prediction.yml".
    - Initialize the model in prediction state, its output groups and channel trace.
    - Check that every configured output group names a module of the model.
    - Build the combine operator from `self.combine_classpath`.
    - Load the state dicts and wrap the model in a `ModelComposite` holding one replica
      per loaded state dict.
    - Exit the process when there is no output dataset and no network with a measure.
    - Build the dataloader for `world_size // self.size` and load each output dataset.

    NOTE: when the user declines to overwrite an existing prediction, this method
    returns early and none of the later steps run.

    Args:
        world_size (int): Total number of processes or GPUs used for distributed prediction.

    Raises:
        PredictorError: If an output group does not match any module in the model.
        Exception: If a specified model file or URL is invalid or inaccessible.
    """
    for dataset_filename in self.datasets_filename:
        path = self.predict_path + dataset_filename
        if os.path.exists(path):
            if os.environ["KONFAI_OVERWRITE"] != "True":
                accept = builtins.input(
                    f"The prediction {path} already exists ! Do you want to overwrite it (yes,no) : "
                )
                if accept != "yes":
                    return

        if not os.path.exists(path):
            os.makedirs(path)

    # With no output dataset configured, nothing above created the prediction directory.
    os.makedirs(self.predict_path, exist_ok=True)
    shutil.copyfile(config_file(), self.predict_path + "Prediction.yml")

    self.model.init(self.autocast, State.PREDICTION, self.dataset.get_groups_dest())
    self.model.init_outputs_group()
    self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)

    modules = [module_name for module_name, _, _ in self.model.named_module_args_dict()]
    for output_group in self.outputs_dataset:
        # The ";accu;" marker is stripped before the group is matched against module names.
        if output_group.replace(";accu;", "") not in modules:
            raise PredictorError(
                f"The output group '{output_group}' defined in 'outputs_dataset' "
                "does not correspond to any module in the model.",
                f"Available modules: {modules}",
                "Please check that the name matches exactly a submodule or output of your model architecture.",
            )

    module, name = get_module(self.combine_classpath, "konfai.predictor")
    if module == "konfai.predictor":
        combine = getattr(importlib.import_module(module), name)()
    else:
        combine = config(f"{konfai_root()}.{self.combine_classpath}")(
            getattr(importlib.import_module(module), name)
        )(config=None)

    # Size the composite from the state dicts actually loaded, so the number of replicas
    # always equals the number of weight sets.
    state_dicts = self._load()
    self.model_composite = ModelComposite(self.model, len(state_dicts), combine)
    self.model_composite.load(state_dicts)

    if not self.outputs_dataset and all(
        network.measure is None for network in self.model_composite.get_networks().values()
    ):
        # Nothing to write and nothing to measure.
        exit(0)

    self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
    self.dataloader, _, _ = self.dataset.get_data(world_size // self.size)
    for group_name, output_dataset in self.outputs_dataset.items():
        output_dataset.load(
            group_name.replace(".", ":"),
            list(self.dataset.datasets.values()),
            {src: dest for src, inner in self.dataset.groups_src.items() for dest in inner},
        )
|
|
803
|
+
|
|
804
|
+
def run_process(
    self,
    world_size: int,
    global_rank: int,
    local_rank: int,
    dataloaders: list[DataLoader],
):
    """
    Run prediction for one process.

    The model composite is placed with `Network.to` using `local_rank * self.size`,
    wrapped in `DDP` when CUDA is available (otherwise in `CPUModel`), and handed to a
    `_Predictor` context whose `run()` performs the prediction.

    Args:
        world_size (int): Total number of processes.
        global_rank (int): Rank of the current process.
        local_rank (int): Local device rank.
        dataloaders (list[DataLoader]): Data loaders forwarded to `_Predictor`.
    """
    placed = Network.to(self.model_composite, local_rank * self.size)
    if torch.cuda.is_available():
        wrapped = DDP(placed, static_graph=True)
    else:
        wrapped = CPUModel(placed)

    runner = _Predictor(
        world_size,
        global_rank,
        local_rank,
        self.autocast,
        self.predict_path,
        self.data_log,
        self.outputs_dataset,
        wrapped,
        *dataloaders,
    )
    with runner as p:
        p.run()
|