monai-weekly 1.5.dev2514__py3-none-any.whl → 1.5.dev2516__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/nnunet/__init__.py +8 -0
- monai/apps/nnunet/nnunet_bundle.py +594 -0
- monai/bundle/scripts.py +9 -18
- monai/networks/blocks/fft_utils_t.py +12 -20
- monai/networks/blocks/selfattention.py +5 -1
- monai/networks/nets/diffusion_model_unet.py +3 -3
- monai/transforms/inverse.py +39 -7
- {monai_weekly-1.5.dev2514.dist-info → monai_weekly-1.5.dev2516.dist-info}/METADATA +3 -1
- {monai_weekly-1.5.dev2514.dist-info → monai_weekly-1.5.dev2516.dist-info}/RECORD +23 -20
- {monai_weekly-1.5.dev2514.dist-info → monai_weekly-1.5.dev2516.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_download.py +5 -0
- tests/integration/test_integration_nnunet_bundle.py +150 -0
- tests/networks/blocks/test_selfattention.py +21 -0
- tests/networks/nets/test_transchex.py +3 -2
- tests/transforms/inverse/test_inverse_dict.py +105 -0
- tests/transforms/inverse/test_traceable_transform.py +2 -2
- {monai_weekly-1.5.dev2514.dist-info → monai_weekly-1.5.dev2516.dist-info}/licenses/LICENSE +0 -0
- {monai_weekly-1.5.dev2514.dist-info → monai_weekly-1.5.dev2516.dist-info}/top_level.txt +0 -0
- /tests/transforms/{test_inverse.py → inverse/test_inverse.py} +0 -0
- /tests/transforms/{test_invert.py → inverse/test_invert.py} +0 -0
- /tests/transforms/{test_invertd.py → inverse/test_invertd.py} +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-04-
+ "date": "2025-04-20T02:32:55+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.5.dev2514"
+ "full-revisionid": "875aa3d53281b83a205442db83ac41ddf817ef9a",
+ "version": "1.5.dev2516"
 }
 ''' # END VERSION_JSON
 
monai/apps/nnunet/__init__.py
CHANGED
@@ -11,5 +11,13 @@
 
 from __future__ import annotations
 
+from .nnunet_bundle import (
+    ModelnnUNetWrapper,
+    convert_monai_bundle_to_nnunet,
+    convert_nnunet_to_monai_bundle,
+    get_network_from_nnunet_plans,
+    get_nnunet_monai_predictor,
+    get_nnunet_trainer,
+)
 from .nnunetv2_runner import nnUNetV2Runner
 from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json
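With this change the new nnUNet bundle helpers are re-exported from `monai.apps.nnunet`. A minimal usage sketch of the new public API follows; the dataset ID and bundle path are placeholders, not taken from this diff:

from monai.apps.nnunet import get_nnunet_monai_predictor, get_nnunet_trainer

# Hypothetical dataset ID and paths, for illustration only.
nnunet_trainer = get_nnunet_trainer(dataset_name_or_id=9, configuration="3d_fullres", fold=0)
predictor = get_nnunet_monai_predictor("path/to/bundle/models/fold_0")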
monai/apps/nnunet/nnunet_bundle.py ADDED
@@ -0,0 +1,594 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+from torch.backends import cudnn
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
+load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
+
+__all__ = [
+    "get_nnunet_trainer",
+    "get_nnunet_monai_predictor",
+    "get_network_from_nnunet_plans",
+    "convert_nnunet_to_monai_bundle",
+    "convert_monai_bundle_to_nnunet",
+    "ModelnnUNetWrapper",
+]
+
+
+def get_nnunet_trainer(
+    dataset_name_or_id: Union[str, int],
+    configuration: str,
+    fold: Union[int, str],
+    trainer_class_name: str = "nnUNetTrainer",
+    plans_identifier: str = "nnUNetPlans",
+    use_compressed_data: bool = False,
+    continue_training: bool = False,
+    only_run_validation: bool = False,
+    disable_checkpointing: bool = False,
+    device: str = "cuda",
+    pretrained_model: Optional[str] = None,
+) -> Any:  # type: ignore
+    """
+    Get the nnUNet trainer instance based on the provided configuration.
+    The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
+    optimizer, loss function, DataLoader, etc.
+
+    Example::
+
+        from monai.engines import SupervisedTrainer
+        from monai.apps.nnunet import get_nnunet_trainer
+
+        dataset_name_or_id = 'Task009_Spleen'
+        fold = 0
+        configuration = '3d_fullres'
+        nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
+
+        trainer = SupervisedTrainer(
+            device=nnunet_trainer.device,
+            max_epochs=nnunet_trainer.num_epochs,
+            train_data_loader=nnunet_trainer.dataloader_train,
+            network=nnunet_trainer.network,
+            optimizer=nnunet_trainer.optimizer,
+            loss_function=nnunet_trainer.loss_function,
+            epoch_length=nnunet_trainer.num_iterations_per_epoch,
+        )
+
+    Parameters
+    ----------
+    dataset_name_or_id : Union[str, int]
+        The name or ID of the dataset to be used.
+    configuration : str
+        The configuration name for the training.
+    fold : Union[int, str]
+        The fold number or 'all' for cross-validation.
+    trainer_class_name : str, optional
+        The class name of the trainer to be used. Default is 'nnUNetTrainer'.
+        For a complete list of supported trainers, check:
+        https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants
+    plans_identifier : str, optional
+        Identifier for the plans to be used. Default is 'nnUNetPlans'.
+    use_compressed_data : bool, optional
+        Whether to use compressed data. Default is False.
+    continue_training : bool, optional
+        Whether to continue training from a checkpoint. Default is False.
+    only_run_validation : bool, optional
+        Whether to only run validation. Default is False.
+    disable_checkpointing : bool, optional
+        Whether to disable checkpointing. Default is False.
+    device : str, optional
+        The device to be used for training. Default is 'cuda'.
+    pretrained_model : Optional[str], optional
+        Path to the pretrained model file.
+
+    Returns
+    -------
+    nnunet_trainer : object
+        The nnUNet trainer instance.
+    """
+    # From nnUNet/nnunetv2/run/run_training.py#run_training
+    if isinstance(fold, str):
+        if fold != "all":
+            try:
+                fold = int(fold)
+            except ValueError as e:
+                print(
+                    f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!'
+                )
+                raise e
+
+    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
+
+    nnunet_trainer = get_trainer_from_args(
+        str(dataset_name_or_id), configuration, fold, trainer_class_name, plans_identifier, device=torch.device(device)
+    )
+    if disable_checkpointing:
+        nnunet_trainer.disable_checkpointing = disable_checkpointing
+
+    assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy."
+
+    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation)
+    nnunet_trainer.on_train_start()  # Added to Initialize Trainer
+    if torch.cuda.is_available():
+        cudnn.deterministic = False
+        cudnn.benchmark = True
+
+    if pretrained_model is not None:
+        state_dict = torch.load(pretrained_model)
+        if "network_weights" in state_dict:
+            nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
+    return nnunet_trainer
+
+
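For reference, a hedged sketch of how the optional flags of `get_nnunet_trainer` compose; the dataset ID and checkpoint path are placeholders, not from this diff:

# Hypothetical call resuming an interrupted run and warm-starting from a
# MONAI-style checkpoint that stores weights under "network_weights".
nnunet_trainer = get_nnunet_trainer(
    dataset_name_or_id=9,
    configuration="3d_fullres",
    fold="all",  # "all" is accepted besides integer folds
    continue_training=True,  # mutually exclusive with only_run_validation (see the assert above)
    pretrained_model="path/to/model.pt",
)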
+class ModelnnUNetWrapper(torch.nn.Module):
+    """
+    A wrapper class for nnUNet model integration with the MONAI framework.
+    The wrapper can be used to integrate an nnUNet bundle within the MONAI framework for inference.
+
+    Parameters
+    ----------
+    predictor : nnUNetPredictor
+        The nnUNet predictor object used for inference.
+    model_folder : Union[str, Path]
+        The folder path where the model and related files are stored.
+    model_name : str, optional
+        The name of the model file, by default "model.pt".
+
+    Attributes
+    ----------
+    predictor : nnUNetPredictor
+        The nnUNet predictor object used for inference.
+    network_weights : torch.nn.Module
+        The network weights of the model.
+
+    Notes
+    -----
+    This class integrates the nnUNet model with the MONAI framework by loading the necessary configurations,
+    restoring the network architecture, and setting up the predictor for inference.
+    """
+
+    def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"):  # type: ignore
+        super().__init__()
+        self.predictor = predictor
+
+        model_training_output_dir = model_folder
+
+        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+        # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
+        dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
+        plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
+        plans_manager = PlansManager(plans)
+
+        parameters = []
+
+        checkpoint = torch.load(
+            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
+        )
+        trainer_name = checkpoint["trainer_name"]
+        configuration_name = checkpoint["init_args"]["configuration"]
+        inference_allowed_mirroring_axes = (
+            checkpoint["inference_allowed_mirroring_axes"]
+            if "inference_allowed_mirroring_axes" in checkpoint.keys()
+            else None
+        )
+        if Path(model_training_output_dir).joinpath(model_name).is_file():
+            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+            if "network_weights" in monai_checkpoint.keys():
+                parameters.append(monai_checkpoint["network_weights"])
+            else:
+                parameters.append(monai_checkpoint)
+
+        configuration_manager = plans_manager.get_configuration(configuration_name)
+        import nnunetv2
+        from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+        from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+
+        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+        trainer_class = recursive_find_python_class(
+            join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, "nnunetv2.training.nnUNetTrainer"
+        )
+        if trainer_class is None:
+            raise RuntimeError(
+                f"Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. "
+                f"Please place it there (in any .py file)!"
+            )
+        network = trainer_class.build_network_architecture(
+            configuration_manager.network_arch_class_name,
+            configuration_manager.network_arch_init_kwargs,
+            configuration_manager.network_arch_init_kwargs_req_import,
+            num_input_channels,
+            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
+            enable_deep_supervision=False,
+        )
+
+        predictor.plans_manager = plans_manager  # type: ignore
+        predictor.configuration_manager = configuration_manager  # type: ignore
+        predictor.list_of_parameters = parameters  # type: ignore
+        predictor.network = network  # type: ignore
+        predictor.dataset_json = dataset_json  # type: ignore
+        predictor.trainer_name = trainer_name  # type: ignore
+        predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes  # type: ignore
+        predictor.label_manager = plans_manager.get_label_manager(dataset_json)  # type: ignore
+
+        self.network_weights = self.predictor.network  # type: ignore
+
+    def forward(self, x: MetaTensor) -> MetaTensor:
+        """
+        Forward pass for the nnUNet model.
+
+        Args:
+            x (MetaTensor): Input tensor, assumed to be a collated batch.
+
+        Returns:
+            MetaTensor: The output tensor with the same metadata as the input.
+
+        Raises:
+            TypeError: If the input is not a MetaTensor.
+
+        Notes:
+            - The voxel spacing is read from the input metadata: the "pixdim" entry if present, otherwise the
+              absolute diagonal of the affine, falling back to [1.0, 1.0, 1.0].
+            - The input array (with the batch dimension stripped) is passed to the nnUNet predictor as a numpy array.
+            - The predictions are converted to torch tensors, with added batch and channel dimensions.
+            - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
+        """
+        if isinstance(x, MetaTensor):
+            if "pixdim" in x.meta:
+                properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
+            elif "affine" in x.meta:
+                spacing = [
+                    abs(x.meta["affine"][0][0].item()),
+                    abs(x.meta["affine"][1][1].item()),
+                    abs(x.meta["affine"][2][2].item()),
+                ]
+                properties_or_list_of_properties = {"spacing": spacing}
+            else:
+                properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
+        else:
+            raise TypeError("Input must be a MetaTensor.")
+
+        image_or_list_of_images = x.cpu().numpy()[0, :]
+
+        # input_files should be a list of file paths, one per modality
+        prediction_output = self.predictor.predict_from_list_of_npy_arrays(  # type: ignore
+            image_or_list_of_images,
+            None,
+            properties_or_list_of_properties,
+            truncated_ofname=None,
+            save_probabilities=False,
+            num_processes=2,
+            num_processes_segmentation_export=2,
+        )
+        # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax
+
+        out_tensors = []
+        for out in prediction_output:  # Add batch and channel dimensions
+            out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
+        out_tensor = torch.cat(out_tensors, 0)  # Concatenate along batch dimension
+
+        return MetaTensor(out_tensor, meta=x.meta)
+
+
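The forward pass derives voxel spacing from the input's metadata before handing the array to the predictor. A small illustration of that fallback order; the shapes and spacing values are made up:

import torch
from monai.data.meta_tensor import MetaTensor

# Collated batch of one single-channel volume: (B, C, H, W, D).
img = MetaTensor(
    torch.zeros(1, 1, 64, 64, 32),
    meta={"affine": torch.diag(torch.tensor([1.5, 1.5, 3.0, 1.0]))},
)
# With no "pixdim" key, the wrapper reads abs() of the affine diagonal,
# i.e. spacing [1.5, 1.5, 3.0]; with neither key it assumes [1.0, 1.0, 1.0].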
+def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper:
+    """
+    Initializes and returns a `ModelnnUNetWrapper` containing the corresponding `nnUNetPredictor`.
+    The model folder should contain the following files, created during training:
+
+    - dataset.json: from the nnUNet results folder
+    - plans.json: from the nnUNet results folder
+    - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
+    - model.pt: The checkpoint file containing the model weights.
+
+    The returned wrapper object can be used for inference with the MONAI framework:
+
+    Example::
+
+        from monai.apps.nnunet import get_nnunet_monai_predictor
+
+        model_folder = 'path/to/monai_bundle/model'
+        model_name = 'model.pt'
+        wrapper = get_nnunet_monai_predictor(model_folder, model_name)
+
+        # Perform inference
+        input_data = ...
+        output = wrapper(input_data)
+
+    Parameters
+    ----------
+    model_folder : Union[str, Path]
+        The folder where the model is stored.
+    model_name : str, optional
+        The name of the model file, by default "model.pt".
+
+    Returns
+    -------
+    ModelnnUNetWrapper
+        A wrapper object that contains the nnUNetPredictor and the loaded model.
+    """
+
+    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+
+    predictor = nnUNetPredictor(
+        tile_step_size=0.5,
+        use_gaussian=True,
+        use_mirroring=False,
+        device=torch.device("cuda", 0),
+        verbose=False,
+        verbose_preprocessing=False,
+        allow_tqdm=True,
+    )
+    # initializes the network architecture, loads the checkpoint
+    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
+    return wrapper
+
+
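Note that `ModelnnUNetWrapper.__init__` resolves `dataset.json`, `plans.json`, and `nnunet_checkpoint.pth` from the *parent* of `model_folder`, and only `model.pt` from the folder itself. That matches the bundle layout written by `convert_nnunet_to_monai_bundle` below. An illustrative sketch; `bundle_root` is a placeholder:

# Hypothetical bundle layout produced by convert_nnunet_to_monai_bundle (fold 0):
#   bundle_root/models/dataset.json
#   bundle_root/models/plans.json
#   bundle_root/models/nnunet_checkpoint.pth
#   bundle_root/models/fold_0/model.pt
# So model_folder should point at the fold subfolder:
wrapper = get_nnunet_monai_predictor("bundle_root/models/fold_0", model_name="model.pt")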
+def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
+    """
+    Convert nnUNet model checkpoints and configuration to MONAI bundle format.
+
+    Parameters
+    ----------
+    nnunet_config : dict
+        Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration',
+        'nnunet_trainer', and 'nnunet_plans'.
+    bundle_root_folder : str
+        Root folder where the MONAI bundle will be saved.
+    fold : int, optional
+        Fold number of the nnUNet model to be converted, by default 0.
+
+    Returns
+    -------
+    None
+    """
+
+    nnunet_trainer = "nnUNetTrainer"
+    nnunet_plans = "nnUNetPlans"
+    nnunet_configuration = "3d_fullres"
+
+    if "nnunet_trainer" in nnunet_config:
+        nnunet_trainer = nnunet_config["nnunet_trainer"]
+
+    if "nnunet_plans" in nnunet_config:
+        nnunet_plans = nnunet_config["nnunet_plans"]
+
+    if "nnunet_configuration" in nnunet_config:
+        nnunet_configuration = nnunet_config["nnunet_configuration"]
+
+    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+
+    dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
+    nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath(
+        dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
+    )
+
+    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
+    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+
+    nnunet_checkpoint = {}
+    nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
+    nnunet_checkpoint["init_args"] = nnunet_checkpoint_final["init_args"]
+    nnunet_checkpoint["trainer_name"] = nnunet_checkpoint_final["trainer_name"]
+
+    torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth"))
+
+    Path(bundle_root_folder).joinpath("models", f"fold_{fold}").mkdir(parents=True, exist_ok=True)
+    monai_last_checkpoint = {}
+    monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"]
+    torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "model.pt"))
+
+    monai_best_checkpoint = {}
+    monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"]
+    torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "best_model.pt"))
+
+    if not os.path.exists(os.path.join(bundle_root_folder, "models", "plans.json")):
+        shutil.copy(
+            Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json")
+        )
+
+    if not os.path.exists(os.path.join(bundle_root_folder, "models", "dataset.json")):
+        shutil.copy(
+            Path(nnunet_model_folder).joinpath("dataset.json"),
+            Path(bundle_root_folder).joinpath("models", "dataset.json"),
+        )
+
+
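A hedged usage sketch of the conversion above; the dataset ID and output path are placeholders, and `nnUNet_results` must point at a valid nnU-Net results directory since the function reads checkpoints from there:

# Hypothetical conversion of a trained nnUNet model into a MONAI bundle.
nnunet_config = {
    "dataset_name_or_id": "009",
    "nnunet_trainer": "nnUNetTrainer",  # optional, this is the default
    "nnunet_plans": "nnUNetPlans",  # optional, this is the default
}
convert_nnunet_to_monai_bundle(nnunet_config, "./spleen_bundle", fold=0)
# Afterwards ./spleen_bundle/models contains nnunet_checkpoint.pth, plans.json,
# dataset.json and fold_0/{model.pt, best_model.pt}.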
+def get_network_from_nnunet_plans(
+    plans_file: str,
+    dataset_file: str,
+    configuration: str,
+    model_ckpt: Optional[str] = None,
+    model_key_in_ckpt: str = "model",
+) -> Union[torch.nn.Module, Any]:
+    """
+    Load and initialize a nnUNet network based on nnUNet plans and configuration.
+
+    Parameters
+    ----------
+    plans_file : str
+        Path to the JSON file containing the nnUNet plans.
+    dataset_file : str
+        Path to the JSON file containing the dataset information.
+    configuration : str
+        The configuration name to be used from the plans.
+    model_ckpt : Optional[str], optional
+        Path to the model checkpoint file. If None, the network is returned without loading weights (default is None).
+    model_key_in_ckpt : str, optional
+        The key in the checkpoint file that contains the model state dictionary (default is "model").
+
+    Returns
+    -------
+    network : torch.nn.Module
+        The initialized neural network, with weights loaded if `model_ckpt` is provided.
+    """
+    from batchgenerators.utilities.file_and_folder_operations import load_json
+    from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
+    from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+    from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+    plans = load_json(plans_file)
+    dataset_json = load_json(dataset_file)
+
+    plans_manager = PlansManager(plans)
+    configuration_manager = plans_manager.get_configuration(configuration)
+    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+    label_manager = plans_manager.get_label_manager(dataset_json)
+
+    enable_deep_supervision = True
+
+    network = get_network_from_plans(
+        configuration_manager.network_arch_class_name,
+        configuration_manager.network_arch_init_kwargs,
+        configuration_manager.network_arch_init_kwargs_req_import,
+        num_input_channels,
+        label_manager.num_segmentation_heads,
+        allow_init=True,
+        deep_supervision=enable_deep_supervision,
+    )
+
+    if model_ckpt is None:
+        return network
+    else:
+        state_dict = torch.load(model_ckpt)
+        network.load_state_dict(state_dict[model_key_in_ckpt])
+        return network
+
+
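A minimal sketch of instantiating a network from a converted bundle's plan files; the paths are placeholders reusing the hypothetical layout above:

network = get_network_from_nnunet_plans(
    plans_file="./spleen_bundle/models/plans.json",
    dataset_file="./spleen_bundle/models/dataset.json",
    configuration="3d_fullres",
)
# With model_ckpt set, weights are loaded via state_dict[model_key_in_ckpt],
# so the checkpoint must store the network under that key ("model" by default).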
+def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
+    """
+    Convert a MONAI bundle to nnU-Net format.
+
+    Parameters
+    ----------
+    nnunet_config : dict
+        Configuration dictionary for nnU-Net. Expected keys are:
+        - "dataset_name_or_id": str, name or ID of the dataset.
+        - "nnunet_trainer": str, optional, name of the nnU-Net trainer (default is "nnUNetTrainer").
+        - "nnunet_plans": str, optional, name of the nnU-Net plans (default is "nnUNetPlans").
+    bundle_root_folder : str
+        Path to the root folder of the MONAI bundle.
+    fold : int, optional
+        Fold number for cross-validation (default is 0).
+
+    Returns
+    -------
+    None
+    """
+    from odict import odict
+
+    nnunet_trainer: str = "nnUNetTrainer"
+    nnunet_plans: str = "nnUNetPlans"
+
+    if "nnunet_trainer" in nnunet_config:
+        nnunet_trainer = nnunet_config["nnunet_trainer"]
+
+    if "nnunet_plans" in nnunet_config:
+        nnunet_plans = nnunet_config["nnunet_plans"]
+
+    from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
+    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+
+    def subfiles(
+        folder: Union[str, Path], prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True
+    ) -> list[str]:
+        res = [
+            i.name
+            for i in Path(folder).iterdir()
+            if i.is_file()
+            and (prefix is None or i.name.startswith(prefix))
+            and (suffix is None or i.name.endswith(suffix))
+        ]
+        if sort:
+            res.sort()
+        return res
+
+    nnunet_model_folder: Path = Path(os.environ["nnUNet_results"]).joinpath(
+        maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]),
+        f"{nnunet_trainer}__{nnunet_plans}__3d_fullres",
+    )
+
+    nnunet_preprocess_model_folder: Path = Path(os.environ["nnUNet_preprocessed"]).joinpath(
+        maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
+    )
+
+    Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
+
+    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
+    latest_checkpoints: list[str] = subfiles(
+        Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
+    )
+    epochs: list[int] = []
+    for latest_checkpoint in latest_checkpoints:
+        epochs.append(int(latest_checkpoint[len("checkpoint_epoch=") : -len(".pt")]))
+
+    epochs.sort()
+    final_epoch: int = epochs[-1]
+    monai_last_checkpoint: dict = torch.load(
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+    )
+
+    best_checkpoints: list[str] = subfiles(
+        Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_key_metric", sort=True
+    )
+    key_metrics: list[str] = []
+    for best_checkpoint in best_checkpoints:
+        key_metrics.append(str(best_checkpoint[len("checkpoint_key_metric=") : -len(".pt")]))
+
+    key_metrics.sort()
+    best_key_metric: str = key_metrics[-1]
+    monai_best_checkpoint: dict = torch.load(
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
+    )
+
+    nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
+
+    nnunet_checkpoint["network_weights"] = odict()
+
+    for key in monai_last_checkpoint["network_weights"]:
+        nnunet_checkpoint["network_weights"][key] = monai_last_checkpoint["network_weights"][key]
+
+    nnunet_checkpoint["current_epoch"] = final_epoch
+    nnunet_checkpoint["logging"] = nnUNetLogger().get_checkpoint()
+    nnunet_checkpoint["_best_ema"] = 0
+    nnunet_checkpoint["grad_scaler_state"] = None
+
+    torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
+
+    nnunet_checkpoint["network_weights"] = odict()
+
+    nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]
+
+    for key in monai_best_checkpoint["network_weights"]:
+        nnunet_checkpoint["network_weights"][key] = monai_best_checkpoint["network_weights"][key]
+
+    torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+
+    if not os.path.exists(os.path.join(nnunet_model_folder, "dataset.json")):
+        shutil.copy(f"{bundle_root_folder}/models/dataset.json", nnunet_model_folder)
+    if not os.path.exists(os.path.join(nnunet_model_folder, "plans.json")):
+        shutil.copy(f"{bundle_root_folder}/models/plans.json", nnunet_model_folder)
+    if not os.path.exists(os.path.join(nnunet_model_folder, "dataset_fingerprint.json")):
+        shutil.copy(f"{nnunet_preprocess_model_folder}/dataset_fingerprint.json", nnunet_model_folder)
+    if not os.path.exists(os.path.join(nnunet_model_folder, "nnunet_checkpoint.pth")):
+        shutil.copy(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", nnunet_model_folder)
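For completeness, a hedged sketch of the reverse conversion; all paths and the dataset ID are placeholders, and both environment variables are read by the function above:

import os

# Placeholders; nnUNet_results and nnUNet_preprocessed must point at valid nnU-Net directories.
os.environ["nnUNet_results"] = "/data/nnUNet_results"
os.environ["nnUNet_preprocessed"] = "/data/nnUNet_preprocessed"

nnunet_config = {"dataset_name_or_id": "009"}  # trainer/plans fall back to the defaults
convert_monai_bundle_to_nnunet(nnunet_config, "./spleen_bundle", fold=0)
# This expects checkpoint_epoch=<N>.pt and checkpoint_key_metric=<M>.pt files under
# ./spleen_bundle/models/fold_0, as produced by MONAI checkpoint handlers during training.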