monai-weekly 1.4.dev2425__py3-none-any.whl → 1.4.dev2427__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/transforms.py +1 -1
- monai/apps/deepgrow/transforms.py +1 -1
- monai/apps/generation/__init__.py +10 -0
- monai/apps/generation/maisi/__init__.py +10 -0
- monai/apps/generation/maisi/networks/__init__.py +10 -0
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +975 -0
- monai/apps/generation/maisi/networks/controlnet_maisi.py +178 -0
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +410 -0
- monai/apps/generation/maisi/utils/__init__.py +10 -0
- monai/apps/generation/maisi/utils/morphological_ops.py +170 -0
- monai/apps/nuclick/transforms.py +1 -1
- monai/apps/pathology/transforms/post/array.py +1 -1
- monai/apps/pathology/utils.py +2 -2
- monai/data/torchscript_utils.py +1 -1
- monai/data/ultrasound_confidence_map.py +41 -10
- monai/losses/dice.py +10 -3
- monai/metrics/utils.py +3 -3
- monai/optimizers/lr_finder.py +1 -1
- monai/transforms/intensity/array.py +25 -2
- monai/transforms/signal/array.py +1 -1
- monai/utils/misc.py +20 -2
- monai/utils/module.py +6 -3
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/METADATA +6 -3
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/RECORD +29 -21
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/top_level.txt +0 -0
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py (new file)
@@ -0,0 +1,975 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import gc
import logging
from typing import TYPE_CHECKING, Sequence, cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import Convolution
from monai.utils import optional_import
from monai.utils.type_conversion import convert_to_tensor

AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")


if TYPE_CHECKING:
    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
else:
    AutoencoderKLType = cast(type, AutoencoderKL)


# Set up logging configuration
logger = logging.getLogger(__name__)


def _empty_cuda_cache(save_mem: bool) -> None:
    if torch.cuda.is_available() and save_mem:
        torch.cuda.empty_cache()
    return

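The three optional_import calls above are what keep MONAI GenerativeModels a soft dependency: each call returns the requested symbol (or a lazy placeholder that only raises when first used) together with an availability flag. A minimal sketch of the same pattern, outside this file:

# Sketch of the optional-import pattern (not part of the diff).
from monai.utils import optional_import

AttentionBlock, has_attentionblock = optional_import(
    "generative.networks.nets.autoencoderkl", name="AttentionBlock"
)
if not has_attentionblock:
    # The placeholder raises an informative error only when actually used,
    # so importing the surrounding module does not fail outright.
    print("monai-generative not installed; AttentionBlock is a lazy stub")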

class MaisiGroupNorm3D(nn.GroupNorm):
    """
    Custom 3D Group Normalization with optional print_info output.

    Args:
        num_groups: Number of groups for the group norm.
        num_channels: Number of channels for the group norm.
        eps: Epsilon value for numerical stability.
        affine: Whether to use learnable affine parameters, default to `True`.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        norm_float16: bool = False,
        print_info: bool = False,
        save_mem: bool = True,
    ):
        super().__init__(num_groups, num_channels, eps, affine)
        self.norm_float16 = norm_float16
        self.print_info = print_info
        self.save_mem = save_mem

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

        if len(input.shape) != 5:
            raise ValueError("Expected a 5D tensor")

        param_n, param_c, param_d, param_h, param_w = input.shape
        input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w)

        inputs = []
        for i in range(input.size(1)):
            array = input[:, i : i + 1, ...].to(dtype=torch.float32)
            mean = array.mean([2, 3, 4, 5], keepdim=True)
            std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
            if self.norm_float16:
                inputs.append(((array - mean) / std).to(dtype=torch.float16))
            else:
                inputs.append((array - mean) / std)

        del input
        _empty_cuda_cache(self.save_mem)

        input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs)

        input = input.view(param_n, param_c, param_d, param_h, param_w)
        if self.affine:
            input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1))

        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with output size: {input.size()}")

        return input

    def _cat_inputs(self, inputs):
        input_type = inputs[0].device.type
        input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone()
        inputs[0] = 0
        _empty_cuda_cache(self.save_mem)

        for k in range(len(inputs) - 1):
            input = torch.cat((input, inputs[k + 1].cpu()), dim=1)
            inputs[k + 1] = 0
            _empty_cuda_cache(self.save_mem)
            gc.collect()

            if self.print_info:
                logger.info(f"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.")

        return input.to("cuda", non_blocking=True) if input_type == "cuda" else input

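MaisiGroupNorm3D computes the same per-sample, per-group statistics as nn.GroupNorm; it just processes one group at a time (optionally staging concatenation through the CPU) to bound peak GPU memory. A CPU smoke test of that equivalence, as a sketch (it assumes monai-generative is installed, since this module subclasses its AutoencoderKL at import time):

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiGroupNorm3D

norm = MaisiGroupNorm3D(num_groups=4, num_channels=8, save_mem=False)
ref = torch.nn.GroupNorm(num_groups=4, num_channels=8)
ref.load_state_dict(norm.state_dict())  # copy the affine weight/bias

x = torch.randn(2, 8, 6, 6, 6)  # 5D input: (N, C, D, H, W)
assert torch.allclose(norm(x), ref(x), atol=1e-5)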

class MaisiConvolution(nn.Module):
    """
    Convolutional layer with optional print_info output and custom splitting mechanism.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        print_info: Whether to print information.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
        Additional arguments for the convolution operation.
        https://docs.monai.io/en/stable/networks.html#convolution
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_splits: int,
        dim_split: int,
        print_info: bool,
        save_mem: bool = True,
        strides: Sequence[int] | int = 1,
        kernel_size: Sequence[int] | int = 3,
        adn_ordering: str = "NDA",
        act: tuple | str | None = "PRELU",
        norm: tuple | str | None = "INSTANCE",
        dropout: tuple | str | float | None = None,
        dropout_dim: int = 1,
        dilation: Sequence[int] | int = 1,
        groups: int = 1,
        bias: bool = True,
        conv_only: bool = False,
        is_transposed: bool = False,
        padding: Sequence[int] | int | None = None,
        output_padding: Sequence[int] | int | None = None,
    ) -> None:
        super().__init__()
        self.conv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            strides=strides,
            kernel_size=kernel_size,
            adn_ordering=adn_ordering,
            act=act,
            norm=norm,
            dropout=dropout,
            dropout_dim=dropout_dim,
            dilation=dilation,
            groups=groups,
            bias=bias,
            conv_only=conv_only,
            is_transposed=is_transposed,
            padding=padding,
            output_padding=output_padding,
        )

        self.dim_split = dim_split
        self.stride = strides[self.dim_split] if isinstance(strides, list) else strides
        self.num_splits = num_splits
        self.print_info = print_info
        self.save_mem = save_mem

    def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]:
        overlaps = [0] + [padding] * (self.num_splits - 1)
        last_padding = x.size(self.dim_split + 2) % split_size

        slices = [slice(None)] * 5
        splits: list[torch.Tensor] = []
        for i in range(self.num_splits):
            slices[self.dim_split + 2] = slice(
                i * split_size - overlaps[i],
                (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding),
            )
            splits.append(x[tuple(slices)])

        if self.print_info:
            for j in range(len(splits)):
                logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}")

        return splits

    def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor:
        slices = [slice(None)] * 5
        for i in range(self.num_splits):
            slices[self.dim_split + 2] = slice(None, split_size) if i == 0 else slice(padding, padding + split_size)
            outputs[i] = outputs[i][tuple(slices)]

        if self.print_info:
            for i in range(self.num_splits):
                logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}")

        if max(outputs[0].size()) < 500:
            x = torch.cat(outputs, dim=self.dim_split + 2)
        else:
            x = outputs[0].clone().to("cpu", non_blocking=True)
            outputs[0] = torch.Tensor(0)
            _empty_cuda_cache(self.save_mem)
            for k in range(len(outputs) - 1):
                x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2)
                outputs[k + 1] = torch.Tensor(0)
                _empty_cuda_cache(self.save_mem)
                gc.collect()
                if self.print_info:
                    logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.")

            x = x.to("cuda", non_blocking=True)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"Number of splits: {self.num_splits}")

        # compute size of splits
        l = x.size(self.dim_split + 2)
        split_size = l // self.num_splits

        # update padding length if necessary
        padding = 3
        if padding % self.stride > 0:
            padding = (padding // self.stride + 1) * self.stride
        if self.print_info:
            logger.info(f"Padding size: {padding}")

        # split tensor into a list of tensors
        splits = self._split_tensor(x, split_size, padding)

        del x
        _empty_cuda_cache(self.save_mem)

        # convolution
        outputs = [self.conv(split) for split in splits]
        if self.print_info:
            for j in range(len(outputs)):
                logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}")

        # update size of splits and padding length for output
        split_size_out = split_size
        padding_s = padding
        non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0
        if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2:
            split_size_out *= 2
            padding_s *= 2
        elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2:
            split_size_out //= 2
            padding_s //= 2

        # concatenate list of tensors
        x = self._concatenate_tensors(outputs, split_size_out, padding_s)

        del outputs
        _empty_cuda_cache(self.save_mem)

        return x

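The split/crop/concatenate round trip is intended to be numerically equivalent to applying the wrapped convolution to the whole tensor: neighbouring splits overlap by `padding` voxels along `dim_split`, and the overlap is cropped away after the convolution. A CPU sketch checking this for a stride-1 "same" convolution (conv_only=True matters here, since a per-split InstanceNorm would not commute with splitting):

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiConvolution

conv = MaisiConvolution(
    spatial_dims=3, in_channels=1, out_channels=4,
    num_splits=2, dim_split=0, print_info=False, save_mem=False,
    strides=1, kernel_size=3, padding=1, conv_only=True,
)
x = torch.randn(1, 1, 32, 24, 24)
y_split = conv(x.clone())  # split along depth, convolve pieces, crop, re-join
y_full = conv.conv(x)      # the wrapped monai Convolution on the whole tensor
assert torch.allclose(y_split, y_full, atol=1e-5)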

class MaisiUpsample(nn.Module):
    """
    Convolution-based upsampling layer.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels to the layer.
        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        print_info: Whether to print information.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        use_convtranspose: bool,
        num_splits: int,
        dim_split: int,
        print_info: bool,
        save_mem: bool = True,
    ) -> None:
        super().__init__()
        self.conv = MaisiConvolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            strides=2 if use_convtranspose else 1,
            kernel_size=3,
            padding=1,
            conv_only=True,
            is_transposed=use_convtranspose,
            num_splits=num_splits,
            dim_split=dim_split,
            print_info=print_info,
            save_mem=save_mem,
        )
        self.use_convtranspose = use_convtranspose
        self.save_mem = save_mem

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_convtranspose:
            x = self.conv(x)
            x_tensor: torch.Tensor = convert_to_tensor(x)
            return x_tensor

        x = F.interpolate(x, scale_factor=2.0, mode="trilinear")
        _empty_cuda_cache(self.save_mem)
        x = self.conv(x)
        _empty_cuda_cache(self.save_mem)

        out_tensor: torch.Tensor = convert_to_tensor(x)
        return out_tensor

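In the default (use_convtranspose=False) path the block doubles every spatial dimension by trilinear interpolation and then applies the split 3x3x3 convolution. A shape-only sketch with deliberately tiny, hypothetical sizes:

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiUpsample

up = MaisiUpsample(spatial_dims=3, in_channels=4, use_convtranspose=False,
                   num_splits=2, dim_split=0, print_info=False, save_mem=False)
print(up(torch.randn(1, 4, 8, 8, 8)).shape)  # torch.Size([1, 4, 16, 16, 16])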

class MaisiDownsample(nn.Module):
    """
    Convolution-based downsampling layer.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        print_info: Whether to print information.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_splits: int,
        dim_split: int,
        print_info: bool,
        save_mem: bool = True,
    ) -> None:
        super().__init__()
        self.pad = (0, 1) * spatial_dims
        self.conv = MaisiConvolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            strides=2,
            kernel_size=3,
            padding=0,
            conv_only=True,
            num_splits=num_splits,
            dim_split=dim_split,
            print_info=print_info,
            save_mem=save_mem,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.pad, mode="constant", value=0.0)
        x = self.conv(x)
        return x

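The block pads each spatial dimension by one voxel on the trailing side and then halves it with a stride-2 valid convolution, matching the asymmetric-padding downsampling of the upstream AutoencoderKL. Shape-only sketch, same assumptions as above:

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiDownsample

down = MaisiDownsample(spatial_dims=3, in_channels=4, num_splits=2, dim_split=0,
                       print_info=False, save_mem=False)
print(down(torch.randn(1, 4, 16, 16, 16)).shape)  # torch.Size([1, 4, 8, 8, 8])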

class MaisiResBlock(nn.Module):
    """
    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
    residual connection between input and output.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Input channels to the layer.
        norm_num_groups: Number of groups for the group norm layer.
        norm_eps: Epsilon for the normalization.
        out_channels: Number of output channels.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm_num_groups: int,
        norm_eps: float,
        out_channels: int,
        num_splits: int,
        dim_split: int,
        norm_float16: bool = False,
        print_info: bool = False,
        save_mem: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.save_mem = save_mem

        self.norm1 = MaisiGroupNorm3D(
            num_groups=norm_num_groups,
            num_channels=in_channels,
            eps=norm_eps,
            affine=True,
            norm_float16=norm_float16,
            print_info=print_info,
            save_mem=save_mem,
        )
        self.conv1 = MaisiConvolution(
            spatial_dims=spatial_dims,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
            num_splits=num_splits,
            dim_split=dim_split,
            print_info=print_info,
            save_mem=save_mem,
        )
        self.norm2 = MaisiGroupNorm3D(
            num_groups=norm_num_groups,
            num_channels=out_channels,
            eps=norm_eps,
            affine=True,
            norm_float16=norm_float16,
            print_info=print_info,
            save_mem=save_mem,
        )
        self.conv2 = MaisiConvolution(
            spatial_dims=spatial_dims,
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
            num_splits=num_splits,
            dim_split=dim_split,
            print_info=print_info,
            save_mem=save_mem,
        )

        self.nin_shortcut = (
            MaisiConvolution(
                spatial_dims=spatial_dims,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
                num_splits=num_splits,
                dim_split=dim_split,
                print_info=print_info,
                save_mem=save_mem,
            )
            if self.in_channels != self.out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        _empty_cuda_cache(self.save_mem)

        h = F.silu(h)
        _empty_cuda_cache(self.save_mem)
        h = self.conv1(h)
        _empty_cuda_cache(self.save_mem)

        h = self.norm2(h)
        _empty_cuda_cache(self.save_mem)

        h = F.silu(h)
        _empty_cuda_cache(self.save_mem)
        h = self.conv2(h)
        _empty_cuda_cache(self.save_mem)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
            _empty_cuda_cache(self.save_mem)

        out = x + h
        out_tensor: torch.Tensor = convert_to_tensor(out)
        return out_tensor

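When in_channels differs from out_channels the skip path goes through the 1x1x1 nin_shortcut, so both branches agree in shape at the final addition. Sketch, same assumptions as above:

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiResBlock

block = MaisiResBlock(spatial_dims=3, in_channels=4, norm_num_groups=4, norm_eps=1e-6,
                      out_channels=8, num_splits=2, dim_split=0, save_mem=False)
print(block(torch.randn(1, 4, 8, 8, 8)).shape)  # torch.Size([1, 8, 8, 8, 8])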

class MaisiEncoder(nn.Module):
    """
    Convolutional cascade that downsamples the image into a spatial latent space.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        num_channels: Sequence of block output channels.
        out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
        num_res_blocks: Number of residual blocks (see ResBlock) per level.
        norm_num_groups: Number of groups for the group norm layers.
        norm_eps: Epsilon for the normalization.
        attention_levels: Indicate which level from num_channels contain an attention block.
        with_nonlocal_attn: If True, use non-local attention block.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_channels: Sequence[int],
        out_channels: int,
        num_res_blocks: Sequence[int],
        norm_num_groups: int,
        norm_eps: float,
        attention_levels: Sequence[bool],
        num_splits: int,
        dim_split: int,
        norm_float16: bool = False,
        print_info: bool = False,
        save_mem: bool = True,
        with_nonlocal_attn: bool = True,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()

        # Check if attention_levels and num_channels have the same size
        if len(attention_levels) != len(num_channels):
            raise ValueError("attention_levels and num_channels must have the same size")

        # Check if num_res_blocks and num_channels have the same size
        if len(num_res_blocks) != len(num_channels):
            raise ValueError("num_res_blocks and num_channels must have the same size")

        self.save_mem = save_mem

        blocks: list[nn.Module] = []

        blocks.append(
            MaisiConvolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=num_channels[0],
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
                num_splits=num_splits,
                dim_split=dim_split,
                print_info=print_info,
                save_mem=save_mem,
            )
        )

        output_channel = num_channels[0]
        for i in range(len(num_channels)):
            input_channel = output_channel
            output_channel = num_channels[i]
            is_final_block = i == len(num_channels) - 1

            for _ in range(num_res_blocks[i]):
                blocks.append(
                    MaisiResBlock(
                        spatial_dims=spatial_dims,
                        in_channels=input_channel,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
                        out_channels=output_channel,
                        num_splits=num_splits,
                        dim_split=dim_split,
                        norm_float16=norm_float16,
                        print_info=print_info,
                        save_mem=save_mem,
                    )
                )
                input_channel = output_channel
                if attention_levels[i]:
                    blocks.append(
                        AttentionBlock(
                            spatial_dims=spatial_dims,
                            num_channels=input_channel,
                            norm_num_groups=norm_num_groups,
                            norm_eps=norm_eps,
                            use_flash_attention=use_flash_attention,
                        )
                    )

            if not is_final_block:
                blocks.append(
                    MaisiDownsample(
                        spatial_dims=spatial_dims,
                        in_channels=input_channel,
                        num_splits=num_splits,
                        dim_split=dim_split,
                        print_info=print_info,
                        save_mem=save_mem,
                    )
                )

        if with_nonlocal_attn:
            blocks.append(
                ResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[-1],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=num_channels[-1],
                )
            )

            blocks.append(
                AttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=num_channels[-1],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    use_flash_attention=use_flash_attention,
                )
            )
            blocks.append(
                ResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[-1],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=num_channels[-1],
                )
            )

        blocks.append(
            MaisiGroupNorm3D(
                num_groups=norm_num_groups,
                num_channels=num_channels[-1],
                eps=norm_eps,
                affine=True,
                norm_float16=norm_float16,
                print_info=print_info,
                save_mem=save_mem,
            )
        )
        blocks.append(
            MaisiConvolution(
                spatial_dims=spatial_dims,
                in_channels=num_channels[-1],
                out_channels=out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
                num_splits=num_splits,
                dim_split=dim_split,
                print_info=print_info,
                save_mem=save_mem,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
            _empty_cuda_cache(self.save_mem)
        return x

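With attention_levels all False and with_nonlocal_attn=False the encoder is built entirely from the Maisi blocks above, so the optional AttentionBlock/ResBlock imports are never instantiated. Each level except the last halves the spatial size. Sketch with tiny, hypothetical sizes:

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiEncoder

enc = MaisiEncoder(
    spatial_dims=3, in_channels=1, num_channels=[8, 16], out_channels=2,
    num_res_blocks=[1, 1], norm_num_groups=8, norm_eps=1e-6,
    attention_levels=[False, False], num_splits=2, dim_split=0,
    with_nonlocal_attn=False, save_mem=False,
)
print(enc(torch.randn(1, 1, 16, 16, 16)).shape)  # torch.Size([1, 2, 8, 8, 8])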

class MaisiDecoder(nn.Module):
    """
    Convolutional cascade upsampling from a spatial latent space into an image space.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        num_channels: Sequence of block output channels.
        in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks (see ResBlock) per level.
        norm_num_groups: Number of groups for the group norm layers.
        norm_eps: Epsilon for the normalization.
        attention_levels: Indicate which level from num_channels contain an attention block.
        with_nonlocal_attn: If True, use non-local attention block.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        spatial_dims: int,
        num_channels: Sequence[int],
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int],
        norm_num_groups: int,
        norm_eps: float,
        attention_levels: Sequence[bool],
        num_splits: int,
        dim_split: int,
        norm_float16: bool = False,
        print_info: bool = False,
        save_mem: bool = True,
        with_nonlocal_attn: bool = True,
        use_flash_attention: bool = False,
        use_convtranspose: bool = False,
    ) -> None:
        super().__init__()
        self.print_info = print_info
        self.save_mem = save_mem

        reversed_block_out_channels = list(reversed(num_channels))

        blocks: list[nn.Module] = []

        blocks.append(
            MaisiConvolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=reversed_block_out_channels[0],
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
                num_splits=num_splits,
                dim_split=dim_split,
                print_info=print_info,
                save_mem=save_mem,
            )
        )

        if with_nonlocal_attn:
            blocks.append(
                ResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=reversed_block_out_channels[0],
                )
            )
            blocks.append(
                AttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    use_flash_attention=use_flash_attention,
                )
            )
            blocks.append(
                ResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=reversed_block_out_channels[0],
                )
            )

        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        block_out_ch = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            block_in_ch = block_out_ch
            block_out_ch = reversed_block_out_channels[i]
            is_final_block = i == len(num_channels) - 1

            for _ in range(reversed_num_res_blocks[i]):
                blocks.append(
                    MaisiResBlock(
                        spatial_dims=spatial_dims,
                        in_channels=block_in_ch,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
                        out_channels=block_out_ch,
                        num_splits=num_splits,
                        dim_split=dim_split,
                        norm_float16=norm_float16,
                        print_info=print_info,
                        save_mem=save_mem,
                    )
                )
                block_in_ch = block_out_ch

                if reversed_attention_levels[i]:
                    blocks.append(
                        AttentionBlock(
                            spatial_dims=spatial_dims,
                            num_channels=block_in_ch,
                            norm_num_groups=norm_num_groups,
                            norm_eps=norm_eps,
                            use_flash_attention=use_flash_attention,
                        )
                    )

            if not is_final_block:
                blocks.append(
                    MaisiUpsample(
                        spatial_dims=spatial_dims,
                        in_channels=block_in_ch,
                        use_convtranspose=use_convtranspose,
                        num_splits=num_splits,
                        dim_split=dim_split,
                        print_info=print_info,
                        save_mem=save_mem,
                    )
                )

        blocks.append(
            MaisiGroupNorm3D(
                num_groups=norm_num_groups,
                num_channels=block_in_ch,
                eps=norm_eps,
                affine=True,
                norm_float16=norm_float16,
                print_info=print_info,
                save_mem=save_mem,
            )
        )
        blocks.append(
            MaisiConvolution(
                spatial_dims=spatial_dims,
                in_channels=block_in_ch,
                out_channels=out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
                num_splits=num_splits,
                dim_split=dim_split,
                print_info=print_info,
                save_mem=save_mem,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
            _empty_cuda_cache(self.save_mem)
        return x

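The decoder mirrors the encoder, with MaisiUpsample in place of MaisiDownsample and the channel sequence reversed. Sketch under the same assumptions:

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiDecoder

dec = MaisiDecoder(
    spatial_dims=3, num_channels=[8, 16], in_channels=2, out_channels=1,
    num_res_blocks=[1, 1], norm_num_groups=8, norm_eps=1e-6,
    attention_levels=[False, False], num_splits=2, dim_split=0,
    with_nonlocal_attn=False, save_mem=False,
)
print(dec(torch.randn(1, 2, 8, 8, 8)).shape)  # torch.Size([1, 1, 16, 16, 16])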

class AutoencoderKlMaisi(AutoencoderKLType):
    """
    AutoencoderKL with custom MaisiEncoder and MaisiDecoder.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks per level.
        num_channels: Sequence of block output channels.
        attention_levels: Indicate which level from num_channels contain an attention block.
        latent_channels: Number of channels in the latent space.
        norm_num_groups: Number of groups for the group norm layers.
        norm_eps: Epsilon for the normalization.
        with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
        with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        use_checkpointing: If True, use activation checkpointing.
        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int],
        num_channels: Sequence[int],
        attention_levels: Sequence[bool],
        latent_channels: int = 3,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        with_encoder_nonlocal_attn: bool = False,
        with_decoder_nonlocal_attn: bool = False,
        use_flash_attention: bool = False,
        use_checkpointing: bool = False,
        use_convtranspose: bool = False,
        num_splits: int = 16,
        dim_split: int = 0,
        norm_float16: bool = False,
        print_info: bool = False,
        save_mem: bool = True,
    ) -> None:
        super().__init__(
            spatial_dims,
            in_channels,
            out_channels,
            num_res_blocks,
            num_channels,
            attention_levels,
            latent_channels,
            norm_num_groups,
            norm_eps,
            with_encoder_nonlocal_attn,
            with_decoder_nonlocal_attn,
            use_flash_attention,
            use_checkpointing,
            use_convtranspose,
        )

        self.encoder = MaisiEncoder(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_channels=num_channels,
            out_channels=latent_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            attention_levels=attention_levels,
            with_nonlocal_attn=with_encoder_nonlocal_attn,
            use_flash_attention=use_flash_attention,
            num_splits=num_splits,
            dim_split=dim_split,
            norm_float16=norm_float16,
            print_info=print_info,
            save_mem=save_mem,
        )

        self.decoder = MaisiDecoder(
            spatial_dims=spatial_dims,
            num_channels=num_channels,
            in_channels=latent_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            attention_levels=attention_levels,
            with_nonlocal_attn=with_decoder_nonlocal_attn,
            use_flash_attention=use_flash_attention,
            use_convtranspose=use_convtranspose,
            num_splits=num_splits,
            dim_split=dim_split,
            norm_float16=norm_float16,
            print_info=print_info,
            save_mem=save_mem,
        )
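End to end, AutoencoderKlMaisi keeps the parent AutoencoderKL interface (forward returns reconstruction, z_mu, z_sigma) while substituting the memory-frugal encoder and decoder. A minimal CPU sketch, assuming monai-generative is installed for the base class and using tiny, hypothetical sizes (the default num_splits=16 targets full-size MAISI volumes; tiny inputs need a smaller value):

import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

model = AutoencoderKlMaisi(
    spatial_dims=3, in_channels=1, out_channels=1,
    num_res_blocks=[1, 1], num_channels=[8, 16],
    attention_levels=[False, False], latent_channels=4,
    norm_num_groups=8, num_splits=2, dim_split=0, save_mem=False,
)
x = torch.randn(1, 1, 16, 16, 16)
recon, z_mu, z_sigma = model(x)
print(recon.shape, z_mu.shape)  # torch.Size([1, 1, 16, 16, 16]) torch.Size([1, 4, 8, 8, 8])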