GANDLF 0.1.3.dev20250318__py3-none-any.whl → 0.1.4.dev20250502__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (55)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +21 -0
  3. GANDLF/cli/main_run.py +4 -12
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +26 -716
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +90 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +29 -35
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +50 -0
  37. GANDLF/metrics/segmentation_panoptica.py +35 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +6 -2
  43. GANDLF/training_manager.py +159 -69
  44. GANDLF/utils/__init__.py +4 -3
  45. GANDLF/utils/imaging.py +121 -2
  46. GANDLF/utils/modelio.py +9 -7
  47. GANDLF/utils/pred_target_processors.py +71 -0
  48. GANDLF/utils/write_parse.py +1 -1
  49. GANDLF/version.py +1 -1
  50. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/METADATA +14 -8
  51. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/RECORD +55 -32
  52. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/WHEEL +1 -1
  53. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/entry_points.txt +0 -0
  54. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info/licenses}/LICENSE +0 -0
  55. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/top_level.txt +0 -0
GANDLF/optimizers/__init__.py CHANGED
@@ -48,13 +48,9 @@ def get_optimizer(params):
         optimizer (torch.optim.Optimizer): An instance of the specified optimizer.
 
     """
-    # Retrieve the optimizer type from the input parameters
-    optimizer_type = params["optimizer"]["type"]
 
+    chosen_optimizer = params["optimizer"]["type"]
     assert (
-        optimizer_type in global_optimizer_dict
-    ), f"Optimizer type {optimizer_type} not found"
-
-    # Create the optimizer instance using the specified type and input parameters
-    optimizer_function = global_optimizer_dict[optimizer_type]
-    return optimizer_function(params)
+        chosen_optimizer in global_optimizer_dict
+    ), f"Could not find the requested optimizer '{params['optimizer']['type']}'"
+    return global_optimizer_dict[chosen_optimizer](params)
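For orientation, the refactored function now fails fast with a readable message instead of a raw `KeyError`. A minimal sketch of the same lookup-assert-dispatch pattern; the registry contents and the `model_parameters`/`learning_rate` keys are illustrative stand-ins, not GANDLF's actual `global_optimizer_dict` entries:

```python
import torch

# Illustrative registry; GANDLF populates global_optimizer_dict from its own wrappers.
global_optimizer_dict = {
    "adam": lambda p: torch.optim.Adam(p["model_parameters"], lr=p["learning_rate"]),
    "sgd": lambda p: torch.optim.SGD(p["model_parameters"], lr=p["learning_rate"]),
}


def get_optimizer(params):
    # Same pattern as the refactored function: look up, assert membership, dispatch.
    chosen_optimizer = params["optimizer"]["type"]
    assert (
        chosen_optimizer in global_optimizer_dict
    ), f"Could not find the requested optimizer '{params['optimizer']['type']}'"
    return global_optimizer_dict[chosen_optimizer](params)


model = torch.nn.Linear(4, 2)
params = {
    "optimizer": {"type": "adam"},
    "model_parameters": model.parameters(),
    "learning_rate": 1e-3,
}
optimizer = get_optimizer(params)  # a typo like "adamm" now fails with a clear message
```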
GANDLF/privacy/opacus/opacus_anonymization_manager.py ADDED
@@ -0,0 +1,243 @@
+import torch
+from opacus import PrivacyEngine
+
+import collections.abc as abc
+from functools import partial
+from torch.utils.data._utils.collate import default_collate
+from typing import Union, Callable
+import copy
+from opacus.optimizers import DPOptimizer
+from opacus.utils.uniform_sampler import (
+    DistributedUniformWithReplacementSampler,
+    UniformWithReplacementSampler,
+)
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+import math
+import numpy as np
+
+from typing import List
+
+
+class BatchSplittingSampler(Sampler[List[int]]):
+    """
+    Samples according to the underlying instance of ``Sampler``, but splits
+    the index sequences into smaller chunks.
+
+    Used to split large logical batches into physical batches of a smaller size,
+    while coordinating with DPOptimizer when the logical batch has ended.
+    """
+
+    def __init__(
+        self,
+        *,
+        sampler: Sampler[List[int]],
+        max_batch_size: int,
+        optimizer: DPOptimizer,
+    ):
+        """
+
+        Args:
+            sampler: Wrapped Sampler instance
+            max_batch_size: Max size of emitted chunk of indices
+            optimizer: optimizer instance to notify when the logical batch is over
+        """
+        self.sampler = sampler
+        self.max_batch_size = max_batch_size
+        self.optimizer = optimizer
+
+    def __iter__(self):
+        for batch_idxs in self.sampler:
+            if len(batch_idxs) == 0:
+                self.optimizer.signal_skip_step(do_skip=False)
+                yield []
+                continue
+
+            split_idxs = np.array_split(
+                batch_idxs, math.ceil(len(batch_idxs) / self.max_batch_size)
+            )
+            split_idxs = [s.tolist() for s in split_idxs]
+            for x in split_idxs[:-1]:
+                self.optimizer.signal_skip_step(do_skip=True)
+                yield x
+            self.optimizer.signal_skip_step(do_skip=False)
+            yield split_idxs[-1]
+
+    def __len__(self):
+        if isinstance(self.sampler, BatchSampler):
+            return math.ceil(
+                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+            )
+        elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
+            self.sampler, DistributedUniformWithReplacementSampler
+        ):
+            expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
+            return math.ceil(
+                len(self.sampler) * (expected_batch_size / self.max_batch_size)
+            )
+
+        return len(self.sampler)
+
+
+class OpacusAnonymizationManager:
+    def __init__(self, params):
+        self.params = params
+
+    def apply_privacy(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+    ):
+        model, optimizer, train_dataloader, privacy_engine = self._apply_privacy(
+            model, optimizer, train_dataloader
+        )
+        train_dataloader.collate_fn = self._empty_collate(train_dataloader.dataset[0])
+        max_physical_batch_size = self.params["differential_privacy"].get(
+            "max_physical_batch_size", self.params["batch_size"]
+        )
+        if max_physical_batch_size != self.params["batch_size"]:
+            train_dataloader = self._wrap_data_loader(
+                data_loader=train_dataloader,
+                max_batch_size=max_physical_batch_size,
+                optimizer=optimizer,
+            )
+
+        return model, optimizer, train_dataloader, privacy_engine
+
+    def _apply_privacy(self, model, optimizer, train_dataloader):
+        privacy_engine = PrivacyEngine(
+            accountant=self.params["differential_privacy"]["accountant"],
+            secure_mode=self.params["differential_privacy"]["secure_mode"],
+        )
+        epsilon = self.params["differential_privacy"].get("epsilon")
+
+        if epsilon is not None:
+            (
+                model,
+                optimizer,
+                train_dataloader,
+            ) = privacy_engine.make_private_with_epsilon(
+                module=model,
+                optimizer=optimizer,
+                data_loader=train_dataloader,
+                max_grad_norm=self.params["differential_privacy"]["max_grad_norm"],
+                epochs=self.params["num_epochs"],
+                target_epsilon=self.params["differential_privacy"]["epsilon"],
+                target_delta=self.params["differential_privacy"]["delta"],
+            )
+        else:
+            model, optimizer, train_dataloader = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=train_dataloader,
+                noise_multiplier=self.params["differential_privacy"][
+                    "noise_multiplier"
+                ],
+                max_grad_norm=self.params["differential_privacy"]["max_grad_norm"],
+            )
+        return model, optimizer, train_dataloader, privacy_engine
+
+    def _empty_collate(
+        self,
+        item_example: Union[
+            torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+        ],
+    ) -> Callable:
+        """
+        Creates a new collate function that behave same as default pytorch one,
+        but can process the empty batches.
+
+        Args:
+            item_example (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): An example item from the dataset.
+
+        Returns:
+            Callable: function that should replace dataloader collate: `dataloader.collate_fn = empty_collate(...)`
+        """
+
+        def custom_collate(batch, _empty_batch_value):
+            if len(batch) > 0:
+                return default_collate(batch)  # default behavior
+            else:
+                return copy.copy(_empty_batch_value)
+
+        empty_batch_value = self._build_empty_batch_value(item_example)
+
+        return partial(custom_collate, _empty_batch_value=empty_batch_value)
+
+    def _build_empty_batch_value(
+        self,
+        sample: Union[
+            torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+        ],
+    ):
+        """
+        Build an empty batch value from a sample. This function is used to create a placeholder for empty batches in an iteration. Inspired from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py#L108. The key difference is that pytorch `collate` has to traverse batch of objects AND unite its fields to lists, while this function traverse a single item AND creates an "empty" version of the batch.
+
+        Args:
+            sample (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): A sample from the dataset.
+
+        Raises:
+            TypeError: If the data type is not supported.
+
+        Returns:
+            Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]: An empty batch value.
+        """
+        if isinstance(sample, torch.Tensor):
+            # Create an empty tensor with the same shape except for the zeroed batch dimension.
+            return torch.empty((0,) + sample.shape)
+        elif isinstance(sample, np.ndarray):
+            # Create an empty tensor from a numpy array, also with the zeroed batch dimension.
+            return torch.empty(
+                (0,) + sample.shape, dtype=torch.from_numpy(sample).dtype
+            )
+        elif isinstance(sample, abc.Mapping):
+            # Recursively handle dictionary-like objects.
+            return {
+                key: self._build_empty_batch_value(value)
+                for key, value in sample.items()
+            }
+        elif isinstance(sample, tuple) and hasattr(sample, "_fields"):  # namedtuple
+            return type(sample)(
+                *(self._build_empty_batch_value(item) for item in sample)
+            )
+        elif isinstance(sample, abc.Sequence) and not isinstance(sample, str):
+            # Handle lists and tuples, but exclude strings.
+            return [self._build_empty_batch_value(item) for item in sample]
+        elif isinstance(sample, (int, float, str)):
+            # Return an empty list for basic data types.
+            return []
+        else:
+            raise TypeError(f"Unsupported data type: {type(sample)}")
+
+    def _wrap_data_loader(
+        self, data_loader: DataLoader, max_batch_size: int, optimizer: DPOptimizer
+    ):
+        """
+        Replaces batch_sampler in the input data loader with ``BatchSplittingSampler``
+
+        Args:
+            data_loader: Wrapper DataLoader
+            max_batch_size: max physical batch size we want to emit
+            optimizer: DPOptimizer instance used for training
+
+        Returns:
+            New DataLoader instance with batch_sampler wrapped in ``BatchSplittingSampler``
+        """
+
+        return DataLoader(
+            dataset=data_loader.dataset,
+            batch_sampler=BatchSplittingSampler(
+                sampler=data_loader.batch_sampler,
+                max_batch_size=max_batch_size,
+                optimizer=optimizer,
+            ),
+            num_workers=data_loader.num_workers,
+            collate_fn=data_loader.collate_fn,
+            pin_memory=data_loader.pin_memory,
+            timeout=data_loader.timeout,
+            worker_init_fn=data_loader.worker_init_fn,
+            multiprocessing_context=data_loader.multiprocessing_context,
+            generator=data_loader.generator,
+            prefetch_factor=data_loader.prefetch_factor,
+            persistent_workers=data_loader.persistent_workers,
+        )
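A rough usage sketch for the new manager follows. The `differential_privacy` keys mirror those read by the code above; the toy model, optimizer, and dataset are invented for illustration, and Opacus is assumed to be installed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from GANDLF.privacy.opacus.opacus_anonymization_manager import (
    OpacusAnonymizationManager,
)

# Toy training objects; a real GANDLF run derives these from its config and CSVs.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=16)

params = {
    "batch_size": 16,
    "num_epochs": 5,
    "differential_privacy": {
        "accountant": "rdp",
        "secure_mode": False,
        "noise_multiplier": 1.0,  # used because "epsilon" is not set
        "max_grad_norm": 1.0,
        # Setting "max_physical_batch_size" below "batch_size" would wrap the
        # loader in BatchSplittingSampler to cap per-step memory use.
    },
}

manager = OpacusAnonymizationManager(params)
model, optimizer, train_loader, privacy_engine = manager.apply_privacy(
    model, optimizer, train_loader
)
```

If `epsilon` is present in the config, the code takes the `make_private_with_epsilon` path (calibrating noise to a target budget); otherwise it uses the fixed `noise_multiplier` shown here, and the returned `privacy_engine` tracks the spent budget via its accountant.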
GANDLF/schedulers/__init__.py CHANGED
@@ -38,6 +38,10 @@ def get_scheduler(params):
         params (dict): The parameters' dictionary.
 
     Returns:
-        model (object): The scheduler definition.
+        scheduler (object): The scheduler definition.
     """
-    return global_schedulers_dict[params["scheduler"]["type"]](params)
+    chosen_scheduler = params["scheduler"]["type"].lower()
+    assert (
+        chosen_scheduler in global_schedulers_dict
+    ), f"Could not find the requested scheduler '{params['scheduler']['type']}'"
+    return global_schedulers_dict[chosen_scheduler](params)
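Unlike the optimizer lookup, this one also lowercases the requested type before the registry check, so mixed-case config values resolve as long as the registry keys are lowercase. A toy illustration (the registry contents are stand-ins):

```python
# Stand-in registry with lowercase keys, mirroring get_scheduler's lookup.
global_schedulers_dict = {"step": lambda params: ("StepLR", params)}

params = {"scheduler": {"type": "Step"}}  # mixed case from a user config
chosen_scheduler = params["scheduler"]["type"].lower()
assert (
    chosen_scheduler in global_schedulers_dict
), f"Could not find the requested scheduler '{params['scheduler']['type']}'"
scheduler = global_schedulers_dict[chosen_scheduler](params)
```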
GANDLF/training_manager.py CHANGED
@@ -1,21 +1,31 @@
+import os
+import yaml
+
+# codacy ignore python-use-of-pickle: Pickle usage is safe in this context (local data only).
+import pickle
+import shutil
 import pandas as pd
-import os, pickle, shutil
 from pathlib import Path
-from torch.profiler import profile, ProfilerActivity
+from warnings import warn
+
+import lightning.pytorch as pl
+from lightning.pytorch.profilers import PyTorchProfiler
+from lightning.pytorch.tuner import Tuner as LightningTuner
 
-from GANDLF.compute import training_loop
 from GANDLF.utils import get_dataframe, split_data
+from GANDLF.models.lightning_module import GandlfLightningModule
+from GANDLF.data.lightning_datamodule import GandlfTrainingDatamodule
 
-import yaml
+from typing import Optional
 
 
 def TrainingManager(
     dataframe: pd.DataFrame,
     outputDir: str,
     parameters: dict,
-    device: str,
     resume: bool,
     reset: bool,
+    profile: Optional[bool] = False,
 ) -> None:
     """
     This is the training manager that ties all the training functionality together
@@ -24,10 +34,14 @@ def TrainingManager(
         dataframe (pandas.DataFrame): The full data from CSV.
         outputDir (str): The main output directory.
         parameters (dict): The parameters dictionary.
-        device (str): The device to perform computations on.
         resume (bool): Whether the previous run will be resumed or not.
         reset (bool): Whether the previous run will be reset or not.
+        profile(bool): Whether we want the profile activity or not. Defaults to False.
+
     """
+
+    if "output_dir" not in parameters:
+        parameters["output_dir"] = outputDir
     if reset:
         shutil.rmtree(outputDir)
     Path(outputDir).mkdir(parents=True, exist_ok=True)
@@ -96,45 +110,79 @@ def TrainingManager(
             # read the data from the pickle if present
             data_dict[data_type] = get_dataframe(currentDataPickle)
 
-        # parallel_compute_command is an empty string, thus no parallel computing requested
-        if not parameters["parallel_compute_command"]:
-            training_loop(
-                training_data=data_dict["training"],
-                validation_data=data_dict["validation"],
-                output_dir=currentValidationOutputFolder,
-                device=device,
-                params=parameters,
-                testing_data=data_dict["testing"],
-            )
+        # Dataloader initialization - should be extracted somewhere else (preferably abstracted away)
+        datamodule = GandlfTrainingDatamodule(data_dict_files, parameters)
+        parameters = datamodule.updated_parameters_dict
+
+        # This entire section should be handled in config parser
+
+        accelerator = parameters.get("accelerator", "auto")
+        allowed_accelerators = ["cpu", "gpu", "auto"]
+        # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+        assert (
+            accelerator in allowed_accelerators
+        ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
+        strategy = parameters.get("strategy", "auto")
+        allowed_strategies = ["auto", "ddp"]
+        # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+        assert (
+            strategy in allowed_strategies
+        ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
+        precision = parameters.get("precision", "32")
+        allowed_precisions = [
+            "64",
+            "64-true",
+            "32",
+            "32-true",
+            "16",
+            "16-mixed",
+            "bf16",
+            "bf16-mixed",
+        ]
+        # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
+        assert (
+            precision in allowed_precisions
+        ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
+
+        warn(
+            f"Configured to use {accelerator} with {strategy} for training, but current development configuration will force single-device only training."
+        )
+        trainer = pl.Trainer(
+            accelerator=accelerator,
+            strategy=strategy,
+            fast_dev_run=False,
+            devices=parameters.get("devices", "auto"),
+            num_nodes=parameters.get("num_nodes", 1),
+            precision=precision,
+            gradient_clip_algorithm=parameters["clip_mode"],
+            gradient_clip_val=parameters["clip_grad"],
+            max_epochs=parameters["num_epochs"],
+            sync_batchnorm=False,
+            enable_checkpointing=False,
+            logger=False,
+            num_sanity_val_steps=0,
+            profiler=PyTorchProfiler(sort_by="cpu_time_total", row_limit=10)
+            if profile
+            else None,
+        )
+
+        lightning_module = GandlfLightningModule(
+            parameters, output_dir=currentValidationOutputFolder
+        )
 
-        else:
-            # call hpc command here
-            parallel_compute_command_actual = parameters[
-                "parallel_compute_command"
-            ].replace("${outputDir}", currentValidationOutputFolder)
-
-            assert (
-                "python" in parallel_compute_command_actual
-            ), "The 'parallel_compute_command_actual' needs to have the python from the virtual environment, which is usually '${GANDLF_dir}/venv/bin/python'"
-
-            command = (
-                parallel_compute_command_actual
-                + " -m GANDLF.training_loop -train_loader_pickle "
-                + data_dict_files["training"]
-                + " -val_loader_pickle "
-                + data_dict_files["validation"]
-                + " -parameter_pickle "
-                + currentModelConfigPickle
-                + " -device "
-                + str(device)
-                + " -outputDir "
-                + currentValidationOutputFolder
-                + " -testing_loader_pickle "
-                + data_dict_files["testing"]
+        if parameters.get("auto_batch_size_find", False):
+            LightningTuner(trainer).scale_batch_size(
+                lightning_module, datamodule=datamodule
             )
 
-            print("Running command: ", command, flush=True)
-            os.system(command, flush=True)
+        if parameters.get("auto_lr_find", False):
+            LightningTuner(trainer).lr_find(lightning_module, datamodule=datamodule)
+
+        trainer.fit(lightning_module, datamodule=datamodule)
+
+        testing_data = data_dict_files.get("testing", None)
+        if testing_data:
+            trainer.test(lightning_module, datamodule=datamodule)
 
 
 def TrainingManager_split(
@@ -143,10 +191,9 @@
     dataframe_testing: pd.DataFrame,
     outputDir: str,
     parameters: dict,
-    device: str,
     resume: bool,
     reset: bool,
-    _profile: bool,
+    profile: Optional[bool] = False,
 ):
     """
     This is the training manager that ties all the training functionality together
@@ -157,10 +204,9 @@ def TrainingManager_split(
         dataframe_testing (pd.DataFrame): The testing data from CSV.
         outputDir (str): The main output directory.
        parameters (dict): The parameters dictionary.
-        device (str): The device to perform computations on.
         resume (bool): Whether the previous run will be resumed or not.
         reset (bool): Whether the previous run will be reset or not.
-        _profile(bool):Whether the we want the profile activity or not.
+        profile(bool): Whether the we want the profile activity or not. Defaults to False.
 
     """
     currentModelConfigPickle = os.path.join(outputDir, "parameters.pkl")
@@ -182,27 +228,71 @@ def TrainingManager_split(
     with open(currentModelConfigYaml, "w") as handle:
         yaml.dump(parameters, handle, default_flow_style=False)
 
-    if _profile:
-        with profile(
-            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-            profile_memory=True,
-            record_shapes=True,
-        ) as prof:
-            training_loop(
-                training_data=dataframe_train,
-                validation_data=dataframe_validation,
-                output_dir=outputDir,
-                device=device,
-                params=parameters,
-                testing_data=dataframe_testing,
-            )
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    else:
-        training_loop(
-            training_data=dataframe_train,
-            validation_data=dataframe_validation,
-            output_dir=outputDir,
-            device=device,
-            params=parameters,
-            testing_data=dataframe_testing,
+    data_dict_files = {
+        "training": dataframe_train,
+        "validation": dataframe_validation,
+        "testing": dataframe_testing,
+    }
+
+    datamodule = GandlfTrainingDatamodule(data_dict_files, parameters)
+    parameters = datamodule.updated_parameters_dict
+
+    # This entire section should be handled in config parser
+
+    accelerator = parameters.get("accelerator", "auto")
+    allowed_accelerators = ["cpu", "gpu", "auto"]
+    assert (
+        accelerator in allowed_accelerators
+    ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
+    strategy = parameters.get("strategy", "auto")
+    allowed_strategies = ["auto", "ddp"]
+    assert (
+        strategy in allowed_strategies
+    ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
+    precision = parameters.get("precision", "32")
+    allowed_precisions = [
+        "64",
+        "64-true",
+        "32",
+        "32-true",
+        "16",
+        "16-mixed",
+        "bf16",
+        "bf16-mixed",
+    ]
+    assert (
+        precision in allowed_precisions
+    ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
+
+    trainer = pl.Trainer(
+        accelerator=accelerator,
+        strategy=strategy,
+        fast_dev_run=False,
+        devices=parameters.get("devices", "auto"),
+        num_nodes=parameters.get("num_nodes", 1),
+        precision=precision,
+        gradient_clip_algorithm=parameters["clip_mode"],
+        gradient_clip_val=parameters["clip_grad"],
+        max_epochs=parameters["num_epochs"],
+        sync_batchnorm=False,
+        enable_checkpointing=False,
+        logger=False,
+        num_sanity_val_steps=0,
+        profiler=PyTorchProfiler(sort_by="cpu_time_total", row_limit=10)
+        if profile
+        else None,
+    )
+    lightning_module = GandlfLightningModule(parameters, output_dir=outputDir)
+
+    if parameters.get("auto_batch_size_find", False):
+        LightningTuner(trainer).scale_batch_size(
+            lightning_module, datamodule=datamodule
        )
+
+    if parameters.get("auto_lr_find", False):
+        LightningTuner(trainer).lr_find(lightning_module, datamodule=datamodule)
+
+    trainer.fit(lightning_module, datamodule=datamodule)
+
+    if dataframe_testing is not None:
+        trainer.test(lightning_module, datamodule=datamodule)
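Putting the refactor together, the split-data entry point can be driven roughly as follows. The parameter keys shown (`num_epochs`, `clip_mode`, `clip_grad`, and the optional `accelerator`/`strategy`/`precision` overrides) come from the code above; the CSV paths and the rest of the configuration are placeholders:

```python
import pandas as pd
from GANDLF.training_manager import TrainingManager_split

# Placeholder data; a real run uses GaNDLF-formatted data CSVs.
df_train = pd.read_csv("train.csv")
df_val = pd.read_csv("val.csv")

parameters = {
    "num_epochs": 10,
    "clip_mode": "norm",  # forwarded to Trainer as gradient_clip_algorithm
    "clip_grad": 1.0,  # forwarded to Trainer as gradient_clip_val
    "accelerator": "auto",  # must be one of: cpu, gpu, auto
    "strategy": "auto",  # must be one of: auto, ddp
    "precision": "32",
    # ...remaining GaNDLF model/data configuration omitted...
}

TrainingManager_split(
    dataframe_train=df_train,
    dataframe_validation=df_val,
    dataframe_testing=None,  # None skips trainer.test(...)
    outputDir="./model_output",
    parameters=parameters,
    resume=False,
    reset=True,
    profile=False,  # True attaches the PyTorchProfiler
)
```

Note that both managers now build a `pl.Trainer` instead of calling the removed `GANDLF.compute.training_loop`, so the old `device` argument and the `parallel_compute_command` HPC path are gone; device placement is delegated to Lightning's `accelerator`/`devices` settings.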
GANDLF/utils/__init__.py CHANGED
@@ -10,6 +10,7 @@ from .imaging import (
     write_training_patches,
     get_correct_padding_size,
     applyCustomColorMap,
+    MapSaver,
 )
 
 from .tensor import (
@@ -58,9 +59,9 @@ from .generic import (
 )
 
 from .modelio import (
-    best_model_path_end,
-    latest_model_path_end,
-    initial_model_path_end,
+    BEST_MODEL_PATH_END,
+    LATEST_MODEL_PATH_END,
+    INITIAL_MODEL_PATH_END,
     load_model,
     load_ov_model,
     save_model,
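These `GANDLF.utils.modelio` constants were renamed to uppercase, so any downstream import of the old lowercase names needs updating. A sketch; the suffix semantics are assumed from how the names read, not stated in this diff:

```python
# 0.1.3.dev: from GANDLF.utils import best_model_path_end
# 0.1.4.dev onward:
from GANDLF.utils import BEST_MODEL_PATH_END

# Presumably a checkpoint filename suffix, used along the lines of:
best_checkpoint = "my_model" + BEST_MODEL_PATH_END
```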