TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
unclip/upsampler.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import math
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class UpsamplerUnCLIP(nn.Module):
    """Diffusion-based upsampler for UnCLIP models.

    A U-Net-like noise predictor. The low-resolution conditioning image is
    bicubically resized to the size of the noisy high-resolution image, both are
    concatenated along the channel axis, and the result is passed through an
    encoder (residual blocks + strided-conv downsampling), two middle residual
    blocks, and a decoder (residual blocks + transposed-conv upsampling) with
    one skip connection per resolution level.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
        It is only stored on the model; this class never calls it.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module used at inference time.
        It is only stored on the model; this class never calls it.
    `in_channels` : int, optional
        Number of input channels (default: 3, for RGB images).
    `out_channels` : int, optional
        Number of output channels (default: 3, for RGB noise prediction).
    `model_channels` : int, optional
        Base number of channels in the model (default: 192). Every
        `model_channels * mult` must be divisible by 8 because the residual
        blocks and the output head use 8-group `nn.GroupNorm`.
    `num_res_blocks` : int, optional
        Number of residual blocks per resolution level in the encoder
        (default: 2). The decoder uses `num_res_blocks + 1` per level.
    `channel_mult` : Tuple[int, ...], optional
        Channel multiplier for each resolution level (default: (1, 2, 4, 8)).
    `dropout_rate` : float, optional
        Dropout probability for regularization (default: 0.1).
    `time_embed_dim` : int, optional
        Dimensionality of time embeddings (default: 768).
    `low_res_size` : int, optional
        Spatial size of low-resolution input (default: 64). Stored as an
        attribute only; `forward` takes the actual sizes from its inputs.
    `high_res_size` : int, optional
        Spatial size of high-resolution output (default: 256). Stored as an
        attribute only; `forward` takes the actual sizes from its inputs.

    Raises
    ------
    ValueError
        If `num_res_blocks` is smaller than 1.
    """

    def __init__(
            self,
            forward_diffusion: nn.Module,
            reverse_diffusion: nn.Module,
            in_channels: int = 3,
            out_channels: int = 3,
            model_channels: int = 192,
            num_res_blocks: int = 2,
            channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
            dropout_rate: float = 0.1,
            time_embed_dim: int = 768,
            low_res_size: int = 64,
            high_res_size: int = 256,
    ) -> None:
        super().__init__()

        # With zero encoder blocks no skip connection is ever recorded, while the
        # decoder is still built to expect one per level, so forward() could only
        # fail with a channel mismatch. Reject the configuration up front.
        if num_res_blocks < 1:
            raise ValueError(f"num_res_blocks must be at least 1, got {num_res_blocks}")

        self.forward_diffusion = forward_diffusion  # this will be used on training time inside 'TrainUpsamplerUnCLIP'
        self.reverse_diffusion = reverse_diffusion  # this module will be used in inference time
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size

        # Time embedding: sinusoidal features followed by a two-layer MLP
        self.time_embed = nn.Sequential(
            SinusoidalPositionalEmbedding(model_channels),
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Input projection; the input is the channel-wise concatenation of the
        # noisy high-res image and the upsampled low-res image, hence `* 2`
        self.input_proj = nn.Conv2d(in_channels * 2, model_channels, 3, padding=1)

        # Encoder (downsampling path)
        self.encoder_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        ch = model_channels
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResBlock(ch, model_channels * mult, time_embed_dim, dropout_rate)
                )
                ch = model_channels * mult

            # No downsampling after the deepest level
            if level != len(channel_mult) - 1:
                self.downsample_blocks.append(DownsampleBlock(ch, ch))

        # Middle blocks
        self.middle_blocks = nn.ModuleList([
            ResBlock(ch, ch, time_embed_dim, dropout_rate),
            ResBlock(ch, ch, time_embed_dim, dropout_rate),
        ])

        # Decoder (upsampling path)
        self.decoder_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()

        for level, mult in reversed(list(enumerate(channel_mult))):
            for i in range(num_res_blocks + 1):
                # Only the first block of a level receives the concatenated
                # encoder skip, which carries `model_channels * mult` channels
                in_ch = ch + (model_channels * mult if i == 0 else 0)
                out_ch = model_channels * mult

                self.decoder_blocks.append(
                    ResBlock(in_ch, out_ch, time_embed_dim, dropout_rate)
                )
                ch = out_ch

            # No upsampling after the shallowest level
            if level != 0:
                self.upsample_blocks.append(UpsampleBlock(ch, ch))

        # Output projection
        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, 3, padding=1),
        )

    def forward(self, x_high: torch.Tensor, t: torch.Tensor, x_low: torch.Tensor) -> torch.Tensor:
        """Predicts noise for the upsampling process.

        Processes a noisy high-resolution image and a low-resolution conditioning image,
        conditioned on timesteps, to predict the noise component for denoising.

        Parameters
        ----------
        `x_high` : torch.Tensor
            Noisy high-resolution image, shape (batch_size, in_channels, height, width).
            `height` and `width` must be divisible by
            `2 ** len(self.downsample_blocks)` so that decoder activations line
            up with the stored encoder skips.
        `t` : torch.Tensor
            Timestep indices, shape (batch_size,).
        `x_low` : torch.Tensor
            Low-resolution conditioning image, shape (batch_size, in_channels, h, w);
            it is resized to the spatial size of `x_high`.

        Returns
        -------
        out : torch.Tensor
            Predicted noise, shape (batch_size, out_channels, height, width).

        Raises
        ------
        ValueError
            If the spatial size of `x_high` is not divisible by
            `2 ** len(self.downsample_blocks)`.
        """
        height, width = x_high.shape[-2], x_high.shape[-1]

        # Each downsample halves (rounding up) and each upsample exactly doubles,
        # so a non-divisible size would only surface later as an opaque
        # size-mismatch error when concatenating a skip connection.
        factor = 2 ** len(self.downsample_blocks)
        if height % factor != 0 or width % factor != 0:
            raise ValueError(
                f"x_high spatial size ({height}, {width}) must be divisible by {factor}"
            )

        # Upsample low-resolution image to match high-resolution
        x_low_upsampled = F.interpolate(
            x_low,
            size=(height, width),
            mode='bicubic',
            align_corners=False
        )

        # Concatenate noisy high-res and upsampled low-res
        x = torch.cat([x_high, x_low_upsampled], dim=1)

        # Time embedding
        time_emb = self.time_embed(t.float())  # Ensure float for embedding

        # Input projection
        h = self.input_proj(x)

        # Encoder: after the last block of each level, remember the activation
        # as a skip connection and then downsample (except at the deepest level)
        skip_connections = []
        for i, block in enumerate(self.encoder_blocks):
            h = block(h, time_emb)
            if (i + 1) % self.num_res_blocks == 0:
                skip_connections.append(h)
                downsample_idx = (i + 1) // self.num_res_blocks - 1
                if downsample_idx < len(self.downsample_blocks):
                    h = self.downsample_blocks[downsample_idx](h)

        # Middle
        for block in self.middle_blocks:
            h = block(h, time_emb)

        # Decoder: levels are visited deepest-first, so skips are popped in
        # reverse order of how the encoder stored them
        blocks_per_level = self.num_res_blocks + 1
        upsample_idx = 0
        for i, block in enumerate(self.decoder_blocks):
            # The first block of each level consumes the matching skip connection
            if i % blocks_per_level == 0 and skip_connections:
                skip = skip_connections.pop()
                h = torch.cat([h, skip], dim=1)

            h = block(h, time_emb)

            # Upsample at the end of each resolution level
            if (i + 1) % blocks_per_level == 0 and upsample_idx < len(self.upsample_blocks):
                h = self.upsample_blocks[upsample_idx](h)
                upsample_idx += 1

        # Output projection
        return self.output_proj(h)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding for timesteps.

    Generates sinusoidal embeddings for timesteps to condition the upsampler on the
    diffusion process stage. The first `dim // 2` columns are sines and the next
    `dim // 2` columns are cosines of the timestep multiplied by geometrically
    spaced frequencies between 1 and 1/10000. For odd `dim` one trailing zero
    column is appended so the output always has exactly `dim` columns.

    Parameters
    ----------
    `dim` : int
        Dimensionality of the embedding; must be at least 2.

    Raises
    ------
    ValueError
        If `dim` is smaller than 2.
    """

    def __init__(self, dim: int):
        super().__init__()
        # dim < 2 leaves no room for a sine/cosine pair
        if dim < 2:
            raise ValueError(f"dim must be at least 2, got {dim}")
        self.dim = dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Generates sinusoidal embeddings for timesteps.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Timestep indices, shape (batch_size,).

        Returns
        -------
        embeddings : torch.Tensor
            Sinusoidal embeddings, shape (batch_size, dim).
        """
        half_dim = self.dim // 2

        # With a single frequency (dim 2 or 3) `half_dim - 1` is zero; clamp the
        # divisor so that frequency is simply 1 instead of dividing by zero
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -log_step)

        # Outer product: one row per timestep, one column per frequency
        angles = timesteps[:, None] * freqs[None, :]
        embeddings = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # Odd dim: sin/cos halves only fill dim - 1 columns, pad the last one
        if self.dim % 2 == 1:
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class ResBlock(nn.Module):
    """Residual block with time embedding conditioning.

    A convolutional residual block with group normalization, time embedding conditioning,
    and optional scale-shift normalization, used in the UnCLIP upsampler.

    Parameters
    ----------
    `in_channels` : int
        Number of input channels; must be divisible by `num_groups`.
    `out_channels` : int
        Number of output channels; must be divisible by `num_groups`.
    `time_embed_dim` : int
        Dimensionality of time embeddings.
    `dropout` : float, optional
        Dropout probability (default: 0.1).
    `use_scale_shift_norm` : bool, optional
        Whether to use scale-shift normalization for time embeddings (default: True).
    `num_groups` : int, optional
        Number of groups for both `nn.GroupNorm` layers (default: 8).
    """

    def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int,
                 dropout: float = 0.1, use_scale_shift_norm: bool = True,
                 num_groups: int = 8):
        super().__init__()
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(num_groups, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

        # Scale-shift mode needs two values per channel (scale and shift);
        # additive mode needs one
        self.time_emb_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, out_channels * 2 if use_scale_shift_norm else out_channels)
        )

        # The output norm is kept separate from the remaining output layers
        # because scale-shift conditioning is applied between the two
        self.out_norm = nn.GroupNorm(num_groups, out_channels)
        self.out_rest = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        # A 1x1 convolution matches channel counts on the residual path when needed
        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """Processes input through the residual block with time conditioning.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        `time_emb` : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).

        Returns
        -------
        out : torch.Tensor
            Output tensor, shape (batch_size, out_channels, height, width).
        """
        h = self.in_layers(x)

        # Project the time embedding and add two singleton spatial axes so it
        # broadcasts over height and width
        emb_out = self.time_emb_proj(time_emb)[:, :, None, None]

        if self.use_scale_shift_norm:
            # Modulate the normalized activations: norm(h) * (1 + scale) + shift
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = self.out_norm(h) * (1 + scale) + shift
        else:
            # Additive conditioning happens before normalization
            h = self.out_norm(h + emb_out)

        h = self.out_rest(h)
        return h + self.skip_connection(x)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class UpsampleBlock(nn.Module):
    """Upsampling block using transposed convolution.

    Doubles the height and width of a feature map with a single learned
    `nn.ConvTranspose2d` (kernel 4, stride 2, padding 1).

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsamples the input tensor.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        out : torch.Tensor
            Upsampled tensor, shape (batch_size, out_channels, height*2, width*2).
        """
        upsampled = self.conv(x)
        return upsampled
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class DownsampleBlock(nn.Module):
    """Downsampling block using strided convolution.

    Halves the height and width of a feature map with a single learned
    `nn.Conv2d` (kernel 3, stride 2, padding 1).

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsamples the input tensor.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        out : torch.Tensor
            Downsampled tensor, shape (batch_size, out_channels, height//2, width//2).
        """
        downsampled = self.conv(x)
        return downsampled
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
"""
|
|
402
|
+
hyp = VarianceSchedulerUnCLIP(
|
|
403
|
+
num_steps=1000,
|
|
404
|
+
beta_start=1e-4,
|
|
405
|
+
beta_end=0.02,
|
|
406
|
+
trainable_beta=False,
|
|
407
|
+
beta_method="cosine"
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
forward = ForwardUnCLIP(hyp)
|
|
411
|
+
|
|
412
|
+
model = UpsamplerUnCLIP(
|
|
413
|
+
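forward_diffusion=forward,
reverse_diffusion=reverse,  # required argument; construct the reverse-diffusion module beforehand (not shown in this example)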
|
|
414
|
+
in_channels= 3,
|
|
415
|
+
out_channels= 3,
|
|
416
|
+
model_channels= 32,
|
|
417
|
+
num_res_blocks = 2,
|
|
418
|
+
channel_mult = (1, 2, 4, 8),
|
|
419
|
+
dropout_rate = 0.1,
|
|
420
|
+
time_embed_dim = 756,
|
|
421
|
+
low_res_size = 256,
|
|
422
|
+
high_res_size = 1024
|
|
423
|
+
)
|
|
424
|
+
xl = torch.randn((2, 3, 256, 256))
|
|
425
|
+
xh = torch.randn((2, 3, 1024, 1024))
|
|
426
|
+
t = torch.tensor([3, 5])
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
result = model(xh, t, xl)
|
|
430
|
+
print(result.size())
|
|
431
|
+
print(result.dtype)
|
|
432
|
+
"""
|