GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (57)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
GANDLF/optimizers/__init__.py CHANGED
@@ -48,13 +48,9 @@ def get_optimizer(params):
         optimizer (torch.optim.Optimizer): An instance of the specified optimizer.
 
     """
-    # Retrieve the optimizer type from the input parameters
-    optimizer_type = params["optimizer"]["type"]
 
+    chosen_optimizer = params["optimizer"]["type"]
     assert (
-        optimizer_type in global_optimizer_dict
-    ), f"Optimizer type {optimizer_type} not found"
-
-    # Create the optimizer instance using the specified type and input parameters
-    optimizer_function = global_optimizer_dict[optimizer_type]
-    return optimizer_function(params)
+        chosen_optimizer in global_optimizer_dict
+    ), f"Could not find the requested optimizer '{params['optimizer']['type']}'"
+    return global_optimizer_dict[chosen_optimizer](params)
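The refactor keeps the lookup behavior of `get_optimizer` but renames the local variable and rewords the assertion message. A minimal sketch of the new failure mode; the `"not-a-real-optimizer"` value is a made-up placeholder:

    from GANDLF.optimizers import get_optimizer

    # An unrecognized type now trips the assert with the reworded message
    # instead of the old "Optimizer type ... not found" text.
    params = {"optimizer": {"type": "not-a-real-optimizer"}}
    try:
        get_optimizer(params)
    except AssertionError as err:
        print(err)  # Could not find the requested optimizer 'not-a-real-optimizer'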
GANDLF/privacy/opacus/opacus_anonymization_manager.py ADDED
@@ -0,0 +1,243 @@
+import torch
+from opacus import PrivacyEngine
+
+import collections.abc as abc
+from functools import partial
+from torch.utils.data._utils.collate import default_collate
+from typing import Union, Callable
+import copy
+from opacus.optimizers import DPOptimizer
+from opacus.utils.uniform_sampler import (
+    DistributedUniformWithReplacementSampler,
+    UniformWithReplacementSampler,
+)
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+import math
+import numpy as np
+
+from typing import List
+
+
+class BatchSplittingSampler(Sampler[List[int]]):
+    """
+    Samples according to the underlying instance of ``Sampler``, but splits
+    the index sequences into smaller chunks.
+
+    Used to split large logical batches into physical batches of a smaller size,
+    while coordinating with DPOptimizer when the logical batch has ended.
+    """
+
+    def __init__(
+        self,
+        *,
+        sampler: Sampler[List[int]],
+        max_batch_size: int,
+        optimizer: DPOptimizer,
+    ):
+        """
+
+        Args:
+            sampler: Wrapped Sampler instance
+            max_batch_size: Max size of emitted chunk of indices
+            optimizer: optimizer instance to notify when the logical batch is over
+        """
+        self.sampler = sampler
+        self.max_batch_size = max_batch_size
+        self.optimizer = optimizer
+
+    def __iter__(self):
+        for batch_idxs in self.sampler:
+            if len(batch_idxs) == 0:
+                self.optimizer.signal_skip_step(do_skip=False)
+                yield []
+                continue
+
+            split_idxs = np.array_split(
+                batch_idxs, math.ceil(len(batch_idxs) / self.max_batch_size)
+            )
+            split_idxs = [s.tolist() for s in split_idxs]
+            for x in split_idxs[:-1]:
+                self.optimizer.signal_skip_step(do_skip=True)
+                yield x
+            self.optimizer.signal_skip_step(do_skip=False)
+            yield split_idxs[-1]
+
+    def __len__(self):
+        if isinstance(self.sampler, BatchSampler):
+            return math.ceil(
+                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+            )
+        elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
+            self.sampler, DistributedUniformWithReplacementSampler
+        ):
+            expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
+            return math.ceil(
+                len(self.sampler) * (expected_batch_size / self.max_batch_size)
+            )
+
+        return len(self.sampler)
+
+
+class OpacusAnonymizationManager:
+    def __init__(self, params):
+        self.params = params
+
+    def apply_privacy(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+    ):
+        model, optimizer, train_dataloader, privacy_engine = self._apply_privacy(
+            model, optimizer, train_dataloader
+        )
+        train_dataloader.collate_fn = self._empty_collate(train_dataloader.dataset[0])
+        max_physical_batch_size = self.params["differential_privacy"].get(
+            "max_physical_batch_size", self.params["batch_size"]
+        )
+        if max_physical_batch_size != self.params["batch_size"]:
+            train_dataloader = self._wrap_data_loader(
+                data_loader=train_dataloader,
+                max_batch_size=max_physical_batch_size,
+                optimizer=optimizer,
+            )
+
+        return model, optimizer, train_dataloader, privacy_engine
+
+    def _apply_privacy(self, model, optimizer, train_dataloader):
+        privacy_engine = PrivacyEngine(
+            accountant=self.params["differential_privacy"]["accountant"],
+            secure_mode=self.params["differential_privacy"]["secure_mode"],
+        )
+        epsilon = self.params["differential_privacy"].get("epsilon")
+
+        if epsilon is not None:
+            (
+                model,
+                optimizer,
+                train_dataloader,
+            ) = privacy_engine.make_private_with_epsilon(
+                module=model,
+                optimizer=optimizer,
+                data_loader=train_dataloader,
+                max_grad_norm=self.params["differential_privacy"]["max_grad_norm"],
+                epochs=self.params["num_epochs"],
+                target_epsilon=self.params["differential_privacy"]["epsilon"],
+                target_delta=self.params["differential_privacy"]["delta"],
+            )
+        else:
+            model, optimizer, train_dataloader = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=train_dataloader,
+                noise_multiplier=self.params["differential_privacy"][
+                    "noise_multiplier"
+                ],
+                max_grad_norm=self.params["differential_privacy"]["max_grad_norm"],
+            )
+        return model, optimizer, train_dataloader, privacy_engine
+
+    def _empty_collate(
+        self,
+        item_example: Union[
+            torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+        ],
+    ) -> Callable:
+        """
+        Creates a new collate function that behaves the same as the default pytorch one,
+        but can process empty batches.
+
+        Args:
+            item_example (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): An example item from the dataset.
+
+        Returns:
+            Callable: function that should replace dataloader collate: `dataloader.collate_fn = empty_collate(...)`
+        """
+
+        def custom_collate(batch, _empty_batch_value):
+            if len(batch) > 0:
+                return default_collate(batch)  # default behavior
+            else:
+                return copy.copy(_empty_batch_value)
+
+        empty_batch_value = self._build_empty_batch_value(item_example)
+
+        return partial(custom_collate, _empty_batch_value=empty_batch_value)
+
+    def _build_empty_batch_value(
+        self,
+        sample: Union[
+            torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+        ],
+    ):
+        """
+        Build an empty batch value from a sample. This function is used to create a placeholder for empty batches in an iteration. Inspired by https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py#L108. The key difference is that pytorch `collate` has to traverse a batch of objects AND unite its fields into lists, while this function traverses a single item AND creates an "empty" version of the batch.
+
+        Args:
+            sample (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): A sample from the dataset.
+
+        Raises:
+            TypeError: If the data type is not supported.
+
+        Returns:
+            Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]: An empty batch value.
+        """
+        if isinstance(sample, torch.Tensor):
+            # Create an empty tensor with the same shape except for the zeroed batch dimension.
+            return torch.empty((0,) + sample.shape)
+        elif isinstance(sample, np.ndarray):
+            # Create an empty tensor from a numpy array, also with the zeroed batch dimension.
+            return torch.empty(
+                (0,) + sample.shape, dtype=torch.from_numpy(sample).dtype
+            )
+        elif isinstance(sample, abc.Mapping):
+            # Recursively handle dictionary-like objects.
+            return {
+                key: self._build_empty_batch_value(value)
+                for key, value in sample.items()
+            }
+        elif isinstance(sample, tuple) and hasattr(sample, "_fields"):  # namedtuple
+            return type(sample)(
+                *(self._build_empty_batch_value(item) for item in sample)
+            )
+        elif isinstance(sample, abc.Sequence) and not isinstance(sample, str):
+            # Handle lists and tuples, but exclude strings.
+            return [self._build_empty_batch_value(item) for item in sample]
+        elif isinstance(sample, (int, float, str)):
+            # Return an empty list for basic data types.
+            return []
+        else:
+            raise TypeError(f"Unsupported data type: {type(sample)}")
+
+    def _wrap_data_loader(
+        self, data_loader: DataLoader, max_batch_size: int, optimizer: DPOptimizer
+    ):
+        """
+        Replaces batch_sampler in the input data loader with ``BatchSplittingSampler``
+
+        Args:
+            data_loader: Wrapped DataLoader
+            max_batch_size: max physical batch size we want to emit
+            optimizer: DPOptimizer instance used for training
+
+        Returns:
+            New DataLoader instance with batch_sampler wrapped in ``BatchSplittingSampler``
+        """
+
+        return DataLoader(
+            dataset=data_loader.dataset,
+            batch_sampler=BatchSplittingSampler(
+                sampler=data_loader.batch_sampler,
+                max_batch_size=max_batch_size,
+                optimizer=optimizer,
+            ),
+            num_workers=data_loader.num_workers,
+            collate_fn=data_loader.collate_fn,
+            pin_memory=data_loader.pin_memory,
+            timeout=data_loader.timeout,
+            worker_init_fn=data_loader.worker_init_fn,
+            multiprocessing_context=data_loader.multiprocessing_context,
+            generator=data_loader.generator,
+            prefetch_factor=data_loader.prefetch_factor,
+            persistent_workers=data_loader.persistent_workers,
+        )
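For orientation, a hedged usage sketch of the new manager (not part of the diff). The `differential_privacy` keys mirror those read by `_apply_privacy` above; the model, optimizer, and dataloader are toy placeholders, and a `max_physical_batch_size` smaller than `batch_size` is what triggers the `BatchSplittingSampler` wrapping:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from GANDLF.privacy.opacus.opacus_anonymization_manager import (
        OpacusAnonymizationManager,
    )

    params = {
        "batch_size": 32,
        "num_epochs": 5,
        "differential_privacy": {
            "accountant": "rdp",        # Opacus accountant name
            "secure_mode": False,
            "noise_multiplier": 1.0,    # used because "epsilon" is absent
            "max_grad_norm": 1.0,
            "max_physical_batch_size": 8,  # != batch_size -> loader gets wrapped
        },
    }
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loader = DataLoader(
        TensorDataset(torch.randn(64, 10)), batch_size=params["batch_size"]
    )

    manager = OpacusAnonymizationManager(params)
    model, optimizer, loader, engine = manager.apply_privacy(model, optimizer, loader)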
GANDLF/schedulers/__init__.py CHANGED
@@ -6,7 +6,8 @@ from .wrap_torch import (
     exp,
     step,
     reduce_on_plateau,
-    cosineannealing,
+    cosineannealingwarmrestarts,
+    cosineannealingLR,
 )
 
 from .wrap_monai import warmupcosineschedule
@@ -24,7 +25,9 @@ global_schedulers_dict = {
     "reduce-on-plateau": reduce_on_plateau,
     "plateau": reduce_on_plateau,
     "reduceonplateau": reduce_on_plateau,
-    "cosineannealing": cosineannealing,
+    "cosineannealing": cosineannealingwarmrestarts,
+    "cosineannealingwarmrestarts": cosineannealingwarmrestarts,
+    "cosineannealinglr": cosineannealingLR,
     "warmupcosineschedule": warmupcosineschedule,
     "wcs": warmupcosineschedule,
 }
@@ -38,6 +41,10 @@ def get_scheduler(params):
         params (dict): The parameters' dictionary.
 
     Returns:
-        model (object): The scheduler definition.
+        scheduler (object): The scheduler definition.
     """
-    return global_schedulers_dict[params["scheduler"]["type"]](params)
+    chosen_scheduler = params["scheduler"]["type"].lower()
+    assert (
+        chosen_scheduler in global_schedulers_dict
+    ), f"Could not find the requested scheduler '{params['scheduler']['type']}'"
+    return global_schedulers_dict[chosen_scheduler](params)
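`get_scheduler` now lower-cases the requested type before the dictionary lookup, so mixed-case config values resolve to the new entries. A small sketch; the toy model and optimizer are placeholders:

    import torch
    from GANDLF.schedulers import get_scheduler

    model = torch.nn.Linear(4, 2)
    params = {
        "scheduler": {"type": "CosineAnnealingLR"},  # lowered to "cosineannealinglr"
        "optimizer_object": torch.optim.SGD(model.parameters(), lr=0.1),
    }
    scheduler = get_scheduler(params)  # torch.optim.lr_scheduler.CosineAnnealingLR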
GANDLF/schedulers/wrap_torch.py CHANGED
@@ -5,6 +5,7 @@ from torch.optim.lr_scheduler import (
     StepLR,
     ReduceLROnPlateau,
     CosineAnnealingWarmRestarts,
+    CosineAnnealingLR,
 )
 import math
 
@@ -169,14 +170,25 @@ def reduce_on_plateau(parameters):
     )
 
 
-def cosineannealing(parameters):
+def cosineannealingwarmrestarts(parameters):
     parameters["scheduler"]["T_0"] = parameters["scheduler"].get("T_0", 5)
     parameters["scheduler"]["T_mult"] = parameters["scheduler"].get("T_mult", 1)
-    parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", 0.001)
+    parameters["scheduler"]["eta_min"] = parameters["scheduler"].get("eta_min", 0.001)
 
     return CosineAnnealingWarmRestarts(
         parameters["optimizer_object"],
         T_0=parameters["scheduler"]["T_0"],
         T_mult=parameters["scheduler"]["T_mult"],
-        eta_min=parameters["scheduler"]["min_lr"],
+        eta_min=parameters["scheduler"]["eta_min"],
+    )
+
+
+def cosineannealingLR(parameters):
+    parameters["scheduler"]["T_max"] = parameters["scheduler"].get("T_max", 50)
+    parameters["scheduler"]["eta_min"] = parameters["scheduler"].get("eta_min", 0.001)
+
+    return CosineAnnealingLR(
+        parameters["optimizer_object"],
+        T_max=parameters["scheduler"]["T_max"],
+        eta_min=parameters["scheduler"]["eta_min"],
    )
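The `min_lr` key is replaced by `eta_min`, matching the PyTorch argument name, and both cosine wrappers fill in defaults when the config omits them (T_0=5 and T_mult=1 for warm restarts; T_max=50 for the plain variant; eta_min=0.001 for both). A hedged sketch of calling the wrappers directly with empty scheduler configs; the optimizer is a placeholder:

    import torch
    from GANDLF.schedulers.wrap_torch import (
        cosineannealingLR,
        cosineannealingwarmrestarts,
    )

    opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)

    # Empty configs pick up the defaults shown in the diff above.
    warm_restarts = cosineannealingwarmrestarts({"scheduler": {}, "optimizer_object": opt})
    plain_cosine = cosineannealingLR({"scheduler": {}, "optimizer_object": opt})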
GANDLF/training_manager.py CHANGED
@@ -1,20 +1,31 @@
+import os
+import yaml
+
+# codacy ignore python-use-of-pickle: Pickle usage is safe in this context (local data only).
+import pickle
+import shutil
 import pandas as pd
-import os, pickle, shutil
 from pathlib import Path
+from warnings import warn
+
+import lightning.pytorch as pl
+from lightning.pytorch.profilers import PyTorchProfiler
+from lightning.pytorch.tuner import Tuner as LightningTuner
 
-from GANDLF.compute import training_loop
 from GANDLF.utils import get_dataframe, split_data
+from GANDLF.models.lightning_module import GandlfLightningModule
+from GANDLF.data.lightning_datamodule import GandlfTrainingDatamodule
 
-import yaml
+from typing import Optional
 
 
 def TrainingManager(
     dataframe: pd.DataFrame,
     outputDir: str,
     parameters: dict,
-    device: str,
     resume: bool,
     reset: bool,
+    profile: Optional[bool] = False,
 ) -> None:
     """
     This is the training manager that ties all the training functionality together
@@ -23,10 +34,14 @@ def TrainingManager(
         dataframe (pandas.DataFrame): The full data from CSV.
         outputDir (str): The main output directory.
        parameters (dict): The parameters dictionary.
-        device (str): The device to perform computations on.
        resume (bool): Whether the previous run will be resumed or not.
        reset (bool): Whether the previous run will be reset or not.
+        profile (bool): Whether to profile activity or not. Defaults to False.
+
     """
+
+    if "output_dir" not in parameters:
+        parameters["output_dir"] = outputDir
     if reset:
         shutil.rmtree(outputDir)
         Path(outputDir).mkdir(parents=True, exist_ok=True)
@@ -95,45 +110,79 @@ def TrainingManager(
             # read the data from the pickle if present
             data_dict[data_type] = get_dataframe(currentDataPickle)
 
-        # parallel_compute_command is an empty string, thus no parallel computing requested
-        if not parameters["parallel_compute_command"]:
-            training_loop(
-                training_data=data_dict["training"],
-                validation_data=data_dict["validation"],
-                output_dir=currentValidationOutputFolder,
-                device=device,
-                params=parameters,
-                testing_data=data_dict["testing"],
-            )
+        # Dataloader initialization - should be extracted somewhere else (preferably abstracted away)
+        datamodule = GandlfTrainingDatamodule(data_dict_files, parameters)
+        parameters = datamodule.updated_parameters_dict
+
+        # This entire section should be handled in config parser
+
+        accelerator = parameters.get("accelerator", "auto")
+        allowed_accelerators = ["cpu", "gpu", "auto"]
+        # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+        assert (
+            accelerator in allowed_accelerators
+        ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
+        strategy = parameters.get("strategy", "auto")
+        allowed_strategies = ["auto", "ddp"]
+        # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+        assert (
+            strategy in allowed_strategies
+        ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
+        precision = parameters.get("precision", "32")
+        allowed_precisions = [
+            "64",
+            "64-true",
+            "32",
+            "32-true",
+            "16",
+            "16-mixed",
+            "bf16",
+            "bf16-mixed",
+        ]
+        # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+        assert (
+            precision in allowed_precisions
+        ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
+
+        warn(
+            f"Configured to use {accelerator} with {strategy} for training, but current development configuration will force single-device only training."
+        )
+        trainer = pl.Trainer(
+            accelerator=accelerator,
+            strategy=strategy,
+            fast_dev_run=False,
+            devices=parameters.get("devices", "auto"),
+            num_nodes=parameters.get("num_nodes", 1),
+            precision=precision,
+            gradient_clip_algorithm=parameters["clip_mode"],
+            gradient_clip_val=parameters["clip_grad"],
+            max_epochs=parameters["num_epochs"],
+            sync_batchnorm=False,
+            enable_checkpointing=False,
+            logger=False,
+            num_sanity_val_steps=0,
+            profiler=PyTorchProfiler(sort_by="cpu_time_total", row_limit=10)
+            if profile
+            else None,
+        )
+
+        lightning_module = GandlfLightningModule(
+            parameters, output_dir=currentValidationOutputFolder
+        )
 
-        else:
-            # call hpc command here
-            parallel_compute_command_actual = parameters[
-                "parallel_compute_command"
-            ].replace("${outputDir}", currentValidationOutputFolder)
-
-            assert (
-                "python" in parallel_compute_command_actual
-            ), "The 'parallel_compute_command_actual' needs to have the python from the virtual environment, which is usually '${GANDLF_dir}/venv/bin/python'"
-
-            command = (
-                parallel_compute_command_actual
-                + " -m GANDLF.training_loop -train_loader_pickle "
-                + data_dict_files["training"]
-                + " -val_loader_pickle "
-                + data_dict_files["validation"]
-                + " -parameter_pickle "
-                + currentModelConfigPickle
-                + " -device "
-                + str(device)
-                + " -outputDir "
-                + currentValidationOutputFolder
-                + " -testing_loader_pickle "
-                + data_dict_files["testing"]
+        if parameters.get("auto_batch_size_find", False):
+            LightningTuner(trainer).scale_batch_size(
+                lightning_module, datamodule=datamodule
             )
 
-            print("Running command: ", command, flush=True)
-            os.system(command, flush=True)
+        if parameters.get("auto_lr_find", False):
+            LightningTuner(trainer).lr_find(lightning_module, datamodule=datamodule)
+
+        trainer.fit(lightning_module, datamodule=datamodule)
+
+        testing_data = data_dict_files.get("testing", None)
+        if testing_data:
+            trainer.test(lightning_module, datamodule=datamodule)
 
 
 def TrainingManager_split(
@@ -142,9 +191,9 @@ def TrainingManager_split(
     dataframe_testing: pd.DataFrame,
     outputDir: str,
     parameters: dict,
-    device: str,
     resume: bool,
     reset: bool,
+    profile: Optional[bool] = False,
 ):
     """
     This is the training manager that ties all the training functionality together
@@ -155,9 +204,10 @@ def TrainingManager_split(
         dataframe_testing (pd.DataFrame): The testing data from CSV.
         outputDir (str): The main output directory.
         parameters (dict): The parameters dictionary.
-        device (str): The device to perform computations on.
         resume (bool): Whether the previous run will be resumed or not.
         reset (bool): Whether the previous run will be reset or not.
+        profile (bool): Whether to profile activity or not. Defaults to False.
+
     """
     currentModelConfigPickle = os.path.join(outputDir, "parameters.pkl")
     currentModelConfigYaml = os.path.join(outputDir, "config.yaml")
@@ -178,11 +228,71 @@ def TrainingManager_split(
     with open(currentModelConfigYaml, "w") as handle:
         yaml.dump(parameters, handle, default_flow_style=False)
 
-    training_loop(
-        training_data=dataframe_train,
-        validation_data=dataframe_validation,
-        output_dir=outputDir,
-        device=device,
-        params=parameters,
-        testing_data=dataframe_testing,
+    data_dict_files = {
+        "training": dataframe_train,
+        "validation": dataframe_validation,
+        "testing": dataframe_testing,
+    }
+
+    datamodule = GandlfTrainingDatamodule(data_dict_files, parameters)
+    parameters = datamodule.updated_parameters_dict
+
+    # This entire section should be handled in config parser
+
+    accelerator = parameters.get("accelerator", "auto")
+    allowed_accelerators = ["cpu", "gpu", "auto"]
+    assert (
+        accelerator in allowed_accelerators
+    ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
+    strategy = parameters.get("strategy", "auto")
+    allowed_strategies = ["auto", "ddp"]
+    assert (
+        strategy in allowed_strategies
+    ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
+    precision = parameters.get("precision", "32")
+    allowed_precisions = [
+        "64",
+        "64-true",
+        "32",
+        "32-true",
+        "16",
+        "16-mixed",
+        "bf16",
+        "bf16-mixed",
+    ]
+    assert (
+        precision in allowed_precisions
+    ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
+
+    trainer = pl.Trainer(
+        accelerator=accelerator,
+        strategy=strategy,
+        fast_dev_run=False,
+        devices=parameters.get("devices", "auto"),
+        num_nodes=parameters.get("num_nodes", 1),
+        precision=precision,
+        gradient_clip_algorithm=parameters["clip_mode"],
+        gradient_clip_val=parameters["clip_grad"],
+        max_epochs=parameters["num_epochs"],
+        sync_batchnorm=False,
+        enable_checkpointing=False,
+        logger=False,
+        num_sanity_val_steps=0,
+        profiler=PyTorchProfiler(sort_by="cpu_time_total", row_limit=10)
+        if profile
+        else None,
     )
+    lightning_module = GandlfLightningModule(parameters, output_dir=outputDir)
+
+    if parameters.get("auto_batch_size_find", False):
+        LightningTuner(trainer).scale_batch_size(
+            lightning_module, datamodule=datamodule
+        )
+
+    if parameters.get("auto_lr_find", False):
+        LightningTuner(trainer).lr_find(lightning_module, datamodule=datamodule)
+
+    trainer.fit(lightning_module, datamodule=datamodule)
+
+    if dataframe_testing is not None:
+        trainer.test(lightning_module, datamodule=datamodule)
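The new Lightning-facing keys are validated inline rather than in the config parser (as the TODO comments in the diff note). A hedged sketch of the relevant `parameters` entries; every other key (model, losses, `clip_mode`, `clip_grad`, `num_epochs`, etc.) is elided and would come from the usual GANDLF config:

    parameters = {
        # ... usual GANDLF configuration keys ...
        "accelerator": "auto",          # one of: "cpu", "gpu", "auto"
        "strategy": "auto",             # one of: "auto", "ddp"
        "precision": "16-mixed",        # one of the allowed_precisions listed above
        "devices": "auto",
        "num_nodes": 1,
        "auto_batch_size_find": False,  # opt-in LightningTuner batch-size search
        "auto_lr_find": False,          # opt-in LightningTuner LR search
    }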
GANDLF/utils/__init__.py CHANGED
@@ -7,9 +7,11 @@ from .imaging import (
     resize_image,
     resample_image,
     perform_sanity_check_on_subject,
+    sanity_check_on_file_readers,
     write_training_patches,
     get_correct_padding_size,
     applyCustomColorMap,
+    MapSaver,
 )
 
 from .tensor import (
@@ -58,9 +60,9 @@ from .generic import (
 )
 
 from .modelio import (
-    best_model_path_end,
-    latest_model_path_end,
-    initial_model_path_end,
+    BEST_MODEL_PATH_END,
+    LATEST_MODEL_PATH_END,
+    INITIAL_MODEL_PATH_END,
     load_model,
     load_ov_model,
     save_model,
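Since the model-path suffixes are now exported as uppercase module-level constants, downstream imports change accordingly; a one-line sketch:

    # The old lowercase names (best_model_path_end, ...) no longer exist.
    from GANDLF.utils import (
        BEST_MODEL_PATH_END,
        INITIAL_MODEL_PATH_END,
        LATEST_MODEL_PATH_END,
    )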