GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of GANDLF might be problematic.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
GANDLF/optimizers/__init__.py
CHANGED
@@ -48,13 +48,9 @@ def get_optimizer(params):
         optimizer (torch.optim.Optimizer): An instance of the specified optimizer.

     """
-    # Retrieve the optimizer type from the input parameters
-    optimizer_type = params["optimizer"]["type"]

+    chosen_optimizer = params["optimizer"]["type"]
     assert (
-        …
-    ), f"…
-
-    # Create the optimizer instance using the specified type and input parameters
-    optimizer_function = global_optimizer_dict[optimizer_type]
-    return optimizer_function(params)
+        chosen_optimizer in global_optimizer_dict
+    ), f"Could not find the requested optimizer '{params['optimizer']['type']}'"
+    return global_optimizer_dict[chosen_optimizer](params)
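For context, the refactored `get_optimizer` is driven entirely by the parameters dictionary. A minimal usage sketch follows; the `sgd` type and `learning_rate` key are illustrative assumptions, not taken from this diff:

from GANDLF.optimizers import get_optimizer

# Hypothetical config: only ["optimizer"]["type"] is read by get_optimizer itself;
# the remaining keys are consumed by whichever wrapper the dict lookup selects.
params = {"optimizer": {"type": "sgd"}, "learning_rate": 0.01}
optimizer = get_optimizer(params)  # raises AssertionError for unregistered types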
GANDLF/privacy/opacus/opacus_anonymization_manager.py
ADDED
@@ -0,0 +1,243 @@
+import torch
+from opacus import PrivacyEngine
+
+import collections.abc as abc
+from functools import partial
+from torch.utils.data._utils.collate import default_collate
+from typing import Union, Callable
+import copy
+from opacus.optimizers import DPOptimizer
+from opacus.utils.uniform_sampler import (
+    DistributedUniformWithReplacementSampler,
+    UniformWithReplacementSampler,
+)
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+import math
+import numpy as np
+
+from typing import List
+
+
+class BatchSplittingSampler(Sampler[List[int]]):
+    """
+    Samples according to the underlying instance of ``Sampler``, but splits
+    the index sequences into smaller chunks.
+
+    Used to split large logical batches into physical batches of a smaller size,
+    while coordinating with DPOptimizer when the logical batch has ended.
+    """
+
+    def __init__(
+        self,
+        *,
+        sampler: Sampler[List[int]],
+        max_batch_size: int,
+        optimizer: DPOptimizer,
+    ):
+        """
+
+        Args:
+            sampler: Wrapped Sampler instance
+            max_batch_size: Max size of emitted chunk of indices
+            optimizer: optimizer instance to notify when the logical batch is over
+        """
+        self.sampler = sampler
+        self.max_batch_size = max_batch_size
+        self.optimizer = optimizer
+
+    def __iter__(self):
+        for batch_idxs in self.sampler:
+            if len(batch_idxs) == 0:
+                self.optimizer.signal_skip_step(do_skip=False)
+                yield []
+                continue
+
+            split_idxs = np.array_split(
+                batch_idxs, math.ceil(len(batch_idxs) / self.max_batch_size)
+            )
+            split_idxs = [s.tolist() for s in split_idxs]
+            for x in split_idxs[:-1]:
+                self.optimizer.signal_skip_step(do_skip=True)
+                yield x
+            self.optimizer.signal_skip_step(do_skip=False)
+            yield split_idxs[-1]
+
+    def __len__(self):
+        if isinstance(self.sampler, BatchSampler):
+            return math.ceil(
+                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+            )
+        elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
+            self.sampler, DistributedUniformWithReplacementSampler
+        ):
+            expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
+            return math.ceil(
+                len(self.sampler) * (expected_batch_size / self.max_batch_size)
+            )
+
+        return len(self.sampler)
+
+
+class OpacusAnonymizationManager:
+    def __init__(self, params):
+        self.params = params
+
+    def apply_privacy(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+    ):
+        model, optimizer, train_dataloader, privacy_engine = self._apply_privacy(
+            model, optimizer, train_dataloader
+        )
+        train_dataloader.collate_fn = self._empty_collate(train_dataloader.dataset[0])
+        max_physical_batch_size = self.params["differential_privacy"].get(
+            "max_physical_batch_size", self.params["batch_size"]
+        )
+        if max_physical_batch_size != self.params["batch_size"]:
+            train_dataloader = self._wrap_data_loader(
+                data_loader=train_dataloader,
+                max_batch_size=max_physical_batch_size,
+                optimizer=optimizer,
+            )
+
+        return model, optimizer, train_dataloader, privacy_engine
+
+    def _apply_privacy(self, model, optimizer, train_dataloader):
+        privacy_engine = PrivacyEngine(
+            accountant=self.params["differential_privacy"]["accountant"],
+            secure_mode=self.params["differential_privacy"]["secure_mode"],
+        )
+        epsilon = self.params["differential_privacy"].get("epsilon")
+
+        if epsilon is not None:
+            (
+                model,
+                optimizer,
+                train_dataloader,
+            ) = privacy_engine.make_private_with_epsilon(
+                module=model,
+                optimizer=optimizer,
+                data_loader=train_dataloader,
+                max_grad_norm=self.params["differential_privacy"]["max_grad_norm"],
+                epochs=self.params["num_epochs"],
+                target_epsilon=self.params["differential_privacy"]["epsilon"],
+                target_delta=self.params["differential_privacy"]["delta"],
+            )
+        else:
+            model, optimizer, train_dataloader = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=train_dataloader,
+                noise_multiplier=self.params["differential_privacy"][
+                    "noise_multiplier"
+                ],
+                max_grad_norm=self.params["differential_privacy"]["max_grad_norm"],
+            )
+        return model, optimizer, train_dataloader, privacy_engine
+
+    def _empty_collate(
+        self,
+        item_example: Union[
+            torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+        ],
+    ) -> Callable:
+        """
+        Creates a new collate function that behave same as default pytorch one,
+        but can process the empty batches.
+
+        Args:
+            item_example (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): An example item from the dataset.
+
+        Returns:
+            Callable: function that should replace dataloader collate: `dataloader.collate_fn = empty_collate(...)`
+        """
+
+        def custom_collate(batch, _empty_batch_value):
+            if len(batch) > 0:
+                return default_collate(batch)  # default behavior
+            else:
+                return copy.copy(_empty_batch_value)
+
+        empty_batch_value = self._build_empty_batch_value(item_example)
+
+        return partial(custom_collate, _empty_batch_value=empty_batch_value)
+
+    def _build_empty_batch_value(
+        self,
+        sample: Union[
+            torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+        ],
+    ):
+        """
+        Build an empty batch value from a sample. This function is used to create a placeholder for empty batches in an iteration. Inspired from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py#L108. The key difference is that pytorch `collate` has to traverse batch of objects AND unite its fields to lists, while this function traverse a single item AND creates an "empty" version of the batch.
+
+        Args:
+            sample (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): A sample from the dataset.
+
+        Raises:
+            TypeError: If the data type is not supported.
+
+        Returns:
+            Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]: An empty batch value.
+        """
+        if isinstance(sample, torch.Tensor):
+            # Create an empty tensor with the same shape except for the zeroed batch dimension.
+            return torch.empty((0,) + sample.shape)
+        elif isinstance(sample, np.ndarray):
+            # Create an empty tensor from a numpy array, also with the zeroed batch dimension.
+            return torch.empty(
+                (0,) + sample.shape, dtype=torch.from_numpy(sample).dtype
+            )
+        elif isinstance(sample, abc.Mapping):
+            # Recursively handle dictionary-like objects.
+            return {
+                key: self._build_empty_batch_value(value)
+                for key, value in sample.items()
+            }
+        elif isinstance(sample, tuple) and hasattr(sample, "_fields"):  # namedtuple
+            return type(sample)(
+                *(self._build_empty_batch_value(item) for item in sample)
+            )
+        elif isinstance(sample, abc.Sequence) and not isinstance(sample, str):
+            # Handle lists and tuples, but exclude strings.
+            return [self._build_empty_batch_value(item) for item in sample]
+        elif isinstance(sample, (int, float, str)):
+            # Return an empty list for basic data types.
+            return []
+        else:
+            raise TypeError(f"Unsupported data type: {type(sample)}")
+
+    def _wrap_data_loader(
+        self, data_loader: DataLoader, max_batch_size: int, optimizer: DPOptimizer
+    ):
+        """
+        Replaces batch_sampler in the input data loader with ``BatchSplittingSampler``
+
+        Args:
+            data_loader: Wrapper DataLoader
+            max_batch_size: max physical batch size we want to emit
+            optimizer: DPOptimizer instance used for training
+
+        Returns:
+            New DataLoader instance with batch_sampler wrapped in ``BatchSplittingSampler``
+        """
+
+        return DataLoader(
+            dataset=data_loader.dataset,
+            batch_sampler=BatchSplittingSampler(
+                sampler=data_loader.batch_sampler,
+                max_batch_size=max_batch_size,
+                optimizer=optimizer,
+            ),
+            num_workers=data_loader.num_workers,
+            collate_fn=data_loader.collate_fn,
+            pin_memory=data_loader.pin_memory,
+            timeout=data_loader.timeout,
+            worker_init_fn=data_loader.worker_init_fn,
+            multiprocessing_context=data_loader.multiprocessing_context,
+            generator=data_loader.generator,
+            prefetch_factor=data_loader.prefetch_factor,
+            persistent_workers=data_loader.persistent_workers,
+        )
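To make the batch-splitting behavior concrete, here is a minimal sketch of wiring the manager into a toy setup. The model, optimizer, loader, and all `differential_privacy` values are hypothetical placeholders; only key names that appear in the code above are used:

import torch
from torch.utils.data import DataLoader, TensorDataset
from GANDLF.privacy.opacus.opacus_anonymization_manager import OpacusAnonymizationManager

# Toy model and data; Opacus validates the module, and a plain Linear passes.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=16
)

params = {
    "batch_size": 16,
    "num_epochs": 5,
    "differential_privacy": {
        "accountant": "rdp",
        "secure_mode": False,
        "noise_multiplier": 1.1,  # used because no target "epsilon" is given
        "max_grad_norm": 1.0,
        "max_physical_batch_size": 8,  # != batch_size, so BatchSplittingSampler kicks in
    },
}
manager = OpacusAnonymizationManager(params)
model, optimizer, loader, engine = manager.apply_privacy(model, optimizer, loader)

The splitting keeps privacy accounting tied to the logical batch size (via `signal_skip_step`) while capping peak memory at the physical batch size.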
GANDLF/schedulers/__init__.py
CHANGED
@@ -6,7 +6,8 @@ from .wrap_torch import (
     exp,
     step,
     reduce_on_plateau,
-    …
+    cosineannealingwarmrestarts,
+    cosineannealingLR,
 )

 from .wrap_monai import warmupcosineschedule
@@ -24,7 +25,9 @@ global_schedulers_dict = {
     "reduce-on-plateau": reduce_on_plateau,
     "plateau": reduce_on_plateau,
     "reduceonplateau": reduce_on_plateau,
-    "cosineannealing": …
+    "cosineannealing": cosineannealingwarmrestarts,
+    "cosineannealingwarmrestarts": cosineannealingwarmrestarts,
+    "cosineannealinglr": cosineannealingLR,
     "warmupcosineschedule": warmupcosineschedule,
     "wcs": warmupcosineschedule,
 }
@@ -38,6 +41,10 @@ def get_scheduler(params):
         params (dict): The parameters' dictionary.

     Returns:
-        …
+        scheduler (object): The scheduler definition.
     """
-    …
+    chosen_scheduler = params["scheduler"]["type"].lower()
+    assert (
+        chosen_scheduler in global_schedulers_dict
+    ), f"Could not find the requested scheduler '{params['scheduler']['type']}'"
+    return global_schedulers_dict[chosen_scheduler](params)
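Scheduler lookup is now case-insensitive over the registered aliases, so a config value such as `CosineAnnealingLR` resolves to the `cosineannealinglr` entry. A small sketch (the optimizer object is assumed to have been created earlier; `optimizer_object` is the key read by the wrappers in `wrap_torch.py`):

from GANDLF.schedulers import get_scheduler

params = {
    "scheduler": {"type": "CosineAnnealingLR"},  # lower-cased before the lookup
    "optimizer_object": optimizer,  # a previously constructed torch optimizer
}
scheduler = get_scheduler(params)  # AssertionError for unregistered scheduler types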
GANDLF/schedulers/wrap_torch.py
CHANGED
@@ -5,6 +5,7 @@ from torch.optim.lr_scheduler import (
     StepLR,
     ReduceLROnPlateau,
     CosineAnnealingWarmRestarts,
+    CosineAnnealingLR,
 )
 import math

@@ -169,14 +170,25 @@ def reduce_on_plateau(parameters):
     )


-def …
+def cosineannealingwarmrestarts(parameters):
     parameters["scheduler"]["T_0"] = parameters["scheduler"].get("T_0", 5)
     parameters["scheduler"]["T_mult"] = parameters["scheduler"].get("T_mult", 1)
-    parameters["scheduler"]["…
+    parameters["scheduler"]["eta_min"] = parameters["scheduler"].get("eta_min", 0.001)

     return CosineAnnealingWarmRestarts(
         parameters["optimizer_object"],
         T_0=parameters["scheduler"]["T_0"],
         T_mult=parameters["scheduler"]["T_mult"],
-        eta_min=parameters["scheduler"]["…
+        eta_min=parameters["scheduler"]["eta_min"],
+    )
+
+
+def cosineannealingLR(parameters):
+    parameters["scheduler"]["T_max"] = parameters["scheduler"].get("T_max", 50)
+    parameters["scheduler"]["eta_min"] = parameters["scheduler"].get("eta_min", 0.001)
+
+    return CosineAnnealingLR(
+        parameters["optimizer_object"],
+        T_max=parameters["scheduler"]["T_max"],
+        eta_min=parameters["scheduler"]["eta_min"],
     )
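As illustration, the new `cosineannealingLR` wrapper back-fills `T_max=50` and `eta_min=0.001` when the config omits them. A direct-call sketch with a placeholder optimizer:

import torch
from GANDLF.schedulers.wrap_torch import cosineannealingLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
parameters = {"scheduler": {"type": "cosineannealinglr"}, "optimizer_object": opt}
scheduler = cosineannealingLR(parameters)
# parameters["scheduler"] now also carries the applied defaults:
# {"type": "cosineannealinglr", "T_max": 50, "eta_min": 0.001}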
GANDLF/training_manager.py
CHANGED
@@ -1,20 +1,31 @@
+import os
+import yaml
+
+# codacy ignore python-use-of-pickle: Pickle usage is safe in this context (local data only).
+import pickle
+import shutil
 import pandas as pd
-import os, pickle, shutil
 from pathlib import Path
+from warnings import warn
+
+import lightning.pytorch as pl
+from lightning.pytorch.profilers import PyTorchProfiler
+from lightning.pytorch.tuner import Tuner as LightningTuner

-from GANDLF.compute import training_loop
 from GANDLF.utils import get_dataframe, split_data
+from GANDLF.models.lightning_module import GandlfLightningModule
+from GANDLF.data.lightning_datamodule import GandlfTrainingDatamodule

-import …
+from typing import Optional


 def TrainingManager(
     dataframe: pd.DataFrame,
     outputDir: str,
     parameters: dict,
-    device: str,
     resume: bool,
     reset: bool,
+    profile: Optional[bool] = False,
 ) -> None:
     """
     This is the training manager that ties all the training functionality together
@@ -23,10 +34,14 @@ def TrainingManager(
         dataframe (pandas.DataFrame): The full data from CSV.
         outputDir (str): The main output directory.
         parameters (dict): The parameters dictionary.
-        device (str): The device to perform computations on.
         resume (bool): Whether the previous run will be resumed or not.
        reset (bool): Whether the previous run will be reset or not.
+        profile (bool): Whether we want to profile the activity or not. Defaults to False.
+
     """
+
+    if "output_dir" not in parameters:
+        parameters["output_dir"] = outputDir
     if reset:
         shutil.rmtree(outputDir)
         Path(outputDir).mkdir(parents=True, exist_ok=True)
@@ -95,45 +110,79 @@ def TrainingManager(
     # read the data from the pickle if present
     data_dict[data_type] = get_dataframe(currentDataPickle)

-    # …
-    …
+    # Dataloader initialization - should be extracted somewhere else (preferably abstracted away)
+    datamodule = GandlfTrainingDatamodule(data_dict_files, parameters)
+    parameters = datamodule.updated_parameters_dict
+
+    # This entire section should be handled in config parser
+
+    accelerator = parameters.get("accelerator", "auto")
+    allowed_accelerators = ["cpu", "gpu", "auto"]
+    # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+    assert (
+        accelerator in allowed_accelerators
+    ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
+    strategy = parameters.get("strategy", "auto")
+    allowed_strategies = ["auto", "ddp"]
+    # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+    assert (
+        strategy in allowed_strategies
+    ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
+    precision = parameters.get("precision", "32")
+    allowed_precisions = [
+        "64",
+        "64-true",
+        "32",
+        "32-true",
+        "16",
+        "16-mixed",
+        "bf16",
+        "bf16-mixed",
+    ]
+    # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+    assert (
+        precision in allowed_precisions
+    ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
+
+    warn(
+        f"Configured to use {accelerator} with {strategy} for training, but current development configuration will force single-device only training."
+    )
+    trainer = pl.Trainer(
+        accelerator=accelerator,
+        strategy=strategy,
+        fast_dev_run=False,
+        devices=parameters.get("devices", "auto"),
+        num_nodes=parameters.get("num_nodes", 1),
+        precision=precision,
+        gradient_clip_algorithm=parameters["clip_mode"],
+        gradient_clip_val=parameters["clip_grad"],
+        max_epochs=parameters["num_epochs"],
+        sync_batchnorm=False,
+        enable_checkpointing=False,
+        logger=False,
+        num_sanity_val_steps=0,
+        profiler=PyTorchProfiler(sort_by="cpu_time_total", row_limit=10)
+        if profile
+        else None,
+    )
+
+    lightning_module = GandlfLightningModule(
+        parameters, output_dir=currentValidationOutputFolder
+    )

-    …
-    parallel_compute_command_actual = parameters[
-        "parallel_compute_command"
-    ].replace("${outputDir}", currentValidationOutputFolder)
-
-    assert (
-        "python" in parallel_compute_command_actual
-    ), "The 'parallel_compute_command_actual' needs to have the python from the virtual environment, which is usually '${GANDLF_dir}/venv/bin/python'"
-
-    command = (
-        parallel_compute_command_actual
-        + " -m GANDLF.training_loop -train_loader_pickle "
-        + data_dict_files["training"]
-        + " -val_loader_pickle "
-        + data_dict_files["validation"]
-        + " -parameter_pickle "
-        + currentModelConfigPickle
-        + " -device "
-        + str(device)
-        + " -outputDir "
-        + currentValidationOutputFolder
-        + " -testing_loader_pickle "
-        + data_dict_files["testing"]
+    if parameters.get("auto_batch_size_find", False):
+        LightningTuner(trainer).scale_batch_size(
+            lightning_module, datamodule=datamodule
         )

-
-
+    if parameters.get("auto_lr_find", False):
+        LightningTuner(trainer).lr_find(lightning_module, datamodule=datamodule)
+
+    trainer.fit(lightning_module, datamodule=datamodule)
+
+    testing_data = data_dict_files.get("testing", None)
+    if testing_data:
+        trainer.test(lightning_module, datamodule=datamodule)


 def TrainingManager_split(
@@ -142,9 +191,9 @@ def TrainingManager_split(
     dataframe_testing: pd.DataFrame,
     outputDir: str,
     parameters: dict,
-    device: str,
     resume: bool,
     reset: bool,
+    profile: Optional[bool] = False,
 ):
     """
     This is the training manager that ties all the training functionality together
@@ -155,9 +204,10 @@ def TrainingManager_split(
         dataframe_testing (pd.DataFrame): The testing data from CSV.
         outputDir (str): The main output directory.
         parameters (dict): The parameters dictionary.
-        device (str): The device to perform computations on.
         resume (bool): Whether the previous run will be resumed or not.
         reset (bool): Whether the previous run will be reset or not.
+        profile (bool): Whether we want to profile the activity or not. Defaults to False.
+
     """
     currentModelConfigPickle = os.path.join(outputDir, "parameters.pkl")
     currentModelConfigYaml = os.path.join(outputDir, "config.yaml")
@@ -178,11 +228,71 @@ def TrainingManager_split(
     with open(currentModelConfigYaml, "w") as handle:
         yaml.dump(parameters, handle, default_flow_style=False)

-    …
+    data_dict_files = {
+        "training": dataframe_train,
+        "validation": dataframe_validation,
+        "testing": dataframe_testing,
+    }
+
+    datamodule = GandlfTrainingDatamodule(data_dict_files, parameters)
+    parameters = datamodule.updated_parameters_dict
+
+    # This entire section should be handled in config parser
+
+    accelerator = parameters.get("accelerator", "auto")
+    allowed_accelerators = ["cpu", "gpu", "auto"]
+    assert (
+        accelerator in allowed_accelerators
+    ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
+    strategy = parameters.get("strategy", "auto")
+    allowed_strategies = ["auto", "ddp"]
+    assert (
+        strategy in allowed_strategies
+    ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
+    precision = parameters.get("precision", "32")
+    allowed_precisions = [
+        "64",
+        "64-true",
+        "32",
+        "32-true",
+        "16",
+        "16-mixed",
+        "bf16",
+        "bf16-mixed",
+    ]
+    assert (
+        precision in allowed_precisions
+    ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
+
+    trainer = pl.Trainer(
+        accelerator=accelerator,
+        strategy=strategy,
+        fast_dev_run=False,
+        devices=parameters.get("devices", "auto"),
+        num_nodes=parameters.get("num_nodes", 1),
+        precision=precision,
+        gradient_clip_algorithm=parameters["clip_mode"],
+        gradient_clip_val=parameters["clip_grad"],
+        max_epochs=parameters["num_epochs"],
+        sync_batchnorm=False,
+        enable_checkpointing=False,
+        logger=False,
+        num_sanity_val_steps=0,
+        profiler=PyTorchProfiler(sort_by="cpu_time_total", row_limit=10)
+        if profile
+        else None,
     )
+    lightning_module = GandlfLightningModule(parameters, output_dir=outputDir)
+
+    if parameters.get("auto_batch_size_find", False):
+        LightningTuner(trainer).scale_batch_size(
+            lightning_module, datamodule=datamodule
+        )
+
+    if parameters.get("auto_lr_find", False):
+        LightningTuner(trainer).lr_find(lightning_module, datamodule=datamodule)
+
+    trainer.fit(lightning_module, datamodule=datamodule)
+
+    if dataframe_testing is not None:
+        trainer.test(lightning_module, datamodule=datamodule)
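The managers' public signatures changed: the `device` argument is gone (device selection now flows through the Lightning `accelerator`/`devices` parameters) and an optional `profile` flag was added. A hedged call sketch; the config keys shown are the ones read above, while the CSV path and remaining values are hypothetical:

import pandas as pd
from GANDLF.training_manager import TrainingManager

parameters = {
    "num_epochs": 10,
    "clip_mode": "norm",    # forwarded to gradient_clip_algorithm
    "clip_grad": 1.0,       # forwarded to gradient_clip_val
    "accelerator": "auto",  # one of "cpu", "gpu", "auto"
    "strategy": "auto",     # one of "auto", "ddp"
    "precision": "32",
    # ... remaining GANDLF config keys (model, data, etc.) ...
}
TrainingManager(
    dataframe=pd.read_csv("data.csv"),
    outputDir="./output",
    parameters=parameters,
    resume=False,
    reset=False,
    profile=True,  # enables the PyTorchProfiler in the Lightning Trainer
)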
GANDLF/utils/__init__.py
CHANGED
@@ -7,9 +7,11 @@ from .imaging import (
     resize_image,
     resample_image,
     perform_sanity_check_on_subject,
+    sanity_check_on_file_readers,
     write_training_patches,
     get_correct_padding_size,
     applyCustomColorMap,
+    MapSaver,
 )

 from .tensor import (
@@ -58,9 +60,9 @@ from .generic import (
 )

 from .modelio import (
-    …
+    BEST_MODEL_PATH_END,
+    LATEST_MODEL_PATH_END,
+    INITIAL_MODEL_PATH_END,
     load_model,
     load_ov_model,
     save_model,