monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registries and is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/scripts.py +29 -17
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/inferers/utils.py +1 -0
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +22 -0
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +946 -0
- monai/networks/utils.py +4 -4
- monai/transforms/__init__.py +13 -2
- monai/transforms/io/array.py +59 -3
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utility/dictionary.py +4 -0
- monai/transforms/utils.py +230 -1
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/enums.py +1 -0
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
@@ -157,6 +157,10 @@ class Encoder(nn.Module):
         norm_eps: epsilon for the normalization.
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
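The three new options are ultimately consumed by the attention blocks (see the selfattention.py and spatialattention.py entries in the file list above). The following is an illustrative sketch only, assuming the updated SABlock accepts the same flag names used throughout this diff; it is not taken from the package itself.

# Illustrative sketch, not part of the diff: passing the new flags to the self-attention block.
# The flag names mirror this diff; other values are assumptions for demonstration.
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(
    hidden_size=64,
    num_heads=4,
    include_fc=True,            # keep the final linear (output projection) layer
    use_combined_linear=False,  # separate to_q/to_k/to_v instead of one fused qkv linear
    use_flash_attention=False,  # True routes through torch.nn.functional.scaled_dot_product_attention
)
out = block(torch.randn(2, 16, 64))  # (batch, sequence, hidden_size)
print(out.shape)  # expected: torch.Size([2, 16, 64])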
@@ -170,6 +174,9 @@ class Encoder(nn.Module):
         norm_eps: float,
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -220,6 +227,9 @@ class Encoder(nn.Module):
                         num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
+                        use_flash_attention=use_flash_attention,
                     )
                 )

@@ -243,6 +253,9 @@ class Encoder(nn.Module):
                     num_channels=channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -291,6 +304,10 @@ class Decoder(nn.Module):
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -305,6 +322,9 @@ class Decoder(nn.Module):
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
         use_convtranspose: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -350,6 +370,9 @@ class Decoder(nn.Module):
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -389,6 +412,9 @@ class Decoder(nn.Module):
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
+                        use_flash_attention=use_flash_attention,
                     )
                 )

@@ -463,6 +489,10 @@ class AutoencoderKL(nn.Module):
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
         use_checkpoint: if True, use activation checkpoint to save memory.
         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+        include_fc: whether to include the final linear layer in the attention block. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -480,6 +510,9 @@ class AutoencoderKL(nn.Module):
         with_decoder_nonlocal_attn: bool = True,
         use_checkpoint: bool = False,
         use_convtranspose: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -499,7 +532,7 @@ class AutoencoderKL(nn.Module):
                 "`num_channels`."
             )

-        self.encoder = Encoder(
+        self.encoder: nn.Module = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             channels=channels,
@@ -509,8 +542,11 @@ class AutoencoderKL(nn.Module):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
-        self.decoder = Decoder(
+        self.decoder: nn.Module = Decoder(
             spatial_dims=spatial_dims,
             channels=channels,
             in_channels=latent_channels,
@@ -521,6 +557,9 @@ class AutoencoderKL(nn.Module):
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
             use_convtranspose=use_convtranspose,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.quant_conv_mu = Convolution(
             spatial_dims=spatial_dims,
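For orientation, here is a hedged construction sketch showing how the three new flags propagate from AutoencoderKL down to its attention blocks. The argument values (channel counts, latent size, input shape) are illustrative, not taken from the diff, and the out_channels/num_res_blocks parameters are assumed from the usual AutoencoderKL signature rather than shown in the hunks above.

# Hedged sketch, not part of the diff: instantiate the updated AutoencoderKL with the new attention options.
import torch
from monai.networks.nets import AutoencoderKL

model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    num_res_blocks=(1, 1, 1),
    attention_levels=(False, False, True),
    latent_channels=8,
    norm_num_groups=16,
    include_fc=True,            # keep the final linear layer inside each attention block
    use_combined_linear=False,  # separate to_q/to_k/to_v projections rather than a fused qkv
    use_flash_attention=False,  # set True to use torch scaled_dot_product_attention
)

with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(torch.randn(1, 1, 64, 64))
print(reconstruction.shape)  # expected: torch.Size([1, 1, 64, 64])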
@@ -665,27 +704,18 @@ class AutoencoderKL(nn.Module):
         # copy over all matching keys
         for k in new_state_dict:
             if k in old_state_dict:
-                new_state_dict[k] = old_state_dict[k]
+                new_state_dict[k] = old_state_dict.pop(k)

         # fix the attention blocks
-        attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
+        attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k]
         for block in attention_blocks:
-            new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.to_q.weight"],
-                    old_state_dict[f"{block}.to_k.weight"],
-                    old_state_dict[f"{block}.to_v.weight"],
-                ],
-                dim=0,
-            )
-            new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.to_q.bias"],
-                    old_state_dict[f"{block}.to_k.bias"],
-                    old_state_dict[f"{block}.to_v.bias"],
-                ],
-                dim=0,
-            )
+            new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+            new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+            new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+            new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+            new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+            new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
+
             # old version did not have a projection so set these to the identity
             new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
                 new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
@@ -698,5 +728,8 @@ class AutoencoderKL(nn.Module):
         for k in new_state_dict:
             if "postconv" in k:
                 old_name = k.replace("postconv", "conv")
-                new_state_dict[k] = old_state_dict[old_name]
-        self.load_state_dict(new_state_dict)
+                new_state_dict[k] = old_state_dict.pop(old_name)
+        if verbose:
+            # print all remaining keys in old_state_dict
+            print("remaining keys in old_state_dict:", old_state_dict.keys())
+        self.load_state_dict(new_state_dict, strict=True)
@@ -0,0 +1,92 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from monai.utils import optional_import
+
+build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
+
+_all__ = ["CellSamWrapper"]
+
+
+class CellSamWrapper(torch.nn.Module):
+    """
+    CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything
+    with an image only decoder, that can be used for segmentation tasks.
+
+
+    Args:
+        auto_resize_inputs: whether to resize inputs before passing to the network.
+            (usually they need be resized, unless they are already at the expected size)
+        network_resize_roi: expected input size for the network.
+            (currently SAM expects 1024x1024)
+        checkpoint: checkpoint file to load the SAM weights from.
+            (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
+        return_features: whether to return features from SAM encoder
+            (without using decoder/upsampling to the original input size)
+
+    """
+
+    def __init__(
+        self,
+        auto_resize_inputs=True,
+        network_resize_roi=(1024, 1024),
+        checkpoint="sam_vit_b_01ec64.pth",
+        return_features=False,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.network_resize_roi = network_resize_roi
+        self.auto_resize_inputs = auto_resize_inputs
+        self.return_features = return_features
+
+        if not has_sam:
+            raise ValueError(
+                "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git"
+            )
+
+        model = build_sam_vit_b(checkpoint=checkpoint)
+
+        model.prompt_encoder = None
+        model.mask_decoder = None
+
+        model.mask_decoder = nn.Sequential(
+            nn.BatchNorm2d(num_features=256),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+            nn.BatchNorm2d(num_features=128),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
+        )
+
+        self.model = model
+
+    def forward(self, x):
+        sh = x.shape[2:]
+
+        if self.auto_resize_inputs:
+            x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")
+
+        x = self.model.image_encoder(x)
+
+        if not self.return_features:
+            x = self.model.mask_decoder(x)
+            if self.auto_resize_inputs:
+                x = F.interpolate(x, size=sh, mode="bilinear")
+
+        return x
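A hedged usage sketch for the new wrapper follows. It assumes segment-anything is installed and the SAM ViT-B checkpoint has been downloaded to the default path used above; the input sizes are illustrative.

# Hedged usage sketch, not part of the diff. Requires the segment-anything package and the
# sam_vit_b_01ec64.pth checkpoint on disk (see the docstring above for the download link).
import torch
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

net = CellSamWrapper(
    auto_resize_inputs=True,          # resize inputs to network_resize_roi before the SAM encoder
    network_resize_roi=(1024, 1024),  # SAM's expected input size
    checkpoint="sam_vit_b_01ec64.pth",
    return_features=False,            # run the replacement conv decoder and resize back
).eval()

with torch.no_grad():
    logits = net(torch.randn(1, 3, 256, 256))  # SAM's image encoder expects 3-channel input
print(logits.shape)  # expected: torch.Size([1, 3, 256, 256])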
@@ -143,6 +143,10 @@ class ControlNet(nn.Module):
         upcast_attention: if True, upcast attention operations to full precision.
         conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
         conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -163,28 +167,29 @@ class ControlNet(nn.Module):
         upcast_attention: bool = False,
         conditioning_embedding_in_channels: int = 1,
         conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
             raise ValueError(
-                "
+                "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
                 "to be specified when with_conditioning=True."
             )
         if cross_attention_dim is not None and with_conditioning is False:
-            raise ValueError(
-                "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
-            )
+            raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.")

         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
             raise ValueError(
-                f"
+                f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
                 f" channels={channels} and norm_num_groups={norm_num_groups}"
             )

         if len(channels) != len(attention_levels):
             raise ValueError(
-                f"
+                f"ControlNet expects channels to have the same length as attention_levels, but got "
                 f"channels={channels} and attention_levels={attention_levels}"
             )

@@ -282,6 +287,9 @@ class ControlNet(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
+                use_flash_attention=use_flash_attention,
             )

             self.down_blocks.append(down_block)
@@ -326,6 +334,9 @@ class ControlNet(nn.Module):
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )

         controlnet_block = Convolution(
@@ -441,25 +452,16 @@ class ControlNet(nn.Module):
         # copy over all matching keys
         for k in new_state_dict:
             if k in old_state_dict:
-                new_state_dict[k] = old_state_dict[k]
+                new_state_dict[k] = old_state_dict.pop(k)

         # fix the attention blocks
-        attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
+        attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k]
         for block in attention_blocks:
-            new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.attn1.to_q.weight"],
-                    old_state_dict[f"{block}.attn1.to_k.weight"],
-                    old_state_dict[f"{block}.attn1.to_v.weight"],
-                ],
-                dim=0,
-            )
-
             # projection
-            new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
-            new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]
-
-            new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
-            new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
+            new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
+            new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")

+        if verbose:
+            # print all remaining keys in old_state_dict
+            print("remaining keys in old_state_dict:", old_state_dict.keys())
         self.load_state_dict(new_state_dict)
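As a hedged illustration of the renamed validation messages above, constructing a ControlNet with cross_attention_dim set but with_conditioning left False should now raise the ControlNet-specific error shown in the hunk. The keyword values below are illustrative assumptions, not taken from the diff.

# Hedged sketch, not part of the diff: exercise the updated ControlNet validation.
from monai.networks.nets import ControlNet

try:
    ControlNet(
        spatial_dims=2,
        in_channels=1,
        channels=(32, 64),
        attention_levels=(False, True),
        num_res_blocks=(1, 1),
        norm_num_groups=16,
        cross_attention_dim=64,   # set while with_conditioning stays False ...
        with_conditioning=False,
    )
except ValueError as err:
    # "ControlNet expects with_conditioning=True when specifying the cross_attention_dim."
    print(err)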