TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ddim/sample_ddim.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""Image generation using a trained Denoising Diffusion Implicit Model (DDIM).
|
|
2
|
+
|
|
3
|
+
This module implements the sampling process for generating images with a trained DDIM
|
|
4
|
+
model, as described in Song et al. (2021, "Denoising Diffusion Implicit Models"). It
|
|
5
|
+
supports both unconditional and conditional generation with text prompts, using a
|
|
6
|
+
subsampled time step schedule for faster sampling.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from transformers import BertTokenizer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SampleDDIM(nn.Module):
    """Image generation using a trained DDIM model.

    Implements the sampling process for DDIM, generating images by iteratively denoising
    random noise using a trained noise predictor and reverse diffusion process with a
    subsampled time step schedule. Supports conditional generation with text prompts,
    as inspired by Song et al. (2021).

    Parameters
    ----------
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDIM) for the reverse process. It must expose
        ``hyper_params.tau_num_steps`` and, when called as
        ``reverse(xt, predicted_noise, time_steps, prev_time_steps)``, return a pair whose
        first element is the denoised sample for the previous step.
    noise_predictor : nn.Module
        Trained model to predict noise at each time step. Called as
        ``noise_predictor(xt, time_steps)`` or ``noise_predictor(xt, time_steps, y)``.
    image_shape : tuple
        Tuple of (height, width) specifying the generated image dimensions.
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None. Called as
        ``conditional_model(input_ids, key_padding_mask)`` where ``key_padding_mask`` is
        True at padding positions.
    tokenizer : str, optional
        Pretrained tokenizer name from Hugging Face, loaded with
        ``BertTokenizer.from_pretrained`` (default: "bert-base-uncased").
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    output_range : tuple, optional
        Tuple of (min, max) for clamping generated images (default: (-1, 1)).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    reverse : nn.Module
        Reverse diffusion module.
    noise_predictor : nn.Module
        Noise prediction model.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for processing text prompts.
    max_length : int
        Maximum length for tokenized prompts.
    in_channels : int
        Number of input channels.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Batch size for generation.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, or `output_range` is not a valid (min, max) tuple with min < max.
    """

    def __init__(self, reverse_diffusion, noise_predictor, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased",
                 max_length=77, batch_size=1, in_channels=3, device=None, output_range=(-1, 1)):
        super().__init__()
        # Validate cheap arguments first so bad input fails before models are moved and
        # the tokenizer is loaded.
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
                isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # `is not None` rather than truthiness: container modules (e.g. an empty
        # nn.Sequential) define __len__ and would otherwise be silently dropped.
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.max_length = max_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized input IDs and attention masks using the
        specified tokenizer, suitable for use with the conditional model.

        Parameters
        ----------
        prompts : str or list
            A single text prompt or a list of text prompts.

        Returns
        -------
        tuple
            A tuple containing:
            - input_ids: Tokenized input IDs, shape (num_prompts, max_length).
            - attention_mask: Attention mask, shape (num_prompts, max_length).

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def _encode_conditions(self, conditions):
        """Tokenize `conditions` and encode them with the conditional model for one batch.

        A single prompt is repeated so that every sample in the batch shares it.

        Raises
        ------
        ValueError
            If the number of prompts is neither 1 nor `batch_size`.
        """
        input_ids, attention_masks = self.tokenize(conditions)
        num_prompts = input_ids.size(0)
        if num_prompts == 1 and self.batch_size > 1:
            input_ids = input_ids.repeat(self.batch_size, 1)
            attention_masks = attention_masks.repeat(self.batch_size, 1)
        elif num_prompts != self.batch_size:
            raise ValueError(
                f"Number of prompts ({num_prompts}) must be 1 or equal to batch_size ({self.batch_size})"
            )
        # Tokenizer masks use 1 = real token; the conditional model expects True = padding.
        key_padding_mask = (attention_masks == 0)
        return self.conditional_model(input_ids, key_padding_mask)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the DDIM sampling process.

        Iteratively denoises random noise to generate images using the reverse diffusion
        process with a subsampled time step schedule and noise predictor. Supports
        conditional generation with text prompts.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None. Either a single
            prompt (shared by the whole batch) or exactly `batch_size` prompts.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width).
            If `normalize_output` is True, images are normalized to [0, 1]; otherwise,
            they are clamped to `output_range`.

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, if
            a conditional model is specified but `conditions` is None, or if the number
            of prompts is neither 1 nor `batch_size`.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")
        is_conditional = self.conditional_model is not None

        noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if is_conditional:
            self.conditional_model.eval()

        with torch.no_grad():
            # The text embedding does not depend on the time step, so compute it once
            # instead of re-tokenizing and re-encoding at every denoising step.
            y = self._encode_conditions(conditions) if is_conditional else None
            xt = noisy_samples
            for t in reversed(range(self.reverse.hyper_params.tau_num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
                prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long)

                if is_conditional:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)

        generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
        if normalize_output:
            generated_imgs = (generated_imgs - self.output_range[0]) / (self.output_range[1] - self.output_range[0])

        return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Updates the device attribute and moves the module. The reverse diffusion, noise
        predictor and conditional model (if present) are assigned as attributes in
        ``__init__``, so they are registered submodules and are moved by ``nn.Module.to``.

        Parameters
        ----------
        device : torch.device
            Target device for the module and its components.

        Returns
        -------
        SampleDDIM
            The module itself, moved to the specified device.
        """
        self.device = device
        return super().to(device)
|
ddim/text_encoder.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
from transformers import BertModel, DistilBertModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TextEncoder(torch.nn.Module):
    """Encode token ids into a sequence of context embeddings.

    Two back-ends are supported:

    * ``use_pretrained_model=True``: a Hugging Face ``BertModel`` loaded from
      ``model_name`` with all of its parameters frozen, followed by a trainable
      ``nn.Linear`` from BERT's hidden size to ``output_dimension``.
    * ``use_pretrained_model=False``: a learned token embedding with sinusoidal
      positional encoding (``Embedding``) followed by ``num_layers`` ``EncoderLayer``
      blocks.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        Select the pretrained BERT back-end (default: True).
    model_name : str, optional
        Hugging Face model name for the pretrained back-end (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Vocabulary size for the custom back-end's token embedding (default: 30522).
    num_layers : int, optional
        Number of encoder layers in the custom back-end (default: 6).
    input_dimension : int, optional
        Token embedding size for the custom back-end (default: 768).
    output_dimension : int, optional
        Size of the returned embeddings (default: 768).
    num_heads : int, optional
        Attention heads per encoder layer (default: 8).
    context_length : int, optional
        Length of the precomputed positional-encoding table (default: 77).
    dropout_rate : float, optional
        Dropout probability in the encoder layers (default: 0.1).
    qkv_bias : bool, optional
        Whether the attention projections use a bias (default: False).
    scaling_value : int, optional
        Hidden-width multiplier of the feed-forward blocks (default: 4).
    epsilon : float, optional
        LayerNorm epsilon (default: 1e-5).
    """

    def __init__(
        self,
        use_pretrained_model=True,
        model_name="bert-base-uncased",
        vocabulary_size=30522,
        num_layers=6,
        input_dimension=768,
        output_dimension=768,
        num_heads=8,
        context_length=77,
        dropout_rate=0.1,
        qkv_bias=False,
        scaling_value=4,
        epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # BERT acts as a frozen feature extractor; only the projection is trainable.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # Only the first layer receives input_dimension features; every later layer
            # consumes the previous layer's output_dimension-sized output.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension if index == 0 else output_dimension,
                    output_dimension=output_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for index in range(num_layers)
            ])

    def forward(self, x, attention_mask=None):
        """Encode token ids.

        Parameters
        ----------
        x : torch.Tensor
            Token ids; the custom back-end requires shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Padding mask. A boolean mask follows the ``nn.MultiheadAttention``
            ``key_padding_mask`` convention (True marks padding), which is what the custom
            back-end consumes directly. For the pretrained back-end a boolean mask is
            inverted to BERT's convention (1 = attend, 0 = padding). Non-boolean masks are
            passed to BERT unchanged, so HF-style 0/1 integer masks keep working.

        Returns
        -------
        torch.Tensor
            Embeddings whose last dimension is ``output_dimension``.
        """
        if self.use_pretrained_model:
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = (~attention_mask).long()
            hidden = self.bert(input_ids=x, attention_mask=attention_mask).last_hidden_state
            return self.projection(hidden)
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        return x
|
|
61
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
62
|
+
class EncoderLayer(torch.nn.Module):
    """Post-norm transformer encoder layer: self-attention then a feed-forward block.

    Each sub-block is added residually and followed by LayerNorm. When
    ``input_dimension != output_dimension``, the attention sub-block's output is projected
    to ``output_dimension`` before the feed-forward sub-block, so the layer maps
    (..., input_dimension) to (..., output_dimension). When the two are equal the
    projection is an identity.

    Parameters
    ----------
    input_dimension : int
        Feature size of the incoming sequence (attention embedding size).
    output_dimension : int
        Feature size of the returned sequence.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout probability for attention weights and both residual branches.
    qkv_bias : bool
        Whether the attention projections use a bias.
    scaling_value : int
        Hidden-width multiplier of the feed-forward block.
    epsilon : float, optional
        LayerNorm epsilon (default: 1e-5).
    """

    def __init__(
        self,
        input_dimension,
        output_dimension,
        num_heads,
        dropout_rate,
        qkv_bias,
        scaling_value,
        epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        # The feed-forward block runs after the projection, so it works in output_dimension.
        self.feedforward = FeedForward(
            embedding_dimension=output_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply the layer.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch_size, seq_len, input_dimension); the attention module
            is built with ``batch_first=True``.
        attention_mask : torch.Tensor, optional
            Passed as ``key_padding_mask`` to ``nn.MultiheadAttention`` (a boolean True
            marks a position to ignore).

        Returns
        -------
        torch.Tensor
            Output of shape (batch_size, seq_len, output_dimension).
        """
        attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
        # Residual is formed in input_dimension; only then is the result projected, so
        # the addition never mixes feature sizes.
        x = self.norm1(x + self.dropout1(attn_output))
        x = self.output_projection(x)
        ff_output = self.feedforward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x
|
|
98
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
99
|
+
class FeedForward(nn.Module):
    """Position-wise MLP used inside each encoder layer.

    Applies Linear -> GELU -> Dropout -> Linear. The hidden width is
    ``embedding_dimension * scaling_value``, and the output has the same size as the input.

    Parameters
    ----------
    embedding_dimension : int
        Feature size of the input and the output.
    scaling_value : int
        Multiplier giving the hidden width.
    dropout_rate : float, optional
        Dropout probability after the activation (default: 0.1).
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_width = embedding_dimension * scaling_value
        # Kept as one Sequential so parameter names stay `layers.0.*` / `layers.3.*`.
        self.layers = nn.Sequential(
            nn.Linear(embedding_dimension, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_width, embedding_dimension),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of `x`."""
        return self.layers(x)
|
|
118
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
119
|
+
class Embedding(torch.nn.Module):
    """Learned token embedding plus fixed sinusoidal positional encoding.

    Parameters
    ----------
    vocabulary_size : int
        Number of token ids in the embedding table.
    embedding_dimension : int, optional
        Embedding size (default: 768). Odd sizes are supported.
    context_length : int, optional
        Length of the precomputed positional-encoding table (default: 77). Longer inputs
        get an encoding computed on the fly.
    """

    def __init__(
        self,
        vocabulary_size,
        embedding_dimension=768,
        context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        # Registered as a buffer so it follows the module across .to() and is saved in the
        # state dict, without being a trainable parameter.
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len):
        """Return the sinusoidal table of shape (1, seq_len, embedding_dimension).

        Even columns hold sin(pos * w_i) and odd columns cos(pos * w_i), with
        w_i = 10000 ** (-2i / embedding_dimension).
        """
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float) *
                             -(math.log(10000.0) / self.embedding_dimension))
        pos_enc = torch.zeros((seq_len, self.embedding_dimension))
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dimension there is one fewer odd column than frequencies.
        pos_enc[:, 1::2] = torch.cos(position * div_term[: self.embedding_dimension // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embed `token_ids` of shape (batch_size, seq_len) and add positional encoding.

        Raises
        ------
        ValueError
            If `token_ids` is not 2-D.
        """
        if token_ids.dim() != 2:
            raise ValueError("Input token_ids should be of shape (batch_size, seq_len)")
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            # The cached table is too short; build one on the fly and match the embedding's
            # device and dtype so the sum does not silently upcast.
            position_encoded = self._generate_positional_encoding(seq_len).to(
                device=token_embedded.device, dtype=token_embedded.dtype
            )
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
        return token_embedded + position_encoded
|