monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/vqvae.py
@@ -0,0 +1,472 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from collections.abc import Sequence
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.blocks import Convolution
+ from monai.networks.layers import Act
+ from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer
+ from monai.utils import ensure_tuple_rep
+
+ __all__ = ["VQVAE"]
+
+
+ class VQVAEResidualUnit(nn.Module):
+     """
+     Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving
+     Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf).
+
+     The original implementation can be found at
+     https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150.
+
+     Args:
+         spatial_dims: number of spatial dimensions of the input data.
+         in_channels: number of input channels.
+         num_res_channels: number of channels in the residual layers.
+         act: activation type and arguments. Defaults to RELU.
+         dropout: dropout ratio. Defaults to no dropout.
+         bias: whether to have a bias term. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         num_res_channels: int,
+         act: tuple | str | None = Act.RELU,
+         dropout: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+
+         self.spatial_dims = spatial_dims
+         self.in_channels = in_channels
+         self.num_res_channels = num_res_channels
+         self.act = act
+         self.dropout = dropout
+         self.bias = bias
+
+         self.conv1 = Convolution(
+             spatial_dims=self.spatial_dims,
+             in_channels=self.in_channels,
+             out_channels=self.num_res_channels,
+             adn_ordering="DA",
+             act=self.act,
+             dropout=self.dropout,
+             bias=self.bias,
+         )
+
+         self.conv2 = Convolution(
+             spatial_dims=self.spatial_dims,
+             in_channels=self.num_res_channels,
+             out_channels=self.in_channels,
+             bias=self.bias,
+             conv_only=True,
+         )
+
+     def forward(self, x):
+         return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True)
+
+
+ class Encoder(nn.Module):
+     """
+     Encoder module for VQ-VAE.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of input channels.
+         out_channels: number of channels in the latent space (embedding_dim).
+         channels: sequence containing the number of channels at each level of the encoder.
+         num_res_layers: number of sequential residual layers at each level.
+         num_res_channels: number of channels in the residual layers at each level.
+         downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
+             following information: stride (int), kernel_size (int), dilation (int) and padding (int).
+         dropout: dropout ratio.
+         act: activation type and arguments.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         channels: Sequence[int],
+         num_res_layers: int,
+         num_res_channels: Sequence[int],
+         downsample_parameters: Sequence[Tuple[int, int, int, int]],
+         dropout: float,
+         act: tuple | str | None,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.channels = channels
+         self.num_res_layers = num_res_layers
+         self.num_res_channels = num_res_channels
+         self.downsample_parameters = downsample_parameters
+         self.dropout = dropout
+         self.act = act
+
+         blocks: list[nn.Module] = []
+
+         for i in range(len(self.channels)):
+             blocks.append(
+                 Convolution(
+                     spatial_dims=self.spatial_dims,
+                     in_channels=self.in_channels if i == 0 else self.channels[i - 1],
+                     out_channels=self.channels[i],
+                     strides=self.downsample_parameters[i][0],
+                     kernel_size=self.downsample_parameters[i][1],
+                     adn_ordering="DA",
+                     act=self.act,
+                     dropout=None if i == 0 else self.dropout,
+                     dropout_dim=1,
+                     dilation=self.downsample_parameters[i][2],
+                     padding=self.downsample_parameters[i][3],
+                 )
+             )
+
+             for _ in range(self.num_res_layers):
+                 blocks.append(
+                     VQVAEResidualUnit(
+                         spatial_dims=self.spatial_dims,
+                         in_channels=self.channels[i],
+                         num_res_channels=self.num_res_channels[i],
+                         act=self.act,
+                         dropout=self.dropout,
+                     )
+                 )
+
+         blocks.append(
+             Convolution(
+                 spatial_dims=self.spatial_dims,
+                 in_channels=self.channels[len(self.channels) - 1],
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+         return x
+
+
+ class Decoder(nn.Module):
+     """
+     Decoder module for VQ-VAE.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of channels in the latent space (embedding_dim).
+         out_channels: number of output channels.
+         channels: sequence containing the number of channels at each level of the decoder.
+         num_res_layers: number of sequential residual layers at each level.
+         num_res_channels: number of channels in the residual layers at each level.
+         upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
+             following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
+         dropout: dropout ratio.
+         act: activation type and arguments.
+         output_act: activation type and arguments for the output.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         channels: Sequence[int],
+         num_res_layers: int,
+         num_res_channels: Sequence[int],
+         upsample_parameters: Sequence[Tuple[int, int, int, int, int]],
+         dropout: float,
+         act: tuple | str | None,
+         output_act: tuple | str | None,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.channels = channels
+         self.num_res_layers = num_res_layers
+         self.num_res_channels = num_res_channels
+         self.upsample_parameters = upsample_parameters
+         self.dropout = dropout
+         self.act = act
+         self.output_act = output_act
+
+         reversed_num_channels = list(reversed(self.channels))
+
+         blocks: list[nn.Module] = []
+         blocks.append(
+             Convolution(
+                 spatial_dims=self.spatial_dims,
+                 in_channels=self.in_channels,
+                 out_channels=reversed_num_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         reversed_num_res_channels = list(reversed(self.num_res_channels))
+         for i in range(len(self.channels)):
+             for _ in range(self.num_res_layers):
+                 blocks.append(
+                     VQVAEResidualUnit(
+                         spatial_dims=self.spatial_dims,
+                         in_channels=reversed_num_channels[i],
+                         num_res_channels=reversed_num_res_channels[i],
+                         act=self.act,
+                         dropout=self.dropout,
+                     )
+                 )
+
+             blocks.append(
+                 Convolution(
+                     spatial_dims=self.spatial_dims,
+                     in_channels=reversed_num_channels[i],
+                     out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1],
+                     strides=self.upsample_parameters[i][0],
+                     kernel_size=self.upsample_parameters[i][1],
+                     adn_ordering="DA",
+                     act=self.act,
+                     dropout=self.dropout if i != len(self.channels) - 1 else None,
+                     norm=None,
+                     dilation=self.upsample_parameters[i][2],
+                     conv_only=i == len(self.channels) - 1,
+                     is_transposed=True,
+                     padding=self.upsample_parameters[i][3],
+                     output_padding=self.upsample_parameters[i][4],
+                 )
+             )
+
+         if self.output_act:
+             blocks.append(Act[self.output_act]())
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+         return x
+
+
+ class VQVAE(nn.Module):
+     """
+     Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative
+     Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf)
+
+     The original implementation can be found at
+     https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
+             following information: stride (int), kernel_size (int), dilation (int) and padding (int).
+         upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
+             following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
+         num_res_layers: number of sequential residual layers at each level.
+         channels: number of channels at each level.
+         num_res_channels: number of channels in the residual layers at each level.
+         num_embeddings: number of atomic elements (codes) in the VectorQuantization codebook.
+         embedding_dim: number of channels of the VectorQuantization input and of its atomic elements.
+         commitment_cost: commitment cost used by the VectorQuantization layer.
+         decay: EMA decay used by the VectorQuantization layer.
+         epsilon: epsilon value used by the VectorQuantization layer.
+         act: activation type and arguments.
+         dropout: dropout ratio.
+         output_act: activation type and arguments for the output.
+         ddp_sync: whether to synchronize the codebook across processes.
+         use_checkpointing: if True, use activation checkpointing to save memory.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         channels: Sequence[int] = (96, 96, 192),
+         num_res_layers: int = 3,
+         num_res_channels: Sequence[int] | int = (96, 96, 192),
+         downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = (
+             (2, 4, 1, 1),
+             (2, 4, 1, 1),
+             (2, 4, 1, 1),
+         ),
+         upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = (
+             (2, 4, 1, 1, 0),
+             (2, 4, 1, 1, 0),
+             (2, 4, 1, 1, 0),
+         ),
+         num_embeddings: int = 32,
+         embedding_dim: int = 64,
+         embedding_init: str = "normal",
+         commitment_cost: float = 0.25,
+         decay: float = 0.5,
+         epsilon: float = 1e-5,
+         dropout: float = 0.0,
+         act: tuple | str | None = Act.RELU,
+         output_act: tuple | str | None = None,
+         ddp_sync: bool = True,
+         use_checkpointing: bool = False,
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.spatial_dims = spatial_dims
+         self.channels = channels
+         self.num_embeddings = num_embeddings
+         self.embedding_dim = embedding_dim
+         self.use_checkpointing = use_checkpointing
+
+         if isinstance(num_res_channels, int):
+             num_res_channels = ensure_tuple_rep(num_res_channels, len(channels))
+
+         if len(num_res_channels) != len(channels):
+             raise ValueError(
+                 "`num_res_channels` should be a single integer or a tuple of integers with the same length as "
+                 "`channels`."
+             )
+         if all(isinstance(values, int) for values in upsample_parameters):
+             upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels)
+         else:
+             upsample_parameters_tuple = upsample_parameters
+
+         if all(isinstance(values, int) for values in downsample_parameters):
+             downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels)
+         else:
+             downsample_parameters_tuple = downsample_parameters
+
+         if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple):
+             raise ValueError("`downsample_parameters` should be a single tuple of integers or a tuple of tuples.")
+
+         # check if upsample_parameters is a tuple of ints or a tuple of tuples of ints
+         if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple):
+             raise ValueError("`upsample_parameters` should be a single tuple of integers or a tuple of tuples.")
+
+         for parameter in downsample_parameters_tuple:
+             if len(parameter) != 4:
+                 raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.")
+
+         for parameter in upsample_parameters_tuple:
+             if len(parameter) != 5:
+                 raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.")
+
+         if len(downsample_parameters_tuple) != len(channels):
+             raise ValueError(
+                 "`downsample_parameters` should be a tuple of tuples with the same length as `channels`."
+             )
+
+         if len(upsample_parameters_tuple) != len(channels):
+             raise ValueError(
+                 "`upsample_parameters` should be a tuple of tuples with the same length as `channels`."
+             )
+
+         self.num_res_layers = num_res_layers
+         self.num_res_channels = num_res_channels
+
+         self.encoder = Encoder(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=embedding_dim,
+             channels=channels,
+             num_res_layers=num_res_layers,
+             num_res_channels=num_res_channels,
+             downsample_parameters=downsample_parameters_tuple,
+             dropout=dropout,
+             act=act,
+         )
+
+         self.decoder = Decoder(
+             spatial_dims=spatial_dims,
+             in_channels=embedding_dim,
+             out_channels=out_channels,
+             channels=channels,
+             num_res_layers=num_res_layers,
+             num_res_channels=num_res_channels,
+             upsample_parameters=upsample_parameters_tuple,
+             dropout=dropout,
+             act=act,
+             output_act=output_act,
+         )
+
+         self.quantizer = VectorQuantizer(
+             quantizer=EMAQuantizer(
+                 spatial_dims=spatial_dims,
+                 num_embeddings=num_embeddings,
+                 embedding_dim=embedding_dim,
+                 commitment_cost=commitment_cost,
+                 decay=decay,
+                 epsilon=epsilon,
+                 embedding_init=embedding_init,
+                 ddp_sync=ddp_sync,
+             )
+         )
+
+     def encode(self, images: torch.Tensor) -> torch.Tensor:
+         output: torch.Tensor
+         if self.use_checkpointing:
+             output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False)
+         else:
+             output = self.encoder(images)
+         return output
+
+     def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         x_loss, x = self.quantizer(encodings)
+         return x, x_loss
+
+     def decode(self, quantizations: torch.Tensor) -> torch.Tensor:
+         output: torch.Tensor
+
+         if self.use_checkpointing:
+             output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False)
+         else:
+             output = self.decoder(quantizations)
+         return output
+
+     def index_quantize(self, images: torch.Tensor) -> torch.Tensor:
+         return self.quantizer.quantize(self.encode(images=images))
+
+     def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+         return self.decode(self.quantizer.embed(embedding_indices))
+
+     def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         quantizations, quantization_losses = self.quantize(self.encode(images))
+         reconstruction = self.decode(quantizations)
+
+         return reconstruction, quantization_losses
+
+     def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+         z = self.encode(x)
+         e, _ = self.quantize(z)
+         return e
+
+     def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
+         e, _ = self.quantize(z)
+         image = self.decode(e)
+         return image
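
A minimal usage sketch of the new VQVAE class, assuming it is re-exported from monai.networks.nets as the monai/networks/nets/__init__.py change in this release suggests; shapes and defaults follow the constructor shown above.

import torch
from monai.networks.nets import VQVAE  # assumed re-export added in monai/networks/nets/__init__.py

# 2D VQ-VAE with the default three stride-2 downsampling levels, so input
# spatial sizes should be divisible by 8.
net = VQVAE(spatial_dims=2, in_channels=1, out_channels=1)

x = torch.randn(2, 1, 64, 64)  # (batch, channel, height, width)
reconstruction, quantization_loss = net(x)  # forward() returns (reconstruction, quantization_losses)

# During training the quantization loss is typically added to a reconstruction loss:
loss = torch.nn.functional.l1_loss(reconstruction, x) + quantization_loss

# Discrete codebook indices, e.g. for a downstream autoregressive "stage 2" model:
indices = net.index_quantize(x)         # integer code index per latent position
decoded = net.decode_samples(indices)   # decode images directly from indices
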
monai/networks/schedulers/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from .ddim import DDIMScheduler
+ from .ddpm import DDPMScheduler
+ from .pndm import PNDMScheduler
+ from .scheduler import NoiseSchedules, Scheduler
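
A minimal import sketch for the new schedulers subpackage, assuming only the re-exports listed above and that the classes are constructible with their default arguments (the specific options live in ddim.py, ddpm.py and pndm.py).

from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler

# Instantiate with default settings; arguments such as the number of training
# timesteps or the beta schedule are defined in the respective modules.
ddpm = DDPMScheduler()
ddim = DDIMScheduler()
pndm = PNDMScheduler()
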