konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry has flagged this version of konfai as potentially problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/evaluator.py
CHANGED
@@ -1,69 +1,182 @@
+import builtins
+import importlib
+import json
 import os
-
+import shutil
+from typing import Any
+
+import numpy as np
 import torch
 import tqdm
-
-
-import
-import builtins
-import importlib
-from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
-from konfai.utils.config import config
-from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
+from torch.utils.data import DataLoader
+
+from konfai import config_file, evaluations_directory, konfai_root
 from konfai.data.data_manager import DataMetric
+from konfai.utils.config import config
+from konfai.utils.utils import DistributedObject, EvaluatorError, get_module, synchronize_data
+
 
-class CriterionsAttr
+class CriterionsAttr:
+    """
+    Container for additional metadata or configuration attributes related to a loss criterion.
+
+    This class is currently empty but acts as a placeholder for future extension.
+    It is passed along with each loss function to allow parameterization or inspection of its behavior.
+
+    Use cases may include:
+    - Weighting of individual loss terms
+    - Conditional activation
+    - Logging preferences
+    """
 
     @config()
     def __init__(self) -> None:
-        pass
+        pass
+
+
+class CriterionsLoader:
+    """
+    Loader for multiple criterion modules to be applied between a model output and one or more targets.
 
-
+    Each loss module (e.g., Dice, CrossEntropy, NCC) is dynamically loaded using its fully-qualified
+    classpath and is associated with a `CriterionsAttr` configuration object.
+
+    Args:
+        criterions_loader (dict): A mapping from module classpaths (as strings) to `CriterionsAttr` instances.
+            The module path is parsed and instantiated via `get_module`.
+
+    """
 
     @config()
-    def __init__(
-        self
+    def __init__(
+        self,
+        criterions_loader: dict[str, CriterionsAttr] = {"default:torch:nn:CrossEntropyLoss:Dice:NCC": CriterionsAttr()},
+    ) -> None:
+        self.criterions_loader = criterions_loader
 
-    def
+    def get_criterions(self, output_group: str, target_group: str) -> dict[torch.nn.Module, CriterionsAttr]:
         criterions = {}
-        for module_classpath,
-            module, name =
-            criterions[
+        for module_classpath, criterions_attr in self.criterions_loader.items():
+            module, name = get_module(module_classpath, "konfai.metric.measure")
+            criterions[
+                config(
+                    f"{konfai_root()}.metrics.{output_group}.targets_criterions.{target_group}"
+                    f".criterions_loader.{module_classpath}"
+                )(getattr(importlib.import_module(module), name))(config=None)
+            ] = criterions_attr
         return criterions
 
-
+
+class TargetCriterionsLoader:
+    """
+    Loader class for handling multiple target groups with associated criterion configurations.
+
+    This class allows defining a set of criterion loaders (e.g., Dice, BCE, MSE) for each
+    target group to be used during evaluation or training. Each target group corresponds
+    to one or more loss functions, all linked to a specific model output.
+
+    Args:
+        targets_criterions (dict[str, CriterionsLoader]): Dictionary mapping each target group name
+            to a `CriterionsLoader` instance that defines its associated loss functions.
+    """
 
     @config()
-    def __init__(
-        self
-
-
-
-
-
-
+    def __init__(
+        self,
+        targets_criterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()},
+    ) -> None:
+        self.targets_criterions = targets_criterions
+
+    def get_targets_criterions(self, output_group: str) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
+        """
+        Retrieve the criterion modules and their attributes for a specific output group.
+
+        This function prepares the loss functions to be applied for a given model output,
+        grouped by their target group.
+
+        Args:
+            output_group (str): Name of the model output group (e.g., "output_segmentation").
+
+        Returns:
+            dict[str, dict[nn.Module, CriterionsAttr]]: A nested dictionary where the first key is the
+                target group name, and the value is a dictionary mapping each loss module to its attributes.
+        """
+        targets_criterions = {}
+        for target_group, criterions_loader in self.targets_criterions.items():
+            targets_criterions[target_group] = criterions_loader.get_criterions(output_group, target_group)
+        return targets_criterions
+
+
+class Statistics:
+    """
+    Utility class to accumulate, structure, and write evaluation metric results.
+
+    This class is used to:
+    - Collect metrics for each dataset sample.
+    - Compute aggregate statistics (mean, std, percentiles, etc.).
+    - Export all results in a structured JSON format, including both per-case and aggregate values.
 
-
+    Args:
+        filename (str): Path to the output JSON file that will store the final results.
+    """
 
     def __init__(self, filename: str) -> None:
         self.measures: dict[str, dict[str, float]] = {}
         self.filename = filename
 
     def add(self, values: dict[str, float], name_dataset: str) -> None:
+        """
+        Add a set of metric values for a given dataset case.
+
+        Args:
+            values (dict): Dictionary of metric names and their values.
+            name_dataset (str): Identifier (e.g., case name) for the sample.
+        """
         for name, value in values.items():
             if name_dataset not in self.measures:
                 self.measures[name_dataset] = {}
             self.measures[name_dataset][name] = value
-
-    @staticmethod
-    def
-
-
-
+
+    @staticmethod
+    def get_statistic(values: list[float]) -> dict[str, float]:
+        """
+        Compute statistical aggregates for a list of metric values.
+
+        Args:
+            values (list of float): Values to summarize.
+
+        Returns:
+            dict[str, float]: A dictionary containing:
+                - max, min, std
+                - 25th, 50th, and 75th percentiles
+                - mean and count
+        """
+        return {
+            "max": np.max(values),
+            "min": np.min(values),
+            "std": np.std(values),
+            "25pc": np.percentile(values, 25),
+            "50pc": np.percentile(values, 50),
+            "75pc": np.percentile(values, 75),
+            "mean": np.mean(values),
+            "count": len(values),
+        }
+
+    def write(self, outputs: list[dict[str, dict[str, Any]]]) -> None:
+        """
+        Write the collected and aggregated statistics to the configured output file.
+
+        The output JSON structure contains:
+        - `case`: All individual metrics per sample.
+        - `aggregates`: Global statistics computed over all cases.
+
+        Args:
+            outputs (list): List of metric dictionaries to merge and serialize.
+        """
         measures = {}
         for output in outputs:
             measures.update(output)
-        result = {}
+        result: dict[str, dict[str, dict[str, Any]]] = {}
        result["case"] = {}
         for name, v in measures.items():
             for metric_name, value in v.items():
@@ -73,92 +186,220 @@ class Statistics():
 
         result["aggregates"] = {}
         tmp: dict[str, list[float]] = {}
-        for
-            for metric_name,
+        for _, v in measures.items():
+            for metric_name, _ in v.items():
                 if metric_name not in tmp:
                     tmp[metric_name] = []
                 tmp[metric_name].append(v[metric_name])
         for metric_name, values in tmp.items():
-            result["aggregates"][metric_name] = Statistics.
+            result["aggregates"][metric_name] = Statistics.get_statistic(values)
 
         with open(self.filename, "w") as f:
             f.write(json.dumps(result, indent=4))
-
+
+
 class Evaluator(DistributedObject):
+    """
+    Distributed evaluation engine for computing metrics on model predictions.
+
+    This class handles the evaluation of predicted outputs using predefined metric loaders.
+    It supports multi-output and multi-target configurations, computes aggregated statistics
+    across training and validation datasets, and synchronizes results across processes.
+
+    Evaluation results are stored in JSON format and optionally displayed during iteration.
+
+    Args:
+        train_name (str): Unique name of the evaluation run, used for logging and output folders.
+        metrics (dict[str, TargetCriterionsLoader]): Dictionary mapping output groups to loaders of target metrics.
+        dataset (DataMetric): Dataset provider configured for evaluation mode.
+
+    Attributes:
+        statistics_train (Statistics): Object used to store training evaluation metrics.
+        statistics_validation (Statistics): Object used to store validation evaluation metrics.
+        dataloader (list[DataLoader]): DataLoaders for training and validation sets.
+        metric_path (str): Path to the evaluation output directory.
+        metrics (dict): Instantiated metrics organized by output and target groups.
+    """
 
     @config("Evaluator")
-    def __init__(
+    def __init__(
+        self,
+        train_name: str = "default:TRAIN_01",
+        metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()},
+        dataset: DataMetric = DataMetric(),
+    ) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
-        self.metric_path =
+        self.metric_path = evaluations_directory() + self.name + "/"
         self.metricsLoader = metrics
         self.dataset = dataset
-        self.metrics = {k: v.
-        self.statistics_train = Statistics(self.metric_path+"Metric_TRAIN.json")
-        self.statistics_validation = Statistics(self.metric_path+"Metric_VALIDATION.json")
+        self.metrics = {k: v.get_targets_criterions(k) for k, v in self.metricsLoader.items()}
+        self.statistics_train = Statistics(self.metric_path + "Metric_TRAIN.json")
+        self.statistics_validation = Statistics(self.metric_path + "Metric_VALIDATION.json")
 
-    def update(self, data_dict: dict[str, tuple[torch.Tensor, str]], statistics
+    def update(self, data_dict: dict[str, tuple[torch.Tensor, str]], statistics: Statistics) -> dict[str, float]:
+        """
+        Compute metrics for a batch and update running statistics.
+
+        Args:
+            data_dict (dict): Dictionary where keys are output/target group names and values are
+                tuples of (tensor, sample name).
+            statistics (Statistics): The statistics object to update (train or validation).
+
+        Returns:
+            dict[str, float]: Dictionary of computed metric values with keys in the format
+                'output_group:target_group:MetricName'.
+        """
         result = {}
         for output_group in self.metrics:
             for target_group in self.metrics[output_group]:
-                targets = [
+                targets = [
+                    (data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0])
+                    for group in target_group.split(";")
+                    if group in data_dict
+                ]
                 name = data_dict[output_group][1][0]
                 for metric in self.metrics[output_group][target_group]:
-                    result["{}:{}:{
+                    result[f"{output_group}:{target_group}:{metric.__class__.__name__}"] = metric(
+                        (data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0]),
+                        *targets,
+                    ).item()
         statistics.add(result, name)
         return result
-
+
     def setup(self, world_size: int):
+        """
+        Prepare the evaluator for distributed metric computation.
+
+        This method performs the following steps:
+        - Checks whether previous evaluation results exist and optionally overwrites them.
+        - Creates the output directory and copies the current configuration file for reproducibility.
+        - Loads the evaluation dataset according to the world size.
+        - Validates that all specified output and target groups used in metric definitions
+          are present in the dataset group configuration.
+
+        Args:
+            world_size (int): Number of processes in the distributed evaluation setup.
+
+        Raises:
+            EvaluatorError: If any metric output or target group is missing in the dataset's group mapping.
+        """
         if os.path.exists(self.metric_path):
             if os.environ["KONFAI_OVERWRITE"] != "True":
-                accept = builtins.input(
+                accept = builtins.input(
+                    f"The metric {self.name} already exists ! Do you want to overwrite it (yes,no) : "
+                )
                 if accept != "yes":
                     return
-
+
             if os.path.exists(self.metric_path):
-                shutil.rmtree(self.metric_path)
+                shutil.rmtree(self.metric_path)
 
         if not os.path.exists(self.metric_path):
             os.makedirs(self.metric_path)
-        metric_namefile_src =
-        shutil.copyfile(
+        metric_namefile_src = config_file().replace(".yml", "")
+        shutil.copyfile(
+            metric_namefile_src + ".yml",
+            f"{self.metric_path}{metric_namefile_src}.yml",
+        )
 
-        self.dataloader = self.dataset.
+        self.dataloader = self.dataset.get_data(world_size)
 
-
+        groups_dest = [group for groups in self.dataset.groups_src.values() for group in groups]
 
-        missing_outputs = set(self.metrics.keys()) - set(
+        missing_outputs = set(self.metrics.keys()) - set(groups_dest)
         if missing_outputs:
             raise EvaluatorError(
-                f"The following metric output groups are missing from '
-                f"Available groups: {sorted(
+                f"The following metric output groups are missing from 'groups_dest': {sorted(missing_outputs)}. ",
+                f"Available groups: {sorted(groups_dest)}",
             )
 
         target_groups = []
         for i in {target for targets in self.metrics.values() for target in targets}:
             for u in i.split(";"):
                 target_groups.append(u)
-        missing_targets = set(target_groups) - set(
+        missing_targets = set(target_groups) - set(groups_dest)
         if missing_targets:
             raise EvaluatorError(
-                f"The following metric target groups are missing from '
-                f"Available groups: {sorted(
+                f"The following metric target groups are missing from 'groups_dest': {sorted(missing_targets)}. ",
+                f"Available groups: {sorted(groups_dest)}",
            )
 
     def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
-
-
+        """
+        Execute the distributed evaluation loop over the training and validation datasets.
+
+        This method iterates through the provided DataLoaders (train and optionally validation),
+        updates the metric statistics using the configured `metrics` dictionary, and synchronizes
+        the results across all processes. On the global rank 0, the metrics are saved as JSON files.
+
+        Metrics are displayed in real-time using `tqdm` progress bars, showing a summary of the
+        current batch's computed values.
+
+        Args:
+            world_size (int): Total number of distributed processes.
+            global_rank (int): Global rank of the current process (used for writing results).
+            gpu (int): Local GPU ID used for synchronization.
+            dataloaders (list[DataLoader]): A list containing one or two DataLoaders:
+                - `dataloaders[0]` is used for training evaluation.
+                - `dataloaders[1]` (optional) is used for validation evaluation.
+
+        Notes:
+            - Only the main process (`global_rank == 0`) writes final results to disk.
+        """
+
+        def description(measure):
+            return (
+                f"Metric TRAIN : {' | '.join(f'{k}: {v:.4f}' for k, v in measure.items())}"
+                if measure is not None
+                else "Metric TRAIN : "
+            )
+
+        with tqdm.tqdm(
+            iterable=enumerate(dataloaders[0]),
+            leave=True,
+            desc=description(None),
+            total=len(dataloaders[0]),
+            ncols=0,
+        ) as batch_iter:
             for _, data_dict in batch_iter:
-                batch_iter.set_description(
+                batch_iter.set_description(
+                    description(
+                        self.update(
+                            {k: (v[0], v[4]) for k, v in data_dict.items()},
+                            self.statistics_train,
+                        )
+                    )
+                )
         outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
         if global_rank == 0:
             self.statistics_train.write(outputs)
         if len(dataloaders) == 2:
-
-
+
+            def description(measure):
+                return (
+                    f"Metric VALIDATION : {' | '.join(f'{k}: {v:.2f}' for k, v in measure.items())}"
+                    if measure is not None
+                    else "Metric VALIDATION : "
+                )
+
+            with tqdm.tqdm(
+                iterable=enumerate(dataloaders[1]),
+                leave=True,
+                desc=description(None),
+                total=len(dataloaders[1]),
+                ncols=0,
+            ) as batch_iter:
                 for _, data_dict in batch_iter:
-                    batch_iter.set_description(
+                    batch_iter.set_description(
+                        description(
+                            self.update(
+                                {k: (v[0], v[4]) for k, v in data_dict.items()},
+                                self.statistics_validation,
+                            )
+                        )
+                    )
             outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
             if global_rank == 0:
-                self.statistics_validation.write(outputs)
+                self.statistics_validation.write(outputs)
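The rewritten `Statistics` class above is usable on its own. Below is a minimal sketch, not part of the diff, of how `add()` and `write()` fit together; the metric names and values are invented for illustration, and the key format mirrors the `output_group:target_group:MetricName` convention used by `Evaluator.update()`:

    from konfai.evaluator import Statistics

    stats = Statistics("Metric_TRAIN.json")

    # One add() call per evaluated case; each key identifies one metric.
    stats.add({"out:mask:Dice": 0.91, "out:mask:NCC": 0.87}, "case_001")
    stats.add({"out:mask:Dice": 0.84, "out:mask:NCC": 0.79}, "case_002")

    # write() merges a list of per-process measure dicts; in the distributed
    # run that list comes from synchronize_data(), but a single-process list
    # containing one dict works the same way.
    stats.write([stats.measures])
    # Metric_TRAIN.json now holds {"case": {...}} plus {"aggregates": {...}}
    # with max/min/std/percentiles/mean/count per metric.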
konfai/main.py
CHANGED
@@ -1,53 +1,102 @@
 import argparse
 import os
-
+import sys
+
 import torch.multiprocessing as mp
-from
-
+from torch.cuda import device_count
+
+from konfai import konfai_nb_cores
+from konfai.utils.utils import Log, TensorBoard, setup
 
-import sys
 sys.path.insert(0, os.getcwd())
 
+
 def main():
+    """
+    Entry point for launching KonfAI training locally.
+
+    - Parses arguments (if any) via a setup parser.
+    - Initializes distributed environment based on available CUDA devices or CPU cores.
+    - Launches training via `mp.spawn`.
+    - Manages logging and TensorBoard context.
+
+    KeyboardInterrupt is caught to allow clean manual termination.
+    """
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
-        with setup(parser) as
-            with Log(
+        with setup(parser) as distributed_object:
+            with Log(distributed_object.name, 0):
                 world_size = device_count()
                 if world_size == 0:
-                    world_size = int(
-
-                with TensorBoard(
-                    mp.spawn(
+                    world_size = int(konfai_nb_cores())
+                distributed_object.setup(world_size)
+                with TensorBoard(distributed_object.name):
+                    mp.spawn(distributed_object, nprocs=world_size)
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")
 
+
 def cluster():
+    """
+    Entry point for launching KonfAI on a cluster using Submitit.
+
+    - Parses cluster-specific arguments: job name, nodes, memory, time limit, etc.
+    - Sets up distributed environment based on number of nodes and GPUs.
+    - Configures Submitit executor with job specs.
+    - Submits the job to SLURM (or another Submitit-compatible backend).
+
+    Environment variables:
+        KONFAI_OVERWRITE: Set to force overwrite of previous training runs.
+        KONFAI_CLUSTER: Mark this as a cluster job (used downstream).
+
+    Raises:
+        KeyboardInterrupt: On manual interruption.
+        Exception: Any submission-related error is printed and causes exit.
+    """
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     # Cluster manager arguments
-    cluster_args = parser.add_argument_group(
-    cluster_args.add_argument(
-    cluster_args.add_argument(
-    cluster_args.add_argument(
-    cluster_args.add_argument(
-
+    cluster_args = parser.add_argument_group("Cluster manager arguments")
+    cluster_args.add_argument("--name", type=str, help="Task name", required=True)
+    cluster_args.add_argument("--num-nodes", "--num_nodes", default=1, type=int, help="Number of nodes")
+    cluster_args.add_argument("--memory", type=int, default=16, help="Amount of memory per node")
+    cluster_args.add_argument(
+        "--time-limit",
+        "--time_limit",
+        type=int,
+        default=1440,
+        help="Job time limit in minute",
+    )
+    cluster_args.add_argument(
+        "--resubmit",
+        action="store_true",
+        help="Automatically resubmit job just before timout",
+    )
     try:
-        with setup(parser) as
+        with setup(parser) as distributed_object:
             args = parser.parse_args()
             config = vars(args)
             os.environ["KONFAI_OVERWRITE"] = "True"
             os.environ["KONFAI_CLUSTER"] = "True"
 
             n_gpu = len(config["gpu"].split(","))
-
+            distributed_object.setup(n_gpu * int(config["num_nodes"]))
             import submitit
+
             executor = submitit.AutoExecutor(folder="./Cluster/")
-            executor.update_parameters(
-
-
+            executor.update_parameters(
+                name=config["name"],
+                mem_gb=config["memory"],
+                gpus_per_node=n_gpu,
+                tasks_per_node=n_gpu // distributed_object.size,
+                cpus_per_task=config["num_workers"],
+                nodes=config["num_nodes"],
+                timeout_min=config["time_limit"],
+            )
+            with TensorBoard(distributed_object.name):
+                executor.submit(distributed_object)
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")
     except Exception as e:
         print(e)
-        exit(1)
+        exit(1)