TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/upsampler.py ADDED
@@ -0,0 +1,432 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Tuple
6
+
7
+
8
+
9
+
10
class UpsamplerUnCLIP(nn.Module):
    """Diffusion-based upsampler for UnCLIP models.

    A U-Net-like noise predictor. The low-resolution conditioning image is
    bicubically resized to the spatial size of the noisy high-resolution image,
    both are concatenated along the channel axis, and the result is processed by
    an encoder / middle / decoder stack of time-conditioned residual blocks with
    skip connections between matching resolution levels.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
        It is only stored on the instance; this class never calls it.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module intended for inference. It is only stored on the
        instance; this class never calls it.
    `in_channels` : int, optional
        Number of input channels (default: 3, for RGB images).
    `out_channels` : int, optional
        Number of output channels (default: 3, for RGB noise prediction).
    `model_channels` : int, optional
        Base number of channels in the model (default: 192).
    `num_res_blocks` : int, optional
        Number of residual blocks per resolution level (default: 2). Must be >= 1.
    `channel_mult` : Tuple[int, ...], optional
        Channel multiplier for each resolution level (default: (1, 2, 4, 8)).
    `dropout_rate` : float, optional
        Dropout probability passed to every residual block (default: 0.1).
    `time_embed_dim` : int, optional
        Dimensionality of time embeddings (default: 768).
    `low_res_size` : int, optional
        Spatial size of low-resolution input (default: 64). Stored for reference only;
        `forward` uses the actual tensor sizes.
    `high_res_size` : int, optional
        Spatial size of high-resolution output (default: 256). Stored for reference only;
        `forward` uses the actual tensor sizes.

    Raises
    ------
    ValueError
        If `num_res_blocks` is smaller than 1.
    """

    def __init__(
            self,
            forward_diffusion: nn.Module,
            reverse_diffusion: nn.Module,
            in_channels: int = 3,
            out_channels: int = 3,
            model_channels: int = 192,
            num_res_blocks: int = 2,
            channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
            dropout_rate: float = 0.1,
            time_embed_dim: int = 768,
            low_res_size: int = 64,
            high_res_size: int = 256,
    ) -> None:
        super().__init__()

        # With no encoder blocks there are no skip connections, yet the decoder
        # is still sized for them, so the forward pass could never succeed.
        if num_res_blocks < 1:
            raise ValueError(
                f"num_res_blocks must be at least 1, got {num_res_blocks}."
            )

        self.forward_diffusion = forward_diffusion  # used at training time inside 'TrainUpsamplerUnCLIP'
        self.reverse_diffusion = reverse_diffusion  # used at inference time
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size

        # Time embedding: sinusoidal features followed by a two-layer MLP.
        self.time_embed = nn.Sequential(
            SinusoidalPositionalEmbedding(model_channels),
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Input projection. The input is the noisy high-res image concatenated
        # with the upsampled low-res image, hence in_channels * 2.
        self.input_proj = nn.Conv2d(in_channels * 2, model_channels, 3, padding=1)

        # Encoder (downsampling path)
        self.encoder_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        ch = model_channels
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResBlock(ch, model_channels * mult, time_embed_dim, dropout_rate)
                )
                ch = model_channels * mult

            # Every level except the last one is followed by a downsample.
            if level != len(channel_mult) - 1:
                self.downsample_blocks.append(DownsampleBlock(ch, ch))

        # Middle blocks
        self.middle_blocks = nn.ModuleList([
            ResBlock(ch, ch, time_embed_dim, dropout_rate),
            ResBlock(ch, ch, time_embed_dim, dropout_rate),
        ])

        # Decoder (upsampling path)
        self.decoder_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()

        for level, mult in reversed(list(enumerate(channel_mult))):
            for i in range(num_res_blocks + 1):
                # Only the first block of a level receives the concatenated
                # skip connection, so only it gets the extra input channels.
                in_ch = ch + (model_channels * mult if i == 0 else 0)
                out_ch = model_channels * mult

                self.decoder_blocks.append(
                    ResBlock(in_ch, out_ch, time_embed_dim, dropout_rate)
                )
                ch = out_ch

            # Every level except the highest-resolution one is followed by an upsample.
            if level != 0:
                self.upsample_blocks.append(UpsampleBlock(ch, ch))

        # Output projection
        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, 3, padding=1),
        )

    def forward(self, x_high: torch.Tensor, t: torch.Tensor, x_low: torch.Tensor) -> torch.Tensor:
        """Predicts noise for the upsampling process.

        Processes a noisy high-resolution image and a low-resolution conditioning image,
        conditioned on timesteps, to predict the noise component for denoising.

        Parameters
        ----------
        `x_high` : torch.Tensor
            Noisy high-resolution image, shape (batch_size, in_channels, height, width).
            `height` and `width` must be divisible by 2 ** (number of downsample blocks).
        `t` : torch.Tensor
            Timestep indices, shape (batch_size,).
        `x_low` : torch.Tensor
            Low-resolution conditioning image, shape (batch_size, in_channels, h_low, w_low).
            It is resized to the spatial size of `x_high`.

        Returns
        -------
        out : torch.Tensor
            Predicted noise, shape (batch_size, out_channels, height, width).

        Raises
        ------
        ValueError
            If the batch sizes of `x_high` and `x_low` differ, or if the spatial size
            of `x_high` is not divisible by the total downsampling factor.
        """
        if x_high.shape[0] != x_low.shape[0]:
            raise ValueError(
                f"Batch size mismatch: x_high has {x_high.shape[0]} samples, "
                f"x_low has {x_low.shape[0]}."
            )

        # Each downsample halves the resolution (rounding up) while each upsample
        # exactly doubles it, so skip connections only line up when the input
        # size is divisible by the total downsampling factor.
        total_stride = 2 ** len(self.downsample_blocks)
        height, width = x_high.shape[-2], x_high.shape[-1]
        if height % total_stride != 0 or width % total_stride != 0:
            raise ValueError(
                f"x_high spatial size ({height}, {width}) must be divisible by "
                f"{total_stride} so that skip connections match the decoder resolution."
            )

        # Upsample low-resolution image to match high-resolution
        x_low_upsampled = F.interpolate(
            x_low,
            size=(height, width),
            mode='bicubic',
            align_corners=False
        )

        # Concatenate noisy high-res and upsampled low-res along channels
        x = torch.cat([x_high, x_low_upsampled], dim=1)

        # Time embedding (cast to float because timesteps are usually integer indices)
        time_emb = self.time_embed(t.float())

        # Input projection
        h = self.input_proj(x)

        # One skip connection is stored per resolution level
        skip_connections = []

        # Encoder
        for i, block in enumerate(self.encoder_blocks):
            h = block(h, time_emb)
            # End of a resolution level: store the skip, then downsample if
            # this is not the last level.
            if (i + 1) % self.num_res_blocks == 0:
                skip_connections.append(h)
                downsample_idx = (i + 1) // self.num_res_blocks - 1
                if downsample_idx < len(self.downsample_blocks):
                    h = self.downsample_blocks[downsample_idx](h)

        # Middle
        for block in self.middle_blocks:
            h = block(h, time_emb)

        # Decoder
        upsample_idx = 0
        for i, block in enumerate(self.decoder_blocks):
            # The first block of each level consumes the matching skip connection
            if i % (self.num_res_blocks + 1) == 0 and skip_connections:
                skip = skip_connections.pop()
                h = torch.cat([h, skip], dim=1)

            h = block(h, time_emb)

            # Upsample at the end of each resolution level
            if ((i + 1) % (self.num_res_blocks + 1) == 0 and
                    upsample_idx < len(self.upsample_blocks)):
                h = self.upsample_blocks[upsample_idx](h)
                upsample_idx += 1

        # Output projection
        out = self.output_proj(h)

        return out
210
+
211
+
212
+
213
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding for timesteps.

    Generates sinusoidal embeddings for timesteps to condition the upsampler on the
    diffusion process stage. The first half of the features are sines and the
    second half cosines of the timestep scaled by geometrically spaced frequencies.

    Parameters
    ----------
    `dim` : int
        Dimensionality of the embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Generates sinusoidal embeddings for timesteps.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Timestep indices, shape (batch_size,).

        Returns
        -------
        embeddings : torch.Tensor
            Sinusoidal embeddings, shape (batch_size, dim). When `dim` is odd the
            last feature is a zero pad so the output width always equals `dim`.
        """
        half_dim = self.dim // 2
        # Clamp the denominator: with dim in {2, 3} half_dim - 1 is 0 and the
        # unclamped division would raise ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(
            torch.arange(half_dim, device=timesteps.device) * -log_step
        )
        angles = timesteps[:, None] * frequencies[None, :]
        embeddings = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        # sin/cos halves only cover 2 * half_dim features; pad odd dims by one
        # zero column so downstream layers sized for `dim` receive `dim` features.
        if self.dim % 2 == 1:
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
249
+
250
+
251
class ResBlock(nn.Module):
    """Time-conditioned convolutional residual block.

    The input passes through GroupNorm -> SiLU -> 3x3 conv, is modulated by a
    projection of the time embedding, then passes through GroupNorm -> SiLU ->
    Dropout -> 3x3 conv, and is finally added to a (possibly 1x1-projected)
    copy of the input.

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    `time_embed_dim` : int
        Dimensionality of time embeddings.
    `dropout` : float, optional
        Dropout probability (default: 0.1).
    `use_scale_shift_norm` : bool, optional
        If True (default), the time embedding is projected to a per-channel
        scale and shift applied after the second normalization; otherwise it is
        projected to `out_channels` and added before that normalization.
    """

    def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int,
                 dropout: float = 0.1, use_scale_shift_norm: bool = True):
        super().__init__()
        self.use_scale_shift_norm = use_scale_shift_norm

        # Scale-shift mode needs two values per output channel, additive mode one.
        emb_features = 2 * out_channels if use_scale_shift_norm else out_channels

        self.in_layers = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        )
        self.time_emb_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, emb_features),
        )

        # The second normalization is kept apart from the layers that follow it
        # so the scale/shift modulation can be applied between the two.
        self.out_norm = nn.GroupNorm(8, out_channels)
        self.out_rest = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        # A 1x1 conv aligns channel counts on the residual path when they differ.
        self.skip_connection = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, 1)
        )

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """Applies the residual block to `x`, conditioned on `time_emb`.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        `time_emb` : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).

        Returns
        -------
        out : torch.Tensor
            Output tensor, shape (batch_size, out_channels, height, width).
        """
        hidden = self.in_layers(x)

        # Project the time embedding and add two trailing singleton axes so it
        # broadcasts over the spatial dimensions.
        conditioning = self.time_emb_proj(time_emb)[:, :, None, None]

        if not self.use_scale_shift_norm:
            # Additive conditioning happens before normalization.
            hidden = self.out_norm(hidden + conditioning)
        else:
            # Split the projection in half along channels: scale, then shift.
            scale, shift = torch.chunk(conditioning, 2, dim=1)
            hidden = self.out_norm(hidden) * (1 + scale) + shift

        hidden = self.out_rest(hidden)
        return hidden + self.skip_connection(x)
333
+
334
+
335
class UpsampleBlock(nn.Module):
    """Learned 2x spatial upsampling.

    Wraps a single transposed convolution (kernel 4, stride 2, padding 1), which
    doubles the height and width of its input.

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Doubles the spatial resolution of `x`.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        out : torch.Tensor
            Upsampled tensor, shape (batch_size, out_channels, height*2, width*2).
        """
        upsampled = self.conv(x)
        return upsampled
366
+
367
+
368
class DownsampleBlock(nn.Module):
    """Learned 2x spatial downsampling.

    Wraps a single strided convolution (kernel 3, stride 2, padding 1), which
    halves the height and width of its input.

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Halves the spatial resolution of `x`.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        out : torch.Tensor
            Downsampled tensor, shape (batch_size, out_channels, height//2, width//2).
        """
        downsampled = self.conv(x)
        return downsampled
399
+
400
+
401
+ """
402
+ hyp = VarianceSchedulerUnCLIP(
403
+ num_steps=1000,
404
+ beta_start=1e-4,
405
+ beta_end=0.02,
406
+ trainable_beta=False,
407
+ beta_method="cosine"
408
+ )
409
+
410
+ forward = ForwardUnCLIP(hyp)
411
+
412
+ model = UpsamplerUnCLIP(
413
+ forward_diffusion=forward,
414
+ in_channels= 3,
415
+ out_channels= 3,
416
+ model_channels= 32,
417
+ num_res_blocks = 2,
418
+ channel_mult = (1, 2, 4, 8),
419
+ dropout_rate = 0.1,
420
+ time_embed_dim = 756,
421
+ low_res_size = 256,
422
+ high_res_size = 1024
423
+ )
424
+ xl = torch.randn((2, 3, 256, 256))
425
+ xh = torch.randn((2, 3, 1024, 1024))
426
+ t = torch.tensor([3, 5])
427
+
428
+
429
+ result = model(xh, t, xl)
430
+ print(result.size())
431
+ print(result.dtype)
432
+ """