TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ldm/reverse_ldm.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ReverseSDE(nn.Module):
    """Single reverse-time step for score-based SDE diffusion.

    Parameters
    ----------
    hyper_params : object
        Schedule container. ``forward`` reads ``dt``, ``betas`` and ``cum_betas``
        (indexed by ``time_steps``) and, for the "ve" method, ``sigmas``.
    method : str
        Step type: "ve", "vp", "sub-vp" or "ode".
    """

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    def _previous_sigmas(self, time_steps, sigma_t):
        """Return sigma_{t-1} per sample, using 0 for samples with t == 0."""
        sigmas = self.hyper_params.sigmas
        # Clamp so t == 0 never wraps to sigmas[-1]; those entries are replaced by 0 below.
        prev_idx = (time_steps - 1).clamp(min=0)
        return torch.where(time_steps > 0, sigmas[prev_idx], torch.zeros_like(sigma_t))

    def forward(self, xt, noise, predicted_noise, time_steps):
        """Apply one reverse step to ``xt``.

        Parameters
        ----------
        xt : torch.Tensor
            Current sample; per-sample schedule values are reshaped to
            ``(-1, 1, 1, 1)``, so a 4-D batch is assumed.
        noise : torch.Tensor or None
            Gaussian noise for the stochastic term; ``None`` disables it.
        predicted_noise : torch.Tensor
            Model output used in the drift. NOTE(review): whether this is a
            noise or a score estimate determines the correct sign/scale of
            the drift; confirm against the training objective.
        time_steps : torch.Tensor
            Integer time indices, one per sample.

        Returns
        -------
        torch.Tensor
            Updated sample. The "ve" and "ode" results are clamped to
            [-1e5, 1e5].

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported values.
        """
        dt = self.hyper_params.dt
        betas = self.hyper_params.betas[time_steps].view(-1, 1, 1, 1)
        cum_betas = self.hyper_params.cum_betas[time_steps].view(-1, 1, 1, 1)

        if self.method == "ve":
            sigma_t = self.hyper_params.sigmas[time_steps]
            # Per-sample previous sigma: a batch mixing t == 0 and t > 0 is handled correctly.
            sigma_t_prev = self._previous_sigmas(time_steps, sigma_t)
            sigma_sq_diff = sigma_t ** 2 - sigma_t_prev ** 2
            sigma_diff = torch.sqrt(torch.clamp(sigma_sq_diff, min=0))
            drift = -sigma_sq_diff.view(-1, 1, 1, 1) * predicted_noise * dt
            diffusion = sigma_diff.view(-1, 1, 1, 1) * noise if noise is not None else 0
            xt = xt + drift + diffusion
            xt = torch.clamp(xt, -1e5, 1e5)

        elif self.method == "vp":
            drift = -0.5 * betas * xt * dt - betas * predicted_noise * dt
            diffusion = torch.sqrt(betas * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.method == "sub-vp":
            sub_vp_scale = 1 - torch.exp(-2 * cum_betas)
            drift = -0.5 * betas * xt * dt - betas * sub_vp_scale * predicted_noise * dt
            diffusion = torch.sqrt(betas * sub_vp_scale * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.method == "ode":
            # Deterministic (probability-flow style) step using the VP drift.
            # No VE variant is selectable here because `method` already equals "ode".
            drift = -0.5 * betas * xt * dt - 0.5 * betas * predicted_noise * dt
            xt = xt + drift
            xt = torch.clamp(xt, -1e5, 1e5)

        else:
            raise ValueError(f"Unknown method: {self.method}")

        return xt
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ReverseDDIM(nn.Module):
    """Single reverse step of DDIM on the sub-sampled (tau) schedule.

    Parameters
    ----------
    hyper_params : object
        Must expose ``tau_num_steps``, ``eta`` and ``get_tau_schedule()``. The
        latter returns a 5-tuple whose last two entries are (by name)
        sqrt(alpha_bar) and sqrt(1 - alpha_bar), indexed by tau step.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, xt, predicted_noise, time_steps, prev_time_steps):
        """Move ``xt`` from tau step ``time_steps`` to ``prev_time_steps``.

        Parameters
        ----------
        xt : torch.Tensor
            Current sample; schedule values are reshaped to ``(-1, 1, 1, 1)``,
            so a 4-D batch is assumed.
        predicted_noise : torch.Tensor
            Noise estimate (epsilon) for ``xt``.
        time_steps, prev_time_steps : torch.Tensor
            Integer tau indices in ``[0, tau_num_steps)``.

        Returns
        -------
        tuple of torch.Tensor
            ``(xt_prev, x0)``: the next sample and the predicted clean sample.

        Raises
        ------
        ValueError
            If any time index is out of range.
        """
        num_tau = self.hyper_params.tau_num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_tau)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < num_tau)):
            raise ValueError(f"prev_time_steps must be between 0 and {self.hyper_params.tau_num_steps - 1}")

        _, _, _, sqrt_ab, sqrt_one_minus_ab = self.hyper_params.get_tau_schedule()

        def _gather(schedule, steps):
            return schedule[steps].to(xt.device).view(-1, 1, 1, 1)

        a_t = _gather(sqrt_ab, time_steps)
        b_t = _gather(sqrt_one_minus_ab, time_steps)
        a_prev = _gather(sqrt_ab, prev_time_steps)
        b_prev = _gather(sqrt_one_minus_ab, prev_time_steps)

        eta = self.hyper_params.eta
        x0 = (xt - b_t * predicted_noise) / a_t

        # DDIM sigma: eta * sqrt((1 - ab_prev) / (1 - ab_t)) * sqrt(1 - ab_t / ab_prev).
        # It is zero when prev == t, so such a step is deterministic.
        alpha_ratio = (a_t / a_prev) ** 2
        noise_coeff = eta * (b_prev / torch.clamp(b_t, min=1e-8)) * torch.sqrt(torch.clamp(1 - alpha_ratio, min=0))
        direction_coeff = torch.clamp(b_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        # Noise is always drawn (even when eta == 0) to keep RNG consumption stable.
        xt_prev = a_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ReverseDDPM(nn.Module):
    """Reverse diffusion process of DDPM (one ancestral sampling step)."""

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params  # hyperparameters class

    def forward(self, xt, predicted_noise, time_steps):
        """Sample x_{t-1} from x_t given the predicted noise.

        Parameters
        ----------
        xt : torch.Tensor
            Current sample; schedule values are reshaped to ``(-1, 1, 1, 1)``,
            so a 4-D batch is assumed.
        predicted_noise : torch.Tensor
            Noise estimate for ``xt``.
        time_steps : torch.Tensor
            Integer indices in ``[0, num_steps)``, one per sample.

        Returns
        -------
        torch.Tensor
            The posterior mean plus scaled Gaussian noise. Samples with t == 0
            receive no noise; if every t is 0, the mean is returned directly.

        Raises
        ------
        ValueError
            If any time index is out of range.
        """
        if not torch.all((time_steps >= 0) & (time_steps < self.hyper_params.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.num_steps - 1}")

        # Full schedules; per-step values are gathered from these below.
        if self.hyper_params.trainable_beta:
            all_betas, all_alphas, all_alpha_bars, _, _ = self.hyper_params.compute_schedule(self.hyper_params.betas)
        else:
            all_betas = self.hyper_params.betas
            all_alphas = self.hyper_params.alphas
            all_alpha_bars = self.hyper_params.alpha_bars

        betas_t = all_betas[time_steps].to(xt.device)
        alphas_t = all_alphas[time_steps].to(xt.device)
        alpha_bars_t = all_alpha_bars[time_steps].to(xt.device)

        sqrt_alphas_t = torch.sqrt(alphas_t).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bars_t = torch.sqrt(1 - alpha_bars_t).view(-1, 1, 1, 1)
        betas_t = betas_t.view(-1, 1, 1, 1)

        mu = (xt - (betas_t / sqrt_one_minus_alpha_bars_t) * predicted_noise) / sqrt_alphas_t

        mask = (time_steps == 0)
        if mask.all():
            return mu

        # alpha_bar_{t-1} comes from the FULL schedule. t == 0 entries are clamped
        # to index 0 (not wrapped to -1); their noise is masked out below.
        prev_idx = (time_steps - 1).clamp(min=0)
        alpha_bars_prev = all_alpha_bars[prev_idx].to(xt.device)
        variance = (1 - alpha_bars_prev) / (1 - alpha_bars_t) * betas_t.view(-1)
        std = torch.sqrt(variance).view(-1, 1, 1, 1)

        z = torch.randn_like(xt)
        xt_minus_1 = mu + (~mask).float().view(-1, 1, 1, 1) * std * z
        return xt_minus_1
|
ldm/sample_ldm.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from transformers import BertTokenizer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SampleLDM(nn.Module):
    """Sampler for generating images using Latent Diffusion Models (LDM).

    Generates images by iteratively denoising a latent with a reverse
    diffusion process and decoding the result back to image space with a
    pre-trained compressor, as described in Rombach et al. (2022). Supports
    DDPM, DDIM and SDE diffusion models, as well as conditional generation
    with text prompts.

    Parameters
    ----------
    model : str
        Diffusion model type. Supported: "ddpm", "ddim", "sde".
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDPM, ReverseDDIM, ReverseSDE).
    noise_predictor : nn.Module
        Model to predict noise added during the forward diffusion process.
    compressor_model : nn.Module
        Pre-trained model to encode/decode between image and latent spaces (e.g., autoencoder).
    image_shape : tuple
        Shape of generated images as (height, width).
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str or BertTokenizer, optional
        Tokenizer for processing text prompts, default "bert-base-uncased".
        A string is passed to ``BertTokenizer.from_pretrained``. Any other
        object is used as-is and must be callable like a Hugging Face tokenizer.
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of channels of the initial noise tensor (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    model : str
        Diffusion model type ("ddpm", "ddim", "sde").
    noise_predictor : nn.Module
        Noise prediction model.
    reverse : nn.Module
        Reverse diffusion module.
    compressor : nn.Module
        Compressor model for latent space encoding/decoding.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : object
        Tokenizer for text prompts.
    in_channels : int
        Number of channels of the initial noise tensor.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Batch size for generation.
    max_length : int
        Maximum length for tokenized prompts.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, `in_channels` is not positive, or `output_range` is not a tuple
        (min, max) with min < max.
    """
    def __init__(self, model, reverse_diffusion, noise_predictor, compressor_model, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased", batch_size=1, in_channels=3, device=None, max_length=77, output_range=(-1, 1)):
        super().__init__()
        # Validate cheap arguments first, so bad input fails before models are
        # moved to the device or a pretrained tokenizer is loaded.
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if in_channels <= 0:
            raise ValueError("in_channels must be positive")
        if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.noise_predictor = noise_predictor.to(self.device)
        self.reverse = reverse_diffusion.to(self.device)
        self.compressor = compressor_model.to(self.device)
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        # Accept either a pretrained name or an already-constructed tokenizer.
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.max_length = max_length
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Parameters
        ----------
        prompts : str or list
            Text prompt(s) for conditional generation. Can be a single string or a list
            of strings.

        Returns
        -------
        tuple
            A tuple containing:
            - input_ids: Tokenized input IDs (torch.Tensor, shape (num_prompts, max_length)).
            - attention_mask: Attention mask for tokenized inputs (torch.Tensor, same shape).

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")

        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the reverse diffusion process in the latent space.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Decoded images. If `normalize_output` is True, they are mapped from
            `output_range` to [0, 1]; otherwise they are clamped to `output_range`.

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, if a
            conditional model is specified but `conditions` is None, or if `model` is not
            one of "ddpm", "ddim", "sde".

        Notes
        -----
        - Sampling is performed with `torch.no_grad()`.
        - The noise predictor, reverse diffusion, compressor, and conditional model
          (if applicable) are set to evaluation mode during sampling.
        - For DDIM, uses the subsampled tau schedule (`tau_num_steps`); for DDPM/SDE,
          uses the full number of steps (`num_steps`).
        - The text condition is tokenized and encoded once per call, because it does
          not depend on the time step.
        - NOTE(review): the initial latent is obtained by passing Gaussian noise of
          shape (batch_size, in_channels, *image_shape) through `compressor.encode`,
          rather than by sampling directly in latent space. Confirm this is intended.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)

        self.noise_predictor.eval()
        self.compressor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        with torch.no_grad():
            xt, _ = self.compressor.encode(noisy_samples)

            if self.model == "ddim":
                num_steps = self.reverse.hyper_params.tau_num_steps
            elif self.model == "ddpm" or self.model == "sde":
                num_steps = self.reverse.hyper_params.num_steps
            else:
                raise ValueError(f"Unknown model: {self.model}. Supported: ddpm, ddim, sde")

            # Loop-invariant conditioning: encode the prompts once, not once per step.
            y = None
            if self.conditional_model is not None:
                input_ids, attention_masks = self.tokenize(conditions)
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            for t in reversed(range(num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
                prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long)

                if self.model == "sde":
                    noise = torch.randn_like(xt) if getattr(self.reverse, "method", None) != "ode" else None

                if y is not None:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                if self.model == "sde":
                    xt = self.reverse(xt, noise, predicted_noise, time_steps)
                elif self.model == "ddim":
                    xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
                else:  # "ddpm" (the model name was validated above)
                    xt = self.reverse(xt, predicted_noise, time_steps)

            x = self.compressor.decode(xt)
            generated_imgs = torch.clamp(x, min=self.output_range[0], max=self.output_range[1])
            if normalize_output:
                generated_imgs = (generated_imgs - self.output_range[0]) / (self.output_range[1] - self.output_range[0])

            return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for computation.

        Returns
        -------
        self
            The module moved to the specified device.

        Notes
        -----
        - Moves `noise_predictor`, `reverse`, `compressor`, and `conditional_model`
          (if applicable) to the specified device, and updates `self.device`.
        """
        self.device = device
        self.noise_predictor.to(device)
        self.reverse.to(device)
        self.compressor.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)
|