konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
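
A note for downstream users before the diff: the dominant change in konfai/predictor.py (and, per the file list above, across the package) is a camelCase-to-snake_case API rename, plus a new module-level Reduction hierarchy (Reduction, Mean, Median) used to merge test-time-augmentation and ensemble outputs. A minimal runnable sketch of the reduction behavior, with the __call__ bodies copied from the diff below but the Reduction ABC omitted and the sample tensors invented here for illustration:

import torch

class Mean:  # __call__ body as added in konfai/predictor.py 1.2.0
    def __call__(self, result: torch.Tensor) -> torch.Tensor:
        return torch.mean(result.float(), dim=0)

class Median:  # __call__ body as added in konfai/predictor.py 1.2.0
    def __call__(self, result: torch.Tensor) -> torch.Tensor:
        return torch.median(result.float(), dim=0).values

# Three stacked predictions for the same sample (invented data): shape (3, 2, 2).
preds = torch.stack([torch.full((2, 2), v) for v in (1.0, 2.0, 9.0)])
print(Mean()(preds))    # every element 4.0, the arithmetic mean
print(Median()(preds))  # every element 2.0, robust to the 9.0 outlier

Note also the renamed configuration value: the reduction/combine setting now defaults to the class name "Mean" rather than "mean".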
konfai/predictor.py CHANGED
@@ -1,328 +1,639 @@
- from abc import ABC, abstractmethod
  import builtins
+ import copy
  import importlib
+ import os
  import shutil
+ from abc import ABC, abstractmethod
+ from collections import defaultdict
+
+ import numpy as np
  import torch
  import tqdm
- import os
+ from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: N817
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard.writer import SummaryWriter

- from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
- from konfai.utils.config import config
- from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
- from konfai.utils.dataset import Dataset, Attribute
+ from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
  from konfai.data.data_manager import DataPrediction, DatasetIter
  from konfai.data.patching import Accumulator, PathCombine
- from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
  from konfai.data.transform import Transform, TransformLoader
+ from konfai.network.network import CPUModel, ModelLoader, NetState, Network
+ from konfai.utils.config import config
+ from konfai.utils.dataset import Attribute, Dataset
+ from konfai.utils.utils import DataLog, DistributedObject, NeedDevice, PredictorError, State, description, get_module
+
+
+ class Reduction(ABC):
+
+     @abstractmethod
+     def __init__(self):
+         pass
+
+     @abstractmethod
+     def __call__(self, result: torch.Tensor) -> torch.Tensor:
+         pass
+
+
+ class Mean(Reduction):
+
+     def __init__(self):
+         pass
+
+     def __call__(self, result: torch.Tensor) -> torch.Tensor:
+         return torch.mean(result.float(), dim=0)

- from torch.utils.tensorboard.writer import SummaryWriter
- from typing import Union
- import numpy as np
- import torch.distributed as dist
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.utils.data import DataLoader
- import importlib
- import copy
- from collections import defaultdict

- class OutDataset(Dataset, NeedDevice, ABC):
+ class Median(Reduction):
+
+     def __init__(self):
+         pass
+
+     def __call__(self, result: torch.Tensor) -> torch.Tensor:
+         return torch.median(result.float(), dim=0).values

-     def __init__(self, filename: str, group: str, before_reduction_transforms : dict[str, TransformLoader], after_reduction_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None], reduction: str) -> None:
-         filename, format = filename.split(":")
-         super().__init__(filename, format)
+
+ class OutputDataset(Dataset, NeedDevice, ABC):
+
+     def __init__(
+         self,
+         filename: str,
+         group: str,
+         before_reduction_transforms: dict[str, TransformLoader],
+         after_reduction_transforms: dict[str, TransformLoader],
+         final_transforms: dict[str, TransformLoader],
+         patch_combine: str | None,
+         reduction: str,
+     ) -> None:
+         filename, file_format = filename.split(":")
+         super().__init__(filename, file_format)
          self.group = group
          self._before_reduction_transforms = before_reduction_transforms
          self._after_reduction_transforms = after_reduction_transforms
          self._final_transforms = final_transforms
-         self._patchCombine = patchCombine
+         self._patch_combine = patch_combine
          self.reduction_classpath = reduction
-         self.reduction = None
-
-         self.before_reduction_transforms : list[Transform] = []
-         self.after_reduction_transforms : list[Transform] = []
-         self.final_transforms : list[Transform] = []
-         self.patchCombine: PathCombine = None
+         self.reduction: Reduction
+
+         self.before_reduction_transforms: list[Transform] = []
+         self.after_reduction_transforms: list[Transform] = []
+         self.final_transforms: list[Transform] = []
+         self.patch_combine: PathCombine | None = None

          self.output_layer_accumulator: dict[int, dict[int, Accumulator]] = {}
          self.attributes: dict[int, dict[int, dict[int, Attribute]]] = {}
          self.names: dict[int, str] = {}
          self.nb_data_augmentation = 0

+     @abstractmethod
      def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
-         transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
-         for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
-
+         transforms_type = [
+             "before_reduction_transforms",
+             "after_reduction_transforms",
+             "final_transforms",
+         ]
+         for name, _transform_type, transform_type in [
+             (k, getattr(self, f"_{k}"), getattr(self, k)) for k in transforms_type
+         ]:
+
              if _transform_type is not None:
                  for classpath, transform in _transform_type.items():
-                     transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, name))
-                     transform.setDatasets(datasets)
+                     transform = transform.get_transform(
+                         classpath,
+                         konfai_args=f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{name}",
+                     )
+                     transform.set_datasets(datasets)
                      transform_type.append(transform)

-         if self._patchCombine is not None:
-             module, name = _getModule(self._patchCombine, "konfai.data.patching")
-             self.patchCombine = config("{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))(getattr(importlib.import_module(module), name))(config = None)
+         if self._patch_combine is not None:
+             module, name = get_module(self._patch_combine, "konfai.data.patching")
+             self.patch_combine = config(f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset")(
+                 getattr(importlib.import_module(module), name)
+             )(config=None)

-         module, name = _getModule(self.reduction_classpath, "konfai.predictor")
+         module, name = get_module(self.reduction_classpath, "konfai.predictor")
          if module == "konfai.predictor":
              self.reduction = getattr(importlib.import_module(module), name)()
          else:
-             self.reduction = config("{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, self.reduction_classpath))(getattr(importlib.import_module(module), name))(config = None)
-
-
-     def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
-         if patchSize is not None and overlap is not None:
-             if self.patchCombine is not None:
-                 self.patchCombine.setPatchConfig(patchSize, overlap)
+             self.reduction = config(
+                 f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{self.reduction_classpath}"
+             )(getattr(importlib.import_module(module), name))(config=None)
+
+     def set_patch_config(
+         self,
+         patch_size: list[int] | None,
+         overlap: int | None,
+         nb_data_augmentation: int,
+     ) -> None:
+         if patch_size is not None and overlap is not None:
+             if self.patch_combine is not None:
+                 self.patch_combine.set_patch_config(patch_size, overlap)
          else:
-             self.patchCombine = None
+             self.patch_combine = None
          self.nb_data_augmentation = nb_data_augmentation
-
-     def setDevice(self, device: torch.device):
-         super().setDevice(device)
-         transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
+
+     def to(self, device: torch.device):
+         super().to(device)
+         transforms_type = [
+             "before_reduction_transforms",
+             "after_reduction_transforms",
+             "final_transforms",
+         ]
          for transform_type in [(getattr(self, k)) for k in transforms_type]:
              if transform_type is not None:
                  for transform in transform_type:
-                     transform.setDevice(device)
+                     transform.to(device)

      @abstractmethod
-     def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
+     def add_layer(
+         self,
+         index_dataset: int,
+         index_augmentation: int,
+         index_patch: int,
+         layer: torch.Tensor,
+         dataset: DatasetIter,
+     ):
          pass

-     def isDone(self, index: int) -> bool:
-         return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all([acc.isFull() for acc in self.output_layer_accumulator[index].values()])
+     def is_done(self, index: int) -> bool:
+         return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all(
+             acc.is_full() for acc in self.output_layer_accumulator[index].values()
+         )

      @abstractmethod
-     def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
+     def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
          pass

-     def write(self, index: int, name: str, layer: torch.Tensor):
+     def write_prediction(self, index: int, name: str, layer: torch.Tensor) -> None:
          super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
          self.attributes.pop(index)

- class Reduction():
-
-     def __init__(self):
-         pass
-
- class Mean(Reduction):
-
-     def __init__(self):
-         pass
-
-     def __call__(self, result: torch.Tensor) -> torch.Tensor:
-         return torch.mean(result.float(), dim=0)
-
- class Median(Reduction):
-
-     def __init__(self):
-         pass
-
-     def __call__(self, result: torch.Tensor) -> torch.Tensor:
-         return torch.median(result.float(), dim=0).values

- class OutSameAsGroupDataset(OutDataset):
-
-     @config("OutDataset")
-     def __init__(self, sameAsGroup: str = "default", dataset_filename: str = "default:./Dataset:mha", group: str = "default", before_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, after_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
-         super().__init__(dataset_filename, group, before_reduction_transforms, after_reduction_transforms, final_transforms, patchCombine, reduction)
-         self.group_src, self.group_dest = sameAsGroup.split(":")
+ class OutSameAsGroupDataset(OutputDataset):
+
+     @config("OutputDataset")
+     def __init__(
+         self,
+         same_as_group: str = "default",
+         dataset_filename: str = "default:./Dataset:mha",
+         group: str = "default",
+         before_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+         after_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+         final_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+         patch_combine: str | None = None,
+         reduction: str = "Mean",
+         inverse_transform: bool = True,
+     ) -> None:
+         super().__init__(
+             dataset_filename,
+             group,
+             before_reduction_transforms,
+             after_reduction_transforms,
+             final_transforms,
+             patch_combine,
+             reduction,
+         )
+         self.group_src, self.group_dest = same_as_group.split(":")
          self.inverse_transform = inverse_transform

-     def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
-         if index_dataset not in self.output_layer_accumulator or index_augmentation not in self.output_layer_accumulator[index_dataset]:
-             input_dataset = dataset.getDatasetFromIndex(self.group_dest, index_dataset)
+     def add_layer(
+         self,
+         index_dataset: int,
+         index_augmentation: int,
+         index_patch: int,
+         layer: torch.Tensor,
+         dataset: DatasetIter,
+     ):
+         if (
+             index_dataset not in self.output_layer_accumulator
+             or index_augmentation not in self.output_layer_accumulator[index_dataset]
+         ):
+             input_dataset = dataset.get_dataset_from_index(self.group_dest, index_dataset)
              if index_dataset not in self.output_layer_accumulator:
                  self.output_layer_accumulator[index_dataset] = {}
                  self.attributes[index_dataset] = {}
                  self.names[index_dataset] = input_dataset.name
              self.attributes[index_dataset][index_augmentation] = {}

-             self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(input_dataset.patch.getPatch_slices(index_augmentation), input_dataset.patch.patch_size, self.patchCombine, batch=False)
+             self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(
+                 input_dataset.patch.get_patch_slices(index_augmentation),
+                 input_dataset.patch.patch_size,
+                 self.patch_combine,
+                 batch=False,
+             )

-             for i in range(len(input_dataset.patch.getPatch_slices(index_augmentation))):
+             for i in range(len(input_dataset.patch.get_patch_slices(index_augmentation))):
                  self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])

          if self.inverse_transform:
-             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
-                 layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
-         self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)
-
+             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
+                 layer = transform.inverse(
+                     self.names[index_dataset],
+                     layer,
+                     self.attributes[index_dataset][index_augmentation][index_patch],
+                 )
+         self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)

      def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
          super().load(name_layer, datasets, groups)
-
+
          if self.group_src not in groups.keys():
-             raise PredictorError(
-                 f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
-             )
+             raise PredictorError(f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}.")

          if self.group_dest not in groups[self.group_src]:
              raise PredictorError(
                  f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
              )
-
-     def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
+
+     def _get_output(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
          layer = self.output_layer_accumulator[index][index_augmentation].assemble()
          if index_augmentation > 0:
-
+
              i = 0
-             index_augmentation_tmp = index_augmentation-1
-             for dataAugmentations in dataset.dataAugmentationsList:
-                 if index_augmentation_tmp >= i and index_augmentation_tmp < i+dataAugmentations.nb:
-                     for dataAugmentation in reversed(dataAugmentations.dataAugmentations):
-                         layer = dataAugmentation.inverse(index, index_augmentation_tmp-i, layer)
+             index_augmentation_tmp = index_augmentation - 1
+             for data_augmentations in dataset.data_augmentations_list:
+                 if index_augmentation_tmp >= i and index_augmentation_tmp < i + data_augmentations.nb:
+                     for data_augmentation in reversed(data_augmentations.data_augmentations):
+                         layer = data_augmentation.inverse(index, index_augmentation_tmp - i, layer)
                      break
-                 i += dataAugmentations.nb
+                 i += data_augmentations.nb

          for transform in self.before_reduction_transforms:
              layer = transform(self.names[index], layer, self.attributes[index][index_augmentation][0])
-
+
          return layer

-     def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
-         result = torch.cat([self._getOutput(index, index_augmentation, dataset).unsqueeze(0) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
+     def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
+         result = torch.cat(
+             [
+                 self._get_output(index, index_augmentation, dataset).unsqueeze(0)
+                 for index_augmentation in self.output_layer_accumulator[index].keys()
+             ],
+             dim=0,
+         )
          self.output_layer_accumulator.pop(index)
          result = self.reduction(result.float()).to(result.dtype)
          for transform in self.after_reduction_transforms:
              result = transform(self.names[index], result, self.attributes[index][0][0])

          if self.inverse_transform:
-             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
+             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
                  result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
-
+
          for transform in self.final_transforms:
-             result = transform(self.names[index], result, self.attributes[index][0][0])
+             result = transform(self.names[index], result, self.attributes[index][0][0])
          return result

- class OutDatasetLoader():

-     @config("OutDataset")
+ class OutputDatasetLoader:
+
+     @config("OutputDataset")
      def __init__(self, name_class: str = "OutSameAsGroupDataset") -> None:
          self.name_class = name_class

-     def getOutDataset(self, layer_name: str) -> OutDataset:
-         return getattr(importlib.import_module("konfai.predictor"), self.name_class)(config = None, DL_args = "Predictor.outsDataset.{}".format(layer_name))
-
- class _Predictor():
-
-     def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
-         self.world_size = world_size
+     def get_output_dataset(self, layer_name: str) -> OutputDataset:
+         return getattr(importlib.import_module("konfai.predictor"), self.name_class)(
+             config=None, konfai_args=f"Predictor.outputs_dataset.{layer_name}"
+         )
+
+
+ class _Predictor:
+     """
+     Internal class that runs distributed inference over a dataset using a composite model.
+
+     This class handles patch-wise prediction, output accumulation, logging to TensorBoard, and
+     writing final predictions to disk. It is designed to be used as a context manager and
+     supports model ensembles via `ModelComposite`.
+
+     Args:
+         world_size (int): Total number of processes or GPUs used.
+         global_rank (int): Rank of the current process across all nodes.
+         local_rank (int): Local GPU index within a single node.
+         autocast (bool): Whether to use automatic mixed precision (AMP).
+         predict_path (str): Output directory path where predictions and metrics are saved.
+         data_log (list[str] | None): List of logging targets in the format 'group/DataLogType/N'.
+         outputs_dataset (dict[str, OutputDataset]): Dictionary of output datasets to store predictions.
+         model_composite (DDP): Distributed model container that wraps the prediction model(s).
+         dataloader_prediction (DataLoader): DataLoader that provides prediction batches.
+     """
+
+     def __init__(
+         self,
+         world_size: int,
+         global_rank: int,
+         local_rank: int,
+         autocast: bool,
+         predict_path: str,
+         data_log: list[str] | None,
+         outputs_dataset: dict[str, OutputDataset],
+         model_composite: DDP,
+         dataloader_prediction: DataLoader,
+     ) -> None:
+         self.world_size = world_size
          self.global_rank = global_rank
          self.local_rank = local_rank

-         self.modelComposite = modelComposite
+         self.model_composite = model_composite
          self.dataloader_prediction = dataloader_prediction
-         self.outsDataset = outsDataset
+         self.outputs_dataset = outputs_dataset
          self.autocast = autocast
-
+
          self.it = 0

-         self.device = self.modelComposite.device
+         self.device = self.model_composite.device
          self.dataset: DatasetIter = self.dataloader_prediction.dataset
-         patch_size, overlap = self.dataset.getPatchConfig()
-         for outDataset in self.outsDataset.values():
-             outDataset.setPatchConfig([size for size in patch_size if size > 1], overlap, np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataset.dataAugmentationsList])+1), 1]))
-         self.data_log : dict[str, tuple[DataLog, int]] = {}
+         patch_size, overlap = self.dataset.get_patch_config()
+         for output_dataset in self.outputs_dataset.values():
+             output_dataset.set_patch_config(
+                 [size for size in patch_size if size > 1] if patch_size else None,
+                 overlap,
+                 np.max(
+                     [
+                         int(
+                             np.sum([data_augmentation.nb for data_augmentation in self.dataset.data_augmentations_list])
+                             + 1
+                         ),
+                         1,
+                     ]
+                 ),
+             )
+         self.data_log: dict[str, tuple[DataLog, int]] = {}
          if data_log is not None:
              for data in data_log:
-                 self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
-         self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if len([network for network in self.modelComposite.module.getNetworks().values() if network.measure is not None]) or len(self.data_log) else None
-
+                 self.data_log[data.split("/")[0].replace(":", ".")] = (
+                     DataLog[data.split("/")[1]],
+                     int(data.split("/")[2]),
+                 )
+         self.tb = (
+             SummaryWriter(log_dir=predict_path + "Metric/")
+             if len(
+                 [
+                     network
+                     for network in self.model_composite.module.get_networks().values()
+                     if network.measure is not None
+                 ]
+             )
+             or len(self.data_log)
+             else None
+         )
+
      def __enter__(self):
+         """
+         Enters the prediction context and returns the predictor instance.
+         """
          return self
-
-     def __exit__(self, type, value, traceback):
+
+     def __exit__(self, exc_type, value, traceback):
+         """
+         Closes the TensorBoard writer upon exit.
+         """
          if self.tb:
              self.tb.close()

-     def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
-         return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
-
+     def get_input(
+         self,
+         data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+     ) -> dict[tuple[str, bool], torch.Tensor]:
+         return {(k, v[5][0]): v[0] for k, v in data_dict.items()}
+
      @torch.no_grad()
      def run(self):
-         self.modelComposite.eval()
-         self.modelComposite.module.setState(NetState.PREDICTION)
-         desc = lambda : "Prediction : {}".format(description(self.modelComposite))
+         """
+         Run the full prediction loop.
+
+         Iterates over the prediction DataLoader, performs inference using the composite model,
+         applies reduction (e.g., mean), and writes the final results using each `OutputDataset`.
+
+         Also logs intermediate data and metrics to TensorBoard if enabled.
+         """
+
+         self.model_composite.eval()
+         self.model_composite.module.set_state(NetState.PREDICTION)
          self.dataloader_prediction.dataset.load("Prediction")
-         with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
-             for it, data_dict in batch_iter:
-                 with torch.amp.autocast('cuda', enabled=self.autocast):
-                     input = self.getInput(data_dict)
-                     for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
+         with tqdm.tqdm(
+             iterable=enumerate(self.dataloader_prediction),
+             leave=True,
+             desc=f"Prediction : {description(self.model_composite)}",
+             total=len(self.dataloader_prediction),
+             ncols=0,
+         ) as batch_iter:
+             for _, data_dict in batch_iter:
+                 with torch.amp.autocast("cuda", enabled=self.autocast):
+                     input_tensor = self.get_input(data_dict)
+                     for name, output in self.model_composite(input_tensor, list(self.outputs_dataset.keys())):
                          self._predict_log(data_dict)
-                         outDataset = self.outsDataset[name]
-                         for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
-                             outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
-                             if outDataset.isDone(index):
-                                 outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
-
-                 batch_iter.set_description(desc())
+                         output_dataset = self.outputs_dataset[name]
+                         for i, (index, patch_augmentation, patch_index) in enumerate(
+                             [
+                                 (int(index), int(patch_augmentation), int(patch_index))
+                                 for index, patch_augmentation, patch_index in zip(
+                                     list(data_dict.values())[0][1],
+                                     list(data_dict.values())[0][2],
+                                     list(data_dict.values())[0][3],
+                                 )
+                             ]
+                         ):
+                             output_dataset.add_layer(
+                                 index,
+                                 patch_augmentation,
+                                 patch_index,
+                                 output[i].cpu(),
+                                 self.dataset,
+                             )
+                             if output_dataset.is_done(index):
+                                 output_dataset.write_prediction(
+                                     index,
+                                     self.dataset.get_dataset_from_index(list(data_dict.keys())[0], index).name.split(
+                                         "/"
+                                     )[-1],
+                                     output_dataset.get_output(index, self.dataset),
+                                 )
+
+                 batch_iter.set_description(f"Prediction : {description(self.model_composite)}")
                  self.it += 1
-
-     def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
-         measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
-
-         if self.global_rank == 0:
-             images_log = []
-             if len(self.data_log):
+
+     def _predict_log(
+         self,
+         data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+     ):
+         """
+         Log prediction results to TensorBoard, including images and metrics.
+
+         This method handles:
+         - Logging image-like data (e.g., inputs, outputs, masks) using `DataLog` instances,
+           based on the `data_log` configuration.
+         - Logging scalar loss and metric values (if present in the network) under the `Prediction/` namespace.
+         - Dynamically retrieving additional feature maps or intermediate layers if requested via `data_log`.
+
+         Logging is performed only on the global rank 0 process and only if `TensorBoard` is active.
+
+         Args:
+             data_dict (dict): Dictionary mapping group names to 6-tuples containing:
+                 - input tensor,
+                 - index,
+                 - patch_augmentation,
+                 - patch_index,
+                 - metadata (list of strings),
+                 - `requires_grad` flag (as a tensor).
+         """
+         measures = DistributedObject.get_measure(
+             self.world_size,
+             self.global_rank,
+             self.local_rank,
+             {"": self.model_composite.module},
+             1,
+         )
+
+         if self.global_rank == 0 and self.tb is not None:
+             data_log = []
+             if len(self.data_log):
                  for name, data_type in self.data_log.items():
                      if name in data_dict:
-                         data_type[0](self.tb, "Prediction/{}".format(name), data_dict[name][0][:self.data_log[name][1]].detach().cpu().numpy(), self.it)
+                         data_type[0](
+                             self.tb,
+                             f"Prediction/{name}",
+                             data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
+                             self.it,
+                         )
                      else:
-                         images_log.append(name.replace(":", "."))
+                         data_log.append(name.replace(":", "."))

-             for name, network in self.modelComposite.module.getNetworks().items():
+             for name, network in self.model_composite.module.get_networks().items():
                  if network.measure is not None:
-                     self.tb.add_scalars("Prediction/{}/Loss".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][0].items()}, self.it)
-                     self.tb.add_scalars("Prediction/{}/Metric".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][1].items()}, self.it)
-             if len(images_log):
-                 for name, layer, _ in self.model.module.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
-                     self.data_log[name][0](self.tb, "Prediction/{}".format(name), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
+                     self.tb.add_scalars(
+                         f"Prediction/{name}/Loss",
+                         {k: v[1] for k, v in measures[name][0].items()},
+                         self.it,
+                     )
+                     self.tb.add_scalars(
+                         f"Prediction/{name}/Metric",
+                         {k: v[1] for k, v in measures[name][1].items()},
+                         self.it,
+                     )
+             if len(data_log):
+                 for name, layer, _ in self.model_composite.module.get_layers(
+                     [v.to(0) for k, v in self.get_input(data_dict).items() if k[1]], data_log
+                 ):
+                     self.data_log[name][0](
+                         self.tb,
+                         f"Prediction/{name}",
+                         layer[: self.data_log[name][1]].detach().cpu().numpy(),
+                         self.it,
+                     )
+

  class ModelComposite(Network):
+     """
+     A composite model that replicates a given base network multiple times and combines their outputs.
+
+     This class is designed to handle model ensembles or repeated predictions from the same architecture.
+     It creates `nb_models` deep copies of the input `model`, each with its own name and output branch,
+     and aggregates their outputs using a provided `Reduction` strategy (e.g., mean, median).
+
+     Args:
+         model (Network): The base network to replicate.
+         nb_models (int): Number of copies of the model to create.
+         combine (Reduction): The reduction method used to combine outputs from all model replicas.
+
+     Attributes:
+         combine (Reduction): The reduction method used during forward inference.
+     """

      def __init__(self, model: Network, nb_models: int, combine: Reduction):
-         super().__init__(model.in_channels, model.optimizer, model.schedulers, model.outputsCriterionsLoader, model.patch, model.nb_batch_per_step, model.init_type, model.init_gain, model.dim)
+         super().__init__(
+             model.in_channels,
+             model.optimizer,
+             model.lr_schedulers_loader,
+             model.outputs_criterions_loader,
+             model.patch,
+             model.nb_batch_per_step,
+             model.init_type,
+             model.init_gain,
+             model.dim,
+         )
          self.combine = combine
          for i in range(nb_models):
-             self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])
+             self.add_module(
+                 f"Model_{i}",
+                 copy.deepcopy(model),
+                 in_branch=[0],
+                 out_branch=[f"output_{i}"],
+             )
+
+     def load(self, state_dicts: list[dict[str, dict[str, torch.Tensor]]]):
+         """
+         Load weights for each sub-model in the composite from the corresponding state dictionaries.

-     def load(self, state_dicts : list[dict[str, dict[str, torch.Tensor]]]):
+         Args:
+             state_dicts (list): A list of state dictionaries, one for each model replica.
+         """
          for i, state_dict in enumerate(state_dicts):
-             self["Model_{}".format(i)].load(state_dict, init=False)
-             self["Model_{}".format(i)].setName("{}_{}".format(self["Model_{}".format(i)].getName(), i))
-
-     def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
+             self[f"Model_{i}"].load(state_dict, init=False)
+             self[f"Model_{i}"].set_name(f"{self[f'Model_{i}'].get_name()}_{i}")
+
+     def forward(
+         self,
+         data_dict: dict[tuple[str, bool], torch.Tensor],
+         output_layers: list[str] = [],
+     ) -> list[tuple[str, torch.Tensor]]:
+         """
+         Perform a forward pass on all model replicas and aggregate their outputs.
+
+         Args:
+             data_dict (dict): A dictionary mapping (group_name, requires_grad) to input tensors.
+             output_layers (list): List of output layer names to extract from each sub-model.
+
+         Returns:
+             list[tuple[str, torch.Tensor]]: Aggregated output for each layer, after applying the reduction.
+         """
          result = {}
          for name, module in self.items():
              result[name] = module(data_dict, output_layers)
-
+
          aggregated = defaultdict(list)
          for module_outputs in result.values():
              for key, tensor in module_outputs:
                  aggregated[key].append(tensor)
-
+
          final_outputs = []
          for key, tensors in aggregated.items():
              final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))

          return final_outputs
-
+

  class Predictor(DistributedObject):
+     """
+     KonfAI's main prediction controller.
+
+     This class orchestrates the prediction phase by:
+     - Loading model weights from checkpoint(s) or URL(s)
+     - Preparing datasets and output configurations
+     - Managing distributed inference with optional multi-GPU support
+     - Applying transformations and saving predictions
+     - Optionally logging results to TensorBoard
+
+     Attributes:
+         model (Network): The neural network model to use for prediction.
+         dataset (DataPrediction): Dataset manager for prediction data.
+         combine_classpath (str): Path to the reduction strategy (e.g., "Mean").
+         autocast (bool): Whether to enable AMP inference.
+         outputs_dataset (dict[str, OutputDataset]): Mapping from layer names to output writers.
+         data_log (list[str] | None): List of tensors to log during inference.
+     """

      @config("Predictor")
-     def __init__(self,
-                  model: ModelLoader = ModelLoader(),
-                  dataset: DataPrediction = DataPrediction(),
-                  combine: str = "mean",
-                  train_name: str = "name",
-                  manual_seed : Union[int, None] = None,
-                  gpu_checkpoints: Union[list[str], None] = None,
-                  autocast : bool = False,
-                  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
-                  images_log: list[str] = []) -> None:
+     def __init__(
+         self,
+         model: ModelLoader = ModelLoader(),
+         dataset: DataPrediction = DataPrediction(),
+         combine: str = "Mean",
+         train_name: str = "name",
+         manual_seed: int | None = None,
+         gpu_checkpoints: list[str] | None = None,
+         autocast: bool = False,
+         outputs_dataset: dict[str, OutputDatasetLoader] | None = {"default:Default": OutputDatasetLoader()},
+         data_log: list[str] | None = None,
+     ) -> None:
          if os.environ["KONFAI_CONFIG_MODE"] != "Done":
              exit(0)
          super().__init__(train_name)

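Between the two hunks, a brief orientation: the new ModelComposite.forward above runs every replica, groups each (name, tensor) output pair by name, and reduces each stack with self.combine. A standalone sketch of just that aggregation step, with the replica outputs invented here for illustration:

from collections import defaultdict

import torch

def combine_outputs(per_replica_outputs, combine):
    # Mirrors the aggregation in ModelComposite.forward (hunk above):
    # collect every replica's tensor for each output name, then reduce.
    aggregated = defaultdict(list)
    for outputs in per_replica_outputs:  # one list of (name, tensor) per replica
        for key, tensor in outputs:
            aggregated[key].append(tensor)
    return [(key, combine(torch.stack(tensors, dim=0))) for key, tensors in aggregated.items()]

# Two hypothetical replicas, each emitting one "seg" output.
replica_outputs = [
    [("seg", torch.zeros(1, 4, 4))],
    [("seg", torch.ones(1, 4, 4))],
]
for name, tensor in combine_outputs(replica_outputs, lambda t: t.float().mean(dim=0)):
    print(name, tensor.unique())  # seg tensor([0.5000])
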
@@ -331,99 +642,194 @@ class Predictor(DistributedObject):
          self.combine_classpath = combine
          self.autocast = autocast

-         self.model = model.getModel(train=False)
+         self.model = model.get_model(train=False)
          self.it = 0
-         self.outsDatasetLoader = outsDataset if outsDataset else {}
-         self.outsDataset = {name.replace(":", ".") : value.getOutDataset(name) for name, value in self.outsDatasetLoader.items()}
+         self.outputs_dataset_loader = outputs_dataset if outputs_dataset else {}
+         self.outputs_dataset = {
+             name.replace(":", "."): value.get_output_dataset(name)
+             for name, value in self.outputs_dataset_loader.items()
+         }

          self.datasets_filename = []
-         self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
-         self.images_log = images_log
-         for outDataset in self.outsDataset.values():
-             self.datasets_filename.append(outDataset.filename)
-             outDataset.filename = "{}{}".format(self.predict_path, outDataset.filename)
-
-
+         self.predict_path = predictions_directory() + self.name + "/"
+         for output_dataset in self.outputs_dataset.values():
+             self.datasets_filename.append(output_dataset.filename)
+             output_dataset.filename = f"{self.predict_path}{output_dataset.filename}"
+
+         self.data_log = data_log
+         modules = []
+         for i, _ in self.model.named_modules():
+             modules.append(i)
+         if self.data_log is not None:
+             for k in self.data_log:
+                 tmp = k.split("/")[0].replace(":", ".")
+                 if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
+                     raise PredictorError(
+                         f"Invalid key '{tmp}' in `data_log`.",
+                         f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
+                         f"nor a valid module name in the model ({modules}).",
+                         "Please check your `data_log` configuration,"
+                         " it should reference either a model output or a dataset group.",
+                     )
+
          self.gpu_checkpoints = gpu_checkpoints

      def _load(self) -> list[dict[str, dict[str, torch.Tensor]]]:
-         model_paths = MODEL().split(":")
+         """
+         Load pretrained model weights from configured paths or URLs.
+
+         This method handles both remote and local model sources:
+         - If the model path is a URL (starting with "https://"), it uses `torch.hub.load_state_dict_from_url`
+           to download and load the state dict.
+         - If the model path is local:
+             - It either loads the explicit file or resolves the latest model file in a default directory
+               based on the prediction name.
+         - All loaded state dicts are returned as a list of nested dictionaries mapping module names
+           to parameter tensors.
+
+         Returns:
+             list[dict[str, dict[str, torch.Tensor]]]: A list of state dictionaries, one per model.
+
+         Raises:
+             Exception: If a model path does not exist or cannot be loaded.
+         """
+         model_paths = path_to_models().split(":")
          state_dicts = []
          for model_path in model_paths:
              if model_path.startswith("https://"):
                  try:
-                     state_dicts.append(torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True))
-                 except:
-                     raise Exception("Model : {} does not exist !".format(model_path))
+                     state_dicts.append(
+                         torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True)
+                     )
+                 except Exception:
+                     raise Exception(f"Model : {model_path} does not exist !")
              else:
                  if model_path != "":
                      path = ""
                      name = model_path
                  else:
                      if self.name.endswith(".pt"):
-                         path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
+                         path = models_directory() + "/".join(self.name.split("/")[:-1]) + "/StateDict/"
                          name = self.name.split("/")[-1]
                      else:
-                         path = MODELS_DIRECTORY()+self.name+"/StateDict/"
+                         path = models_directory() + self.name + "/StateDict/"
                          name = sorted(os.listdir(path))[-1]
-                 if os.path.exists(path+name):
-                     state_dicts.append(torch.load(path+name, weights_only=False))
+                 if os.path.exists(path + name):
+                     state_dicts.append(torch.load(path + name, weights_only=True))
                  else:
-                     raise Exception("Model : {} does not exist !".format(path+name))
+                     raise Exception(f"Model : {path + name} does not exist !")
          return state_dicts
-
+
      def setup(self, world_size: int):
+         """
+         Set up the predictor for inference.
+
+         This method performs all necessary initialization steps before running predictions:
+         - Ensures output directories exist, and optionally prompts the user before overwriting existing predictions.
+         - Copies the current configuration file (Prediction.yml) into the output directory for reproducibility.
+         - Initializes the model in prediction mode, including output configuration and channel tracing.
+         - Validates that the configured output groups match existing modules in the model architecture.
+         - Dynamically loads pretrained weights from local files or remote URLs.
+         - Wraps the base model into a `ModelComposite` to support ensemble inference.
+         - Initializes the prediction dataloader, with proper distribution across available GPUs.
+         - Loads and prepares each configured `OutputDataset` object for storing predictions.
+
+         Args:
+             world_size (int): Total number of processes or GPUs used for distributed prediction.
+
+         Raises:
+             PredictorError: If an output group does not match any module in the model.
+             Exception: If a specified model file or URL is invalid or inaccessible.
+         """
          for dataset_filename in self.datasets_filename:
-             path = self.predict_path +dataset_filename
+             path = self.predict_path + dataset_filename
              if os.path.exists(path):
                  if os.environ["KONFAI_OVERWRITE"] != "True":
-                     accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
+                     accept = builtins.input(
+                         f"The prediction {path} already exists ! Do you want to overwrite it (yes,no) : "
+                     )
                      if accept != "yes":
                          return
-
+
              if not os.path.exists(path):
                  os.makedirs(path)

-         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
+         shutil.copyfile(config_file(), self.predict_path + "Prediction.yml")

-
-         self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
-         self.model.init_outputsGroup()
+         self.model.init(self.autocast, State.PREDICTION, self.dataset.get_groups_dest())
+         self.model.init_outputs_group()
          self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
-
+
          modules = []
-         for i,_,_ in self.model.named_ModuleArgsDict():
+         for i, _, _ in self.model.named_module_args_dict():
              modules.append(i)
-         for output_group in self.outsDataset.keys():
-             if output_group not in modules:
-                 raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
-                     "Available modules: {}".format(modules),
-                     "Please check that the name matches exactly a submodule or output of your model architecture."
+         for output_group in self.outputs_dataset.keys():
+             if output_group.replace(";accu;", "") not in modules:
+                 raise PredictorError(
+                     f"The output group '{output_group}' defined in 'outputs_criterions' "
+                     "does not correspond to any module in the model.",
+                     f"Available modules: {modules}",
+                     "Please check that the name matches exactly a submodule or" "output of your model architecture.",
                  )
-
-         module, name = _getModule(self.combine_classpath, "konfai.predictor")
+
+         module, name = get_module(self.combine_classpath, "konfai.predictor")
          if module == "konfai.predictor":
              combine = getattr(importlib.import_module(module), name)()
          else:
-             combine = config("{}.{}".format(KONFAI_ROOT(), self.combine_classpath))(getattr(importlib.import_module(module), name))(config = None)
-
-
-         self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), combine)
-         self.modelComposite.load(self._load())
+             combine = config(f"{konfai_root()}.{self.combine_classpath}")(
+                 getattr(importlib.import_module(module), name)
+             )(config=None)
+
+         self.model_composite = ModelComposite(self.model, len(path_to_models().split(":")), combine)
+         self.model_composite.load(self._load())

-         if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
+         if (
+             len(list(self.outputs_dataset.keys())) == 0
+             and len(
+                 [network for network in self.model_composite.get_networks().values() if network.measure is not None]
+             )
+             == 0
+         ):
              exit(0)
-
-         self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
-         self.dataloader = self.dataset.getData(world_size//self.size)
-         for name, outDataset in self.outsDataset.items():
-             outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
-
-
-     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
-         modelComposite = Network.to(self.modelComposite, local_rank*self.size)
-         modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
-         with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
-             p.run()

-
+         self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
+         self.dataloader, _, _ = self.dataset.get_data(world_size // self.size)
+         for name, output_dataset in self.outputs_dataset.items():
+             output_dataset.load(
+                 name.replace(".", ":"),
+                 list(self.dataset.datasets.values()),
+                 {src: dest for src, inner in self.dataset.groups_src.items() for dest in inner},
+             )
+
+     def run_process(
+         self,
+         world_size: int,
+         global_rank: int,
+         local_rank: int,
+         dataloaders: list[DataLoader],
+     ):
+         """
+         Launch prediction on the given process rank.
+
+         Args:
+             world_size (int): Total number of processes.
+             global_rank (int): Rank of the current process.
+             local_rank (int): Local device rank.
+             dataloaders (list[DataLoader]): List of data loaders for prediction.
+         """
+         model_composite = Network.to(self.model_composite, local_rank * self.size)
+         model_composite = (
+             DDP(model_composite, static_graph=True) if torch.cuda.is_available() else CPUModel(model_composite)
+         )
+         with _Predictor(
+             world_size,
+             global_rank,
+             local_rank,
+             self.autocast,
+             self.predict_path,
+             self.data_log,
+             self.outputs_dataset,
+             model_composite,
+             *dataloaders,
+         ) as p:
+             p.run()
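
One behavioral change in _load worth flagging for reviewers: checkpoints are now deserialized with torch.load(..., weights_only=True) instead of weights_only=False, which restricts unpickling to tensors and plain containers and is the safer default against malicious checkpoint files. A minimal sketch of the distinction (the file name is invented for illustration):

import torch

ckpt = {"Model_0": {"weight": torch.zeros(2, 2)}}
torch.save(ckpt, "demo.pt")  # illustrative checkpoint

# 1.1.8 behavior: arbitrary pickled Python objects were accepted on load.
state = torch.load("demo.pt", weights_only=False)

# 1.2.0 behavior: only tensors and primitive containers are deserialized;
# a checkpoint embedding arbitrary objects would now raise instead of executing code.
state = torch.load("demo.pt", weights_only=True)
print(state["Model_0"]["weight"].shape)  # torch.Size([2, 2])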