monai-weekly 1.5.dev2510__py3-none-any.whl → 1.5.dev2511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +4 -0
- monai/inferers/inferer.py +29 -9
- monai/networks/schedulers/__init__.py +1 -0
- monai/networks/schedulers/rectified_flow.py +322 -0
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/METADATA +1 -1
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/RECORD +15 -13
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/WHEEL +1 -1
- tests/inferers/test_controlnet_inferers.py +96 -32
- tests/inferers/test_diffusion_inferer.py +99 -1
- tests/inferers/test_latent_diffusion_inferer.py +217 -211
- tests/networks/schedulers/test_scheduler_rflow.py +105 -0
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2025-03-
|
11
|
+
"date": "2025-03-16T02:30:38+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.5.
|
14
|
+
"full-revisionid": "7876647f87c763d854f9546bbc60e12f13af84a6",
|
15
|
+
"version": "1.5.dev2511"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -232,6 +232,10 @@ class MaisiConvolution(nn.Module):
|
|
232
232
|
if self.print_info:
|
233
233
|
logger.info(f"Number of splits: {self.num_splits}")
|
234
234
|
|
235
|
+
if self.dim_split <= 1 and self.num_splits <= 1:
|
236
|
+
x = self.conv(x)
|
237
|
+
return x
|
238
|
+
|
235
239
|
# compute size of splits
|
236
240
|
l = x.size(self.dim_split + 2)
|
237
241
|
split_size = l // self.num_splits
|
monai/inferers/inferer.py
CHANGED
@@ -39,7 +39,7 @@ from monai.networks.nets import (
|
|
39
39
|
SPADEAutoencoderKL,
|
40
40
|
SPADEDiffusionModelUNet,
|
41
41
|
)
|
42
|
-
from monai.networks.schedulers import Scheduler
|
42
|
+
from monai.networks.schedulers import RFlowScheduler, Scheduler
|
43
43
|
from monai.transforms import CenterSpatialCrop, SpatialPad
|
44
44
|
from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
|
45
45
|
from monai.visualize import CAM, GradCAM, GradCAMpp
|
@@ -859,12 +859,18 @@ class DiffusionInferer(Inferer):
|
|
859
859
|
if not scheduler:
|
860
860
|
scheduler = self.scheduler
|
861
861
|
image = input_noise
|
862
|
+
|
863
|
+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
|
862
864
|
if verbose and has_tqdm:
|
863
|
-
progress_bar = tqdm(
|
865
|
+
progress_bar = tqdm(
|
866
|
+
zip(scheduler.timesteps, all_next_timesteps),
|
867
|
+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
|
868
|
+
)
|
864
869
|
else:
|
865
|
-
progress_bar = iter(scheduler.timesteps)
|
870
|
+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
|
866
871
|
intermediates = []
|
867
|
-
|
872
|
+
|
873
|
+
for t, next_t in progress_bar:
|
868
874
|
# 1. predict noise model_output
|
869
875
|
diffusion_model = (
|
870
876
|
partial(diffusion_model, seg=seg)
|
@@ -882,9 +888,13 @@ class DiffusionInferer(Inferer):
|
|
882
888
|
)
|
883
889
|
|
884
890
|
# 2. compute previous image: x_t -> x_t-1
|
885
|
-
|
891
|
+
if not isinstance(scheduler, RFlowScheduler):
|
892
|
+
image, _ = scheduler.step(model_output, t, image) # type: ignore
|
893
|
+
else:
|
894
|
+
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
|
886
895
|
if save_intermediates and t % intermediate_steps == 0:
|
887
896
|
intermediates.append(image)
|
897
|
+
|
888
898
|
if save_intermediates:
|
889
899
|
return image, intermediates
|
890
900
|
else:
|
@@ -1392,12 +1402,18 @@ class ControlNetDiffusionInferer(DiffusionInferer):
|
|
1392
1402
|
if not scheduler:
|
1393
1403
|
scheduler = self.scheduler
|
1394
1404
|
image = input_noise
|
1405
|
+
|
1406
|
+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
|
1395
1407
|
if verbose and has_tqdm:
|
1396
|
-
progress_bar = tqdm(
|
1408
|
+
progress_bar = tqdm(
|
1409
|
+
zip(scheduler.timesteps, all_next_timesteps),
|
1410
|
+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
|
1411
|
+
)
|
1397
1412
|
else:
|
1398
|
-
progress_bar = iter(scheduler.timesteps)
|
1413
|
+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
|
1399
1414
|
intermediates = []
|
1400
|
-
|
1415
|
+
|
1416
|
+
for t, next_t in progress_bar:
|
1401
1417
|
diffuse = diffusion_model
|
1402
1418
|
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
|
1403
1419
|
diffuse = partial(diffusion_model, seg=seg)
|
@@ -1436,7 +1452,11 @@ class ControlNetDiffusionInferer(DiffusionInferer):
|
|
1436
1452
|
)
|
1437
1453
|
|
1438
1454
|
# 3. compute previous image: x_t -> x_t-1
|
1439
|
-
|
1455
|
+
if not isinstance(scheduler, RFlowScheduler):
|
1456
|
+
image, _ = scheduler.step(model_output, t, image) # type: ignore
|
1457
|
+
else:
|
1458
|
+
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
|
1459
|
+
|
1440
1460
|
if save_intermediates and t % intermediate_steps == 0:
|
1441
1461
|
intermediates.append(image)
|
1442
1462
|
if save_intermediates:
|
@@ -0,0 +1,322 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
#
|
12
|
+
# =========================================================================
|
13
|
+
# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
|
14
|
+
# which has the following license:
|
15
|
+
# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE
|
16
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
17
|
+
# you may not use this file except in compliance with the License.
|
18
|
+
# You may obtain a copy of the License at
|
19
|
+
#
|
20
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
21
|
+
#
|
22
|
+
# Unless required by applicable law or agreed to in writing, software
|
23
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
24
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
25
|
+
# See the License for the specific language governing permissions and
|
26
|
+
# limitations under the License.
|
27
|
+
# =========================================================================
|
28
|
+
|
29
|
+
from __future__ import annotations
|
30
|
+
|
31
|
+
from typing import Union
|
32
|
+
|
33
|
+
import numpy as np
|
34
|
+
import torch
|
35
|
+
from torch.distributions import LogisticNormal
|
36
|
+
|
37
|
+
from monai.utils import StrEnum
|
38
|
+
|
39
|
+
from .ddpm import DDPMPredictionType
|
40
|
+
from .scheduler import Scheduler
|
41
|
+
|
42
|
+
|
43
|
+
class RFlowPredictionType(StrEnum):
    """
    Enumerates the prediction targets a rectified-flow scheduler can work with.

    ``RFlowScheduler`` stores one of these members in its ``prediction_type`` attribute.
    Rectified flow supports a single target:

    - ``V_PREDICTION``: the network predicts velocity; see section 2.4 of
      https://imagen.research.google/video/paper.pdf
    """

    # Borrow the DDPM member so the underlying string value is shared across schedulers.
    V_PREDICTION = DDPMPredictionType.V_PREDICTION
|
51
|
+
|
52
|
+
|
53
|
+
def timestep_transform(
    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
):
    """
    Warp timestep(s) according to how much larger the input image is than a reference size.

    The timestep is first normalized to ``u = t / num_train_timesteps`` and then remapped as
    ``u' = r * u / (1 + (r - 1) * u)``, where
    ``r = (input_img_size_numel / base_img_size_numel) ** (1 / spatial_dim) * scale``.
    The result is scaled back by ``num_train_timesteps``. With ``r == 1``, ``t`` is returned
    unchanged. With ``r > 1``, intermediate timesteps move toward ``num_train_timesteps``
    (the noisy end), while ``u = 0`` and ``u = 1`` stay fixed.

    Args:
        t (int | float | torch.Tensor): The original timestep(s).
        input_img_size_numel (int | torch.Tensor): Number of spatial elements of the input image (H * W * D).
        base_img_size_numel (int): Reference spatial element count, usually smaller than input_img_size_numel.
        scale (float): Extra multiplier applied to the per-axis resolution ratio.
        num_train_timesteps (int): Total number of training timesteps.
        spatial_dim (int): Number of spatial dimensions in the image.

    Returns:
        Transformed timestep(s), expressed on the same ``[0, num_train_timesteps]`` scale as ``t``.
    """
    # Per-axis resolution ratio (root of the element-count ratio), amplified by ``scale``.
    ratio = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim) * scale

    frac = t / num_train_timesteps
    warped = ratio * frac / (1 + (ratio - 1) * frac)
    return warped * num_train_timesteps
|
78
|
+
|
79
|
+
|
80
|
+
class RFlowScheduler(Scheduler):
    """
    A rectified flow scheduler for guiding the diffusion process in a generative model.

    Supports uniform and logit-normal sampling methods, timestep transformation for
    different resolutions, and noise addition during diffusion.

    Args:
        num_train_timesteps (int): Total number of training timesteps.
        use_discrete_timesteps (bool): Whether to use discrete timesteps.
        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        use_timestep_transform (bool): Whether to apply timestep transformation.
            If true, there will be more inference timesteps at early(noisy) stages for larger image volumes.
        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
            It is applied both to sampled training timesteps and to inference timesteps.
        steps_offset (int): Constant offset added to every inference timestep computed by ``set_timesteps``.
        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
        spatial_dim (int): 2 or 3, indicating 2D or 3D images. Used by ``set_timesteps`` only if
            use_timestep_transform=True; ``sample_timesteps`` infers it from the input's shape instead.

    Example:

        .. code-block:: python

            # define a scheduler
            noise_scheduler = RFlowScheduler(
                num_train_timesteps = 1000,
                use_discrete_timesteps = True,
                sample_method = 'logit-normal',
                use_timestep_transform = True,
                base_img_size_numel = 32 * 32 * 32,
                spatial_dim = 3
            )

            # during training
            inputs = torch.ones(2,4,64,64,32)
            noise = torch.randn_like(inputs)
            timesteps = noise_scheduler.sample_timesteps(inputs)
            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
            predicted_velocity = diffusion_unet(
                x=noisy_inputs,
                timesteps=timesteps
            )
            loss = loss_l1(predicted_velocity, (inputs - noise))

            # during inference
            noisy_inputs = torch.randn(2,4,64,64,32)
            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
            noise_scheduler.set_timesteps(
                num_inference_steps=30, input_img_size_numel=input_img_size_numel
            )
            all_next_timesteps = torch.cat(
                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
            )
            for t, next_t in tqdm(
                zip(noise_scheduler.timesteps, all_next_timesteps),
                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
            ):
                predicted_velocity = diffusion_unet(
                    x=noisy_inputs,
                    timesteps=t.expand(noisy_inputs.shape[0])
                )
                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
            final_output = noisy_inputs
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        use_discrete_timesteps: bool = True,
        sample_method: str = "uniform",
        loc: float = 0.0,
        scale: float = 1.0,
        use_timestep_transform: bool = False,
        transform_scale: float = 1.0,
        steps_offset: int = 0,
        base_img_size_numel: int = 32 * 32 * 32,
        spatial_dim: int = 3,
    ):
        # rectified flow only accepts velocity prediction
        self.prediction_type = RFlowPredictionType.V_PREDICTION

        self.num_train_timesteps = num_train_timesteps
        self.use_discrete_timesteps = use_discrete_timesteps
        self.base_img_size_numel = base_img_size_numel
        self.spatial_dim = spatial_dim

        # sample method
        if sample_method not in ["uniform", "logit-normal"]:
            raise ValueError(
                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
            )
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            # one draw per batch element; column 0 of each simplex sample is used as the
            # normalized timestep, then moved to the input's device
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale
        self.steps_offset = steps_offset

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Add noise to the original samples by linear interpolation.

        Computes ``(1 - t / T) * original_samples + (t / T) * noise`` with ``T = num_train_timesteps``,
        so ``t = 0`` returns the clean sample and ``t = T`` returns pure noise.

        Args:
            original_samples: original samples
            noise: noise to add to samples; must be a 4D or 5D tensor
            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.

        Returns:
            noisy_samples: sample with added noise

        Raises:
            ValueError: if ``noise`` is neither 4D nor 5D.
        """
        timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
        # weight of the clean sample: 1 at t=0, shrinking toward 0 as t approaches T
        timepoints = 1 - timepoints

        # expand timepoint to noise shape
        if noise.ndim == 5:
            timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:])
        elif noise.ndim == 4:
            timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:])
        else:
            raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}")

        noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise

        return noisy_samples

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: str | torch.device | None = None,
        input_img_size_numel: int | None = None,
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Timesteps start at ``num_train_timesteps`` and are evenly spaced downward (0 itself is excluded).
        They are optionally warped by ``timestep_transform`` and then shifted by ``steps_offset``.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
            input_img_size_numel: int, H*W*D of the image, required when self.use_timestep_transform is True.

        Raises:
            ValueError: if ``num_inference_steps`` is outside ``[1, num_train_timesteps]``, or if
                ``input_img_size_numel`` is missing while ``use_timestep_transform`` is True.
        """
        if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} should be at least 1, "
                "and cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )
        if self.use_timestep_transform and input_img_size_numel is None:
            # fail early with a clear message instead of a TypeError deep inside timestep_transform
            raise ValueError("`input_img_size_numel` must be provided when `use_timestep_transform` is True.")

        self.num_inference_steps = num_inference_steps
        # prepare timesteps
        timesteps = [
            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
        ]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        if self.use_timestep_transform:
            timesteps = [
                timestep_transform(
                    t,
                    input_img_size_numel=input_img_size_numel,
                    base_img_size_numel=self.base_img_size_numel,
                    scale=self.transform_scale,
                    num_train_timesteps=self.num_train_timesteps,
                    spatial_dim=self.spatial_dim,
                )
                for t in timesteps
            ]
        # NOTE: values pass through float16; in discrete mode they are then truncated to int64.
        timesteps_np = np.array(timesteps).astype(np.float16)
        if self.use_discrete_timesteps:
            timesteps_np = timesteps_np.astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps_np).to(device)
        self.timesteps += self.steps_offset

    def sample_timesteps(self, x_start):
        """
        Randomly samples training timesteps using the chosen sampling method.

        One timestep is drawn per batch element (``x_start.shape[0]``). When ``use_timestep_transform``
        is True, the spatial size and dimensionality are taken from ``x_start.shape[2:]``. Note that the
        transformed result is floating point even when ``use_discrete_timesteps`` is True.

        Args:
            x_start (torch.Tensor): The input tensor for sampling.

        Returns:
            torch.Tensor: Sampled timesteps.
        """
        if self.sample_method == "uniform":
            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
        elif self.sample_method == "logit-normal":
            t = self.sample_t(x_start) * self.num_train_timesteps

        if self.use_discrete_timesteps:
            t = t.long()

        if self.use_timestep_transform:
            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:]))
            t = timestep_transform(
                t,
                input_img_size_numel=input_img_size_numel,
                base_img_size_numel=self.base_img_size_numel,
                scale=self.transform_scale,
                num_train_timesteps=self.num_train_timesteps,
                spatial_dim=len(x_start.shape) - 2,
            )

        return t

    def step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts the next sample in the diffusion process with one Euler step along the predicted velocity.

        Args:
            model_output (torch.Tensor): Output from the trained diffusion model (the predicted velocity).
            timestep (int): Current timestep in the diffusion chain.
            sample (torch.Tensor): Current sample in the process.
            next_timestep (Union[int, None]): Optional timestep to advance to. When None, the step size
                defaults to ``1 / num_inference_steps``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``sample + v * dt`` (the sample at the next step) and
            ``sample + v * timestep / num_train_timesteps`` (the estimate of the fully denoised sample).

        Raises:
            AttributeError: if ``num_inference_steps`` is missing or not an int (e.g. ``set_timesteps`` was not run).
        """
        # Ensure num_inference_steps exists and is a valid integer
        if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
            raise AttributeError(
                "num_inference_steps is missing or not an integer in the class. "
                "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it."
            )

        v_pred = model_output

        dt: float
        if next_timestep is not None:
            # Use float, not int: truncating fractional (non-discrete) timesteps would make the
            # per-step dt values stop summing to timesteps[0] / num_train_timesteps.
            dt = (float(timestep) - float(next_timestep)) / self.num_train_timesteps
        else:
            dt = 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0  # avoid div by zero

        pred_post_sample = sample + v_pred * dt
        pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps

        return pred_post_sample, pred_original_sample
|
@@ -1,5 +1,5 @@
|
|
1
|
-
monai/__init__.py,sha256=
|
2
|
-
monai/_version.py,sha256=
|
1
|
+
monai/__init__.py,sha256=fnGV8I63_2ZDjUJguv4MIFw1U9IR2N6Y09ZWn4tRbZY,4095
|
2
|
+
monai/_version.py,sha256=vuxb7QhJeI5JKiKSbZNMJDItG7-knm3jNnxT04rJ6Xc,503
|
3
3
|
monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
|
5
5
|
monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
|
@@ -49,7 +49,7 @@ monai/apps/detection/utils/predict_utils.py,sha256=6j7U-7pLtbmgE6SXKR_MVImc67-M8
|
|
49
49
|
monai/apps/generation/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
50
50
|
monai/apps/generation/maisi/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
51
51
|
monai/apps/generation/maisi/networks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
52
|
-
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py,sha256=
|
52
|
+
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py,sha256=ClSQuCZAkQhXGLgZ2WEPLg7anSFHnT9v19JDKXyqYPo,36812
|
53
53
|
monai/apps/generation/maisi/networks/controlnet_maisi.py,sha256=0K2uyMfvc1-2cFCoNDngeMbzcPpvFR1JZ0fqF9pj8r4,7707
|
54
54
|
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py,sha256=XFOiy6GngXC_OKM1dUiel_gp71yUFWgPErYdgrVLQAU,19072
|
55
55
|
monai/apps/mmars/__init__.py,sha256=BolpgEi9jNBgrOQd3Kwp-9QQLeWQwQtlN_MJkK1eu5s,726
|
@@ -195,7 +195,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
|
|
195
195
|
monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
|
196
196
|
monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
|
197
197
|
monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
|
198
|
-
monai/inferers/inferer.py,sha256=
|
198
|
+
monai/inferers/inferer.py,sha256=rgAI5qnLpszoXiSj3HCaqYiMAxymvqYO0Ltujq_lJUo,94617
|
199
199
|
monai/inferers/merger.py,sha256=JxSLdlXTKW1xug11UWQNi6dNtpqVRbGCLc-ifj06g8U,16613
|
200
200
|
monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
|
201
201
|
monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
|
@@ -341,10 +341,11 @@ monai/networks/nets/vitautoenc.py,sha256=vfQBWjTb0k7EY4uC76rmuOCIUUgeBvf_EIXBofC
|
|
341
341
|
monai/networks/nets/vnet.py,sha256=zaJi5kSiTLAuFHThSZfhJvHP6zKh3oBWsTWG-328O_g,10820
|
342
342
|
monai/networks/nets/voxelmorph.py,sha256=Q5VQFLLKSFqhsG0Z8_72ZGfK1nA4kdCfFnGbqI6Eofg,20665
|
343
343
|
monai/networks/nets/vqvae.py,sha256=Zf9fTL_rluhuJhH6gTNB6iiKRfwBxfuuyhCdU9TLmAk,18417
|
344
|
-
monai/networks/schedulers/__init__.py,sha256=
|
344
|
+
monai/networks/schedulers/__init__.py,sha256=Jic-Ln0liMjDVQ1KAv9Z1fsoxGZXuBKxqBeWJthgwHY,798
|
345
345
|
monai/networks/schedulers/ddim.py,sha256=MygHvgLB_NL9488DhHsE_g-EvV6DlDPtiBROpnCvDHc,14380
|
346
346
|
monai/networks/schedulers/ddpm.py,sha256=LPqmlNJex32QrqcVb5s7XCNKVlFPsd_05-IJHpUJZPI,11387
|
347
347
|
monai/networks/schedulers/pndm.py,sha256=9Qe8NOw_tvlpCBK7yvkmyriyGfIO5RRDV8ZKPh85cQY,14472
|
348
|
+
monai/networks/schedulers/rectified_flow.py,sha256=n0Pi03Z8GJBZVf9G5YUQ-uc9dZSGK4ra2SnMc4sI0GE,13757
|
348
349
|
monai/networks/schedulers/scheduler.py,sha256=X5eu5AmtNiads9cgaFy5r7BdlKYASSICyGSyF-fk6x8,9206
|
349
350
|
monai/optimizers/__init__.py,sha256=XUL7o9vSL7bZImpxVZqcc1c8MwUMrOZL4nJ-mjAA7yM,796
|
350
351
|
monai/optimizers/lr_finder.py,sha256=tbVi6qd-LLI6pENM9cDUv-Hh1HqziO3Wb9aI6JoaPng,21992
|
@@ -570,9 +571,9 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
|
|
570
571
|
tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
|
571
572
|
tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
572
573
|
tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
|
573
|
-
tests/inferers/test_controlnet_inferers.py,sha256=
|
574
|
-
tests/inferers/test_diffusion_inferer.py,sha256=
|
575
|
-
tests/inferers/test_latent_diffusion_inferer.py,sha256=
|
574
|
+
tests/inferers/test_controlnet_inferers.py,sha256=pGseHgnfMH-UOoAoUsKXdqka-IZc8X83ThauSanH--o,52825
|
575
|
+
tests/inferers/test_diffusion_inferer.py,sha256=U6zNPnem9_cY9bDxMh6L2hThsmla7sDq9ivWQEyqNAk,14613
|
576
|
+
tests/inferers/test_latent_diffusion_inferer.py,sha256=4cnS77I5YpFX1wKcTrlPfKVP3g6UHOkbuADgiXrScks,33544
|
576
577
|
tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
|
577
578
|
tests/inferers/test_saliency_inferer.py,sha256=7miHRbA4yb_WGcxql6za9uXXoZlql_7y23f7IzsyIps,1949
|
578
579
|
tests/inferers/test_slice_inferer.py,sha256=kzaJjjTnf2rAiR75l8A_J-Kie4NaLj2bogi-aJ5L5mk,1897
|
@@ -801,6 +802,7 @@ tests/networks/schedulers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr
|
|
801
802
|
tests/networks/schedulers/test_scheduler_ddim.py,sha256=0JnqgUAgA9W3H3QTNaRxAPvUmsrJKpZm6QX7tp71lxE,3540
|
802
803
|
tests/networks/schedulers/test_scheduler_ddpm.py,sha256=MizmzprAO5dBIeaFX8jlYrZmF0VK444lK9gq1X3Wxk4,4577
|
803
804
|
tests/networks/schedulers/test_scheduler_pndm.py,sha256=f_TDa2yUkFCWq9OAhYyXvQ-zUoZMJaqYxulOYnqQIAg,4612
|
805
|
+
tests/networks/schedulers/test_scheduler_rflow.py,sha256=IxkbVUHdBQ3p_RS1-83VFBg3b4hMdeDIDAMk28kX8lE,4339
|
804
806
|
tests/networks/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
805
807
|
tests/networks/utils/test_copy_model_state.py,sha256=SI0dlUkA0rdZVFvi5acr0nve002oIiftWIPWkqLQH2Q,6764
|
806
808
|
tests/networks/utils/test_eval_mode.py,sha256=HQqgC4COr5fAsBo2Z-DCjnOfx6WLxXlPywlGMnQY7_0,1086
|
@@ -1178,8 +1180,8 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
|
|
1178
1180
|
tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
1179
1181
|
tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
|
1180
1182
|
tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
|
1181
|
-
monai_weekly-1.5.
|
1182
|
-
monai_weekly-1.5.
|
1183
|
-
monai_weekly-1.5.
|
1184
|
-
monai_weekly-1.5.
|
1185
|
-
monai_weekly-1.5.
|
1183
|
+
monai_weekly-1.5.dev2511.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
1184
|
+
monai_weekly-1.5.dev2511.dist-info/METADATA,sha256=kCgsN9iKyvPUdX4ymb4dANgTBXyEtc9PxfAIQnU-MIA,11986
|
1185
|
+
monai_weekly-1.5.dev2511.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
1186
|
+
monai_weekly-1.5.dev2511.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
|
1187
|
+
monai_weekly-1.5.dev2511.dist-info/RECORD,,
|