monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/apps/vista3d/inferer.py +177 -0
  7. monai/apps/vista3d/sampler.py +179 -0
  8. monai/apps/vista3d/transforms.py +224 -0
  9. monai/bundle/scripts.py +29 -17
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/inferers/utils.py +1 -0
  13. monai/losses/__init__.py +1 -0
  14. monai/losses/dice.py +10 -1
  15. monai/losses/nacl_loss.py +139 -0
  16. monai/networks/blocks/crossattention.py +48 -26
  17. monai/networks/blocks/mlp.py +16 -4
  18. monai/networks/blocks/selfattention.py +75 -23
  19. monai/networks/blocks/spatialattention.py +16 -1
  20. monai/networks/blocks/transformerblock.py +17 -2
  21. monai/networks/layers/filtering.py +6 -2
  22. monai/networks/nets/__init__.py +2 -1
  23. monai/networks/nets/autoencoderkl.py +55 -22
  24. monai/networks/nets/cell_sam_wrapper.py +92 -0
  25. monai/networks/nets/controlnet.py +24 -22
  26. monai/networks/nets/diffusion_model_unet.py +159 -19
  27. monai/networks/nets/segresnet_ds.py +127 -1
  28. monai/networks/nets/spade_autoencoderkl.py +22 -0
  29. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  30. monai/networks/nets/transformer.py +17 -17
  31. monai/networks/nets/vista3d.py +946 -0
  32. monai/networks/utils.py +4 -4
  33. monai/transforms/__init__.py +13 -2
  34. monai/transforms/io/array.py +59 -3
  35. monai/transforms/io/dictionary.py +29 -2
  36. monai/transforms/spatial/functional.py +1 -1
  37. monai/transforms/transform.py +2 -2
  38. monai/transforms/utility/dictionary.py +4 -0
  39. monai/transforms/utils.py +230 -1
  40. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  41. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  42. monai/utils/enums.py +1 -0
  43. monai/utils/module.py +7 -6
  44. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
  45. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
  46. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
  47. /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  48. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
  49. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/apps/vista3d/sampler.py ADDED
@@ -0,0 +1,179 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import copy
+ import random
+ from collections.abc import Callable, Sequence
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+
+ ENABLE_SPECIAL = True
+ SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
+ MERGE_LIST = {
+     1: [25, 26],  # hepatic tumor and vessel merge into liver
+     4: [24],  # pancreatic tumor merges into pancreas
+     132: [57],  # overlap with trachea merges into airway
+ }
+
+ __all__ = ["sample_prompt_pairs"]
+
+
+ def _get_point_label(id: int) -> tuple[int, int]:
+     if id in SPECIAL_INDEX and ENABLE_SPECIAL:
+         return 2, 3
+     else:
+         return 0, 1
+
+
+ def sample_prompt_pairs(
+     labels: Tensor,
+     label_set: Sequence[int],
+     max_prompt: int | None = None,
+     max_foreprompt: int | None = None,
+     max_backprompt: int = 1,
+     max_point: int = 20,
+     include_background: bool = False,
+     drop_label_prob: float = 0.2,
+     drop_point_prob: float = 0.2,
+     point_sampler: Callable | None = None,
+     **point_sampler_kwargs: Any,
+ ) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
+     """
+     Sample training pairs for VISTA3D training.
+
+     Args:
+         labels: [1, 1, H, W, D], ground truth labels.
+         label_set: the label list for the specific dataset. Note that if 0 is included in label_set,
+             it will be added into automatic branch training. It is recommended to remove 0 from label_set
+             for multi-partially-labeled-dataset training, and to add 0 when finetuning on a specific dataset,
+             because a region labeled 0 in one partially labeled dataset may contain foreground in
+             another dataset.
+         max_prompt: maximum number of total prompts, including foreground and background.
+         max_foreprompt: maximum number of prompts from the foreground.
+         max_backprompt: maximum number of prompts from the background.
+         max_point: maximum number of points for each object.
+         include_background: whether to include 0 in the training prompts. If included, background 0 is treated
+             the same as foreground and points will be sampled. Set this to true only to segment
+             background 0 with point clicks; otherwise it should always be false.
+         drop_label_prob: probability of dropping the label prompt.
+         drop_point_prob: probability of dropping the point prompt.
+         point_sampler: sampler to augment masks with supervoxels.
+         point_sampler_kwargs: arguments for point_sampler.
+
+     Returns:
+         tuple:
+             - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
+               training automatic segmentation.
+             - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
+               for each class. Note that background label prompts require matching points as well
+               (e.g., [0, 0, 0] is used).
+             - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
+               labels for each point (negative or positive). -1 is used for padding the background
+               label prompt and will be ignored.
+             - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
+               for label indexing during training. If label_prompt is None, prompt_class is used to
+               identify point classes.
+
+     """
+
+     # class label number
+     if not labels.shape[0] == 1:
+         raise ValueError("only support batch size 1")
+     labels = labels[0, 0]
+     device = labels.device
+     unique_labels = labels.unique().cpu().numpy().tolist()
+     if include_background:
+         unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)))
+     else:
+         unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0})
+     background_labels = list(set(label_set) - set(unique_labels))
+     # during training, balance background and foreground prompts
+     if max_backprompt is not None:
+         if len(background_labels) > max_backprompt:
+             random.shuffle(background_labels)
+             background_labels = background_labels[:max_backprompt]
+
+     if max_foreprompt is not None:
+         if len(unique_labels) > max_foreprompt:
+             random.shuffle(unique_labels)
+             unique_labels = unique_labels[:max_foreprompt]
+
+     if max_prompt is not None:
+         if len(unique_labels) + len(background_labels) > max_prompt:
+             if len(unique_labels) > max_prompt:
+                 unique_labels = random.sample(unique_labels, max_prompt)
+                 background_labels = []
+             else:
+                 background_labels = random.sample(background_labels, max_prompt - len(unique_labels))
+     _point = []
+     _point_label = []
+     # regular point sampling
+     if point_sampler is None:
+         num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)
+         num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))
+         for id in unique_labels:
+             neg_id, pos_id = _get_point_label(id)
+             plabels = labels == int(id)
+             nlabels = ~plabels
+             plabelpoints = torch.nonzero(plabels)
+             nlabelpoints = torch.nonzero(nlabels)
+             # final sampled positive points
+             num_pa = min(len(plabelpoints), num_p)
+             # final sampled negative points
+             num_na = min(len(nlabelpoints), num_n)
+             _point.append(
+                 torch.stack(
+                     random.choices(plabelpoints, k=num_pa)
+                     + random.choices(nlabelpoints, k=num_na)
+                     + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na)
+                 )
+             )
+             _point_label.append(
+                 torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to(
+                     device
+                 )
+             )
+         for _ in background_labels:
+             # pad the background labels
+             _point.append(torch.zeros(num_p + num_n, 3).to(device))  # all 0
+             _point_label.append(torch.zeros(num_p + num_n).to(device) - 1)  # -1 not a point
+     else:
+         _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs)
+         for _ in background_labels:
+             # pad the background labels
+             _point.append(torch.zeros(len(_point_label[0]), 3).to(device))  # all 0
+             _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1)  # -1 not a point
+     if len(unique_labels) == 0 and len(background_labels) == 0:
+         # if max_backprompt is 0 and len(unique_labels) is 0, there is no effective prompt and the
+         # iteration must be skipped. Handle this in the trainer.
+         label_prompt, point, point_label, prompt_class = None, None, None, None
+     else:
+         label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long()
+         point = torch.stack(_point)
+         point_label = torch.stack(_point_label)
+         prompt_class = copy.deepcopy(label_prompt)
+         if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0:
+             label_prompt = None
+             # if the label prompt is dropped, there is no need to pad with points with label -1
+             pad = len(background_labels)
+             point = point[: len(point) - pad]  # type: ignore
+             point_label = point_label[: len(point_label) - pad]
+             prompt_class = prompt_class[: len(prompt_class) - pad]
+         else:
+             if random.uniform(0, 1) < drop_point_prob:
+                 point = None
+                 point_label = None
+     return label_prompt, point, point_label, prompt_class
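
For orientation, here is a minimal smoke test of `sample_prompt_pairs`; the label volume and `label_set` below are synthetic, chosen only to exercise the new API, and are not part of the diff:

```python
# Hedged sketch: synthetic labels, illustrative arguments only.
import torch
from monai.apps.vista3d.sampler import sample_prompt_pairs

labels = torch.randint(0, 3, (1, 1, 64, 64, 64))  # [1, 1, H, W, D]; batch size must be 1
label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
    labels, label_set=[1, 2], max_point=5, drop_label_prob=0.0, drop_point_prob=0.0
)
# with both drop probabilities at 0, all four returns are tensors:
# label_prompt/prompt_class are [B, 1], point is [B, N, 3], point_label is [B, N]
print(label_prompt.shape, point.shape, point_label.shape, prompt_class.shape)
```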
monai/apps/vista3d/transforms.py ADDED
@@ -0,0 +1,224 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import warnings
+ from typing import Sequence
+
+ import numpy as np
+ import torch
+
+ from monai.config import DtypeLike, KeysCollection
+ from monai.transforms import MapLabelValue
+ from monai.transforms.transform import MapTransform
+ from monai.transforms.utils import keep_components_with_positive_points
+ from monai.utils import look_up_option
+
+ __all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"]
+
+
+ def _get_name_to_index_mapping(labels_dict: dict | None) -> dict:
+     """get the label name to index mapping"""
+     name_to_index_mapping = {}
+     if labels_dict is not None:
+         name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()}
+     return name_to_index_mapping
+
+
+ def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:
+     """convert the label names to indices"""
+     if label_prompt is not None and isinstance(label_prompt, list):
+         converted_label_prompt = []
+         # for new classes, add to the mapping
+         for l in label_prompt:
+             if isinstance(l, str) and not l.isdigit():
+                 if l.lower() not in name_to_index_mapping:
+                     name_to_index_mapping[l.lower()] = len(name_to_index_mapping)
+         for l in label_prompt:
+             if isinstance(l, (int, str)):
+                 converted_label_prompt.append(
+                     name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l)
+                 )
+             else:
+                 converted_label_prompt.append(l)
+         return converted_label_prompt
+     return label_prompt
+
+
+ class VistaPreTransformd(MapTransform):
+     def __init__(
+         self,
+         keys: KeysCollection,
+         allow_missing_keys: bool = False,
+         special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),
+         labels_dict: dict | None = None,
+         subclass: dict | None = None,
+     ) -> None:
+         """
+         Pre-transform for Vista3d.
+
+         It performs two functions:
+
+         1. If the label prompt indicates that the points belong to a special class (defined by the special
+            index, e.g. tumors, vessels), convert the point labels from 0 (negative), 1 (positive) to the
+            special 2 (negative), 3 (positive).
+
+         2. If the label prompt is among the keys in subclass, convert the label prompt to its subclasses
+            defined by subclass[key]. e.g. the "lung" label is converted to ["left lung", "right lung"].
+
+         The `label_prompt` is a list of int values of length [B], and `point_labels` is a list of length B,
+         where each element is an int list of length [N].
+
+         Args:
+             keys: keys of the corresponding items to be transformed.
+             special_index: the indices that define the special classes.
+             labels_dict: an optional dictionary mapping label indices to label names.
+             subclass: a dictionary that maps a label prompt to its subclasses.
+             allow_missing_keys: don't raise exception if key is missing.
+         """
+         super().__init__(keys, allow_missing_keys)
+         self.special_index = special_index
+         self.subclass = subclass
+         self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)
+
+     def __call__(self, data):
+         label_prompt = data.get("label_prompt", None)
+         point_labels = data.get("point_labels", None)
+         # convert the label names to indices if needed
+         label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)
+         try:
+             # The evaluator will check the prompt. Invalid prompts are skipped here and captured by the evaluator.
+             if self.subclass is not None and label_prompt is not None:
+                 _label_prompt = []
+                 subclass_keys = list(map(int, self.subclass.keys()))
+                 for i in range(len(label_prompt)):
+                     if label_prompt[i] in subclass_keys:
+                         _label_prompt.extend(self.subclass[str(label_prompt[i])])
+                     else:
+                         _label_prompt.append(label_prompt[i])
+                 data["label_prompt"] = _label_prompt
+             if label_prompt is not None and point_labels is not None:
+                 if label_prompt[0] in self.special_index:
+                     point_labels = np.array(point_labels)
+                     point_labels[point_labels == 0] = 2
+                     point_labels[point_labels == 1] = 3
+                     point_labels = point_labels.tolist()
+                 data["point_labels"] = point_labels
+         except Exception:
+             # There are specific requirements for `label_prompt` and `point_labels`:
+             # if B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None.
+             # Those formatting errors should be captured later.
+             warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.")
+
+         return data
+
+
+ class VistaPostTransformd(MapTransform):
+     def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
+         """
+         Post-transform for Vista3d. It converts the model output logits into final segmentation masks.
+         If `label_prompt` is None, the output will be thresholded to sequential indexes [0, 1, 2, ...];
+         otherwise the indexes will be [0, label_prompt[0], label_prompt[1], ...].
+         If `label_prompt` is None while `points` are provided, the model will postprocess the output to
+         remove regions that do not contain positive points.
+
+         Args:
+             keys: keys of the corresponding items to be transformed.
+             allow_missing_keys: don't raise exception if key is missing.
+
+         """
+         super().__init__(keys, allow_missing_keys)
+
+     def __call__(self, data):
+         """data["label_prompt"] should not contain 0"""
+         for keys in self.keys:
+             if keys in data:
+                 pred = data[keys]
+                 object_num = pred.shape[0]
+                 device = pred.device
+                 if data.get("label_prompt", None) is None and data.get("points", None) is not None:
+                     pred = keep_components_with_positive_points(
+                         pred.unsqueeze(0),
+                         point_coords=data.get("points").to(device),
+                         point_labels=data.get("point_labels").to(device),
+                     )[0]
+                 pred[pred < 0] = 0.0
+                 # if it's multichannel, perform argmax
+                 if object_num > 1:
+                     # concatenate the background channel; make sure the user did not provide 0 as a prompt
+                     is_bk = torch.all(pred <= 0, dim=0, keepdim=True)
+                     pred = pred.argmax(0).unsqueeze(0).float() + 1.0
+                     pred[is_bk] = 0.0
+                 else:
+                     # AsDiscrete will remove NaN
+                     # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred)
+                     pred[pred > 0] = 1.0
+                 if "label_prompt" in data and data["label_prompt"] is not None:
+                     pred += 0.5  # in-place mapping to avoid cloning pred
+                     label_prompt = data["label_prompt"].to(device)  # ensure label_prompt is on the same device
+                     for i in range(1, object_num + 1):
+                         frac = i + 0.5
+                         pred[pred == frac] = label_prompt[i - 1].to(pred.dtype)
+                     pred[pred == 0.5] = 0.0
+                 data[keys] = pred
+         return data
+
+
+ class Relabeld(MapTransform):
+     def __init__(
+         self,
+         keys: KeysCollection,
+         label_mappings: dict[str, list[tuple[int, int]]],
+         dtype: DtypeLike = np.int16,
+         dataset_key: str = "dataset_name",
+         allow_missing_keys: bool = False,
+     ) -> None:
+         """
+         Remap the voxel labels in the input data dictionary based on the specified mapping.
+
+         This list of local -> global label mappings will be applied to each input `data[keys]`.
+         If `data[dataset_key]` is not in `label_mappings`, `label_mappings['default']` will be used.
+         If `label_mappings[data[dataset_key]]` is None, no relabeling will be performed.
+
+         Args:
+             keys: keys of the corresponding items to be transformed.
+             label_mappings: a dictionary specifying how local dataset class indices map to the
+                 global class indices. The dictionary keys are dataset names and the values are lists of
+                 (local label, global label) pairs. This list of local -> global label mappings
+                 will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,
+                 `label_mappings['default']` will be used. If `label_mappings[data[dataset_key]]` is None,
+                 no relabeling will be performed. Set `label_mappings={}` to skip this transform entirely.
+             dtype: convert the output data to dtype, defaults to np.int16.
+             dataset_key: key to get the dataset name from the data dictionary, defaults to "dataset_name".
+             allow_missing_keys: don't raise exception if key is missing.
+
+         """
+         super().__init__(keys, allow_missing_keys)
+         self.mappers = {}
+         self.dataset_key = dataset_key
+         for name, mapping in label_mappings.items():
+             self.mappers[name] = MapLabelValue(
+                 orig_labels=[int(pair[0]) for pair in mapping],
+                 target_labels=[int(pair[1]) for pair in mapping],
+                 dtype=dtype,
+             )
+
+     def __call__(self, data):
+         d = dict(data)
+         dataset_name = d.get(self.dataset_key, "default")
+         _m = look_up_option(dataset_name, self.mappers, default=None)
+         if _m is None:
+             return d
+         for key in self.key_iterator(d):
+             d[key] = _m(d[key])
+         return d
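
As a quick illustration of `Relabeld`, a hedged sketch follows; the mapping and tensor values are made up for the example and are not from any MONAI dataset:

```python
# Hedged sketch: hypothetical local -> global label mapping.
import numpy as np
import torch
from monai.apps.vista3d.transforms import Relabeld

relabel = Relabeld(keys="label", label_mappings={"default": [(1, 25), (2, 26)]}, dtype=np.int16)
data = {"label": torch.tensor([[0, 1], [2, 1]]), "dataset_name": "default"}
print(relabel(data)["label"])  # 1 -> 25, 2 -> 26, 0 left unchanged
```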
monai/bundle/scripts.py CHANGED
@@ -18,6 +18,7 @@ import re
  import warnings
  import zipfile
  from collections.abc import Mapping, Sequence
+ from functools import partial
  from pathlib import Path
  from pydoc import locate
  from shutil import copyfile
@@ -1254,6 +1255,7 @@ def verify_net_in_out(

  def _export(
      converter: Callable,
+     saver: Callable,
      parser: ConfigParser,
      net_id: str,
      filepath: str,
@@ -1268,6 +1270,8 @@ def _export(
      Args:
          converter: a callable object that takes a torch.nn.module and kwargs as input and
              converts the module to another type.
+         saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
+             (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
          parser: a ConfigParser of the bundle to be converted.
          net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
          filepath: filepath to export, if filename has no extension, it becomes `.ts`.
@@ -1307,14 +1311,9 @@ def _export(
      # add .json extension to all extra files which are always encoded as JSON
      extra_files = {k + ".json": v for k, v in extra_files.items()}

-     save_net_with_metadata(
-         jit_obj=net,
-         filename_prefix_or_stream=filepath,
-         include_config_vals=False,
-         append_timestamp=False,
-         meta_values=parser.get().pop("_meta_", None),
-         more_extra_files=extra_files,
-     )
+     meta_values = parser.get().pop("_meta_", None)
+     saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
+
      logger.info(f"exported to file: {filepath}.")


@@ -1413,17 +1412,23 @@ def onnx_export(
      input_shape_ = _get_fake_input_shape(parser=parser)

      inputs_ = [torch.rand(input_shape_)]
-     net = parser.get_parsed_content(net_id_)
-     if has_ignite:
-         # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
-         Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
-     else:
-         ckpt = torch.load(ckpt_file_)
-         copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

      converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
-     onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
-     onnx.save(onnx_model, filepath_)
+
+     def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
+         onnx.save(onnx_obj, filename_prefix_or_stream)
+
+     _export(
+         convert_to_onnx,
+         save_onnx,
+         parser,
+         net_id=net_id_,
+         filepath=filepath_,
+         ckpt_file=ckpt_file_,
+         config_file=config_file_,
+         key_in_ckpt=key_in_ckpt_,
+         **converter_kwargs_,
+     )


  def ckpt_export(
@@ -1544,8 +1549,12 @@ def ckpt_export(

      converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
      # Use the given converter to convert a model and save with metadata, config content
+
+     save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
+
      _export(
          convert_to_torchscript,
+         save_ts,
          parser,
          net_id=net_id_,
          filepath=filepath_,
@@ -1715,8 +1724,11 @@ def trt_export(
      }
      converter_kwargs_.update(trt_api_parameters)

+     save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
+
      _export(
          convert_to_trt,
+         save_ts,
          parser,
          net_id=net_id_,
          filepath=filepath_,
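
The refactor above threads a `saver` callable through `_export`, so the TorchScript and TensorRT paths keep using `save_net_with_metadata` while ONNX supplies a thin `onnx.save` wrapper. A sketch of the shared contract (the `save_debug` helper is hypothetical, added only to show the expected signature; the import path assumes `save_net_with_metadata` lives in `monai.networks.utils`):

```python
from functools import partial
from monai.networks.utils import save_net_with_metadata

# TorchScript/TensorRT exports bind the fixed flags up front:
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)

# _export then invokes every saver through one uniform call shape:
#     saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
# so a saver that has no use for the metadata keywords can simply ignore them.
def save_debug(model, filepath, meta_values=None, more_extra_files=None):
    """Hypothetical saver used only to illustrate the expected signature."""
    print(f"would save {type(model).__name__} to {filepath}")
```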
monai/data/utils.py CHANGED
@@ -927,7 +927,7 @@ def compute_shape_offset(
      corners = in_affine_ @ corners
      all_dist = corners_out[:-1].copy()
      corners_out = corners_out[:-1] / corners_out[-1]
-     out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
+     out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0)
      offset = None
      for i in range(corners.shape[1]):
          min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
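
This one-line change in `compute_shape_offset` appears to be a NumPy 2.0 compatibility fix: the `ndarray.ptp` method was removed in NumPy 2.0, while the `np.ptp` function remains available in both 1.x and 2.x. A quick equivalence check:

```python
import numpy as np

corners_out = np.array([[0.0, 4.0], [1.0, 3.0]])
# peak-to-peak (max - min) along axis 1; same result as the removed method form
assert np.allclose(np.ptp(corners_out, axis=1), [4.0, 2.0])
```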
monai/data/wsi_datasets.py CHANGED
@@ -23,7 +23,7 @@ from monai.data.meta_tensor import MetaTensor
  from monai.data.utils import iter_patch_position
  from monai.data.wsi_reader import BaseWSIReader, WSIReader
  from monai.transforms import ForegroundMask, Randomizable, apply_transform
- from monai.utils import convert_to_dst_type, ensure_tuple_rep
+ from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep
  from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys

  __all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
@@ -123,9 +123,9 @@ class PatchWSIDataset(Dataset):
      def _get_location(self, sample: dict):
          if self.center_location:
              size = self._get_size(sample)
-             return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))]
+             return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size)))
          else:
-             return sample[WSIPatchKeys.LOCATION]
+             return ensure_tuple(sample[WSIPatchKeys.LOCATION])

      def _get_level(self, sample: dict):
          if self.patch_level is None:
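
`ensure_tuple` normalizes lists, arrays, and generator expressions to a plain tuple, so `_get_location` now returns a consistent type on both branches. For example:

```python
from monai.utils import ensure_tuple

print(ensure_tuple([10, 20]))                 # (10, 20)
print(ensure_tuple(v - 2 for v in (10, 20)))  # (8, 18)
```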
monai/inferers/utils.py CHANGED
@@ -300,6 +300,7 @@ def sliding_window_inference(

      # remove padding if image_size smaller than roi_size
      if any(pad_size):
+         kwargs.update({"pad_size": pad_size})
          for ss, output_i in enumerate(output_image_list):
              zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
              final_slicing: list[slice] = []
monai/losses/__init__.py CHANGED
@@ -37,6 +37,7 @@ from .giou_loss import BoxGIoULoss, giou
  from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
  from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
  from .multi_scale import MultiScaleLoss
+ from .nacl_loss import NACLLoss
  from .perceptual import PerceptualLoss
  from .spatial_mask import MaskedLoss
  from .spectral_loss import JukeboxLoss
monai/losses/dice.py CHANGED
@@ -666,6 +666,7 @@ class DiceCELoss(_Loss):
          weight: torch.Tensor | None = None,
          lambda_dice: float = 1.0,
          lambda_ce: float = 1.0,
+         label_smoothing: float = 0.0,
      ) -> None:
          """
          Args:
@@ -704,6 +705,9 @@ class DiceCELoss(_Loss):
                  Defaults to 1.0.
              lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
                  Defaults to 1.0.
+             label_smoothing: a value in the [0, 1] range. If > 0, the labels are smoothed
+                 by the given factor to reduce overfitting.
+                 Defaults to 0.0.

          """
          super().__init__()
@@ -728,7 +732,12 @@ class DiceCELoss(_Loss):
              batch=batch,
              weight=dice_weight,
          )
-         self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
+         if pytorch_after(1, 10):
+             self.cross_entropy = nn.CrossEntropyLoss(
+                 weight=weight, reduction=reduction, label_smoothing=label_smoothing
+             )
+         else:
+             self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
          self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
          if lambda_dice < 0.0:
              raise ValueError("lambda_dice should be no less than 0.0.")