TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ldm/noise_predictor.py ADDED
@@ -0,0 +1,1074 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class NoisePredictor(nn.Module):
    """Time- and text-conditioned U-Net that predicts noise on latent tensors.

    The network consists of an encoder (``down_blocks``), a bottleneck
    (``mid_blocks``) and a decoder (``up_blocks``) joined by skip connections.
    Every block receives the projected sinusoidal time embedding and, optionally,
    the conditioning tensor ``y``. Intended to be used as the ``noise_predictor``
    of ``TrainLDM`` and ``SampleLDM`` from the ``ldm`` module.

    Parameters
    ----------
    in_channels : int
        Channels of the input latent (matches the latent channels of
        ``AutoencoderLDM``); the final convolution maps back to this many channels.
    down_channels : list of int
        Channel widths of the encoder; block ``i`` maps ``down_channels[i]`` to
        ``down_channels[i + 1]``.
    mid_channels : list of int
        Channel widths of the bottleneck, consumed pairwise like ``down_channels``.
    up_channels : list of int
        Channel widths of the decoder, consumed pairwise like ``down_channels``.
    down_sampling : list of bool
        Per encoder block flag enabling spatial downsampling.
    time_embed_dim : int
        Size of the time embedding.
    y_embed_dim : int
        Size of the conditioning embedding (output dimension of the text
        conditional network).
    num_down_blocks : int
        Number of convolutional layer pairs inside each down block.
    num_mid_blocks : int
        Number of convolutional layer pairs inside each middle block.
    num_up_blocks : int
        Number of convolutional layer pairs inside each up block.
    dropout_rate : float, optional
        Dropout probability handed to every block and the output head
        (default: 0.1).
    down_sampling_factor : int, optional
        Spatial factor used for both downsampling and upsampling (default: 2).
    where_y : bool, optional
        If False and ``y`` is given, ``y`` is concatenated to the input along the
        channel axis before the first convolution (default: True).
    y_to_all : bool, optional
        If True, attention runs after every layer pair of a block; if False,
        only after the first one (default: False).

    Attributes
    ----------
    in_channels, down_channels, mid_channels, up_channels, down_sampling,
    time_embed_dim, y_embed_dim, num_down_blocks, num_mid_blocks, num_up_blocks,
    dropout_rate, where_y
        The constructor arguments of the same name, stored unchanged.
    up_sampling : list of bool
        ``down_sampling`` in reverse order, used by the up blocks.
    conv1 : torch.nn.Conv2d
        Initial 3x3 convolution.
    time_projection : torch.nn.Sequential
        Linear - SiLU - Linear projection of the sinusoidal time embedding.
    down_blocks : torch.nn.ModuleList
        ``DownBlock`` modules.
    mid_blocks : torch.nn.ModuleList
        ``MiddleBlock`` modules.
    up_blocks : torch.nn.ModuleList
        ``UpBlock`` modules.
    conv2 : torch.nn.Sequential
        Output head: group normalization, dropout, 3x3 convolution.

    Notes
    -----
    - ``initialize_weights`` is not called by the constructor; call it explicitly
      to apply Kaiming-normal initialization.
    - One skip tensor is stored per down block and one is consumed per up block,
      so there must be at least as many down blocks as up blocks.
    """

    def __init__(self, in_channels, down_channels, mid_channels, up_channels, down_sampling,
                 time_embed_dim, y_embed_dim, num_down_blocks, num_mid_blocks, num_up_blocks,
                 dropout_rate=0.1, down_sampling_factor=2, where_y=True, y_to_all=False):
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        # The decoder mirrors the encoder, so its sampling flags are reversed.
        self.up_sampling = list(reversed(self.down_sampling))

        # Stem: lift the latent to the width of the first encoder stage.
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1,
        )

        # Two-layer MLP applied to the sinusoidal time embedding.
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
        )

        # Encoder.
        self.down_blocks = nn.ModuleList()
        for idx in range(len(self.down_channels) - 1):
            self.down_blocks.append(
                DownBlock(
                    in_channels=self.down_channels[idx],
                    out_channels=self.down_channels[idx + 1],
                    time_embed_dim=self.time_embed_dim,
                    y_embed_dim=y_embed_dim,
                    num_layers=self.num_down_blocks,
                    down_sampling_factor=down_sampling_factor,
                    down_sample=self.down_sampling[idx],
                    dropout_rate=self.dropout_rate,
                    y_to_all=y_to_all,
                )
            )

        # Bottleneck.
        self.mid_blocks = nn.ModuleList()
        for idx in range(len(self.mid_channels) - 1):
            self.mid_blocks.append(
                MiddleBlock(
                    in_channels=self.mid_channels[idx],
                    out_channels=self.mid_channels[idx + 1],
                    time_embed_dim=self.time_embed_dim,
                    y_embed_dim=y_embed_dim,
                    num_layers=self.num_mid_blocks,
                    dropout_rate=self.dropout_rate,
                    y_to_all=y_to_all,
                )
            )

        # Decoder; the skip widths handed to each block are the encoder widths reversed.
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList()
        for idx in range(len(self.up_channels) - 1):
            self.up_blocks.append(
                UpBlock(
                    in_channels=self.up_channels[idx],
                    out_channels=self.up_channels[idx + 1],
                    skip_channels=skip_channels[idx],
                    time_embed_dim=self.time_embed_dim,
                    y_embed_dim=y_embed_dim,
                    num_layers=self.num_up_blocks,
                    up_sampling_factor=down_sampling_factor,
                    up_sampling=self.up_sampling[idx],
                    dropout_rate=self.dropout_rate,
                    y_to_all=y_to_all,
                )
            )

        # Output head mapping back to the latent channel count.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(
                in_channels=self.up_channels[-1],
                out_channels=self.in_channels,
                kernel_size=3,
                padding=1,
            ),
        )

    def initialize_weights(self):
        """Re-initialize convolutional and linear layers in place.

        Every ``Conv2d``, ``Linear`` and ``ConvTranspose2d`` submodule gets
        Kaiming-normal weights (leaky-ReLU gain, ``a=0.2``) and, if it has a bias,
        a zero bias.
        """
        target_types = (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)
        for layer in self.modules():
            if not isinstance(layer, target_types):
                continue
            nn.init.kaiming_normal_(layer.weight, a=0.2, nonlinearity='leaky_relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x, t, y=None):
        """Predict the noise for latent ``x`` at time steps ``t``.

        Parameters
        ----------
        x : torch.Tensor
            Input latent tensor, shape (batch_size, in_channels, height, width).
        t : torch.Tensor
            Time steps, shape (batch_size,).
        y : torch.Tensor, optional
            Conditioning embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).

        Returns
        -------
        torch.Tensor
            Output of the final convolution, with ``in_channels`` channels.
        """
        if y is not None and not self.where_y:
            # NOTE(review): y is concatenated here and is still forwarded to every
            # block below - confirm that both uses are intended.
            x = torch.cat(tensors=[x, y], dim=1)
        hidden = self.conv1(x)

        # Sinusoidal encoding of t followed by the learned projection.
        sinusoid = GetEmbeddedTime(embed_dim=self.time_embed_dim)
        time_embed = self.time_projection(sinusoid(time_steps=t))

        # Encoder: remember the input of every down block for the decoder.
        skips = []
        for down in self.down_blocks:
            skips.append(hidden)
            hidden = down(x=hidden, embed_time=time_embed, y=y)

        for mid in self.mid_blocks:
            hidden = mid(x=hidden, embed_time=time_embed, y=y)

        # Decoder: consume the stored tensors in last-in, first-out order.
        for up in self.up_blocks:
            hidden = up(x=hidden, skip_connection=skips.pop(), embed_time=time_embed, y=y)

        return self.conv2(hidden)
237
+ #-----------------------------------------------------------------------------
238
class DownBlock(nn.Module):
    """Encoder stage of NoisePredictor.

    Runs ``num_layers`` residual layer pairs (Conv3, time-embedding injection,
    Conv3, 1x1 shortcut), applies attention after the first pair (or after every
    pair when ``y_to_all`` is set), and finally downsamples if enabled.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    time_embed_dim : int
        Dimensionality of the time embedding.
    y_embed_dim : int
        Dimensionality of the conditioning embedding.
    num_layers : int
        Number of convolutional layer pairs.
    down_sampling_factor : int
        Factor handed to the downsampling module.
    down_sample : bool
        If True a ``DownSampling`` module is used, otherwise ``nn.Identity``.
    dropout_rate : float
        Dropout rate for the Conv3 and attention layers.
    y_to_all : bool
        If True, attention follows every layer pair; if False, only the first.

    Attributes
    ----------
    num_layers : int
        Number of convolutional layer pairs.
    y_to_all : bool
        Attention scope flag.
    conv1 : torch.nn.ModuleList
        First Conv3 of every pair.
    conv2 : torch.nn.ModuleList
        Second Conv3 of every pair.
    time_embedding : torch.nn.ModuleList
        TimeEmbedding projections, one per pair.
    attention : torch.nn.ModuleList
        Attention modules, one per pair.
    down_sampling : DownSampling or torch.nn.Identity
        Final spatial reduction (or a no-op).
    resnet : torch.nn.ModuleList
        1x1 convolutions forming the residual shortcuts.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers,
                 down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        # Only the first pair sees the block's input width.
        layer_inputs = [in_channels if idx == 0 else out_channels
                        for idx in range(self.num_layers)]

        def build_conv(width_in):
            return Conv3(
                in_channels=width_in,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            )

        # Submodule lists are created one after another (not interleaved) so that
        # parameter registration and random initialization order stay unchanged.
        self.conv1 = nn.ModuleList([build_conv(width) for width in layer_inputs])
        self.conv2 = nn.ModuleList([build_conv(out_channels) for _ in layer_inputs])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in layer_inputs
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate,
            )
            for _ in layer_inputs
        ])
        if down_sample:
            self.down_sampling = DownSampling(
                in_channels=out_channels,
                out_channels=out_channels,
                down_sampling_factor=down_sampling_factor,
                conv_block=True,
                max_pool=True,
            )
        else:
            self.down_sampling = nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in layer_inputs
        ])

    def forward(self, x, embed_time, y):
        """Apply the residual pairs, attention and the optional downsampling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor or None
            Conditioning embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim). When None the attention modules are called
            with the feature map only.

        Returns
        -------
        torch.Tensor
            Tensor with ``out_channels`` channels, passed through
            ``self.down_sampling`` as the last step.
        """
        hidden = x
        for idx in range(self.num_layers):
            shortcut_input = hidden
            hidden = self.conv1[idx](hidden)
            # Broadcast the per-channel time bias over the spatial axes.
            hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
            hidden = self.conv2[idx](hidden)
            hidden = hidden + self.resnet[idx](shortcut_input)

            # Attention follows the first pair, or every pair when y_to_all is set.
            if self.y_to_all or idx == 0:
                attend = self.attention[idx]
                attn_out = attend(hidden, y) if y is not None else attend(hidden)
                hidden = hidden + attn_out

        return self.down_sampling(hidden)
383
+ #------------------------------------------------------------------------------
384
class MiddleBlock(nn.Module):
    """Bottleneck stage of NoisePredictor.

    Runs one residual layer pair, then ``num_layers`` rounds of (optional
    attention, residual layer pair). No resampling module is applied.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    time_embed_dim : int
        Dimensionality of the time embedding.
    y_embed_dim : int
        Dimensionality of the conditioning embedding.
    num_layers : int
        Number of attention rounds; ``num_layers + 1`` layer pairs are built.
    dropout_rate : float
        Dropout rate for the Conv3 and attention layers.
    y_to_all : bool, optional
        If True, attention runs in every round; if False, only in the first
        (default: False).

    Attributes
    ----------
    num_layers : int
        Number of attention rounds.
    y_to_all : bool
        Attention scope flag.
    conv1 : torch.nn.ModuleList
        First Conv3 of every pair (``num_layers + 1`` entries).
    conv2 : torch.nn.ModuleList
        Second Conv3 of every pair (``num_layers + 1`` entries).
    time_embedding : torch.nn.ModuleList
        TimeEmbedding projections (``num_layers + 1`` entries).
    attention : torch.nn.ModuleList
        Attention modules (``num_layers + 1`` entries; ``forward`` only uses the
        first ``num_layers``).
    resnet : torch.nn.ModuleList
        1x1 convolutions forming the residual shortcuts (``num_layers + 1`` entries).
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers,
                 dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        # One more pair than attention rounds; only the first pair sees in_channels.
        layer_inputs = [in_channels if idx == 0 else out_channels
                        for idx in range(self.num_layers + 1)]

        def build_conv(width_in):
            return Conv3(
                in_channels=width_in,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            )

        # Lists are built sequentially to keep registration and init order unchanged.
        self.conv1 = nn.ModuleList([build_conv(width) for width in layer_inputs])
        self.conv2 = nn.ModuleList([build_conv(out_channels) for _ in layer_inputs])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in layer_inputs
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate,
            )
            for _ in layer_inputs
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in layer_inputs
        ])

    def _residual_pair(self, idx, hidden, embed_time):
        """Run layer pair ``idx``: Conv3, time bias, Conv3, plus the 1x1 shortcut."""
        shortcut_input = hidden
        hidden = self.conv1[idx](hidden)
        # Broadcast the per-channel time bias over the spatial axes.
        hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
        hidden = self.conv2[idx](hidden)
        return hidden + self.resnet[idx](shortcut_input)

    def forward(self, x, embed_time, y=None):
        """Apply the residual pairs with attention in between.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Conditioning embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None). When None the attention
            modules are called with the feature map only.

        Returns
        -------
        torch.Tensor
            Output tensor with ``out_channels`` channels.
        """
        hidden = self._residual_pair(0, x, embed_time)
        for idx in range(self.num_layers):
            # Attention runs in the first round, or in every round when y_to_all is set.
            if self.y_to_all or idx == 0:
                attend = self.attention[idx]
                attn_out = attend(hidden, y) if y is not None else attend(hidden)
                hidden = hidden + attn_out
            hidden = self._residual_pair(idx + 1, hidden, embed_time)
        return hidden
517
+ #------------------------------------------------------------------------------
518
class UpBlock(nn.Module):
    """Decoder stage of NoisePredictor.

    Optionally upsamples the input, concatenates the skip connection along the
    channel axis, then runs ``num_layers`` residual layer pairs with attention
    after the first pair (or after every pair when ``y_to_all`` is set).

    Parameters
    ----------
    in_channels : int
        Number of input channels (before upsampling).
    out_channels : int
        Number of output channels.
    skip_channels : int
        Channel count used, together with ``in_channels``, to size the first
        layer pair (see the note in ``__init__``).
    time_embed_dim : int
        Dimensionality of the time embedding.
    y_embed_dim : int
        Dimensionality of the conditioning embedding.
    num_layers : int
        Number of convolutional layer pairs.
    up_sampling_factor : int
        Factor handed to the upsampling module.
    up_sampling : bool, optional
        If True an ``UpSampling`` module is used, otherwise ``nn.Identity``
        (default: True).
    dropout_rate : float, optional
        Dropout rate for the Conv3 and attention layers (default: 0.2).
    y_to_all : bool, optional
        If True, attention follows every layer pair; if False, only the first
        (default: False).

    Attributes
    ----------
    num_layers : int
        Number of convolutional layer pairs.
    y_to_all : bool
        Attention scope flag.
    conv1 : torch.nn.ModuleList
        First Conv3 of every pair.
    conv2 : torch.nn.ModuleList
        Second Conv3 of every pair.
    time_embedding : torch.nn.ModuleList
        TimeEmbedding projections, one per pair.
    attention : torch.nn.ModuleList
        Attention modules, one per pair.
    up_sampling : UpSampling or torch.nn.Identity
        Initial spatial enlargement (or a no-op).
    resnet : torch.nn.ModuleList
        1x1 convolutions forming the residual shortcuts.
    """

    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim,
                 num_layers, up_sampling_factor, up_sampling=True, dropout_rate=0.2,
                 y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        # Width of the tensor produced by the concatenation in forward().
        # NOTE(review): this presumes (channels leaving self.up_sampling) +
        # (channels of the skip tensor) == in_channels // 2 + skip_channels -
        # confirm against UpSampling and the caller's channel lists.
        effective_in_channels = in_channels // 2 + skip_channels
        layer_inputs = [effective_in_channels if idx == 0 else out_channels
                        for idx in range(self.num_layers)]

        def build_conv(width_in):
            return Conv3(
                in_channels=width_in,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            )

        # Lists are built sequentially to keep registration and init order unchanged.
        self.conv1 = nn.ModuleList([build_conv(width) for width in layer_inputs])
        self.conv2 = nn.ModuleList([build_conv(out_channels) for _ in layer_inputs])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in layer_inputs
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate,
            )
            for _ in layer_inputs
        ])
        if up_sampling:
            self.up_sampling = UpSampling(
                in_channels=in_channels,
                out_channels=in_channels,
                up_sampling_factor=up_sampling_factor,
                conv_block=True,
                up_sampling=True,
            )
        else:
            self.up_sampling = nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in layer_inputs
        ])

    def forward(self, x, skip_connection, embed_time, y=None):
        """Upsample, merge the skip connection, then apply the residual pairs.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        skip_connection : torch.Tensor
            Encoder tensor concatenated to the upsampled input along dim 1; its
            spatial size must match the upsampled input.
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Conditioning embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None). When None the attention
            modules are called with the feature map only.

        Returns
        -------
        torch.Tensor
            Output tensor with ``out_channels`` channels at the spatial size of
            the concatenated input.
        """
        merged = torch.cat(tensors=[self.up_sampling(x), skip_connection], dim=1)

        hidden = merged
        for idx in range(self.num_layers):
            shortcut_input = hidden
            hidden = self.conv1[idx](hidden)
            # Broadcast the per-channel time bias over the spatial axes.
            hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
            hidden = self.conv2[idx](hidden)
            hidden = hidden + self.resnet[idx](shortcut_input)

            # Attention follows the first pair, or every pair when y_to_all is set.
            if self.y_to_all or idx == 0:
                attend = self.attention[idx]
                attn_out = attend(hidden, y) if y is not None else attend(hidden)
                hidden = hidden + attn_out
        return hidden
672
+ #------------------------------------------------------------------------
673
class Conv3(nn.Module):
    """Convolution followed by optional group norm, optional SiLU, and dropout.

    Building block used by DownBlock, MiddleBlock and UpBlock of NoisePredictor.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    kernel_size : int, optional
        Convolution kernel size; padding is ``(kernel_size - 1) // 2``
        (default: 3).
    norm : bool, optional
        If True, apply group normalization; otherwise identity (default: True).
    activation : bool, optional
        If True, apply SiLU; otherwise identity (default: True).
    dropout_rate : float, optional
        Dropout probability (default: 0.2).

    Attributes
    ----------
    conv : torch.nn.Conv2d
        The convolution.
    group_norm : torch.nn.GroupNorm or torch.nn.Identity
        Normalization stage.
    activation : torch.nn.SiLU or torch.nn.Identity
        Activation stage.
    dropout : torch.nn.Dropout
        Dropout stage.
    """

    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True,
                 activation=True, dropout_rate=0.2):
        super().__init__()
        half_kernel = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=half_kernel)

        if norm:
            self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        else:
            self.group_norm = nn.Identity()

        if activation:
            self.activation = nn.SiLU()
        else:
            self.activation = nn.Identity()

        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        """Run convolution, normalization, activation and dropout in that order.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor with ``out_channels`` channels.
        """
        features = self.group_norm(self.conv(batch))
        return self.dropout(self.activation(features))
731
+ #----------------------------------------------------------------
732
class TimeEmbedding(nn.Module):
    """Project a time embedding to a given channel count.

    Used inside the NoisePredictor blocks to turn the shared time embedding into
    a per-channel bias for a convolutional feature map.

    Parameters
    ----------
    output_dim : int
        Size of the projected embedding (the channel count it is added to).
    embed_dim : int
        Size of the incoming time embedding.

    Attributes
    ----------
    embedding : torch.nn.Sequential
        SiLU activation followed by a linear projection.
    """

    def __init__(self, output_dim, embed_dim):
        super().__init__()
        activation = nn.SiLU()
        projection = nn.Linear(in_features=embed_dim, out_features=output_dim)
        self.embedding = nn.Sequential(activation, projection)

    def forward(self, batch):
        """Apply SiLU and the linear projection.

        Parameters
        ----------
        batch : torch.Tensor
            Time embeddings, shape (batch_size, embed_dim).

        Returns
        -------
        torch.Tensor
            Projected embeddings, shape (batch_size, output_dim).
        """
        projected = self.embedding(batch)
        return projected
769
+ #----------------------------------------------------------------
770
class GetEmbeddedTime(nn.Module):
    """Generates sinusoidal time embeddings for NoisePredictor.

    Encodes each time step with sine and cosine functions at ``embed_dim // 2``
    geometrically spaced frequencies, following the transformer positional
    encoding approach. The module has no learnable parameters.

    Parameters
    ----------
    embed_dim : int
        Dimensionality of the time embeddings (must be even).

    Attributes
    ----------
    embed_dim : int
        Time embedding dimension.

    Raises
    ------
    AssertionError
        If `embed_dim` is not divisible by 2.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Raised explicitly instead of via an `assert` statement so the check is
        # not stripped when Python runs with -O; the exception type and message
        # are kept for callers that rely on them.
        if embed_dim % 2 != 0:
            raise AssertionError("The embedding dimension must be divisible by two")
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        """Generates sinusoidal embeddings for time steps.

        Parameters
        ----------
        time_steps : torch.Tensor
            Time steps, shape (batch_size,).

        Returns
        -------
        torch.Tensor
            Sinusoidal embeddings, shape (batch_size, embed_dim): all sine terms
            first, followed by all cosine terms.
        """
        # Frequency index 0 .. embed_dim/2 - 1, created on the input's device.
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        # Divisor 10000^(2i / embed_dim) for each frequency.
        factor = 10000 ** (2 * i / self.embed_dim)
        # (batch_size, 1) / (embed_dim/2,) broadcasts to (batch_size, embed_dim/2).
        embed_time = time_steps[:, None] / factor
        embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
        return embed_time
814
+ #----------------------------------------------------------------
815
class Attention(nn.Module):
    """Attention module for NoisePredictor, supporting text conditioning or self-attention.

    Flattens the spatial grid into a sequence and applies multi-head attention.
    When text embeddings are supplied they are projected to `in_channels` and
    used as keys/values (cross-attention); otherwise the flattened input attends
    to itself. The attention output is reshaped back to the input layout, then
    group-normalized and passed through dropout.

    Parameters
    ----------
    in_channels : int
        Number of input channels (embedding dimension for attention).
    y_embed_dim : int, optional
        Dimensionality of text embeddings (default: 768).
    num_heads : int, optional
        Number of attention heads (default: 4).
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    dropout_rate : float, optional
        Dropout rate for attention and output (default: 0.1).

    Attributes
    ----------
    in_channels : int
        Input channel dimension.
    y_embed_dim : int
        Text embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate.
    attention : torch.nn.MultiheadAttention
        Multi-head attention with `batch_first=True`.
    norm : torch.nn.GroupNorm
        Group normalization applied to the attention output.
    dropout : torch.nn.Dropout
        Dropout layer for output.
    y_projection : torch.nn.Linear
        Projection for text embeddings to match `in_channels`.

    Raises
    ------
    AssertionError
        If input channels do not match `in_channels`.
    ValueError
        If text embeddings (`y`) have incorrect dimensions after projection.
    """

    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(
            embed_dim=in_channels,
            num_heads=num_heads,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x, y=None):
        """Applies attention to input features with optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
            Does not need to be contiguous.
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).

        Returns
        -------
        torch.Tensor
            Output tensor, same shape as input `x`.
        """
        batch_size, channels, h, w = x.shape
        # Explicit raise (same type and message as before) so the check is not
        # stripped when Python runs with -O.
        if channels != self.in_channels:
            raise AssertionError(f"Expected {self.in_channels} channels, got {channels}")
        # `reshape` instead of `view`: `view` raises on inputs whose strides do
        # not allow merging H and W (e.g. a transposed tensor).
        x_reshaped = x.reshape(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            y = self.y_projection(y)
            if y.dim() == 2:
                # (batch_size, in_channels) -> single-token sequence.
                y = y.unsqueeze(1)
            elif y.dim() != 3:
                raise ValueError(
                    f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                )
            if y.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
                )
            # Cross-attention: image tokens query the text tokens.
            out, _ = self.attention(x_reshaped, y, y)
        else:
            # Self-attention over the flattened spatial positions.
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).reshape(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out
910
+ #-----------------------------------------------------------------
911
class DownSampling(nn.Module):
    """Downsampling module for NoisePredictor's DownBlock.

    Combines convolutional downsampling and max pooling (if enabled), concatenating
    outputs along the channel axis to preserve feature information.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels. When both paths are enabled the convolutional
        path produces ``out_channels // 2`` channels and the pooling path the
        remainder, so odd values are honoured exactly.
    down_sampling_factor : int
        Factor for spatial downsampling.
    conv_block : bool, optional
        If True, include convolutional path (default: True).
    max_pool : bool, optional
        If True, include max pooling path (default: True).

    Attributes
    ----------
    conv_block : bool
        Flag for convolutional path.
    max_pool : bool
        Flag for max pooling path.
    down_sampling_factor : int
        Downsampling factor.
    conv : torch.nn.Sequential or torch.nn.Identity
        Convolutional path or identity if `conv_block=False`.
    pool : torch.nn.Sequential or torch.nn.Identity
        Max pooling path or identity if `max_pool=False`.

    Notes
    -----
    - If both `conv_block` and `max_pool` are False, both paths are identities and
      `forward` returns its input unchanged.
    """

    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor

        both_paths = conv_block and max_pool
        if both_paths:
            # Split the channel budget so the concatenation has exactly
            # `out_channels` channels, even when `out_channels` is odd.
            conv_channels = out_channels // 2
            pool_channels = out_channels - conv_channels
        else:
            conv_channels = pool_channels = out_channels

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=conv_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            # The strided 3x3 conv (padding=1) yields ceil(size / factor) positions;
            # ceil_mode makes the pooled map match it when the two are concatenated.
            # The pool-only configuration keeps the previous floor behaviour.
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor,
                         ceil_mode=both_paths),
            nn.Conv2d(in_channels=in_channels, out_channels=pool_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch):
        """Downsamples input using convolutional and/or pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Downsampled tensor, shape (batch_size, out_channels,
            height/down_sampling_factor, width/down_sampling_factor).
        """
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
978
+ #--------------------------------------------------------------------------
979
class UpSampling(nn.Module):
    """Upsampling module for NoisePredictor's UpBlock.

    Combines transposed convolution and nearest-neighbor upsampling (if enabled),
    concatenating outputs along the channel axis to preserve feature information,
    with interpolation to align spatial dimensions if needed.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels. When both paths are enabled the transposed
        convolutional path produces ``out_channels // 2`` channels and the
        nearest-neighbor path the remainder, so odd values are honoured exactly.
    up_sampling_factor : int
        Factor for spatial upsampling.
    conv_block : bool, optional
        If True, include transposed convolutional path (default: True).
    up_sampling : bool, optional
        If True, include nearest-neighbor upsampling path (default: True).

    Attributes
    ----------
    conv_block : bool
        Flag for convolutional path.
    up_sampling : bool
        Flag for upsampling path.
    up_sampling_factor : int
        Upsampling factor.
    conv : torch.nn.Sequential or torch.nn.Identity
        Transposed convolutional path or identity if `conv_block=False`.
    up_sample : torch.nn.Sequential or torch.nn.Identity
        Nearest-neighbor upsampling path or identity if `up_sampling=False`.

    Notes
    -----
    - If both `conv_block` and `up_sampling` are False, both paths are identities
      and `forward` returns its input unchanged.
    """

    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor

        if conv_block and up_sampling:
            # Split the channel budget so the concatenation has exactly
            # `out_channels` channels, even when `out_channels` is odd.
            conv_channels = out_channels // 2
            up_channels = out_channels - conv_channels
        else:
            conv_channels = up_channels = out_channels

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=conv_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                # With kernel 3 / padding 1 this makes the output exactly
                # `up_sampling_factor` times the input size.
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=conv_channels,
                out_channels=conv_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        ) if conv_block else nn.Identity()

        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(in_channels=in_channels, out_channels=up_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if up_sampling else nn.Identity()

    def forward(self, batch):
        """Upsamples input using convolutional and/or upsampling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Upsampled tensor, shape (batch_size, out_channels,
            height*up_sampling_factor, width*up_sampling_factor).

        Notes
        -----
        - Interpolation is applied if the spatial dimensions of the convolutional and
          upsampling paths differ, using nearest-neighbor mode.
        """
        if not self.conv_block:
            return self.up_sample(batch)
        if not self.up_sampling:
            return self.conv(batch)
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            # Resize the nearest-neighbor branch to the conv branch's grid so the
            # two can be concatenated along the channel axis.
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)