konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/predictor.py
CHANGED
@@ -1,304 +1,594 @@
-from abc import ABC, abstractmethod
 import builtins
+import copy
 import importlib
+import os
 import shutil
+from abc import ABC, abstractmethod
+from collections import defaultdict
+
+import numpy as np
 import torch
 import tqdm
-import
+from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: N817
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard.writer import SummaryWriter
 
-from konfai import
-from konfai.utils.config import config
-from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
-from konfai.utils.dataset import Dataset, Attribute
+from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
-from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
 from konfai.data.transform import Transform, TransformLoader
+from konfai.network.network import CPUModel, ModelLoader, NetState, Network
+from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset
+from konfai.utils.utils import DataLog, DistributedObject, NeedDevice, PredictorError, State, description, get_module
+
+
+class Reduction(ABC):
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class Mean(Reduction):
+
+    def __init__(self):
+        pass
+
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        return torch.mean(result.float(), dim=0)
 
-from torch.utils.tensorboard.writer import SummaryWriter
-from typing import Union
-import numpy as np
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
-import importlib
-import copy
-from collections import defaultdict
 
-class
+class Median(Reduction):
+
+    def __init__(self):
+        pass
+
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        return torch.median(result.float(), dim=0).values
 
-
-
-
+
+class OutputDataset(Dataset, NeedDevice, ABC):
+
+    def __init__(
+        self,
+        filename: str,
+        group: str,
+        before_reduction_transforms: dict[str, TransformLoader],
+        after_reduction_transforms: dict[str, TransformLoader],
+        final_transforms: dict[str, TransformLoader],
+        patch_combine: str | None,
+        reduction: str,
+    ) -> None:
+        filename, file_format = filename.split(":")
+        super().__init__(filename, file_format)
         self.group = group
         self._before_reduction_transforms = before_reduction_transforms
         self._after_reduction_transforms = after_reduction_transforms
         self._final_transforms = final_transforms
-        self.
+        self._patch_combine = patch_combine
         self.reduction_classpath = reduction
-        self.reduction
-
-        self.before_reduction_transforms
-        self.after_reduction_transforms
-        self.final_transforms
-        self.
+        self.reduction: Reduction
+
+        self.before_reduction_transforms: list[Transform] = []
+        self.after_reduction_transforms: list[Transform] = []
+        self.final_transforms: list[Transform] = []
+        self.patch_combine: PathCombine | None = None
 
         self.output_layer_accumulator: dict[int, dict[int, Accumulator]] = {}
         self.attributes: dict[int, dict[int, dict[int, Attribute]]] = {}
         self.names: dict[int, str] = {}
         self.nb_data_augmentation = 0
 
+    @abstractmethod
     def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
-        transforms_type = [
-
-
+        transforms_type = [
+            "before_reduction_transforms",
+            "after_reduction_transforms",
+            "final_transforms",
+        ]
+        for name, _transform_type, transform_type in [
+            (k, getattr(self, f"_{k}"), getattr(self, k)) for k in transforms_type
+        ]:
+
             if _transform_type is not None:
                 for classpath, transform in _transform_type.items():
-                    transform = transform.
-
+                    transform = transform.get_transform(
+                        classpath,
+                        konfai_args=f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{name}",
+                    )
+                    transform.set_datasets(datasets)
                     transform_type.append(transform)
 
-        if self.
-            module, name =
-            self.
+        if self._patch_combine is not None:
+            module, name = get_module(self._patch_combine, "konfai.data.patching")
+            self.patch_combine = config(f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset")(
+                getattr(importlib.import_module(module), name)
+            )(config=None)
 
-        module, name =
+        module, name = get_module(self.reduction_classpath, "konfai.predictor")
         if module == "konfai.predictor":
-            self.reduction = getattr(importlib.import_module(module), name)
+            self.reduction = getattr(importlib.import_module(module), name)()
         else:
-            self.reduction = config(
-
-
-
-
-
-
+            self.reduction = config(
+                f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{self.reduction_classpath}"
+            )(getattr(importlib.import_module(module), name))(config=None)
+
+    def set_patch_config(
+        self,
+        patch_size: list[int] | None,
+        overlap: int | None,
+        nb_data_augmentation: int,
+    ) -> None:
+        if patch_size is not None and overlap is not None:
+            if self.patch_combine is not None:
+                self.patch_combine.set_patch_config(patch_size, overlap)
         else:
-            self.
+            self.patch_combine = None
         self.nb_data_augmentation = nb_data_augmentation
-
-    def
-        super().
-        transforms_type = [
+
+    def to(self, device: torch.device):
+        super().to(device)
+        transforms_type = [
+            "before_reduction_transforms",
+            "after_reduction_transforms",
+            "final_transforms",
+        ]
         for transform_type in [(getattr(self, k)) for k in transforms_type]:
             if transform_type is not None:
                 for transform in transform_type:
-                    transform.
+                    transform.to(device)
 
     @abstractmethod
-    def
+    def add_layer(
+        self,
+        index_dataset: int,
+        index_augmentation: int,
+        index_patch: int,
+        layer: torch.Tensor,
+        dataset: DatasetIter,
+    ):
         pass
 
-    def
-        return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all(
+    def is_done(self, index: int) -> bool:
+        return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all(
+            acc.is_full() for acc in self.output_layer_accumulator[index].values()
+        )
 
     @abstractmethod
-    def
+    def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
         pass
 
-    def
+    def write_prediction(self, index: int, name: str, layer: torch.Tensor) -> None:
         super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
         self.attributes.pop(index)
 
-class Reduction():
-
-    def __init__(self):
-        pass
-
-class Mean(Reduction):
 
-
-
-
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class OutSameAsGroupDataset(OutputDataset):
+
+    @config("OutputDataset")
+    def __init__(
+        self,
+        same_as_group: str = "default",
+        dataset_filename: str = "default:./Dataset:mha",
+        group: str = "default",
+        before_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+        after_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+        final_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+        patch_combine: str | None = None,
+        reduction: str = "Mean",
+        inverse_transform: bool = True,
+    ) -> None:
+        super().__init__(
+            dataset_filename,
+            group,
+            before_reduction_transforms,
+            after_reduction_transforms,
+            final_transforms,
+            patch_combine,
+            reduction,
+        )
+        self.group_src, self.group_dest = same_as_group.split(":")
         self.inverse_transform = inverse_transform
 
-    def
-
-
+    def add_layer(
+        self,
+        index_dataset: int,
+        index_augmentation: int,
+        index_patch: int,
+        layer: torch.Tensor,
+        dataset: DatasetIter,
+    ):
+        if (
+            index_dataset not in self.output_layer_accumulator
+            or index_augmentation not in self.output_layer_accumulator[index_dataset]
+        ):
+            input_dataset = dataset.get_dataset_from_index(self.group_dest, index_dataset)
             if index_dataset not in self.output_layer_accumulator:
                 self.output_layer_accumulator[index_dataset] = {}
                 self.attributes[index_dataset] = {}
                 self.names[index_dataset] = input_dataset.name
             self.attributes[index_dataset][index_augmentation] = {}
 
-            self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(
+            self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(
+                input_dataset.patch.get_patch_slices(index_augmentation),
+                input_dataset.patch.patch_size,
+                self.patch_combine,
+                batch=False,
+            )
 
-            for i in range(len(input_dataset.patch.
+            for i in range(len(input_dataset.patch.get_patch_slices(index_augmentation))):
                 self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])
 
         if self.inverse_transform:
-            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].
-                layer = transform.inverse(
-
-
+            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
+                layer = transform.inverse(
+                    self.names[index_dataset],
+                    layer,
+                    self.attributes[index_dataset][index_augmentation][index_patch],
+                )
+        self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)
 
     def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
         super().load(name_layer, datasets, groups)
-
+
         if self.group_src not in groups.keys():
-            raise PredictorError(
-                f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
-            )
+            raise PredictorError(f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}.")
 
         if self.group_dest not in groups[self.group_src]:
             raise PredictorError(
                 f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
             )
-
-    def
+
+    def _get_output(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
        layer = self.output_layer_accumulator[index][index_augmentation].assemble()
         if index_augmentation > 0:
-
+
            i = 0
-            index_augmentation_tmp = index_augmentation-1
-            for
-                if index_augmentation_tmp >= i and index_augmentation_tmp < i+
-                    for
-                        layer =
+            index_augmentation_tmp = index_augmentation - 1
+            for data_augmentations in dataset.data_augmentations_list:
+                if index_augmentation_tmp >= i and index_augmentation_tmp < i + data_augmentations.nb:
+                    for data_augmentation in reversed(data_augmentations.data_augmentations):
+                        layer = data_augmentation.inverse(index, index_augmentation_tmp - i, layer)
                     break
-                i +=
+                i += data_augmentations.nb
 
         for transform in self.before_reduction_transforms:
             layer = transform(self.names[index], layer, self.attributes[index][index_augmentation][0])
-
+
         return layer
 
-    def
-        result = torch.cat(
+    def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
+        result = torch.cat(
+            [
+                self._get_output(index, index_augmentation, dataset).unsqueeze(0)
+                for index_augmentation in self.output_layer_accumulator[index].keys()
+            ],
+            dim=0,
+        )
         self.output_layer_accumulator.pop(index)
         result = self.reduction(result.float()).to(result.dtype)
-
         for transform in self.after_reduction_transforms:
             result = transform(self.names[index], result, self.attributes[index][0][0])
 
         if self.inverse_transform:
-            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].
+            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
                 result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
-
+
         for transform in self.final_transforms:
-            result = transform(self.names[index], result, self.attributes[index][0][0])
+            result = transform(self.names[index], result, self.attributes[index][0][0])
         return result
 
-class OutDatasetLoader():
 
-
+class OutputDatasetLoader:
+
+    @config("OutputDataset")
     def __init__(self, name_class: str = "OutSameAsGroupDataset") -> None:
         self.name_class = name_class
 
-    def
-        return getattr(importlib.import_module("konfai.predictor"), self.name_class)(
-
-
-
-
-
+    def get_output_dataset(self, layer_name: str) -> OutputDataset:
+        return getattr(importlib.import_module("konfai.predictor"), self.name_class)(
+            config=None, konfai_args=f"Predictor.outputs_dataset.{layer_name}"
+        )
+
+
+class _Predictor:
+    """
+    Internal class that runs distributed inference over a dataset using a composite model.
+
+    This class handles patch-wise prediction, output accumulation, logging to TensorBoard, and
+    writing final predictions to disk. It is designed to be used as a context manager and
+    supports model ensembles via `ModelComposite`.
+
+    Args:
+        world_size (int): Total number of processes or GPUs used.
+        global_rank (int): Rank of the current process across all nodes.
+        local_rank (int): Local GPU index within a single node.
+        autocast (bool): Whether to use automatic mixed precision (AMP).
+        predict_path (str): Output directory path where predictions and metrics are saved.
+        data_log (list[str] | None): List of logging targets in the format 'group/DataLogType/N'.
+        outputs_dataset (dict[str, OutputDataset]): Dictionary of output datasets to store predictions.
+        model_composite (DDP): Distributed model container that wraps the prediction model(s).
+        dataloader_prediction (DataLoader): DataLoader that provides prediction batches.
+    """
+
+    def __init__(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        autocast: bool,
+        predict_path: str,
+        data_log: list[str] | None,
+        outputs_dataset: dict[str, OutputDataset],
+        model_composite: DDP,
+        dataloader_prediction: DataLoader,
+    ) -> None:
+        self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank
 
-        self.
+        self.model_composite = model_composite
         self.dataloader_prediction = dataloader_prediction
-        self.
+        self.outputs_dataset = outputs_dataset
         self.autocast = autocast
-
+
         self.it = 0
 
-        self.device = self.
+        self.device = self.model_composite.device
         self.dataset: DatasetIter = self.dataloader_prediction.dataset
-        patch_size, overlap = self.dataset.
-        for
-
-
+        patch_size, overlap = self.dataset.get_patch_config()
+        for output_dataset in self.outputs_dataset.values():
+            output_dataset.set_patch_config(
+                [size for size in patch_size if size > 1] if patch_size else None,
+                overlap,
+                np.max(
+                    [
+                        int(
+                            np.sum([data_augmentation.nb for data_augmentation in self.dataset.data_augmentations_list])
+                            + 1
+                        ),
+                        1,
+                    ]
+                ),
+            )
+        self.data_log: dict[str, tuple[DataLog, int]] = {}
         if data_log is not None:
             for data in data_log:
-                self.data_log[data.split("/")[0].replace(":", ".")] = (
-
-
+                self.data_log[data.split("/")[0].replace(":", ".")] = (
+                    DataLog[data.split("/")[1]],
+                    int(data.split("/")[2]),
+                )
+        self.tb = (
+            SummaryWriter(log_dir=predict_path + "Metric/")
+            if len(
+                [
+                    network
+                    for network in self.model_composite.module.get_networks().values()
+                    if network.measure is not None
+                ]
+            )
+            or len(self.data_log)
+            else None
+        )
+
     def __enter__(self):
+        """
+        Enters the prediction context and returns the predictor instance.
+        """
         return self
-
-    def __exit__(self,
+
+    def __exit__(self, exc_type, value, traceback):
+        """
+        Closes the TensorBoard writer upon exit.
+        """
         if self.tb:
             self.tb.close()
 
-    def
-
-
+    def get_input(
+        self,
+        data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+    ) -> dict[tuple[str, bool], torch.Tensor]:
+        return {(k, v[5][0]): v[0] for k, v in data_dict.items()}
+
     @torch.no_grad()
     def run(self):
-
-
-
+        """
+        Run the full prediction loop.
+
+        Iterates over the prediction DataLoader, performs inference using the composite model,
+        applies reduction (e.g., mean), and writes the final results using each `OutputDataset`.
+
+        Also logs intermediate data and metrics to TensorBoard if enabled.
+        """
+
+        self.model_composite.eval()
+        self.model_composite.module.set_state(NetState.PREDICTION)
         self.dataloader_prediction.dataset.load("Prediction")
-        with tqdm.tqdm(
-
-
-
-
+        with tqdm.tqdm(
+            iterable=enumerate(self.dataloader_prediction),
+            leave=True,
+            desc=f"Prediction : {description(self.model_composite)}",
+            total=len(self.dataloader_prediction),
+            ncols=0,
+        ) as batch_iter:
+            for _, data_dict in batch_iter:
+                with torch.amp.autocast("cuda", enabled=self.autocast):
+                    input_tensor = self.get_input(data_dict)
+                    for name, output in self.model_composite(input_tensor, list(self.outputs_dataset.keys())):
                         self._predict_log(data_dict)
-
-                        for i, (index, patch_augmentation, patch_index) in enumerate(
-
-
-
-
-
+                        output_dataset = self.outputs_dataset[name]
+                        for i, (index, patch_augmentation, patch_index) in enumerate(
+                            [
+                                (int(index), int(patch_augmentation), int(patch_index))
+                                for index, patch_augmentation, patch_index in zip(
+                                    list(data_dict.values())[0][1],
+                                    list(data_dict.values())[0][2],
+                                    list(data_dict.values())[0][3],
+                                )
+                            ]
+                        ):
+                            output_dataset.add_layer(
+                                index,
+                                patch_augmentation,
+                                patch_index,
+                                output[i].cpu(),
+                                self.dataset,
+                            )
+                            if output_dataset.is_done(index):
+                                output_dataset.write_prediction(
+                                    index,
+                                    self.dataset.get_dataset_from_index(list(data_dict.keys())[0], index).name.split(
+                                        "/"
+                                    )[-1],
+                                    output_dataset.get_output(index, self.dataset),
+                                )
+
+                batch_iter.set_description(f"Prediction : {description(self.model_composite)}")
                 self.it += 1
-
-    def _predict_log(
-
-
-
-
-
+
+    def _predict_log(
+        self,
+        data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+    ):
+        """
+        Log prediction results to TensorBoard, including images and metrics.
+
+        This method handles:
+        - Logging image-like data (e.g., inputs, outputs, masks) using `DataLog` instances,
+          based on the `data_log` configuration.
+        - Logging scalar loss and metric values (if present in the network) under the `Prediction/` namespace.
+        - Dynamically retrieving additional feature maps or intermediate layers if requested via `data_log`.
+
+        Logging is performed only on the global rank 0 process and only if `TensorBoard` is active.
+
+        Args:
+            data_dict (dict): Dictionary mapping group names to 6-tuples containing:
+                - input tensor,
+                - index,
+                - patch_augmentation,
+                - patch_index,
+                - metadata (list of strings),
+                - `requires_grad` flag (as a tensor).
+        """
+        measures = DistributedObject.get_measure(
+            self.world_size,
+            self.global_rank,
+            self.local_rank,
+            {"": self.model_composite.module},
+            1,
+        )
+
+        if self.global_rank == 0 and self.tb is not None:
+            data_log = []
+            if len(self.data_log):
                 for name, data_type in self.data_log.items():
                     if name in data_dict:
-                        data_type[0](
+                        data_type[0](
+                            self.tb,
+                            f"Prediction/{name}",
+                            data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
+                            self.it,
+                        )
                     else:
-
+                        data_log.append(name.replace(":", "."))
 
-            for name, network in self.
+            for name, network in self.model_composite.module.get_networks().items():
                 if network.measure is not None:
-                    self.tb.add_scalars(
-
-
-
-
+                    self.tb.add_scalars(
+                        f"Prediction/{name}/Loss",
+                        {k: v[1] for k, v in measures[name][0].items()},
+                        self.it,
+                    )
+                    self.tb.add_scalars(
+                        f"Prediction/{name}/Metric",
+                        {k: v[1] for k, v in measures[name][1].items()},
+                        self.it,
+                    )
+            if len(data_log):
+                for name, layer, _ in self.model_composite.module.get_layers(
+                    [v.to(0) for k, v in self.get_input(data_dict).items() if k[1]], data_log
+                ):
+                    self.data_log[name][0](
+                        self.tb,
+                        f"Prediction/{name}",
+                        layer[: self.data_log[name][1]].detach().cpu().numpy(),
+                        self.it,
+                    )
+
 
 class ModelComposite(Network):
+    """
+    A composite model that replicates a given base network multiple times and combines their outputs.
+
+    This class is designed to handle model ensembles or repeated predictions from the same architecture.
+    It creates `nb_models` deep copies of the input `model`, each with its own name and output branch,
+    and aggregates their outputs using a provided `Reduction` strategy (e.g., mean, median).
+
+    Args:
+        model (Network): The base network to replicate.
+        nb_models (int): Number of copies of the model to create.
+        combine (Reduction): The reduction method used to combine outputs from all model replicas.
+
+    Attributes:
+        combine (Reduction): The reduction method used during forward inference.
+    """
 
     def __init__(self, model: Network, nb_models: int, combine: Reduction):
-        super().__init__(
+        super().__init__(
+            model.in_channels,
+            model.optimizer,
+            model.lr_schedulers_loader,
+            model.outputs_criterions_loader,
+            model.patch,
+            model.nb_batch_per_step,
+            model.init_type,
+            model.init_gain,
+            model.dim,
+        )
         self.combine = combine
         for i in range(nb_models):
-            self.add_module(
+            self.add_module(
+                f"Model_{i}",
+                copy.deepcopy(model),
+                in_branch=[0],
+                out_branch=[f"output_{i}"],
+            )
 
-    def load(self, state_dicts
+    def load(self, state_dicts: list[dict[str, dict[str, torch.Tensor]]]):
+        """
+        Load weights for each sub-model in the composite from the corresponding state dictionaries.
+
+        Args:
+            state_dicts (list): A list of state dictionaries, one for each model replica.
+        """
         for i, state_dict in enumerate(state_dicts):
-            self["Model_{}"
-            self["Model_{}"
-
-    def forward(
+            self[f"Model_{i}"].load(state_dict, init=False)
+            self[f"Model_{i}"].set_name(f"{self[f'Model_{i}'].get_name()}_{i}")
+
+    def forward(
+        self,
+        data_dict: dict[tuple[str, bool], torch.Tensor],
+        output_layers: list[str] = [],
+    ) -> list[tuple[str, torch.Tensor]]:
+        """
+        Perform a forward pass on all model replicas and aggregate their outputs.
+
+        Args:
+            data_dict (dict): A dictionary mapping (group_name, requires_grad) to input tensors.
+            output_layers (list): List of output layer names to extract from each sub-model.
+
+        Returns:
+            list[tuple[str, torch.Tensor]]: Aggregated output for each layer, after applying the reduction.
+        """
         result = {}
         for name, module in self.items():
             result[name] = module(data_dict, output_layers)
-
+
         aggregated = defaultdict(list)
         for module_outputs in result.values():
             for key, tensor in module_outputs:
@@ -309,21 +599,41 @@ class ModelComposite(Network):
             final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))
 
         return final_outputs
-
+
 
 class Predictor(DistributedObject):
+    """
+    KonfAI's main prediction controller.
+
+    This class orchestrates the prediction phase by:
+    - Loading model weights from checkpoint(s) or URL(s)
+    - Preparing datasets and output configurations
+    - Managing distributed inference with optional multi-GPU support
+    - Applying transformations and saving predictions
+    - Optionally logging results to TensorBoard
+
+    Attributes:
+        model (Network): The neural network model to use for prediction.
+        dataset (DataPrediction): Dataset manager for prediction data.
+        combine_classpath (str): Path to the reduction strategy (e.g., "Mean").
+        autocast (bool): Whether to enable AMP inference.
+        outputs_dataset (dict[str, OutputDataset]): Mapping from layer names to output writers.
+        data_log (list[str] | None): List of tensors to log during inference.
+    """
 
     @config("Predictor")
-    def __init__(
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        model: ModelLoader = ModelLoader(),
+        dataset: DataPrediction = DataPrediction(),
+        combine: str = "Mean",
+        train_name: str = "name",
+        manual_seed: int | None = None,
+        gpu_checkpoints: list[str] | None = None,
+        autocast: bool = False,
+        outputs_dataset: dict[str, OutputDatasetLoader] | None = {"default:Default": OutputDatasetLoader()},
+        data_log: list[str] | None = None,
+    ) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
@@ -332,99 +642,194 @@
         self.combine_classpath = combine
         self.autocast = autocast
 
-        self.model = model.
+        self.model = model.get_model(train=False)
         self.it = 0
-        self.
-        self.
+        self.outputs_dataset_loader = outputs_dataset if outputs_dataset else {}
+        self.outputs_dataset = {
+            name.replace(":", "."): value.get_output_dataset(name)
+            for name, value in self.outputs_dataset_loader.items()
+        }
 
         self.datasets_filename = []
-        self.predict_path =
-        self.
-
-        self.
-
-
+        self.predict_path = predictions_directory() + self.name + "/"
+        for output_dataset in self.outputs_dataset.values():
+            self.datasets_filename.append(output_dataset.filename)
+            output_dataset.filename = f"{self.predict_path}{output_dataset.filename}"
+
+        self.data_log = data_log
+        modules = []
+        for i, _ in self.model.named_modules():
+            modules.append(i)
+        if self.data_log is not None:
+            for k in self.data_log:
+                tmp = k.split("/")[0].replace(":", ".")
+                if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
+                    raise PredictorError(
+                        f"Invalid key '{tmp}' in `data_log`.",
+                        f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
+                        f"nor a valid module name in the model ({modules}).",
+                        "Please check your `data_log` configuration,"
+                        " it should reference either a model output or a dataset group.",
+                    )
+
         self.gpu_checkpoints = gpu_checkpoints
 
     def _load(self) -> list[dict[str, dict[str, torch.Tensor]]]:
-
+        """
+        Load pretrained model weights from configured paths or URLs.
+
+        This method handles both remote and local model sources:
+        - If the model path is a URL (starting with "https://"), it uses `torch.hub.load_state_dict_from_url`
+          to download and load the state dict.
+        - If the model path is local:
+            - It either loads the explicit file or resolves the latest model file in a default directory
+              based on the prediction name.
+        - All loaded state dicts are returned as a list of nested dictionaries mapping module names
+          to parameter tensors.
+
+        Returns:
+            list[dict[str, dict[str, torch.Tensor]]]: A list of state dictionaries, one per model.
+
+        Raises:
+            Exception: If a model path does not exist or cannot be loaded.
+        """
+        model_paths = path_to_models().split(":")
         state_dicts = []
         for model_path in model_paths:
             if model_path.startswith("https://"):
                 try:
-                    state_dicts.append(
-
-
+                    state_dicts.append(
+                        torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True)
+                    )
+                except Exception:
+                    raise Exception(f"Model : {model_path} does not exist !")
             else:
                 if model_path != "":
                     path = ""
                    name = model_path
                 else:
                     if self.name.endswith(".pt"):
-                        path =
+                        path = models_directory() + "/".join(self.name.split("/")[:-1]) + "/StateDict/"
                         name = self.name.split("/")[-1]
                     else:
-                        path =
+                        path = models_directory() + self.name + "/StateDict/"
                         name = sorted(os.listdir(path))[-1]
-                if os.path.exists(path+name):
-                    state_dicts.append(torch.load(path+name, weights_only=
+                if os.path.exists(path + name):
+                    state_dicts.append(torch.load(path + name, weights_only=True))
                 else:
-                    raise Exception("Model : {} does not exist !"
+                    raise Exception(f"Model : {path + name} does not exist !")
         return state_dicts
-
+
     def setup(self, world_size: int):
+        """
+        Set up the predictor for inference.
+
+        This method performs all necessary initialization steps before running predictions:
+        - Ensures output directories exist, and optionally prompts the user before overwriting existing predictions.
+        - Copies the current configuration file (Prediction.yml) into the output directory for reproducibility.
+        - Initializes the model in prediction mode, including output configuration and channel tracing.
+        - Validates that the configured output groups match existing modules in the model architecture.
+        - Dynamically loads pretrained weights from local files or remote URLs.
+        - Wraps the base model into a `ModelComposite` to support ensemble inference.
+        - Initializes the prediction dataloader, with proper distribution across available GPUs.
+        - Loads and prepares each configured `OutputDataset` object for storing predictions.
+
+        Args:
+            world_size (int): Total number of processes or GPUs used for distributed prediction.
+
+        Raises:
+            PredictorError: If an output group does not match any module in the model.
+            Exception: If a specified model file or URL is invalid or inaccessible.
+        """
        for dataset_filename in self.datasets_filename:
-            path = self.predict_path +dataset_filename
+            path = self.predict_path + dataset_filename
             if os.path.exists(path):
                 if os.environ["KONFAI_OVERWRITE"] != "True":
-                    accept = builtins.input(
+                    accept = builtins.input(
+                        f"The prediction {path} already exists ! Do you want to overwrite it (yes,no) : "
+                    )
                    if accept != "yes":
                         return
-
+
             if not os.path.exists(path):
                 os.makedirs(path)
 
-        shutil.copyfile(
+        shutil.copyfile(config_file(), self.predict_path + "Prediction.yml")
 
-
-        self.model.
-        self.model.init_outputsGroup()
+        self.model.init(self.autocast, State.PREDICTION, self.dataset.get_groups_dest())
+        self.model.init_outputs_group()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
-
+
         modules = []
-        for i,_,_ in self.model.
+        for i, _, _ in self.model.named_module_args_dict():
             modules.append(i)
-        for output_group in self.
-            if output_group not in modules:
-                raise PredictorError(
-                    "
-                    "
+        for output_group in self.outputs_dataset.keys():
+            if output_group.replace(";accu;", "") not in modules:
+                raise PredictorError(
+                    f"The output group '{output_group}' defined in 'outputs_criterions' "
+                    "does not correspond to any module in the model.",
+                    f"Available modules: {modules}",
+                    "Please check that the name matches exactly a submodule or" "output of your model architecture.",
                 )
-
-        module, name =
+
+        module, name = get_module(self.combine_classpath, "konfai.predictor")
         if module == "konfai.predictor":
-            combine = getattr(importlib.import_module(module), name)
+            combine = getattr(importlib.import_module(module), name)()
         else:
-            combine = config("{}.{
-
-
-
-            self.
+            combine = config(f"{konfai_root()}.{self.combine_classpath}")(
+                getattr(importlib.import_module(module), name)
+            )(config=None)
+
+        self.model_composite = ModelComposite(self.model, len(path_to_models().split(":")), combine)
+        self.model_composite.load(self._load())
 
-        if
+        if (
+            len(list(self.outputs_dataset.keys())) == 0
+            and len(
+                [network for network in self.model_composite.get_networks().values() if network.measure is not None]
+            )
+            == 0
+        ):
             exit(0)
-
-        self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
-        self.dataloader = self.dataset.getData(world_size//self.size)
-        for name, outDataset in self.outsDataset.items():
-            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
-
-
-    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
-        modelComposite = Network.to(self.modelComposite, local_rank*self.size)
-        modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
-        with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
-            p.run()
 
-
+        self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
+        self.dataloader = self.dataset.get_data(world_size // self.size)
+        for name, output_dataset in self.outputs_dataset.items():
+            output_dataset.load(
+                name.replace(".", ":"),
+                list(self.dataset.datasets.values()),
+                {src: dest for src, inner in self.dataset.groups_src.items() for dest in inner},
+            )
+
+    def run_process(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        dataloaders: list[DataLoader],
+    ):
+        """
+        Launch prediction on the given process rank.
+
+        Args:
+            world_size (int): Total number of processes.
+            global_rank (int): Rank of the current process.
+            local_rank (int): Local device rank.
+            dataloaders (list[DataLoader]): List of data loaders for prediction.
+        """
+        model_composite = Network.to(self.model_composite, local_rank * self.size)
+        model_composite = (
+            DDP(model_composite, static_graph=True) if torch.cuda.is_available() else CPUModel(model_composite)
+        )
+        with _Predictor(
+            world_size,
+            global_rank,
+            local_rank,
+            self.autocast,
+            self.predict_path,
+            self.data_log,
+            self.outputs_dataset,
+            model_composite,
+            *dataloaders,
+        ) as p:
+            p.run()