GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
GANDLF/cli/deploy.py
CHANGED
@@ -246,7 +246,7 @@ def get_metrics_mlcube_config(
         mlcube_config = yaml.safe_load(f)
     if entrypoint_script:
         # modify the entrypoint to run a custom script
-        mlcube_config["tasks"]["evaluate"]["entrypoint"] = "python3.
+        mlcube_config["tasks"]["evaluate"]["entrypoint"] = "python3.11 /entrypoint.py"
     mlcube_config["docker"]["build_strategy"] = "auto"
     return mlcube_config

@@ -315,7 +315,7 @@ def get_model_mlcube_config(
     device = "cuda" if requires_gpu else "cpu"
     mlcube_config["tasks"]["infer"][
         "entrypoint"
-    ] = f"python3.
+    ] = f"python3.11 /entrypoint.py --device {device}"

     return mlcube_config
     # Duplicate training task into one from reset (must be explicit) and one that resumes with new data
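Both hunks pin the MLCube entrypoints to Python 3.11. A minimal sketch of the override pattern on a toy config; only the two assignments mirror the patched code, the surrounding keys are illustrative assumptions:

    # Hypothetical minimal config standing in for a parsed mlcube.yaml.
    mlcube_config = {
        "tasks": {"evaluate": {"entrypoint": "placeholder"}},
        "docker": {"build_strategy": "always"},
    }

    # Mirrors the patched lines: pin the interpreter and the build strategy.
    mlcube_config["tasks"]["evaluate"]["entrypoint"] = "python3.11 /entrypoint.py"
    mlcube_config["docker"]["build_strategy"] = "auto"
    print(mlcube_config["tasks"]["evaluate"]["entrypoint"])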
GANDLF/cli/generate_metrics.py
CHANGED
@@ -10,7 +10,11 @@ import SimpleITK as sitk
 import numpy as np

 from GANDLF.config_manager import ConfigManager
-from GANDLF.utils import
+from GANDLF.utils import (
+    find_problem_type_from_parameters,
+    one_hot,
+    sanity_check_on_file_readers,
+)
 from GANDLF.metrics import (
     overall_stats,
     structural_similarity_index,

@@ -20,6 +24,7 @@ from GANDLF.metrics import (
     mean_squared_log_error,
     mean_absolute_error,
     ncc_metrics,
+    generate_instance_segmentation,
 )
 from GANDLF.losses.segmentation import dice
 from GANDLF.metrics.segmentation import (

@@ -175,6 +180,9 @@ def generate_metrics_dict(
         for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
             current_subject_id = row["SubjectID"]
             overall_stats_dict[current_subject_id] = {}
+            sanity_check_on_file_readers(
+                row["Target"], row["Prediction"], current_subject_id
+            )
             label_image = torchio.LabelMap(row["Target"])
             pred_image = torchio.LabelMap(row["Prediction"])
             label_tensor = label_image.data

@@ -259,6 +267,29 @@ def generate_metrics_dict(
                     "volumeSimilarity_" + str(class_index)
                 ] = label_overlap_filter.GetVolumeSimilarity()

+    elif problem_type == "segmentation_brats":
+        for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
+            current_subject_id = row["SubjectID"]
+            overall_stats_dict[current_subject_id] = {}
+            sanity_check_on_file_readers(
+                row["Target"], row["Prediction"], current_subject_id
+            )
+            label_image = torchio.LabelMap(row["Target"])
+            pred_image = torchio.LabelMap(row["Prediction"])
+            label_tensor = label_image.data
+            pred_tensor = pred_image.data
+            spacing = label_image.spacing
+            # if label_tensor.data.shape[-1] == 1:
+            #     spacing = spacing[0:2]
+            # remove dimension to ensure 3D tensors
+            if label_tensor.data.ndim == 4:
+                label_array = label_tensor.squeeze(0).numpy().astype(int)
+                pred_array = pred_tensor.squeeze(0).numpy().astype(int)
+
+            overall_stats_dict[current_subject_id] = generate_instance_segmentation(
+                prediction=pred_array, target=label_array, parameters=parameters
+            )
+
     elif problem_type == "synthesis":

         def __fix_2d_tensor(input_tensor):

@@ -319,6 +350,9 @@ def generate_metrics_dict(
         for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
             current_subject_id = row["SubjectID"]
             overall_stats_dict[current_subject_id] = {}
+            sanity_check_on_file_readers(
+                row["Target"], row["Prediction"], current_subject_id
+            )
             target_image = __fix_2d_tensor(torchio.ScalarImage(row["Target"]).data)
             pred_image = __fix_2d_tensor(torchio.ScalarImage(row["Prediction"]).data)
             # if "Mask" is not in the row, we assume that the whole image is the mask
GANDLF/cli/main_run.py
CHANGED
@@ -16,9 +16,9 @@ def main_run(
     config_file: str,
     model_dir: str,
     train_mode: bool,
-    device: str,
     resume: bool,
     reset: bool,
+    profile: Optional[bool] = False,
     output_dir: Optional[str] = None,
 ) -> None:
     """

@@ -29,9 +29,9 @@ def main_run(
         config_file (str): The YAML file of the training configuration.
         model_dir (str): The model directory; for training, model is written out here, and for inference, trained model is expected here.
         train_mode (bool): Whether to train or infer.
-        device (str): The device type.
         resume (bool): Whether the previous run will be resumed or not.
         reset (bool): Whether the previous run will be reset or not.
+        profile (bool): Whether to profile the training or not. Defaults to False.
         output_dir (str): The output directory for the inference session. Defaults to None.

     Returns:

@@ -39,9 +39,7 @@ def main_run(
     """
     file_data_full = data_csv
     model_parameters = config_file
-    device = device
     parameters = ConfigManager(model_parameters)
-    parameters["device_id"] = -1

     if train_mode:
         if resume:

@@ -59,9 +57,6 @@
     parameters["output_dir"] = model_dir
     Path(parameters["output_dir"]).mkdir(parents=True, exist_ok=True)

-    if "-1" in device:
-        device = "cpu"
-
     # parse training CSV
     if "," in file_data_full:
         # training and validation pre-split

@@ -95,9 +90,9 @@
             dataframe_testing=data_testing,
             outputDir=parameters["output_dir"],
             parameters=parameters,
-            device=device,
             resume=resume,
             reset=reset,
+            profile=profile,
         )
     else:
         data_full, headers = parseTrainingCSV(file_data_full, train=train_mode)

@@ -107,9 +102,9 @@
             dataframe=data_full,
             outputDir=parameters["output_dir"],
             parameters=parameters,
-            device=device,
             resume=resume,
             reset=reset,
+            profile=profile,
         )
     else:
         _, data_full, headers = parseTestingCSV(

@@ -120,5 +115,4 @@
             modelDir=model_dir,
             outputDir=output_dir,
             parameters=parameters,
-            device=device,
         )
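With the device argument removed from main_run and the managers, device selection is now resolved from the parsed configuration rather than a CLI parameter. A hedged call sketch of the new signature; all paths are placeholders:

    from GANDLF.cli.main_run import main_run

    # Placeholder paths; note there is no device argument anymore.
    main_run(
        data_csv="train.csv",
        config_file="config.yaml",
        model_dir="model_out",
        train_mode=True,
        resume=False,
        reset=False,
        profile=False,  # new optional argument introduced in this release
    )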
GANDLF/compute/__init__.py
CHANGED
GANDLF/compute/forward_pass.py
CHANGED
@@ -150,7 +150,6 @@ def validate_network(
                     tensor=subject[key]["data"].squeeze(0),
                     affine=subject[key]["affine"].squeeze(0),
                 )
-
         # regression/classification problem AND label is present
         if (params["problem_type"] != "segmentation") and label_present:
             sampler = torchio.data.LabelSampler(params["patch_size"])
GANDLF/compute/generic.py
CHANGED
@@ -2,17 +2,122 @@ from typing import Optional, Tuple
 from pandas.util import hash_pandas_object
 import torch
 from torch.utils.data import DataLoader
-
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from GANDLF.models import get_model
 from GANDLF.schedulers import get_scheduler
 from GANDLF.optimizers import get_optimizer
-from GANDLF.data import get_train_loader, get_validation_loader
+from GANDLF.data import get_train_loader, get_validation_loader, ImagesFromDataFrame
 from GANDLF.utils import (
     populate_header_in_parameters,
+    populate_channel_keys_in_params,
     parseTrainingCSV,
     send_model_to_device,
     get_class_imbalance_weights,
 )
+from GANDLF.utils.write_parse import get_dataframe
+from torchio import SubjectsDataset, Queue
+from typing import Union
+
+
+@dataclass
+class AbstractSubsetDataParser(ABC):
+    """
+    Interface for subset data parsers, needed to separate the dataset creation
+    from construction of the dataloaders.
+    """
+
+    subset_csv_path: str
+    parameters_dict: dict
+
+    @abstractmethod
+    def create_subset_dataset(self) -> Union[SubjectsDataset, Queue]:
+        """
+        Method to create the subset dataset based on the subset CSV file
+        and the parameters dict.
+
+        Returns:
+            Union[SubjectsDataset, Queue]: The subset dataset.
+        """
+        pass
+
+    def get_params_extended_with_subset_data(self) -> dict:
+        """
+        Trick to get around the fact that parameters dict need to be modified
+        during this parsing procedure. This method should be called after
+        create_subset_dataset(), as this method will populate the parameters
+        dict with the headers from the subset data.
+        """
+        return self.parameters_dict
+
+
+class TrainingSubsetDataParser(AbstractSubsetDataParser):
+    def create_subset_dataset(self) -> Union[SubjectsDataset, Queue]:
+        (
+            self.parameters_dict["training_data"],
+            headers_to_populate_train,
+        ) = parseTrainingCSV(self.subset_csv_path, train=True)
+
+        self.parameters_dict = populate_header_in_parameters(
+            self.parameters_dict, headers_to_populate_train
+        )
+
+        (
+            self.parameters_dict["penalty_weights"],
+            self.parameters_dict["sampling_weights"],
+            self.parameters_dict["class_weights"],
+        ) = get_class_imbalance_weights(
+            self.parameters_dict["training_data"], self.parameters_dict
+        )
+
+        print("Penalty weights : ", self.parameters_dict["penalty_weights"])
+        print("Sampling weights: ", self.parameters_dict["sampling_weights"])
+        print("Class weights : ", self.parameters_dict["class_weights"])
+
+        return ImagesFromDataFrame(
+            get_dataframe(self.parameters_dict["training_data"]),
+            self.parameters_dict,
+            train=True,
+            loader_type="train",
+        )
+
+
+class ValidationSubsetDataParser(AbstractSubsetDataParser):
+    def create_subset_dataset(self) -> Union[SubjectsDataset, Queue]:
+        (self.parameters_dict["validation_data"], _) = parseTrainingCSV(
+            self.subset_csv_path, train=False
+        )
+        validation_dataset = ImagesFromDataFrame(
+            get_dataframe(self.parameters_dict["validation_data"]),
+            self.parameters_dict,
+            train=False,
+            loader_type="validation",
+        )
+        self.parameters_dict = populate_channel_keys_in_params(
+            validation_dataset, self.parameters_dict
+        )
+        return validation_dataset
+
+
+class TestSubsetDataParser(AbstractSubsetDataParser):
+    def create_subset_dataset(self) -> Union[SubjectsDataset, Queue]:
+        testing_dataset = ImagesFromDataFrame(
+            get_dataframe(self.subset_csv_path),
+            self.parameters_dict,
+            train=False,
+            loader_type="testing",
+        )
+        if not ("channel_keys" in self.parameters_dict):
+            self.parameters_dict = populate_channel_keys_in_params(
+                testing_dataset, self.parameters_dict
+            )
+        return testing_dataset
+
+
+class InferenceSubsetDataParserRadiology(TestSubsetDataParser):
+    """Simple wrapper for name coherency, functionally this is the same as TestSubsetDataParser"""
+
+    pass


 def create_pytorch_objects(
GANDLF/compute/inference_loop.py
CHANGED
@@ -18,8 +18,8 @@ from torch.cuda.amp import autocast
 import openslide
 from GANDLF.data import get_testing_loader
 from GANDLF.utils import (
-
-
+    BEST_MODEL_PATH_END,
+    LATEST_MODEL_PATH_END,
     load_ov_model,
     print_model_summary,
     applyCustomColorMap,

@@ -72,11 +72,11 @@ def inference_loop(
     files_to_check = [
         os.path.join(
             modelDir,
-            str(parameters["model"]["architecture"]) +
+            str(parameters["model"]["architecture"]) + BEST_MODEL_PATH_END,
         ),
         os.path.join(
             modelDir,
-            str(parameters["model"]["architecture"]) +
+            str(parameters["model"]["architecture"]) + LATEST_MODEL_PATH_END,
         ),
     ]

GANDLF/compute/loss_and_metric.py
CHANGED

@@ -1,4 +1,3 @@
-import sys
 import warnings
 from typing import Dict, Tuple, Union
 from GANDLF.losses import global_losses_dict

@@ -134,7 +133,7 @@ def get_loss_and_metrics(
     # Metrics should be a list
     for metric in params["metrics"]:
         metric_lower = metric.lower()
-        metric_output[metric] = 0
+        metric_output[metric] = 0.0
         if metric_lower not in global_metrics_dict:
             warnings.warn("WARNING: Could not find the requested metric '" + metric)
             continue
GANDLF/compute/training_loop.py
CHANGED
@@ -13,9 +13,9 @@ from GANDLF.grad_clipping.grad_scaler import GradScaler, model_parameters_exclud
 from GANDLF.grad_clipping.clip_gradients import dispatch_clip_grad_
 from GANDLF.utils import (
     get_date_time,
-
-
-
+    BEST_MODEL_PATH_END,
+    LATEST_MODEL_PATH_END,
+    INITIAL_MODEL_PATH_END,
     save_model,
     optimize_and_save_model,
     load_model,

@@ -281,13 +281,13 @@
     first_model_saved = False
     model_paths = {
         "best": os.path.join(
-            output_dir, params["model"]["architecture"] +
+            output_dir, params["model"]["architecture"] + BEST_MODEL_PATH_END
         ),
         "initial": os.path.join(
-            output_dir, params["model"]["architecture"] +
+            output_dir, params["model"]["architecture"] + INITIAL_MODEL_PATH_END
         ),
         "latest": os.path.join(
-            output_dir, params["model"]["architecture"] +
+            output_dir, params["model"]["architecture"] + LATEST_MODEL_PATH_END
         ),
     }

@@ -481,14 +481,14 @@
                 + str(mem[3])
             )
             if params["device"] == "cuda":
-
+                cuda_memory_stats = torch.cuda.memory_stats()
                 outputToWrite_mem += (
                     ","
-                    + str(
+                    + str(cuda_memory_stats["active.all.peak"])
                     + ","
-                    + str(
+                    + str(cuda_memory_stats["active.all.current"])
                     + ","
-                    + str(
+                    + str(cuda_memory_stats["active.all.allocated"])
                 )
                 outputToWrite_mem += ",\n"
                 file_mem.write(outputToWrite_mem)
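The checkpoint suffixes used in both inference_loop.py and training_loop.py now come from shared constants in GANDLF.utils instead of inline strings. A small sketch of the resulting path assembly; the suffix values shown are assumptions for illustration only, the real constants should be imported from GANDLF.utils:

    import os

    # Assumed values for illustration; import the real ones from GANDLF.utils.
    BEST_MODEL_PATH_END = "_best.pth.tar"
    LATEST_MODEL_PATH_END = "_latest.pth.tar"
    INITIAL_MODEL_PATH_END = "_initial.pth.tar"

    output_dir, architecture = "model_out", "unet"  # placeholder run settings
    model_paths = {
        "best": os.path.join(output_dir, architecture + BEST_MODEL_PATH_END),
        "initial": os.path.join(output_dir, architecture + INITIAL_MODEL_PATH_END),
        "latest": os.path.join(output_dir, architecture + LATEST_MODEL_PATH_END),
    }
    print(model_paths["best"])  # model_out/unet_best.pth.tar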