konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
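Most of the churn in this release is a library-wide rename from camelCase to snake_case. The mapping below is an illustrative (not exhaustive) Python summary of correspondences that appear in the konfai/predictor.py diff that follows; it is drawn from the hunks themselves, not from konfai documentation:

    # Selected 1.1.7 -> 1.1.9 renames visible in konfai/predictor.py (illustrative)
    OLD_TO_NEW = {
        "OutDataset": "OutputDataset",
        "OutDatasetLoader": "OutputDatasetLoader",
        "setDevice(device)": "to(device)",
        "addLayer(...)": "add_layer(...)",
        "isDone(index)": "is_done(index)",
        "getOutput(index, dataset)": "get_output(index, dataset)",
        "KONFAI_ROOT()": "konfai_root()",
        "MODEL()": "path_to_models()",
        "CPU_Model": "CPUModel",
    }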
konfai/predictor.py CHANGED
@@ -1,304 +1,594 @@
-from abc import ABC, abstractmethod
 import builtins
+import copy
 import importlib
+import os
 import shutil
+from abc import ABC, abstractmethod
+from collections import defaultdict
+
+import numpy as np
 import torch
 import tqdm
-import os
+from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: N817
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard.writer import SummaryWriter

-from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
-from konfai.utils.config import config
-from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
-from konfai.utils.dataset import Dataset, Attribute
+from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
-from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
 from konfai.data.transform import Transform, TransformLoader
+from konfai.network.network import CPUModel, ModelLoader, NetState, Network
+from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset
+from konfai.utils.utils import DataLog, DistributedObject, NeedDevice, PredictorError, State, description, get_module
+
+
+class Reduction(ABC):
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class Mean(Reduction):
+
+    def __init__(self):
+        pass
+
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        return torch.mean(result.float(), dim=0)

-from torch.utils.tensorboard.writer import SummaryWriter
-from typing import Union
-import numpy as np
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
-import importlib
-import copy
-from collections import defaultdict

-class OutDataset(Dataset, NeedDevice, ABC):
+class Median(Reduction):
+
+    def __init__(self):
+        pass
+
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        return torch.median(result.float(), dim=0).values

-    def __init__(self, filename: str, group: str, before_reduction_transforms : dict[str, TransformLoader], after_reduction_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None], reduction: str) -> None:
-        filename, format = filename.split(":")
-        super().__init__(filename, format)
+
+class OutputDataset(Dataset, NeedDevice, ABC):
+
+    def __init__(
+        self,
+        filename: str,
+        group: str,
+        before_reduction_transforms: dict[str, TransformLoader],
+        after_reduction_transforms: dict[str, TransformLoader],
+        final_transforms: dict[str, TransformLoader],
+        patch_combine: str | None,
+        reduction: str,
+    ) -> None:
+        filename, file_format = filename.split(":")
+        super().__init__(filename, file_format)
         self.group = group
         self._before_reduction_transforms = before_reduction_transforms
         self._after_reduction_transforms = after_reduction_transforms
         self._final_transforms = final_transforms
-        self._patchCombine = patchCombine
+        self._patch_combine = patch_combine
         self.reduction_classpath = reduction
-        self.reduction = None
-
-        self.before_reduction_transforms : list[Transform] = []
-        self.after_reduction_transforms : list[Transform] = []
-        self.final_transforms : list[Transform] = []
-        self.patchCombine: PathCombine = None
+        self.reduction: Reduction
+
+        self.before_reduction_transforms: list[Transform] = []
+        self.after_reduction_transforms: list[Transform] = []
+        self.final_transforms: list[Transform] = []
+        self.patch_combine: PathCombine | None = None

         self.output_layer_accumulator: dict[int, dict[int, Accumulator]] = {}
         self.attributes: dict[int, dict[int, dict[int, Attribute]]] = {}
         self.names: dict[int, str] = {}
         self.nb_data_augmentation = 0

+    @abstractmethod
     def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
-        transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
-        for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
-
+        transforms_type = [
+            "before_reduction_transforms",
+            "after_reduction_transforms",
+            "final_transforms",
+        ]
+        for name, _transform_type, transform_type in [
+            (k, getattr(self, f"_{k}"), getattr(self, k)) for k in transforms_type
+        ]:
+
             if _transform_type is not None:
                 for classpath, transform in _transform_type.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, name))
-                    transform.setDatasets(datasets)
+                    transform = transform.get_transform(
+                        classpath,
+                        konfai_args=f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{name}",
+                    )
+                    transform.set_datasets(datasets)
                     transform_type.append(transform)

-        if self._patchCombine is not None:
-            module, name = _getModule(self._patchCombine, "konfai.data.patching")
-            self.patchCombine = config("{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))(getattr(importlib.import_module(module), name))(config = None)
+        if self._patch_combine is not None:
+            module, name = get_module(self._patch_combine, "konfai.data.patching")
+            self.patch_combine = config(f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset")(
+                getattr(importlib.import_module(module), name)
+            )(config=None)

-        module, name = _getModule(self.reduction_classpath, "konfai.predictor")
+        module, name = get_module(self.reduction_classpath, "konfai.predictor")
         if module == "konfai.predictor":
-            self.reduction = getattr(importlib.import_module(module), name)
+            self.reduction = getattr(importlib.import_module(module), name)()
         else:
-            self.reduction = config("{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, self.reduction_classpath))(getattr(importlib.import_module(module), name))(config = None)
-
-
-    def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
-        if patchSize is not None and overlap is not None:
-            if self.patchCombine is not None:
-                self.patchCombine.setPatchConfig(patchSize, overlap)
+            self.reduction = config(
+                f"{konfai_root()}.outputs_dataset.{name_layer}.OutputDataset.{self.reduction_classpath}"
+            )(getattr(importlib.import_module(module), name))(config=None)
+
+    def set_patch_config(
+        self,
+        patch_size: list[int] | None,
+        overlap: int | None,
+        nb_data_augmentation: int,
+    ) -> None:
+        if patch_size is not None and overlap is not None:
+            if self.patch_combine is not None:
+                self.patch_combine.set_patch_config(patch_size, overlap)
         else:
-            self.patchCombine = None
+            self.patch_combine = None
         self.nb_data_augmentation = nb_data_augmentation
-
-    def setDevice(self, device: torch.device):
-        super().setDevice(device)
-        transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
+
+    def to(self, device: torch.device):
+        super().to(device)
+        transforms_type = [
+            "before_reduction_transforms",
+            "after_reduction_transforms",
+            "final_transforms",
+        ]
         for transform_type in [(getattr(self, k)) for k in transforms_type]:
             if transform_type is not None:
                 for transform in transform_type:
-                    transform.setDevice(device)
+                    transform.to(device)

     @abstractmethod
-    def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
+    def add_layer(
+        self,
+        index_dataset: int,
+        index_augmentation: int,
+        index_patch: int,
+        layer: torch.Tensor,
+        dataset: DatasetIter,
+    ):
         pass

-    def isDone(self, index: int) -> bool:
-        return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all([acc.isFull() for acc in self.output_layer_accumulator[index].values()])
+    def is_done(self, index: int) -> bool:
+        return len(self.output_layer_accumulator[index]) == self.nb_data_augmentation and all(
+            acc.is_full() for acc in self.output_layer_accumulator[index].values()
+        )

     @abstractmethod
-    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
+    def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
         pass

-    def write(self, index: int, name: str, layer: torch.Tensor):
+    def write_prediction(self, index: int, name: str, layer: torch.Tensor) -> None:
         super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
         self.attributes.pop(index)

-class Reduction():
-
-    def __init__(self):
-        pass
-
-class Mean(Reduction):

-    def __init__(self):
-        pass
-
-    def __call__(self, result: torch.Tensor) -> torch.Tensor:
-        return torch.mean(result.float(), dim=0)
-
-class Median(Reduction):
-
-    def __init__(self):
-        pass
-
-    def __call__(self, result: torch.Tensor) -> torch.Tensor:
-        return torch.median(result.float(), dim=0).values
-
-class OutSameAsGroupDataset(OutDataset):
-
-    @config("OutDataset")
-    def __init__(self, sameAsGroup: str = "default", dataset_filename: str = "default:./Dataset:mha", group: str = "default", before_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, after_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
-        super().__init__(dataset_filename, group, before_reduction_transforms, after_reduction_transforms, final_transforms, patchCombine, reduction)
-        self.group_src, self.group_dest = sameAsGroup.split(":")
+class OutSameAsGroupDataset(OutputDataset):
+
+    @config("OutputDataset")
+    def __init__(
+        self,
+        same_as_group: str = "default",
+        dataset_filename: str = "default:./Dataset:mha",
+        group: str = "default",
+        before_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+        after_reduction_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+        final_transforms: dict[str, TransformLoader] = {"default:Normalize": TransformLoader()},
+        patch_combine: str | None = None,
+        reduction: str = "Mean",
+        inverse_transform: bool = True,
+    ) -> None:
+        super().__init__(
+            dataset_filename,
+            group,
+            before_reduction_transforms,
+            after_reduction_transforms,
+            final_transforms,
+            patch_combine,
+            reduction,
+        )
+        self.group_src, self.group_dest = same_as_group.split(":")
         self.inverse_transform = inverse_transform

-    def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
-        if index_dataset not in self.output_layer_accumulator or index_augmentation not in self.output_layer_accumulator[index_dataset]:
-            input_dataset = dataset.getDatasetFromIndex(self.group_dest, index_dataset)
+    def add_layer(
+        self,
+        index_dataset: int,
+        index_augmentation: int,
+        index_patch: int,
+        layer: torch.Tensor,
+        dataset: DatasetIter,
+    ):
+        if (
+            index_dataset not in self.output_layer_accumulator
+            or index_augmentation not in self.output_layer_accumulator[index_dataset]
+        ):
+            input_dataset = dataset.get_dataset_from_index(self.group_dest, index_dataset)
             if index_dataset not in self.output_layer_accumulator:
                 self.output_layer_accumulator[index_dataset] = {}
                 self.attributes[index_dataset] = {}
                 self.names[index_dataset] = input_dataset.name
             self.attributes[index_dataset][index_augmentation] = {}

-            self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(input_dataset.patch.getPatch_slices(index_augmentation), input_dataset.patch.patch_size, self.patchCombine, batch=False)
+            self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(
+                input_dataset.patch.get_patch_slices(index_augmentation),
+                input_dataset.patch.patch_size,
+                self.patch_combine,
+                batch=False,
+            )

-            for i in range(len(input_dataset.patch.getPatch_slices(index_augmentation))):
+            for i in range(len(input_dataset.patch.get_patch_slices(index_augmentation))):
                 self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])

         if self.inverse_transform:
-            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
-                layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
-        self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)
-
+            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
+                layer = transform.inverse(
+                    self.names[index_dataset],
+                    layer,
+                    self.attributes[index_dataset][index_augmentation][index_patch],
+                )
+        self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)

     def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
         super().load(name_layer, datasets, groups)
-
+
         if self.group_src not in groups.keys():
-            raise PredictorError(
-                f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
-            )
+            raise PredictorError(f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}.")

         if self.group_dest not in groups[self.group_src]:
             raise PredictorError(
                 f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
             )
-
-    def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
+
+    def _get_output(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
         layer = self.output_layer_accumulator[index][index_augmentation].assemble()
         if index_augmentation > 0:
-
+
             i = 0
-            index_augmentation_tmp = index_augmentation-1
-            for dataAugmentations in dataset.dataAugmentationsList:
-                if index_augmentation_tmp >= i and index_augmentation_tmp < i+dataAugmentations.nb:
-                    for dataAugmentation in reversed(dataAugmentations.dataAugmentations):
-                        layer = dataAugmentation.inverse(index, index_augmentation_tmp-i, layer)
+            index_augmentation_tmp = index_augmentation - 1
+            for data_augmentations in dataset.data_augmentations_list:
+                if index_augmentation_tmp >= i and index_augmentation_tmp < i + data_augmentations.nb:
+                    for data_augmentation in reversed(data_augmentations.data_augmentations):
+                        layer = data_augmentation.inverse(index, index_augmentation_tmp - i, layer)
                     break
-                i += dataAugmentations.nb
+                i += data_augmentations.nb

         for transform in self.before_reduction_transforms:
             layer = transform(self.names[index], layer, self.attributes[index][index_augmentation][0])
-
+
         return layer

-    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
-        result = torch.cat([self._getOutput(index, index_augmentation, dataset).unsqueeze(0) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
+    def get_output(self, index: int, dataset: DatasetIter) -> torch.Tensor:
+        result = torch.cat(
+            [
+                self._get_output(index, index_augmentation, dataset).unsqueeze(0)
+                for index_augmentation in self.output_layer_accumulator[index].keys()
+            ],
+            dim=0,
+        )
         self.output_layer_accumulator.pop(index)
         result = self.reduction(result.float()).to(result.dtype)
-
         for transform in self.after_reduction_transforms:
             result = transform(self.names[index], result, self.attributes[index][0][0])

         if self.inverse_transform:
-            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
+            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
                 result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
-
+
         for transform in self.final_transforms:
-            result = transform(self.names[index], result, self.attributes[index][0][0])
+            result = transform(self.names[index], result, self.attributes[index][0][0])
         return result

-class OutDatasetLoader():

-    @config("OutDataset")
+class OutputDatasetLoader:
+
+    @config("OutputDataset")
     def __init__(self, name_class: str = "OutSameAsGroupDataset") -> None:
         self.name_class = name_class

-    def getOutDataset(self, layer_name: str) -> OutDataset:
-        return getattr(importlib.import_module("konfai.predictor"), self.name_class)(config = None, DL_args = "Predictor.outsDataset.{}".format(layer_name))
-
-class _Predictor():
-
-    def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
-        self.world_size = world_size
+    def get_output_dataset(self, layer_name: str) -> OutputDataset:
+        return getattr(importlib.import_module("konfai.predictor"), self.name_class)(
+            config=None, konfai_args=f"Predictor.outputs_dataset.{layer_name}"
+        )
+
+
+class _Predictor:
+    """
+    Internal class that runs distributed inference over a dataset using a composite model.
+
+    This class handles patch-wise prediction, output accumulation, logging to TensorBoard, and
+    writing final predictions to disk. It is designed to be used as a context manager and
+    supports model ensembles via `ModelComposite`.
+
+    Args:
+        world_size (int): Total number of processes or GPUs used.
+        global_rank (int): Rank of the current process across all nodes.
+        local_rank (int): Local GPU index within a single node.
+        autocast (bool): Whether to use automatic mixed precision (AMP).
+        predict_path (str): Output directory path where predictions and metrics are saved.
+        data_log (list[str] | None): List of logging targets in the format 'group/DataLogType/N'.
+        outputs_dataset (dict[str, OutputDataset]): Dictionary of output datasets to store predictions.
+        model_composite (DDP): Distributed model container that wraps the prediction model(s).
+        dataloader_prediction (DataLoader): DataLoader that provides prediction batches.
+    """
+
+    def __init__(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        autocast: bool,
+        predict_path: str,
+        data_log: list[str] | None,
+        outputs_dataset: dict[str, OutputDataset],
+        model_composite: DDP,
+        dataloader_prediction: DataLoader,
+    ) -> None:
+        self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank

-        self.modelComposite = modelComposite
+        self.model_composite = model_composite
         self.dataloader_prediction = dataloader_prediction
-        self.outsDataset = outsDataset
+        self.outputs_dataset = outputs_dataset
         self.autocast = autocast
-
+
         self.it = 0

-        self.device = self.modelComposite.device
+        self.device = self.model_composite.device
         self.dataset: DatasetIter = self.dataloader_prediction.dataset
-        patch_size, overlap = self.dataset.getPatchConfig()
-        for outDataset in self.outsDataset.values():
-            outDataset.setPatchConfig([size for size in patch_size if size > 1], overlap, np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataset.dataAugmentationsList])+1), 1]))
-        self.data_log : dict[str, tuple[DataLog, int]] = {}
+        patch_size, overlap = self.dataset.get_patch_config()
+        for output_dataset in self.outputs_dataset.values():
+            output_dataset.set_patch_config(
+                [size for size in patch_size if size > 1] if patch_size else None,
+                overlap,
+                np.max(
+                    [
+                        int(
+                            np.sum([data_augmentation.nb for data_augmentation in self.dataset.data_augmentations_list])
+                            + 1
+                        ),
+                        1,
+                    ]
+                ),
+            )
+        self.data_log: dict[str, tuple[DataLog, int]] = {}
         if data_log is not None:
             for data in data_log:
-                self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
-        self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if len([network for network in self.modelComposite.module.getNetworks().values() if network.measure is not None]) or len(self.data_log) else None
-
+                self.data_log[data.split("/")[0].replace(":", ".")] = (
+                    DataLog[data.split("/")[1]],
+                    int(data.split("/")[2]),
+                )
+        self.tb = (
+            SummaryWriter(log_dir=predict_path + "Metric/")
+            if len(
+                [
+                    network
+                    for network in self.model_composite.module.get_networks().values()
+                    if network.measure is not None
+                ]
+            )
+            or len(self.data_log)
+            else None
+        )
+
     def __enter__(self):
+        """
+        Enters the prediction context and returns the predictor instance.
+        """
         return self
-
-    def __exit__(self, type, value, traceback):
+
+    def __exit__(self, exc_type, value, traceback):
+        """
+        Closes the TensorBoard writer upon exit.
+        """
         if self.tb:
             self.tb.close()

-    def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
-        return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
-
+    def get_input(
+        self,
+        data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+    ) -> dict[tuple[str, bool], torch.Tensor]:
+        return {(k, v[5][0]): v[0] for k, v in data_dict.items()}
+
     @torch.no_grad()
     def run(self):
-        self.modelComposite.eval()
-        self.modelComposite.module.setState(NetState.PREDICTION)
-        desc = lambda : "Prediction : {}".format(description(self.modelComposite))
+        """
+        Run the full prediction loop.
+
+        Iterates over the prediction DataLoader, performs inference using the composite model,
+        applies reduction (e.g., mean), and writes the final results using each `OutputDataset`.
+
+        Also logs intermediate data and metrics to TensorBoard if enabled.
+        """
+
+        self.model_composite.eval()
+        self.model_composite.module.set_state(NetState.PREDICTION)
         self.dataloader_prediction.dataset.load("Prediction")
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
-            for it, data_dict in batch_iter:
-                with torch.amp.autocast('cuda', enabled=self.autocast):
-                    input = self.getInput(data_dict)
-                    for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
+        with tqdm.tqdm(
+            iterable=enumerate(self.dataloader_prediction),
+            leave=True,
+            desc=f"Prediction : {description(self.model_composite)}",
+            total=len(self.dataloader_prediction),
+            ncols=0,
+        ) as batch_iter:
+            for _, data_dict in batch_iter:
+                with torch.amp.autocast("cuda", enabled=self.autocast):
+                    input_tensor = self.get_input(data_dict)
+                    for name, output in self.model_composite(input_tensor, list(self.outputs_dataset.keys())):
                         self._predict_log(data_dict)
-                        outDataset = self.outsDataset[name]
-                        for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
-                            outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
-                            if outDataset.isDone(index):
-                                outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
-
-                batch_iter.set_description(desc())
+                        output_dataset = self.outputs_dataset[name]
+                        for i, (index, patch_augmentation, patch_index) in enumerate(
+                            [
+                                (int(index), int(patch_augmentation), int(patch_index))
+                                for index, patch_augmentation, patch_index in zip(
+                                    list(data_dict.values())[0][1],
+                                    list(data_dict.values())[0][2],
+                                    list(data_dict.values())[0][3],
+                                )
+                            ]
+                        ):
+                            output_dataset.add_layer(
+                                index,
+                                patch_augmentation,
+                                patch_index,
+                                output[i].cpu(),
+                                self.dataset,
+                            )
+                            if output_dataset.is_done(index):
+                                output_dataset.write_prediction(
+                                    index,
+                                    self.dataset.get_dataset_from_index(list(data_dict.keys())[0], index).name.split(
+                                        "/"
+                                    )[-1],
+                                    output_dataset.get_output(index, self.dataset),
+                                )
+
+                batch_iter.set_description(f"Prediction : {description(self.model_composite)}")
                 self.it += 1
-
-    def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
-        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
-
-        if self.global_rank == 0:
-            images_log = []
-            if len(self.data_log):
+
+    def _predict_log(
+        self,
+        data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+    ):
+        """
+        Log prediction results to TensorBoard, including images and metrics.
+
+        This method handles:
+        - Logging image-like data (e.g., inputs, outputs, masks) using `DataLog` instances,
+          based on the `data_log` configuration.
+        - Logging scalar loss and metric values (if present in the network) under the `Prediction/` namespace.
+        - Dynamically retrieving additional feature maps or intermediate layers if requested via `data_log`.
+
+        Logging is performed only on the global rank 0 process and only if `TensorBoard` is active.
+
+        Args:
+            data_dict (dict): Dictionary mapping group names to 6-tuples containing:
+                - input tensor,
+                - index,
+                - patch_augmentation,
+                - patch_index,
+                - metadata (list of strings),
+                - `requires_grad` flag (as a tensor).
+        """
+        measures = DistributedObject.get_measure(
+            self.world_size,
+            self.global_rank,
+            self.local_rank,
+            {"": self.model_composite.module},
+            1,
+        )
+
+        if self.global_rank == 0 and self.tb is not None:
+            data_log = []
+            if len(self.data_log):
                 for name, data_type in self.data_log.items():
                     if name in data_dict:
-                        data_type[0](self.tb, "Prediction/{}".format(name), data_dict[name][0][:self.data_log[name][1]].detach().cpu().numpy(), self.it)
+                        data_type[0](
+                            self.tb,
+                            f"Prediction/{name}",
+                            data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
+                            self.it,
+                        )
                     else:
-                        images_log.append(name.replace(":", "."))
+                        data_log.append(name.replace(":", "."))

-            for name, network in self.modelComposite.module.getNetworks().items():
+            for name, network in self.model_composite.module.get_networks().items():
                 if network.measure is not None:
-                    self.tb.add_scalars("Prediction/{}/Loss".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][0].items()}, self.it)
-                    self.tb.add_scalars("Prediction/{}/Metric".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][1].items()}, self.it)
-            if len(images_log):
-                for name, layer, _ in self.model.module.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
-                    self.data_log[name][0](self.tb, "Prediction/{}".format(name), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
+                    self.tb.add_scalars(
+                        f"Prediction/{name}/Loss",
+                        {k: v[1] for k, v in measures[name][0].items()},
+                        self.it,
+                    )
+                    self.tb.add_scalars(
+                        f"Prediction/{name}/Metric",
+                        {k: v[1] for k, v in measures[name][1].items()},
+                        self.it,
+                    )
+            if len(data_log):
+                for name, layer, _ in self.model_composite.module.get_layers(
+                    [v.to(0) for k, v in self.get_input(data_dict).items() if k[1]], data_log
+                ):
+                    self.data_log[name][0](
+                        self.tb,
+                        f"Prediction/{name}",
+                        layer[: self.data_log[name][1]].detach().cpu().numpy(),
+                        self.it,
+                    )
+

 class ModelComposite(Network):
+    """
+    A composite model that replicates a given base network multiple times and combines their outputs.
+
+    This class is designed to handle model ensembles or repeated predictions from the same architecture.
+    It creates `nb_models` deep copies of the input `model`, each with its own name and output branch,
+    and aggregates their outputs using a provided `Reduction` strategy (e.g., mean, median).
+
+    Args:
+        model (Network): The base network to replicate.
+        nb_models (int): Number of copies of the model to create.
+        combine (Reduction): The reduction method used to combine outputs from all model replicas.
+
+    Attributes:
+        combine (Reduction): The reduction method used during forward inference.
+    """

     def __init__(self, model: Network, nb_models: int, combine: Reduction):
-        super().__init__(model.in_channels, model.optimizer, model.schedulers, model.outputsCriterionsLoader, model.patch, model.nb_batch_per_step, model.init_type, model.init_gain, model.dim)
+        super().__init__(
+            model.in_channels,
+            model.optimizer,
+            model.lr_schedulers_loader,
+            model.outputs_criterions_loader,
+            model.patch,
+            model.nb_batch_per_step,
+            model.init_type,
+            model.init_gain,
+            model.dim,
+        )
         self.combine = combine
         for i in range(nb_models):
-            self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])
+            self.add_module(
+                f"Model_{i}",
+                copy.deepcopy(model),
+                in_branch=[0],
+                out_branch=[f"output_{i}"],
+            )

-    def load(self, state_dicts : list[dict[str, dict[str, torch.Tensor]]]):
+    def load(self, state_dicts: list[dict[str, dict[str, torch.Tensor]]]):
+        """
+        Load weights for each sub-model in the composite from the corresponding state dictionaries.
+
+        Args:
+            state_dicts (list): A list of state dictionaries, one for each model replica.
+        """
         for i, state_dict in enumerate(state_dicts):
-            self["Model_{}".format(i)].load(state_dict, init=False)
-            self["Model_{}".format(i)].setName("{}_{}".format(self["Model_{}".format(i)].getName(), i))
-
-    def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
+            self[f"Model_{i}"].load(state_dict, init=False)
+            self[f"Model_{i}"].set_name(f"{self[f'Model_{i}'].get_name()}_{i}")
+
+    def forward(
+        self,
+        data_dict: dict[tuple[str, bool], torch.Tensor],
+        output_layers: list[str] = [],
+    ) -> list[tuple[str, torch.Tensor]]:
+        """
+        Perform a forward pass on all model replicas and aggregate their outputs.
+
+        Args:
+            data_dict (dict): A dictionary mapping (group_name, requires_grad) to input tensors.
+            output_layers (list): List of output layer names to extract from each sub-model.
+
+        Returns:
+            list[tuple[str, torch.Tensor]]: Aggregated output for each layer, after applying the reduction.
+        """
         result = {}
         for name, module in self.items():
             result[name] = module(data_dict, output_layers)
-
+
         aggregated = defaultdict(list)
         for module_outputs in result.values():
             for key, tensor in module_outputs:
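
The hunk above moves the `Reduction` classes to the top of the module and turns them into instantiated callables (note the added `()` in `getattr(importlib.import_module(module), name)()`). A minimal runnable sketch of the contract, with the ABC plumbing and konfai imports stripped, using made-up sample values:

    import torch

    class Mean:
        def __call__(self, result: torch.Tensor) -> torch.Tensor:
            return torch.mean(result.float(), dim=0)

    class Median:
        def __call__(self, result: torch.Tensor) -> torch.Tensor:
            return torch.median(result.float(), dim=0).values

    # Three hypothetical augmented predictions of the same sample, stacked on dim 0:
    stacked = torch.stack([torch.ones(2, 2), 2 * torch.ones(2, 2), 10 * torch.ones(2, 2)])
    print(Mean()(stacked))    # element-wise mean   -> 4.3333
    print(Median()(stacked))  # element-wise median -> 2.0

This also explains why the default changed from `reduction: str = "mean"` to `reduction: str = "Mean"`: the string is now resolved to a class name in `konfai.predictor` and instantiated.
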
@@ -309,21 +599,41 @@ class ModelComposite(Network):
             final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))

         return final_outputs
-
+

 class Predictor(DistributedObject):
+    """
+    KonfAI's main prediction controller.
+
+    This class orchestrates the prediction phase by:
+    - Loading model weights from checkpoint(s) or URL(s)
+    - Preparing datasets and output configurations
+    - Managing distributed inference with optional multi-GPU support
+    - Applying transformations and saving predictions
+    - Optionally logging results to TensorBoard
+
+    Attributes:
+        model (Network): The neural network model to use for prediction.
+        dataset (DataPrediction): Dataset manager for prediction data.
+        combine_classpath (str): Path to the reduction strategy (e.g., "Mean").
+        autocast (bool): Whether to enable AMP inference.
+        outputs_dataset (dict[str, OutputDataset]): Mapping from layer names to output writers.
+        data_log (list[str] | None): List of tensors to log during inference.
+    """

     @config("Predictor")
-    def __init__(self,
-        model: ModelLoader = ModelLoader(),
-        dataset: DataPrediction = DataPrediction(),
-        combine: str = "mean",
-        train_name: str = "name",
-        manual_seed : Union[int, None] = None,
-        gpu_checkpoints: Union[list[str], None] = None,
-        autocast : bool = False,
-        outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
-        images_log: list[str] = []) -> None:
+    def __init__(
+        self,
+        model: ModelLoader = ModelLoader(),
+        dataset: DataPrediction = DataPrediction(),
+        combine: str = "Mean",
+        train_name: str = "name",
+        manual_seed: int | None = None,
+        gpu_checkpoints: list[str] | None = None,
+        autocast: bool = False,
+        outputs_dataset: dict[str, OutputDatasetLoader] | None = {"default:Default": OutputDatasetLoader()},
+        data_log: list[str] | None = None,
+    ) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
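
`ModelComposite.forward`, whose tail appears at the top of this hunk, groups per-replica outputs by layer name, stacks each group on a new leading dimension, and applies the configured reduction. A self-contained sketch of that aggregation step with hypothetical replica outputs (`combine_outputs` is an illustrative name, not konfai API):

    from collections import defaultdict
    import torch

    def combine_outputs(per_replica_outputs, combine):
        # per_replica_outputs: one [(layer_name, tensor), ...] list per model replica
        aggregated = defaultdict(list)
        for module_outputs in per_replica_outputs:
            for key, tensor in module_outputs:
                aggregated[key].append(tensor)
        # stack each layer's tensors on dim 0 and reduce, as ModelComposite.forward does
        return [(key, combine(torch.stack(tensors, dim=0))) for key, tensors in aggregated.items()]

    replica_a = [("output", torch.zeros(1, 3))]
    replica_b = [("output", torch.ones(1, 3))]
    print(combine_outputs([replica_a, replica_b], lambda t: torch.mean(t, dim=0)))
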
@@ -332,99 +642,194 @@ class Predictor(DistributedObject):
         self.combine_classpath = combine
         self.autocast = autocast

-        self.model = model.getModel(train=False)
+        self.model = model.get_model(train=False)
         self.it = 0
-        self.outsDatasetLoader = outsDataset if outsDataset else {}
-        self.outsDataset = {name.replace(":", ".") : value.getOutDataset(name) for name, value in self.outsDatasetLoader.items()}
+        self.outputs_dataset_loader = outputs_dataset if outputs_dataset else {}
+        self.outputs_dataset = {
+            name.replace(":", "."): value.get_output_dataset(name)
+            for name, value in self.outputs_dataset_loader.items()
+        }

         self.datasets_filename = []
-        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
-        self.images_log = images_log
-        for outDataset in self.outsDataset.values():
-            self.datasets_filename.append(outDataset.filename)
-            outDataset.filename = "{}{}".format(self.predict_path, outDataset.filename)
-
-
+        self.predict_path = predictions_directory() + self.name + "/"
+        for output_dataset in self.outputs_dataset.values():
+            self.datasets_filename.append(output_dataset.filename)
+            output_dataset.filename = f"{self.predict_path}{output_dataset.filename}"
+
+        self.data_log = data_log
+        modules = []
+        for i, _ in self.model.named_modules():
+            modules.append(i)
+        if self.data_log is not None:
+            for k in self.data_log:
+                tmp = k.split("/")[0].replace(":", ".")
+                if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
+                    raise PredictorError(
+                        f"Invalid key '{tmp}' in `data_log`.",
+                        f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
+                        f"nor a valid module name in the model ({modules}).",
+                        "Please check your `data_log` configuration,"
+                        " it should reference either a model output or a dataset group.",
+                    )
+
         self.gpu_checkpoints = gpu_checkpoints

     def _load(self) -> list[dict[str, dict[str, torch.Tensor]]]:
-        model_paths = MODEL().split(":")
+        """
+        Load pretrained model weights from configured paths or URLs.
+
+        This method handles both remote and local model sources:
+        - If the model path is a URL (starting with "https://"), it uses `torch.hub.load_state_dict_from_url`
+          to download and load the state dict.
+        - If the model path is local:
+            - It either loads the explicit file or resolves the latest model file in a default directory
+              based on the prediction name.
+        - All loaded state dicts are returned as a list of nested dictionaries mapping module names
+          to parameter tensors.
+
+        Returns:
+            list[dict[str, dict[str, torch.Tensor]]]: A list of state dictionaries, one per model.
+
+        Raises:
+            Exception: If a model path does not exist or cannot be loaded.
+        """
+        model_paths = path_to_models().split(":")
         state_dicts = []
         for model_path in model_paths:
             if model_path.startswith("https://"):
                 try:
-                    state_dicts.append(torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True))
-                except:
-                    raise Exception("Model : {} does not exist !".format(model_path))
+                    state_dicts.append(
+                        torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True)
+                    )
+                except Exception:
+                    raise Exception(f"Model : {model_path} does not exist !")
             else:
                 if model_path != "":
                     path = ""
                     name = model_path
                 else:
                     if self.name.endswith(".pt"):
-                        path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
+                        path = models_directory() + "/".join(self.name.split("/")[:-1]) + "/StateDict/"
                         name = self.name.split("/")[-1]
                     else:
-                        path = MODELS_DIRECTORY()+self.name+"/StateDict/"
+                        path = models_directory() + self.name + "/StateDict/"
                         name = sorted(os.listdir(path))[-1]
-                if os.path.exists(path+name):
-                    state_dicts.append(torch.load(path+name, weights_only=False))
+                if os.path.exists(path + name):
+                    state_dicts.append(torch.load(path + name, weights_only=True))
                 else:
-                    raise Exception("Model : {} does not exist !".format(path+name))
+                    raise Exception(f"Model : {path + name} does not exist !")
         return state_dicts
-
+
     def setup(self, world_size: int):
+        """
+        Set up the predictor for inference.
+
+        This method performs all necessary initialization steps before running predictions:
+        - Ensures output directories exist, and optionally prompts the user before overwriting existing predictions.
+        - Copies the current configuration file (Prediction.yml) into the output directory for reproducibility.
+        - Initializes the model in prediction mode, including output configuration and channel tracing.
+        - Validates that the configured output groups match existing modules in the model architecture.
+        - Dynamically loads pretrained weights from local files or remote URLs.
+        - Wraps the base model into a `ModelComposite` to support ensemble inference.
+        - Initializes the prediction dataloader, with proper distribution across available GPUs.
+        - Loads and prepares each configured `OutputDataset` object for storing predictions.
+
+        Args:
+            world_size (int): Total number of processes or GPUs used for distributed prediction.
+
+        Raises:
+            PredictorError: If an output group does not match any module in the model.
+            Exception: If a specified model file or URL is invalid or inaccessible.
+        """
         for dataset_filename in self.datasets_filename:
-            path = self.predict_path +dataset_filename
+            path = self.predict_path + dataset_filename
             if os.path.exists(path):
                 if os.environ["KONFAI_OVERWRITE"] != "True":
-                    accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
+                    accept = builtins.input(
+                        f"The prediction {path} already exists ! Do you want to overwrite it (yes,no) : "
+                    )
                     if accept != "yes":
                         return
-
+
             if not os.path.exists(path):
                 os.makedirs(path)

-        shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
+        shutil.copyfile(config_file(), self.predict_path + "Prediction.yml")

-
-        self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
-        self.model.init_outputsGroup()
+        self.model.init(self.autocast, State.PREDICTION, self.dataset.get_groups_dest())
+        self.model.init_outputs_group()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
-
+
         modules = []
-        for i,_,_ in self.model.named_ModuleArgsDict():
+        for i, _, _ in self.model.named_module_args_dict():
             modules.append(i)
-        for output_group in self.outsDataset.keys():
-            if output_group not in modules:
-                raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
-                    "Available modules: {}".format(modules),
-                    "Please check that the name matches exactly a submodule or output of your model architecture."
+        for output_group in self.outputs_dataset.keys():
+            if output_group.replace(";accu;", "") not in modules:
+                raise PredictorError(
+                    f"The output group '{output_group}' defined in 'outputs_criterions' "
+                    "does not correspond to any module in the model.",
+                    f"Available modules: {modules}",
+                    "Please check that the name matches exactly a submodule or" "output of your model architecture.",
                 )
-
-        module, name = _getModule(self.combine_classpath, "konfai.predictor")
+
+        module, name = get_module(self.combine_classpath, "konfai.predictor")
         if module == "konfai.predictor":
-            combine = getattr(importlib.import_module(module), name)
+            combine = getattr(importlib.import_module(module), name)()
         else:
-            combine = config("{}.{}".format(KONFAI_ROOT(), self.combine_classpath))(getattr(importlib.import_module(module), name))(config = None)
-
-
-        self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), combine)
-        self.modelComposite.load(self._load())
+            combine = config(f"{konfai_root()}.{self.combine_classpath}")(
+                getattr(importlib.import_module(module), name)
+            )(config=None)
+
+        self.model_composite = ModelComposite(self.model, len(path_to_models().split(":")), combine)
+        self.model_composite.load(self._load())

-        if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
+        if (
+            len(list(self.outputs_dataset.keys())) == 0
+            and len(
+                [network for network in self.model_composite.get_networks().values() if network.measure is not None]
+            )
+            == 0
+        ):
             exit(0)
-
-        self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
-        self.dataloader = self.dataset.getData(world_size//self.size)
-        for name, outDataset in self.outsDataset.items():
-            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
-
-
-    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
-        modelComposite = Network.to(self.modelComposite, local_rank*self.size)
-        modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
-        with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
-            p.run()

-
+        self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
+        self.dataloader = self.dataset.get_data(world_size // self.size)
+        for name, output_dataset in self.outputs_dataset.items():
+            output_dataset.load(
+                name.replace(".", ":"),
+                list(self.dataset.datasets.values()),
+                {src: dest for src, inner in self.dataset.groups_src.items() for dest in inner},
+            )
+
+    def run_process(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        dataloaders: list[DataLoader],
+    ):
+        """
+        Launch prediction on the given process rank.
+
+        Args:
+            world_size (int): Total number of processes.
+            global_rank (int): Rank of the current process.
+            local_rank (int): Local device rank.
+            dataloaders (list[DataLoader]): List of data loaders for prediction.
+        """
+        model_composite = Network.to(self.model_composite, local_rank * self.size)
+        model_composite = (
+            DDP(model_composite, static_graph=True) if torch.cuda.is_available() else CPUModel(model_composite)
+        )
+        with _Predictor(
+            world_size,
+            global_rank,
+            local_rank,
+            self.autocast,
+            self.predict_path,
+            self.data_log,
+            self.outputs_dataset,
+            model_composite,
+            *dataloaders,
+        ) as p:
+            p.run()
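
`Predictor._load` (above) accepts a colon-separated list of checkpoint sources, mixing `https://` URLs with local paths, and returns one state dict per `ModelComposite` replica. A hedged sketch of the per-entry resolution, taking the sources as an already-split list and using illustrative names and error messages; note that 1.1.9 switches `torch.load` to `weights_only=True`, which refuses arbitrary pickled objects in checkpoints:

    import os
    import torch

    def load_state_dicts(model_paths: list[str]) -> list[dict]:
        state_dicts = []
        for model_path in model_paths:
            if model_path.startswith("https://"):
                # remote checkpoint, downloaded and cached by torch.hub
                state_dicts.append(
                    torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True)
                )
            elif os.path.exists(model_path):
                # weights_only=True loads tensors only, not pickled code
                state_dicts.append(torch.load(model_path, weights_only=True))
            else:
                raise FileNotFoundError(f"Model: {model_path} does not exist!")
        return state_dicts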