monai-weekly 1.5.dev2514__py3-none-any.whl → 1.5.dev2515__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
136
136
 
137
137
  if MONAIEnvVars.debug():
138
138
  raise
139
- __commit_id__ = "4986d7ffd2d351c9d66de0e0329884b1a26d5500"
139
+ __commit_id__ = "f27517b81ded6f3de730861d95d10d72fb0c4a51"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-04-06T02:31:51+0000",
11
+ "date": "2025-04-13T03:01:18+0000",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "a3ea49fc4e600d131daadad61ea340df25fcfdaa",
15
- "version": "1.5.dev2514"
14
+ "full-revisionid": "2f0c8e65507306bf5b92e1ac85642ca808d1c5e2",
15
+ "version": "1.5.dev2515"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -11,5 +11,13 @@
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
+ from .nnunet_bundle import (
15
+ ModelnnUNetWrapper,
16
+ convert_monai_bundle_to_nnunet,
17
+ convert_nnunet_to_monai_bundle,
18
+ get_network_from_nnunet_plans,
19
+ get_nnunet_monai_predictor,
20
+ get_nnunet_trainer,
21
+ )
14
22
  from .nnunetv2_runner import nnUNetV2Runner
15
23
  from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json
@@ -0,0 +1,594 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import shutil
15
+ from pathlib import Path
16
+ from typing import Any, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from torch.backends import cudnn
21
+
22
+ from monai.data.meta_tensor import MetaTensor
23
+ from monai.utils import optional_import
24
+
25
+ join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
26
+ load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
27
+
28
+ __all__ = [
29
+ "get_nnunet_trainer",
30
+ "get_nnunet_monai_predictor",
31
+ "get_network_from_nnunet_plans",
32
+ "convert_nnunet_to_monai_bundle",
33
+ "convert_monai_bundle_to_nnunet",
34
+ "ModelnnUNetWrapper",
35
+ ]
36
+
37
+
38
def get_nnunet_trainer(
    dataset_name_or_id: Union[str, int],
    configuration: str,
    fold: Union[int, str],
    trainer_class_name: str = "nnUNetTrainer",
    plans_identifier: str = "nnUNetPlans",
    use_compressed_data: bool = False,
    continue_training: bool = False,
    only_run_validation: bool = False,
    disable_checkpointing: bool = False,
    device: str = "cuda",
    pretrained_model: Optional[str] = None,
) -> Any:  # type: ignore
    """
    Get the nnUNet trainer instance based on the provided configuration.
    The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
    optimizer, loss function, DataLoader, etc.

    Example::

        from monai.apps.nnunet import get_nnunet_trainer
        from monai.engines import SupervisedTrainer

        dataset_name_or_id = 'Task009_Spleen'
        fold = 0
        configuration = '3d_fullres'
        nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)

        trainer = SupervisedTrainer(
            device=nnunet_trainer.device,
            max_epochs=nnunet_trainer.num_epochs,
            train_data_loader=nnunet_trainer.dataloader_train,
            network=nnunet_trainer.network,
            optimizer=nnunet_trainer.optimizer,
            loss_function=nnunet_trainer.loss_function,
            epoch_length=nnunet_trainer.num_iterations_per_epoch,
        )

    Parameters
    ----------
    dataset_name_or_id : Union[str, int]
        The name or ID of the dataset to be used.
    configuration : str
        The configuration name for the training.
    fold : Union[int, str]
        The fold number or 'all' for cross-validation. Numeric strings are converted to int.
    trainer_class_name : str, optional
        The class name of the trainer to be used. Default is 'nnUNetTrainer'.
        For a complete list of supported trainers, check:
        https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants
    plans_identifier : str, optional
        Identifier for the plans to be used. Default is 'nnUNetPlans'.
    use_compressed_data : bool, optional
        Whether to use compressed data. Default is False.
        NOTE: this value is currently not forwarded to nnUNet by this function.
    continue_training : bool, optional
        Whether to continue training from a checkpoint. Default is False.
    only_run_validation : bool, optional
        Whether to only run validation. Default is False.
    disable_checkpointing : bool, optional
        Whether to disable checkpointing. Default is False.
    device : str, optional
        The device to be used for training. Default is 'cuda'.
    pretrained_model : Optional[str], optional
        Path to a checkpoint file; if it contains a ``"network_weights"`` entry, those weights are loaded
        into the trainer's network.

    Returns
    -------
    nnunet_trainer : object
        The nnUNet trainer instance.

    Raises
    ------
    ValueError
        If ``fold`` is a string that is neither ``"all"`` nor an integer, or if ``continue_training`` and
        ``only_run_validation`` are both True.
    """
    # From nnUNet/nnunetv2/run/run_training.py#run_training
    if isinstance(fold, str) and fold != "all":
        try:
            fold = int(fold)
        except ValueError as e:
            raise ValueError(
                f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!'
            ) from e

    # Validate the flag combination before building the trainer (an `assert` would be stripped under -O).
    if continue_training and only_run_validation:
        raise ValueError("Cannot set --c and --val flag at the same time: continue_training and only_run_validation.")

    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint

    nnunet_trainer = get_trainer_from_args(
        str(dataset_name_or_id), configuration, fold, trainer_class_name, plans_identifier, device=torch.device(device)
    )
    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation)
    nnunet_trainer.on_train_start()  # Added to initialize the trainer
    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    if pretrained_model is not None:
        # Map onto the requested device so checkpoints saved on another device (e.g. GPU) still load.
        state_dict = torch.load(pretrained_model, map_location=torch.device(device))
        if "network_weights" in state_dict:
            network = nnunet_trainer.network
            # `_orig_mod` only exists when the network was wrapped by torch.compile; otherwise use it directly.
            getattr(network, "_orig_mod", network).load_state_dict(state_dict["network_weights"])
    return nnunet_trainer
140
+
141
+
142
class ModelnnUNetWrapper(torch.nn.Module):
    """
    A wrapper class for nnUNet model integration with MONAI framework.
    The wrapper can be used to integrate the nnUNet Bundle within MONAI framework for inference.

    Parameters
    ----------
    predictor : nnUNetPredictor
        The nnUNet predictor object used for inference. The constructor configures it in place
        (plans and configuration managers, network, weights, dataset json, trainer name,
        allowed mirroring axes and label manager).
    model_folder : Union[str, Path]
        The folder containing the model weights file ``model_name``. The files ``dataset.json``,
        ``plans.json`` and ``nnunet_checkpoint.pth`` are read from the *parent* of this folder.
    model_name : str, optional
        The name of the model weights file inside ``model_folder``, by default "model.pt".

    Attributes
    ----------
    predictor : nnUNetPredictor
        The nnUNet predictor object used for inference.
    network_weights : torch.nn.Module
        The network built from the nnUNet plans (the same object as ``predictor.network``).

    Notes
    -----
    This class integrates nnUNet model with MONAI framework by loading necessary configurations,
    restoring network architecture, and setting up the predictor for inference.
    """

    def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"):  # type: ignore
        super().__init__()
        self.predictor = predictor

        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager

        model_training_output_dir = Path(model_folder)
        # dataset.json, plans.json and nnunet_checkpoint.pth are expected one level above the weights folder.
        bundle_models_dir = model_training_output_dir.parent

        # Block adapted from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
        dataset_json = load_json(join(bundle_models_dir, "dataset.json"))
        plans = load_json(join(bundle_models_dir, "plans.json"))
        plans_manager = PlansManager(plans)

        checkpoint = torch.load(join(bundle_models_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu"))
        trainer_name = checkpoint["trainer_name"]
        configuration_name = checkpoint["init_args"]["configuration"]
        # Optional entry: None when the checkpoint does not store it.
        inference_allowed_mirroring_axes = checkpoint.get("inference_allowed_mirroring_axes")

        parameters = []
        weights_file = model_training_output_dir / model_name
        # NOTE(review): a missing weights file is skipped silently, leaving `list_of_parameters` empty;
        # presumably prediction then fails later rather than here — confirm.
        if weights_file.is_file():
            monai_checkpoint = torch.load(weights_file, map_location=torch.device("cpu"))
            # Accept either {"network_weights": state_dict} or a bare state_dict.
            if "network_weights" in monai_checkpoint:
                parameters.append(monai_checkpoint["network_weights"])
            else:
                parameters.append(monai_checkpoint)

        configuration_manager = plans_manager.get_configuration(configuration_name)
        import nnunetv2
        from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
        from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels

        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(
            join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, "nnunetv2.training.nnUNetTrainer"
        )
        if trainer_class is None:
            raise RuntimeError(
                f"Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. "
                f"Please place it there (in any .py file)!"
            )
        label_manager = plans_manager.get_label_manager(dataset_json)
        network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            label_manager.num_segmentation_heads,
            enable_deep_supervision=False,
        )

        predictor.plans_manager = plans_manager  # type: ignore
        predictor.configuration_manager = configuration_manager  # type: ignore
        predictor.list_of_parameters = parameters  # type: ignore
        predictor.network = network  # type: ignore
        predictor.dataset_json = dataset_json  # type: ignore
        predictor.trainer_name = trainer_name  # type: ignore
        predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes  # type: ignore
        predictor.label_manager = label_manager  # type: ignore

        self.network_weights = self.predictor.network  # type: ignore

    @staticmethod
    def _spacing_from_affine(affine: Any) -> list[float]:
        """
        Return the voxel spacing encoded in the top-left 3x3 block of a (4x4) affine.

        The spacing of each axis is the Euclidean norm of the corresponding column. This is correct for
        rotated or axis-permuted affines; for axis-aligned affines it equals the absolute diagonal.
        A stacked batch of affines (3 dims) is reduced to its first element, matching the single image
        that `forward` sends to the predictor.
        """
        aff = torch.as_tensor(affine).detach().to(dtype=torch.float64)
        if aff.ndim == 3:
            aff = aff[0]
        return torch.linalg.norm(aff[:3, :3], dim=0).cpu().tolist()

    def forward(self, x: MetaTensor) -> MetaTensor:
        """
        Forward pass for the nnUNet model.

        Args:
            x (MetaTensor): Input batch. Only the first item of the batch (``x[0]``) is passed to the
                nnUNet predictor.

        Returns:
            MetaTensor: The predictions returned by ``predict_from_list_of_npy_arrays``, each with added batch and
            channel dimensions, concatenated along the batch dimension and carrying the input's metadata.

        Raises:
            TypeError: If the input is not a MetaTensor.

        Notes:
            - The spacing passed to nnUNet comes from ``x.meta["pixdim"]`` (entries 1:4 of its first row) if present,
              otherwise from ``x.meta["affine"]`` (column norms of its 3x3 block), otherwise ``[1.0, 1.0, 1.0]``.
        """
        if not isinstance(x, MetaTensor):
            raise TypeError("Input must be a MetaTensor.")

        if "pixdim" in x.meta:
            # assumes a collated pixdim (first index selects the batch item) — TODO confirm for un-collated input
            spacing = x.meta["pixdim"][0][1:4].numpy().tolist()
        elif "affine" in x.meta:
            spacing = self._spacing_from_affine(x.meta["affine"])
        else:
            spacing = [1.0, 1.0, 1.0]
        properties_or_list_of_properties = {"spacing": spacing}

        # Only the first batch item is predicted.
        image_or_list_of_images = x.cpu().numpy()[0, :]

        prediction_output = self.predictor.predict_from_list_of_npy_arrays(  # type: ignore
            image_or_list_of_images,
            None,
            properties_or_list_of_properties,
            truncated_ofname=None,
            save_probabilities=False,
            num_processes=2,
            num_processes_segmentation_export=2,
        )
        # presumably a list of argmax label maps (spatial dims only), one per input image — verify against nnUNet

        out_tensors = [torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)) for out in prediction_output]
        out_tensor = torch.cat(out_tensors, 0)  # Concatenate along batch dimension

        return MetaTensor(out_tensor, meta=x.meta)
291
+
292
+
293
def get_nnunet_monai_predictor(
    model_folder: Union[str, Path],
    model_name: str = "model.pt",
    device: Optional[Union[str, torch.device]] = None,
) -> ModelnnUNetWrapper:
    """
    Initializes and returns a `ModelnnUNetWrapper` containing the corresponding `nnUNetPredictor`.

    ``model_folder`` must contain the model weights file (``model_name``). Its *parent* folder must contain
    the following files, created during training/conversion:

    - dataset.json: from the nnUNet results folder
    - plans.json: from the nnUNet results folder
    - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration

    The returned wrapper object can be used for inference with MONAI framework:

    Example::

        from monai.apps.nnunet import get_nnunet_monai_predictor

        model_folder = 'path/to/monai_bundle/models/fold_0'
        model_name = 'model.pt'
        wrapper = get_nnunet_monai_predictor(model_folder, model_name)

        # Perform inference
        input_data = ...
        output = wrapper(input_data)


    Parameters
    ----------
    model_folder : Union[str, Path]
        The folder where the model weights file is stored.
    model_name : str, optional
        The name of the model file, by default "model.pt".
    device : Optional[Union[str, torch.device]], optional
        Device used by the nnUNet predictor. By default ``cuda:0`` when CUDA is available, otherwise ``cpu``.

    Returns
    -------
    ModelnnUNetWrapper
        A wrapper object that contains the nnUNetPredictor and the loaded model.
    """

    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    if device is None:
        # Keep the previous cuda:0 default when possible, but do not require a GPU.
        device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        device=torch.device(device),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    # initializes the network architecture, loads the checkpoint
    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
    return wrapper
345
+
346
+
347
def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
    """
    Convert nnUNet model checkpoints and configuration to MONAI bundle format.

    Reads ``checkpoint_final.pth`` and ``checkpoint_best.pth`` from
    ``$nnUNet_results/<dataset>/<trainer>__<plans>__<configuration>/fold_<fold>`` and writes into
    ``<bundle_root_folder>/models``:

    - ``nnunet_checkpoint.pth``: trainer name, init args and inference mirroring axes of the final checkpoint
    - ``fold_<fold>/model.pt``: ``{"network_weights": ...}`` from the final checkpoint
    - ``fold_<fold>/best_model.pt``: ``{"network_weights": ...}`` from the best checkpoint
    - ``plans.json`` and ``dataset.json``: copied from the nnUNet results folder unless already present

    Parameters
    ----------
    nnunet_config : dict
        Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration',
        'nnunet_trainer', and 'nnunet_plans'. Only 'dataset_name_or_id' is required; the others default to
        '3d_fullres', 'nnUNetTrainer' and 'nnUNetPlans'.
    bundle_root_folder : str
        Root folder where the MONAI bundle will be saved.
    fold : int, optional
        Fold number of the nnUNet model to be converted, by default 0.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If 'dataset_name_or_id' is missing from ``nnunet_config`` or the ``nnUNet_results`` environment
        variable is not set.
    """
    nnunet_trainer = nnunet_config.get("nnunet_trainer", "nnUNetTrainer")
    nnunet_plans = nnunet_config.get("nnunet_plans", "nnUNetPlans")
    nnunet_configuration = nnunet_config.get("nnunet_configuration", "3d_fullres")

    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

    dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
    nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath(
        dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
    )
    nnunet_fold_folder = nnunet_model_folder / f"fold_{fold}"

    # Load on CPU so checkpoints trained on a GPU can be converted on machines without CUDA.
    cpu = torch.device("cpu")
    nnunet_checkpoint_final = torch.load(nnunet_fold_folder / "checkpoint_final.pth", map_location=cpu)
    nnunet_checkpoint_best = torch.load(nnunet_fold_folder / "checkpoint_best.pth", map_location=cpu)

    models_folder = Path(bundle_root_folder) / "models"
    bundle_fold_folder = models_folder / f"fold_{fold}"
    # Create the output folders before the first save: `models/` may not exist yet.
    bundle_fold_folder.mkdir(parents=True, exist_ok=True)

    nnunet_checkpoint = {
        "inference_allowed_mirroring_axes": nnunet_checkpoint_final["inference_allowed_mirroring_axes"],
        "init_args": nnunet_checkpoint_final["init_args"],
        "trainer_name": nnunet_checkpoint_final["trainer_name"],
    }
    torch.save(nnunet_checkpoint, models_folder / "nnunet_checkpoint.pth")

    torch.save({"network_weights": nnunet_checkpoint_final["network_weights"]}, bundle_fold_folder / "model.pt")
    torch.save({"network_weights": nnunet_checkpoint_best["network_weights"]}, bundle_fold_folder / "best_model.pt")

    # Copy plans/dataset descriptions once; existing files in the bundle are left untouched.
    for json_name in ("plans.json", "dataset.json"):
        target = models_folder / json_name
        if not target.exists():
            shutil.copy(nnunet_model_folder / json_name, target)
415
+
416
+
417
def get_network_from_nnunet_plans(
    plans_file: str,
    dataset_file: str,
    configuration: str,
    model_ckpt: Optional[str] = None,
    model_key_in_ckpt: str = "model",
    deep_supervision: bool = True,
) -> Union[torch.nn.Module, Any]:
    """
    Load and initialize a nnUNet network based on nnUNet plans and configuration.

    Parameters
    ----------
    plans_file : str
        Path to the JSON file containing the nnUNet plans.
    dataset_file : str
        Path to the JSON file containing the dataset information.
    configuration : str
        The configuration name to be used from the plans.
    model_ckpt : Optional[str], optional
        Path to the model checkpoint file. If None, the network is returned without loading weights (default is None).
    model_key_in_ckpt : str, optional
        The key in the checkpoint file that contains the model state dictionary (default is "model").
    deep_supervision : bool, optional
        Whether the network is built with deep supervision enabled (default is True).

    Returns
    -------
    network : torch.nn.Module
        The initialized neural network, with weights loaded if `model_ckpt` is provided.
    """
    from batchgenerators.utilities.file_and_folder_operations import load_json
    from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
    from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
    from nnunetv2.utilities.plans_handling.plans_handler import PlansManager

    plans = load_json(plans_file)
    dataset_json = load_json(dataset_file)

    plans_manager = PlansManager(plans)
    configuration_manager = plans_manager.get_configuration(configuration)
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    label_manager = plans_manager.get_label_manager(dataset_json)

    network = get_network_from_plans(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        label_manager.num_segmentation_heads,
        allow_init=True,
        deep_supervision=deep_supervision,
    )

    if model_ckpt is None:
        return network

    # The freshly built network lives on the CPU; load the weights there too so GPU-saved
    # checkpoints also load on hosts without CUDA.
    state_dict = torch.load(model_ckpt, map_location=torch.device("cpu"))
    network.load_state_dict(state_dict[model_key_in_ckpt])
    return network
476
+
477
+
478
def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
    """
    Convert a MONAI bundle to nnU-Net format.

    Reads ``<bundle_root_folder>/models/nnunet_checkpoint.pth``, the latest ``checkpoint_epoch=<N>.pt`` and the
    best ``checkpoint_key_metric=<value>.pt`` from ``<bundle_root_folder>/models/fold_<fold>``. It writes
    ``checkpoint_final.pth`` and ``checkpoint_best.pth`` into
    ``$nnUNet_results/<dataset>/<trainer>__<plans>__<configuration>/fold_<fold>``. It also copies
    ``dataset.json``, ``plans.json``, ``dataset_fingerprint.json`` (from ``$nnUNet_preprocessed/<dataset>``) and
    ``nnunet_checkpoint.pth`` into the nnU-Net model folder, unless they are already present.

    Parameters
    ----------
    nnunet_config : dict
        Configuration dictionary for nnU-Net. Expected keys are:
        - "dataset_name_or_id": str, name or ID of the dataset.
        - "nnunet_trainer": str, optional, name of the nnU-Net trainer (default is "nnUNetTrainer").
        - "nnunet_plans": str, optional, name of the nnU-Net plans (default is "nnUNetPlans").
        - "nnunet_configuration": str, optional, nnU-Net configuration (default is "3d_fullres").
    bundle_root_folder : str
        Path to the root folder of the MONAI bundle.
    fold : int, optional
        Fold number for cross-validation (default is 0).

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If no epoch or key-metric checkpoint is found in the bundle fold folder.
    """
    from collections import OrderedDict

    nnunet_trainer: str = nnunet_config.get("nnunet_trainer", "nnUNetTrainer")
    nnunet_plans: str = nnunet_config.get("nnunet_plans", "nnUNetPlans")
    nnunet_configuration: str = nnunet_config.get("nnunet_configuration", "3d_fullres")

    from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

    def subfiles(
        folder: Union[str, Path], prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True
    ) -> list[str]:
        # Names (not paths) of the regular files in `folder` matching the optional prefix/suffix.
        res = [
            i.name
            for i in Path(folder).iterdir()
            if i.is_file()
            and (prefix is None or i.name.startswith(prefix))
            and (suffix is None or i.name.endswith(suffix))
        ]
        if sort:
            res.sort()
        return res

    dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
    nnunet_model_folder: Path = Path(os.environ["nnUNet_results"]).joinpath(
        dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
    )
    nnunet_preprocess_model_folder: Path = Path(os.environ["nnUNet_preprocessed"]).joinpath(dataset_name)
    nnunet_fold_folder = nnunet_model_folder / f"fold_{fold}"
    nnunet_fold_folder.mkdir(parents=True, exist_ok=True)

    models_folder = Path(bundle_root_folder) / "models"
    bundle_fold_folder = models_folder / f"fold_{fold}"
    cpu = torch.device("cpu")

    nnunet_checkpoint: dict = torch.load(models_folder / "nnunet_checkpoint.pth", map_location=cpu)

    # Latest checkpoint: highest epoch number in "checkpoint_epoch=<N>.pt".
    epoch_prefix = "checkpoint_epoch="
    latest_checkpoints = subfiles(bundle_fold_folder, prefix="checkpoint_epoch", sort=True)
    if not latest_checkpoints:
        raise FileNotFoundError(f"No '{epoch_prefix}<N>.pt' checkpoint found in {bundle_fold_folder}.")
    final_epoch: int = max(int(name[len(epoch_prefix) : -len(".pt")]) for name in latest_checkpoints)
    monai_last_checkpoint: dict = torch.load(
        bundle_fold_folder / f"checkpoint_epoch={final_epoch}.pt", map_location=cpu
    )

    # Best checkpoint: highest key metric in "checkpoint_key_metric=<value>.pt", compared numerically
    # (string order would mis-rank e.g. negative values or values >= 10).
    metric_prefix = "checkpoint_key_metric="
    best_checkpoints = subfiles(bundle_fold_folder, prefix="checkpoint_key_metric", sort=True)
    if not best_checkpoints:
        raise FileNotFoundError(f"No '{metric_prefix}<value>.pt' checkpoint found in {bundle_fold_folder}.")
    key_metrics = [name[len(metric_prefix) : -len(".pt")] for name in best_checkpoints]
    try:
        best_key_metric: str = max(key_metrics, key=float)
    except ValueError:
        # Non-numeric metric tags: fall back to plain string ordering.
        best_key_metric = max(key_metrics)
    monai_best_checkpoint: dict = torch.load(
        bundle_fold_folder / f"checkpoint_key_metric={best_key_metric}.pt", map_location=cpu
    )

    # Final checkpoint: optimizer state and weights from the latest epoch; logging/EMA/scaler reset.
    # stdlib OrderedDict keeps the pickled checkpoint loadable without extra third-party packages.
    nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
    nnunet_checkpoint["network_weights"] = OrderedDict(monai_last_checkpoint["network_weights"])
    nnunet_checkpoint["current_epoch"] = final_epoch
    nnunet_checkpoint["logging"] = nnUNetLogger().get_checkpoint()
    nnunet_checkpoint["_best_ema"] = 0
    nnunet_checkpoint["grad_scaler_state"] = None

    torch.save(nnunet_checkpoint, nnunet_fold_folder / "checkpoint_final.pth")

    # Best checkpoint: same metadata, but optimizer state and weights from the best key metric.
    nnunet_checkpoint["network_weights"] = OrderedDict(monai_best_checkpoint["network_weights"])
    nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]

    torch.save(nnunet_checkpoint, nnunet_fold_folder / "checkpoint_best.pth")

    # Copy supporting files into the nnU-Net model folder unless already present.
    for source in (
        models_folder / "dataset.json",
        models_folder / "plans.json",
        nnunet_preprocess_model_folder / "dataset_fingerprint.json",
        models_folder / "nnunet_checkpoint.pth",
    ):
        if not (nnunet_model_folder / source.name).exists():
            shutil.copy(source, nnunet_model_folder)