TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ddpm/sample_ddpm.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""Image generation using a trained Denoising Diffusion Probabilistic Model (DDPM).
|
|
2
|
+
|
|
3
|
+
This module implements the sampling process for generating images with a trained DDPM
|
|
4
|
+
model, as described in Ho et al. (2020, "Denoising Diffusion Probabilistic Models").
|
|
5
|
+
It supports both unconditional and conditional generation with text prompts.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from transformers import BertTokenizer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SampleDDPM(nn.Module):
    """Image generation using a trained DDPM model.

    Implements the sampling process for DDPM, generating images by iteratively
    denoising random noise using a trained noise predictor and reverse diffusion
    process. Supports conditional generation with text prompts via a conditional
    model, as inspired by Ho et al. (2020).

    Parameters
    ----------
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDPM) for the reverse process.
        Must expose ``hyper_params.num_steps`` and be callable as
        ``reverse(xt, predicted_noise, time_steps)``.
    noise_predictor : nn.Module
        Trained model to predict noise at each time step.
    image_shape : tuple
        Tuple of (height, width) specifying the generated image dimensions.
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str, optional
        Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    output_range : tuple, optional
        Tuple of (min, max) for clamping generated images (default: (-1, 1)).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    reverse : nn.Module
        Reverse diffusion module.
    noise_predictor : nn.Module
        Noise prediction model.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for processing text prompts.
    max_length : int
        Maximum length for tokenized prompts.
    in_channels : int
        Number of input channels.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Batch size for generation.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, or `output_range` is not a valid (min, max) tuple with min < max.
    """

    def __init__(self, reverse_diffusion, noise_predictor, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased", max_length=77, batch_size=1, in_channels=3,
                 device=None, output_range=(-1, 1)):
        super().__init__()
        # Validate the cheap arguments first so bad input fails fast, before any
        # model is moved to a device or a tokenizer is loaded.
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
                isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # `is not None` rather than truthiness: container modules such as an empty
        # nn.Sequential define __len__ and would otherwise be treated as missing.
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.max_length = max_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized input IDs and attention masks using the
        specified tokenizer, suitable for use with the conditional model.

        Parameters
        ----------
        prompts : str or list
            A single text prompt or a list of text prompts.

        Returns
        -------
        tuple
            A tuple containing:
            - input_ids: Tokenized input IDs, shape (num_prompts, max_length).
            - attention_mask: Attention mask, shape (num_prompts, max_length).

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the DDPM sampling process.

        Iteratively denoises random noise to generate images using the reverse diffusion
        process and noise predictor. Supports conditional generation with text prompts.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
            NOTE(review): the number of prompts is not checked against
            `batch_size`; whether a mismatch works depends on the noise predictor.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width).
            If `normalize_output` is True, images are normalized to [0, 1]; otherwise,
            they are clamped to `output_range`.

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, or if
            a conditional model is specified but `conditions` is None.
        TypeError
            If `conditions` is not a string or a list of strings (raised by `tokenize`).
        """
        is_conditional = self.conditional_model is not None
        if conditions is not None and not is_conditional:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and is_conditional:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if is_conditional:
            self.conditional_model.eval()

        with torch.no_grad():
            # The text conditioning does not depend on the time step, so the prompts
            # are tokenized and encoded once rather than on every denoising step.
            y = None
            if is_conditional:
                input_ids, attention_masks = self.tokenize(conditions)
                # True marks padding positions (key-padding-mask convention).
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            xt = noisy_samples
            for t in reversed(range(self.reverse.hyper_params.num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
                if is_conditional:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)
                xt = self.reverse(xt, predicted_noise, time_steps)

            generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
            if normalize_output:
                # Affine map from [min, max] onto [0, 1].
                generated_imgs = (generated_imgs - self.output_range[0]) / (self.output_range[1] - self.output_range[0])

        return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Updates the device attribute and moves the reverse diffusion, noise predictor,
        and conditional model (if present) to the specified device. Unlike
        `nn.Module.to`, only a single device argument is accepted.

        Parameters
        ----------
        device : torch.device or str
            Target device for the module and its components.

        Returns
        -------
        SampleDDPM
            The module itself, moved to the specified device.
        """
        # Normalise so `self.device` is always a torch.device, even if a string is passed.
        self.device = torch.device(device)
        self.noise_predictor.to(self.device)
        self.reverse.to(self.device)
        if self.conditional_model is not None:
            self.conditional_model.to(self.device)
        return super().to(self.device)
|
ddpm/text_encoder.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
from transformers import BertModel, DistilBertModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TextEncoder(torch.nn.Module):
    """Encode token ids into a sequence of feature vectors for conditioning.

    Two backends are supported:

    * ``use_pretrained_model=True``: a frozen Hugging Face ``BertModel`` loaded
      from ``model_name``, followed by a trainable ``nn.Linear`` projection from
      BERT's hidden size to ``output_dimension``.
    * ``use_pretrained_model=False``: a learned token ``Embedding`` with
      sinusoidal positions, followed by ``num_layers`` ``EncoderLayer`` blocks.

    Parameters
    ----------
    use_pretrained_model : bool
        Select the BERT backend (True) or the from-scratch transformer (False).
    model_name : str
        Hugging Face model name for ``BertModel.from_pretrained``.
    vocabulary_size : int
        Vocabulary size of the from-scratch token embedding.
    num_layers : int
        Number of encoder layers in the from-scratch backend.
    input_dimension : int
        Embedding width of the from-scratch backend.
    output_dimension : int
        Width of the returned features (both backends).
    num_heads, dropout_rate, qkv_bias, scaling_value, epsilon
        Forwarded to each ``EncoderLayer``.
    context_length : int
        Number of precomputed positional encodings in ``Embedding``.
    """

    def __init__(
        self,
        use_pretrained_model=True,
        model_name="bert-base-uncased",
        vocabulary_size=30522,
        num_layers=6,
        input_dimension=768,
        output_dimension=768,
        num_heads=8,
        context_length=77,
        dropout_rate=0.1,
        qkv_bias=False,
        scaling_value=4,
        epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # BERT is used as a frozen feature extractor; only the projection trains.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # Only the first layer receives input_dimension features; every later
            # layer receives the previous layer's output_dimension features.
            # (Identical to stacking same-shaped layers when the two are equal.)
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension if index == 0 else output_dimension,
                    output_dimension=output_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for index in range(num_layers)
            ])

    def forward(self, x, attention_mask=None):
        """Encode token ids ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Token ids (passed as ``input_ids`` to BERT, or to ``Embedding``).
        attention_mask : torch.Tensor, optional
            Padding mask. A boolean mask uses the ``nn.MultiheadAttention``
            key-padding convention (True = padding to ignore). A non-boolean mask
            is forwarded unchanged (for BERT: 1 = attend, 0 = padding).

        Returns
        -------
        torch.Tensor
            Encoded features whose last dimension is ``output_dimension``
            (from-scratch backend: assumes ``num_layers >= 1``).
        """
        if self.use_pretrained_model:
            # SampleDDPM passes a boolean key-padding mask (True = pad). Hugging Face
            # BERT reads masks the other way round (1 = attend), so a boolean mask is
            # inverted here; otherwise BERT would attend only to the padding tokens.
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = (~attention_mask).long()
            hidden = self.bert(input_ids=x, attention_mask=attention_mask).last_hidden_state
            return self.projection(hidden)

        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        return x
|
|
61
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
62
|
+
class EncoderLayer(torch.nn.Module):
    """Post-norm transformer encoder layer.

    Self-attention and a feed-forward sub-block, each wrapped in a residual
    connection followed by LayerNorm. Both sub-blocks run at
    ``input_dimension``. When ``output_dimension`` differs, a final linear
    projection maps the result to it; otherwise the projection is the identity.

    Parameters
    ----------
    input_dimension : int
        Feature width of the input sequence.
    output_dimension : int
        Feature width of the returned sequence.
    num_heads : int
        Number of attention heads for ``nn.MultiheadAttention``.
    dropout_rate : float
        Dropout used in attention, after each sub-block, and inside the feed-forward.
    qkv_bias : bool
        Passed as ``bias`` to ``nn.MultiheadAttention``.
    scaling_value : int
        Hidden-width multiplier of the feed-forward sub-block.
    epsilon : float
        LayerNorm epsilon.
    """

    def __init__(
        self,
        input_dimension,
        output_dimension,
        num_heads,
        dropout_rate,
        qkv_bias,
        scaling_value,
        epsilon=1e-5
    ):
        super().__init__()
        # Submodules are created in the original order so seeded initialization
        # is unchanged.
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        # The residual stream stays at input_dimension until the final projection,
        # so norm2 normalizes input_dimension features. It previously used
        # output_dimension, which only worked when the two widths were equal.
        self.norm2 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply the layer.

        Parameters
        ----------
        x : torch.Tensor
            Input sequence; batch-first (``batch_first=True``) with last dimension
            ``input_dimension``.
        attention_mask : torch.Tensor, optional
            Forwarded as ``key_padding_mask`` to ``nn.MultiheadAttention``
            (boolean: True marks positions to ignore).

        Returns
        -------
        torch.Tensor
            Sequence with last dimension ``output_dimension``.
        """
        attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        # Project only after both residual blocks, where the widths are consistent.
        return self.output_projection(x)
|
|
98
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
99
|
+
class FeedForward(torch.nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Widens the last dimension from ``embedding_dimension`` to
    ``embedding_dimension * scaling_value`` and projects it back, so the output
    has the same trailing size as the input.

    Parameters
    ----------
    embedding_dimension : int
        Size of the input and output feature dimension.
    scaling_value : int
        Multiplier for the hidden width.
    dropout_rate : float, optional
        Dropout applied after the activation (default: 0.1).
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_width = embedding_dimension * scaling_value
        widen = torch.nn.Linear(in_features=embedding_dimension, out_features=hidden_width, bias=True)
        activation = torch.nn.GELU()
        regularizer = torch.nn.Dropout(dropout_rate)
        narrow = torch.nn.Linear(in_features=hidden_width, out_features=embedding_dimension, bias=True)
        # A single Sequential named ``layers`` keeps the state_dict keys
        # (layers.0.*, layers.3.*) compatible with existing checkpoints.
        self.layers = torch.nn.Sequential(widen, activation, regularizer, narrow)

    def forward(self, x):
        """Run ``x`` through each stage of the MLP in order."""
        for stage in self.layers:
            x = stage(x)
        return x
|
|
118
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
119
|
+
class Embedding(torch.nn.Module):
|
|
120
|
+
def __init__(
|
|
121
|
+
self,
|
|
122
|
+
vocabulary_size,
|
|
123
|
+
embedding_dimension=768,
|
|
124
|
+
context_length=77
|
|
125
|
+
):
|
|
126
|
+
super().__init__()
|
|
127
|
+
self.token_embedding = nn.Embedding(
|
|
128
|
+
num_embeddings=vocabulary_size,
|
|
129
|
+
embedding_dim=embedding_dimension
|
|
130
|
+
)
|
|
131
|
+
self.embedding_dimension = embedding_dimension
|
|
132
|
+
self.context_length = context_length
|
|
133
|
+
self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))
|
|
134
|
+
|
|
135
|
+
def _generate_positional_encoding(self, seq_len):
|
|
136
|
+
position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
|
|
137
|
+
div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float) *
|
|
138
|
+
-(math.log(10000.0) / self.embedding_dimension))
|
|
139
|
+
pos_enc = torch.zeros((seq_len, self.embedding_dimension), device=position.device)
|
|
140
|
+
pos_enc[:, 0::2] = torch.sin(position * div_term)
|
|
141
|
+
pos_enc[:, 1::2] = torch.cos(position * div_term)
|
|
142
|
+
return pos_enc.unsqueeze(0)
|
|
143
|
+
|
|
144
|
+
def forward(self, token_ids):
|
|
145
|
+
assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
|
|
146
|
+
token_embedded = self.token_embedding(token_ids)
|
|
147
|
+
seq_len = token_ids.size(1)
|
|
148
|
+
if seq_len > self.context_length:
|
|
149
|
+
position_encoded = self._generate_positional_encoding(seq_len).to(token_embedded.device)
|
|
150
|
+
else:
|
|
151
|
+
position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
|
|
152
|
+
return token_embedded + position_encoded
|