konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to a supported public registry. It is provided for informational purposes only.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/evaluator.py CHANGED
@@ -1,69 +1,182 @@
+import builtins
+import importlib
+import json
 import os
-from torch.utils.data import DataLoader
+import shutil
+from typing import Any
+
+import numpy as np
 import torch
 import tqdm
-import numpy as np
-import json
-import shutil
-import builtins
-import importlib
-from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
-from konfai.utils.config import config
-from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
+from torch.utils.data import DataLoader
+
+from konfai import config_file, evaluations_directory, konfai_root
 from konfai.data.data_manager import DataMetric
+from konfai.utils.config import config
+from konfai.utils.utils import DistributedObject, EvaluatorError, get_module, synchronize_data
+
 
-class CriterionsAttr():
+class CriterionsAttr:
+    """
+    Container for additional metadata or configuration attributes related to a loss criterion.
+
+    This class is currently empty but acts as a placeholder for future extension.
+    It is passed along with each loss function to allow parameterization or inspection of its behavior.
+
+    Use cases may include:
+    - Weighting of individual loss terms
+    - Conditional activation
+    - Logging preferences
+    """
 
     @config()
     def __init__(self) -> None:
-        pass
+        pass
+
+
+class CriterionsLoader:
+    """
+    Loader for multiple criterion modules to be applied between a model output and one or more targets.
 
-class CriterionsLoader():
+    Each loss module (e.g., Dice, CrossEntropy, NCC) is dynamically loaded using its fully-qualified
+    classpath and is associated with a `CriterionsAttr` configuration object.
+
+    Args:
+        criterions_loader (dict): A mapping from module classpaths (as strings) to `CriterionsAttr` instances.
+            The module path is parsed and instantiated via `get_module`.
+
+    """
 
     @config()
-    def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
-        self.criterionsLoader = criterionsLoader
+    def __init__(
+        self,
+        criterions_loader: dict[str, CriterionsAttr] = {"default:torch:nn:CrossEntropyLoss:Dice:NCC": CriterionsAttr()},
+    ) -> None:
+        self.criterions_loader = criterions_loader
 
-    def getCriterions(self, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
+    def get_criterions(self, output_group: str, target_group: str) -> dict[torch.nn.Module, CriterionsAttr]:
         criterions = {}
-        for module_classpath, criterionsAttr in self.criterionsLoader.items():
-            module, name = _getModule(module_classpath, "konfai.metric.measure")
-            criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
+        for module_classpath, criterions_attr in self.criterions_loader.items():
+            module, name = get_module(module_classpath, "konfai.metric.measure")
+            criterions[
+                config(
+                    f"{konfai_root()}.metrics.{output_group}.targets_criterions.{target_group}"
+                    f".criterions_loader.{module_classpath}"
+                )(getattr(importlib.import_module(module), name))(config=None)
+            ] = criterions_attr
         return criterions
 
-class TargetCriterionsLoader():
+
+class TargetCriterionsLoader:
+    """
+    Loader class for handling multiple target groups with associated criterion configurations.
+
+    This class allows defining a set of criterion loaders (e.g., Dice, BCE, MSE) for each
+    target group to be used during evaluation or training. Each target group corresponds
+    to one or more loss functions, all linked to a specific model output.
+
+    Args:
+        targets_criterions (dict[str, CriterionsLoader]): Dictionary mapping each target group name
+            to a `CriterionsLoader` instance that defines its associated loss functions.
+    """
 
     @config()
-    def __init__(self, targetsCriterions : dict[str, CriterionsLoader] = {"default" : CriterionsLoader()}) -> None:
-        self.targetsCriterions = targetsCriterions
-
-    def getTargetsCriterions(self, output_group : str) -> dict[str, dict[torch.nn.Module, float]]:
-        targetsCriterions = {}
-        for target_group, criterionsLoader in self.targetsCriterions.items():
-            targetsCriterions[target_group] = criterionsLoader.getCriterions(output_group, target_group)
-        return targetsCriterions
+    def __init__(
+        self,
+        targets_criterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()},
+    ) -> None:
+        self.targets_criterions = targets_criterions
+
+    def get_targets_criterions(self, output_group: str) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
+        """
+        Retrieve the criterion modules and their attributes for a specific output group.
+
+        This function prepares the loss functions to be applied for a given model output,
+        grouped by their target group.
+
+        Args:
+            output_group (str): Name of the model output group (e.g., "output_segmentation").
+
+        Returns:
+            dict[str, dict[nn.Module, CriterionsAttr]]: A nested dictionary where the first key is the
+                target group name, and the value is a dictionary mapping each loss module to its attributes.
+        """
+        targets_criterions = {}
+        for target_group, criterions_loader in self.targets_criterions.items():
+            targets_criterions[target_group] = criterions_loader.get_criterions(output_group, target_group)
+        return targets_criterions
+
+
+class Statistics:
+    """
+    Utility class to accumulate, structure, and write evaluation metric results.
+
+    This class is used to:
+    - Collect metrics for each dataset sample.
+    - Compute aggregate statistics (mean, std, percentiles, etc.).
+    - Export all results in a structured JSON format, including both per-case and aggregate values.
 
-class Statistics():
+    Args:
+        filename (str): Path to the output JSON file that will store the final results.
+    """
 
     def __init__(self, filename: str) -> None:
         self.measures: dict[str, dict[str, float]] = {}
         self.filename = filename
 
     def add(self, values: dict[str, float], name_dataset: str) -> None:
+        """
+        Add a set of metric values for a given dataset case.
+
+        Args:
+            values (dict): Dictionary of metric names and their values.
+            name_dataset (str): Identifier (e.g., case name) for the sample.
+        """
         for name, value in values.items():
             if name_dataset not in self.measures:
                 self.measures[name_dataset] = {}
             self.measures[name_dataset][name] = value
-
-    @staticmethod
-    def getStatistic(values: list[float]) -> dict[str, float]:
-        return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}
-
-    def write(self, outputs: list[dict[str, any]]) -> None:
+
+    @staticmethod
+    def get_statistic(values: list[float]) -> dict[str, float]:
+        """
+        Compute statistical aggregates for a list of metric values.
+
+        Args:
+            values (list of float): Values to summarize.
+
+        Returns:
+            dict[str, float]: A dictionary containing:
+            - max, min, std
+            - 25th, 50th, and 75th percentiles
+            - mean and count
+        """
+        return {
+            "max": np.nanmax(values) if np.any(~np.isnan(values)) else np.nan,
+            "min": np.nanmin(values) if np.any(~np.isnan(values)) else np.nan,
+            "std": np.nanstd(values) if np.any(~np.isnan(values)) else np.nan,
+            "25pc": np.nanpercentile(values, 25) if np.any(~np.isnan(values)) else np.nan,
+            "50pc": np.nanpercentile(values, 50) if np.any(~np.isnan(values)) else np.nan,
+            "75pc": np.nanpercentile(values, 75) if np.any(~np.isnan(values)) else np.nan,
+            "mean": np.nanmean(values) if np.any(~np.isnan(values)) else np.nan,
+            "count": np.count_nonzero(~np.isnan(values)) if np.any(~np.isnan(values)) else np.nan,
+        }
+
+    def write(self, outputs: list[dict[str, dict[str, Any]]]) -> None:
+        """
+        Write the collected and aggregated statistics to the configured output file.
+
+        The output JSON structure contains:
+        - `case`: All individual metrics per sample.
+        - `aggregates`: Global statistics computed over all cases.
+
+        Args:
+            outputs (list): List of metric dictionaries to merge and serialize.
+        """
         measures = {}
         for output in outputs:
             measures.update(output)
-        result = {}
+        result: dict[str, dict[str, dict[str, Any]]] = {}
         result["case"] = {}
         for name, v in measures.items():
             for metric_name, value in v.items():
@@ -73,92 +186,236 @@ class Statistics():
 
         result["aggregates"] = {}
         tmp: dict[str, list[float]] = {}
-        for name, v in measures.items():
-            for metric_name, value in v.items():
+        for _, v in measures.items():
+            for metric_name, _ in v.items():
                 if metric_name not in tmp:
                     tmp[metric_name] = []
                 tmp[metric_name].append(v[metric_name])
         for metric_name, values in tmp.items():
-            result["aggregates"][metric_name] = Statistics.getStatistic(values)
+            result["aggregates"][metric_name] = Statistics.get_statistic(values)
 
         with open(self.filename, "w") as f:
             f.write(json.dumps(result, indent=4))
-
+
+
 class Evaluator(DistributedObject):
+    """
+    Distributed evaluation engine for computing metrics on model predictions.
+
+    This class handles the evaluation of predicted outputs using predefined metric loaders.
+    It supports multi-output and multi-target configurations, computes aggregated statistics
+    across training and validation datasets, and synchronizes results across processes.
+
+    Evaluation results are stored in JSON format and optionally displayed during iteration.
+
+    Args:
+        train_name (str): Unique name of the evaluation run, used for logging and output folders.
+        metrics (dict[str, TargetCriterionsLoader]): Dictionary mapping output groups to loaders of target metrics.
+        dataset (DataMetric): Dataset provider configured for evaluation mode.
+
+    Attributes:
+        statistics_train (Statistics): Object used to store training evaluation metrics.
+        statistics_validation (Statistics): Object used to store validation evaluation metrics.
+        dataloader (list[DataLoader]): DataLoaders for training and validation sets.
+        metric_path (str): Path to the evaluation output directory.
+        metrics (dict): Instantiated metrics organized by output and target groups.
+    """
 
     @config("Evaluator")
-    def __init__(self, train_name: str = "default:TRAIN_01", metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()}, dataset : DataMetric = DataMetric(),) -> None:
+    def __init__(
+        self,
+        train_name: str = "default:TRAIN_01",
+        metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()},
+        dataset: DataMetric = DataMetric(),
+    ) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
-        self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
+        self.metric_path = evaluations_directory() + self.name + "/"
         self.metricsLoader = metrics
         self.dataset = dataset
-        self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
-        self.statistics_train = Statistics(self.metric_path+"Metric_TRAIN.json")
-        self.statistics_validation = Statistics(self.metric_path+"Metric_VALIDATION.json")
+        self.metrics = {k: v.get_targets_criterions(k) for k, v in self.metricsLoader.items()}
+        self.statistics_train = Statistics(self.metric_path + "Metric_TRAIN.json")
+        self.statistics_validation = Statistics(self.metric_path + "Metric_VALIDATION.json")
+
+    def update(self, data_dict: dict[str, tuple[torch.Tensor, str]], statistics: Statistics) -> dict[str, float]:
+        """
+        Compute metrics for a batch and update running statistics.
 
-    def update(self, data_dict: dict[str, tuple[torch.Tensor, str]], statistics : Statistics) -> dict[str, float]:
+        Args:
+            data_dict (dict): Dictionary where keys are output/target group names and values are
+                tuples of (tensor, sample name).
+            statistics (Statistics): The statistics object to update (train or validation).
+
+        Returns:
+            dict[str, float]: Dictionary of computed metric values with keys in the format
+                'output_group:target_group:MetricName'.
+        """
         result = {}
         for output_group in self.metrics:
            for target_group in self.metrics[output_group]:
-                targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split(";") if group in data_dict]
+                targets = [
+                    (data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0])
+                    for group in target_group.split(";")
+                    if group in data_dict
+                ]
                 name = data_dict[output_group][1][0]
                 for metric in self.metrics[output_group][target_group]:
-                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
+                    loss = metric(
+                        (data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0]),
+                        *targets,
+                    )
+                    if isinstance(loss, tuple):
+                        true_loss = loss[1]
+                    else:
+                        true_loss = loss.item()
+
+                    if isinstance(true_loss, dict):
+                        loss = 0
+                        c = 0
+                        for k, v in true_loss.items():
+                            result[f"{output_group}:{target_group}:{metric.__class__.__name__}:{k}"] = v
+                            if not np.isnan(v):
+                                loss += v
+                                c += 1
+                        result[f"{output_group}:{target_group}:{metric.__class__.__name__}"] = loss / c
+                    else:
+                        result[f"{output_group}:{target_group}:{metric.__class__.__name__}"] = true_loss
         statistics.add(result, name)
         return result
-
+
     def setup(self, world_size: int):
+        """
+        Prepare the evaluator for distributed metric computation.
+
+        This method performs the following steps:
+        - Checks whether previous evaluation results exist and optionally overwrites them.
+        - Creates the output directory and copies the current configuration file for reproducibility.
+        - Loads the evaluation dataset according to the world size.
+        - Validates that all specified output and target groups used in metric definitions
+          are present in the dataset group configuration.
+
+        Args:
+            world_size (int): Number of processes in the distributed evaluation setup.
+
+        Raises:
+            EvaluatorError: If any metric output or target group is missing in the dataset's group mapping.
+        """
        if os.path.exists(self.metric_path):
             if os.environ["KONFAI_OVERWRITE"] != "True":
-                accept = builtins.input("The metric {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
+                accept = builtins.input(
+                    f"The metric {self.name} already exists ! Do you want to overwrite it (yes,no) : "
+                )
                 if accept != "yes":
                     return
-
+
             if os.path.exists(self.metric_path):
-                shutil.rmtree(self.metric_path)
+                shutil.rmtree(self.metric_path)
 
         if not os.path.exists(self.metric_path):
             os.makedirs(self.metric_path)
-        metric_namefile_src = CONFIG_FILE().replace(".yml", "")
-        shutil.copyfile(metric_namefile_src+".yml", "{}{}.yml".format(self.metric_path, metric_namefile_src))
+        metric_namefile_src = config_file().replace(".yml", "")
+        shutil.copyfile(
+            metric_namefile_src + ".yml",
+            f"{self.metric_path}{metric_namefile_src}.yml",
+        )
 
-        self.dataloader = self.dataset.getData(world_size)
+        self.dataloader, _, _ = self.dataset.get_data(world_size)
 
-        groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
+        groups_dest = [group for groups in self.dataset.groups_src.values() for group in groups]
 
-        missing_outputs = set(self.metrics.keys()) - set(groupsDest)
+        missing_outputs = set(self.metrics.keys()) - set(groups_dest)
         if missing_outputs:
             raise EvaluatorError(
-                f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
-                f"Available groups: {sorted(groupsDest)}"
+                f"The following metric output groups are missing from 'groups_dest': {sorted(missing_outputs)}. ",
+                f"Available groups: {sorted(groups_dest)}",
             )
 
         target_groups = []
         for i in {target for targets in self.metrics.values() for target in targets}:
             for u in i.split(";"):
                 target_groups.append(u)
-        missing_targets = set(target_groups) - set(groupsDest)
+        missing_targets = set(target_groups) - set(groups_dest)
         if missing_targets:
             raise EvaluatorError(
-                f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
-                f"Available groups: {sorted(groupsDest)}"
+                f"The following metric target groups are missing from 'groups_dest': {sorted(missing_targets)}. ",
+                f"Available groups: {sorted(groups_dest)}",
            )
 
     def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
-        description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
+        """
+        Execute the distributed evaluation loop over the training and validation datasets.
+
+        This method iterates through the provided DataLoaders (train and optionally validation),
+        updates the metric statistics using the configured `metrics` dictionary, and synchronizes
+        the results across all processes. On the global rank 0, the metrics are saved as JSON files.
+
+        Metrics are displayed in real-time using `tqdm` progress bars, showing a summary of the
+        current batch's computed values.
+
+        Args:
+            world_size (int): Total number of distributed processes.
+            global_rank (int): Global rank of the current process (used for writing results).
+            gpu (int): Local GPU ID used for synchronization.
+            dataloaders (list[DataLoader]): A list containing one or two DataLoaders:
+                - `dataloaders[0]` is used for training evaluation.
+                - `dataloaders[1]` (optional) is used for validation evaluation.
+
+        Notes:
+            - Only the main process (`global_rank == 0`) writes final results to disk.
+        """
+
+        def description(measure):
+            return (
+                f"Metric TRAIN : {' | '.join(f'{k}: {v:.4f}' for k, v in measure.items())}"
+                if measure is not None
+                else "Metric TRAIN : "
+            )
+
+        with tqdm.tqdm(
+            iterable=enumerate(dataloaders[0]),
+            leave=True,
+            desc=description(None),
+            total=len(dataloaders[0]),
+            ncols=0,
+        ) as batch_iter:
             for _, data_dict in batch_iter:
-                batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
+                batch_iter.set_description(
+                    description(
+                        self.update(
+                            {k: (v[0], v[4]) for k, v in data_dict.items()},
+                            self.statistics_train,
+                        )
+                    )
+                )
         outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
         if global_rank == 0:
             self.statistics_train.write(outputs)
         if len(dataloaders) == 2:
-            description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
+
+            def description(measure):
+                return (
+                    f"Metric VALIDATION : {' | '.join(f'{k}: {v:.2f}' for k, v in measure.items())}"
+                    if measure is not None
+                    else "Metric VALIDATION : "
+                )
+
+            with tqdm.tqdm(
+                iterable=enumerate(dataloaders[1]),
+                leave=True,
+                desc=description(None),
+                total=len(dataloaders[1]),
+                ncols=0,
+            ) as batch_iter:
                for _, data_dict in batch_iter:
-                    batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
+                    batch_iter.set_description(
+                        description(
+                            self.update(
+                                {k: (v[0], v[4]) for k, v in data_dict.items()},
+                                self.statistics_validation,
+                            )
+                        )
+                    )
             outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
             if global_rank == 0:
-                self.statistics_validation.write(outputs)
+                self.statistics_validation.write(outputs)
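
The most consequential behavioral change in `evaluator.py` is that `Statistics.get_statistic` now uses NumPy's NaN-aware reductions, and `Evaluator.update` skips NaN entries when averaging dict-valued metric results. A minimal standalone sketch of the aggregation difference (illustrative values only, not KonfAI's API):

```python
import numpy as np

values = [0.91, 0.87, float("nan"), 0.95]  # one case yielded no valid metric

# 1.1.8-style aggregation: a single NaN poisons every aggregate
print(np.mean(values))            # nan
print(np.percentile(values, 50))  # nan

# 1.2.0-style aggregation: NaN cases are skipped; "count" reports valid cases only
print(np.nanmean(values))                   # 0.91
print(np.nanpercentile(values, 50))         # 0.91
print(np.count_nonzero(~np.isnan(values)))  # 3
```
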
konfai/main.py CHANGED
@@ -1,53 +1,102 @@
 import argparse
 import os
-from torch.cuda import device_count
+import sys
+
 import torch.multiprocessing as mp
-from konfai.utils.utils import setup, TensorBoard, Log
-from konfai import KONFAI_NB_CORES
+from torch.cuda import device_count
+
+from konfai import konfai_nb_cores
+from konfai.utils.utils import Log, TensorBoard, setup
 
-import sys
 sys.path.insert(0, os.getcwd())
 
+
 def main():
+    """
+    Entry point for launching KonfAI training locally.
+
+    - Parses arguments (if any) via a setup parser.
+    - Initializes distributed environment based on available CUDA devices or CPU cores.
+    - Launches training via `mp.spawn`.
+    - Manages logging and TensorBoard context.
+
+    KeyboardInterrupt is caught to allow clean manual termination.
+    """
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
-        with setup(parser) as distributedObject:
-            with Log(distributedObject.name, 0):
+        with setup(parser) as distributed_object:
+            with Log(distributed_object.name, 0):
                 world_size = device_count()
                 if world_size == 0:
-                    world_size = int(KONFAI_NB_CORES())
-                distributedObject.setup(world_size)
-                with TensorBoard(distributedObject.name):
-                    mp.spawn(distributedObject, nprocs=world_size)
+                    world_size = int(konfai_nb_cores())
+                distributed_object.setup(world_size)
+                with TensorBoard(distributed_object.name):
+                    mp.spawn(distributed_object, nprocs=world_size)
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")
 
+
 def cluster():
+    """
+    Entry point for launching KonfAI on a cluster using Submitit.
+
+    - Parses cluster-specific arguments: job name, nodes, memory, time limit, etc.
+    - Sets up distributed environment based on number of nodes and GPUs.
+    - Configures Submitit executor with job specs.
+    - Submits the job to SLURM (or another Submitit-compatible backend).
+
+    Environment variables:
+        KONFAI_OVERWRITE: Set to force overwrite of previous training runs.
+        KONFAI_CLUSTER: Mark this as a cluster job (used downstream).
+
+    Raises:
+        KeyboardInterrupt: On manual interruption.
+        Exception: Any submission-related error is printed and causes exit.
+    """
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     # Cluster manager arguments
-    cluster_args = parser.add_argument_group('Cluster manager arguments')
-    cluster_args.add_argument('--name', type=str, help='Task name', required=True)
-    cluster_args.add_argument('--num-nodes', '--num_nodes', default=1, type=int, help='Number of nodes')
-    cluster_args.add_argument('--memory', type=int, default=16, help='Amount of memory per node')
-    cluster_args.add_argument('--time-limit', '--time_limit', type=int, default=1440, help='Job time limit in minute')
-    cluster_args.add_argument('--resubmit', action='store_true', help='Automatically resubmit job just before timout')
+    cluster_args = parser.add_argument_group("Cluster manager arguments")
+    cluster_args.add_argument("--name", type=str, help="Task name", required=True)
+    cluster_args.add_argument("--num-nodes", "--num_nodes", default=1, type=int, help="Number of nodes")
+    cluster_args.add_argument("--memory", type=int, default=16, help="Amount of memory per node")
+    cluster_args.add_argument(
+        "--time-limit",
+        "--time_limit",
+        type=int,
+        default=1440,
+        help="Job time limit in minute",
+    )
+    cluster_args.add_argument(
+        "--resubmit",
+        action="store_true",
+        help="Automatically resubmit job just before timout",
+    )
     try:
-        with setup(parser) as distributedObject:
+        with setup(parser) as distributed_object:
             args = parser.parse_args()
             config = vars(args)
             os.environ["KONFAI_OVERWRITE"] = "True"
             os.environ["KONFAI_CLUSTER"] = "True"
 
             n_gpu = len(config["gpu"].split(","))
-            distributedObject.setup(n_gpu*int(config["num_nodes"]))
+            distributed_object.setup(n_gpu * int(config["num_nodes"]))
             import submitit
+
             executor = submitit.AutoExecutor(folder="./Cluster/")
-            executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
-            with TensorBoard(distributedObject.name):
-                executor.submit(distributedObject)
+            executor.update_parameters(
+                name=config["name"],
+                mem_gb=config["memory"],
+                gpus_per_node=n_gpu,
+                tasks_per_node=n_gpu // distributed_object.size,
+                cpus_per_task=config["num_workers"],
+                nodes=config["num_nodes"],
+                timeout_min=config["time_limit"],
+            )
+            with TensorBoard(distributed_object.name):
+                executor.submit(distributed_object)
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")
     except Exception as e:
         print(e)
-        exit(1)
+        exit(1)
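
The Submitit call that `cluster()` now spreads over multiple lines follows the library's standard workflow: create an `AutoExecutor`, set resource parameters, then submit a callable. A self-contained sketch with illustrative values (the job payload and parameter values here are placeholders, not KonfAI defaults):

```python
import submitit


def job() -> str:
    # Placeholder payload; KonfAI submits its DistributedObject callable instead.
    return "done"


# AutoExecutor picks the backend automatically (SLURM if available, else local).
executor = submitit.AutoExecutor(folder="./Cluster/")
executor.update_parameters(
    name="TRAIN_01",   # illustrative job name
    mem_gb=16,
    gpus_per_node=1,
    tasks_per_node=1,
    cpus_per_task=4,
    nodes=1,
    timeout_min=1440,
)
job_handle = executor.submit(job)  # returns a submitit Job
print(job_handle.result())         # blocks until the job completes
```
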