TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
sde/sample_sde.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from transformers import BertTokenizer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SampleSDE(nn.Module):
    """Sampler for generating images using SDE-based generative models.

    Generates images by iteratively denoising random noise using the reverse SDE process
    and a trained noise predictor, as described in Song et al. (2021). Supports both
    unconditional and conditional generation with text prompts.

    Parameters
    ----------
    reverse_diffusion : ReverseSDE
        Reverse SDE diffusion module for denoising. Must expose `method` and
        `hyper_params.num_steps`, both of which are read during sampling.
    noise_predictor : nn.Module
        Model to predict noise added during the forward SDE process.
    image_shape : tuple
        Shape of generated images as (height, width).
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str or BertTokenizer, optional
        Tokenizer for processing text prompts, default "bert-base-uncased". A string is
        loaded with `BertTokenizer.from_pretrained`; any other object is used as-is.
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    reverse : ReverseSDE
        Reverse SDE diffusion module.
    noise_predictor : nn.Module
        Noise prediction model.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for text prompts.
    max_length : int
        Maximum length for tokenized prompts.
    in_channels : int
        Number of input channels.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Batch size for generation.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, or `output_range` is not a tuple (min, max) with min < max.
    """

    def __init__(self, reverse_diffusion, noise_predictor, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased", max_length=77, batch_size=1, in_channels=3, device=None, output_range=(-1, 1)):
        super().__init__()

        # Validate first so that bad arguments fail before any module is moved
        # to a device or a tokenizer is loaded.
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # `is not None` rather than truthiness: container modules define
        # __len__, so an empty container would otherwise be silently dropped.
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        # Accept either a pretrained name or an already constructed tokenizer,
        # as the parameter documentation promises.
        if isinstance(tokenizer, str):
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.max_length = max_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized tensors using the specified tokenizer.
        Prompts are padded/truncated to `max_length` and the resulting tensors are
        moved to `self.device`.

        Parameters
        ----------
        prompts : str or list
            Text prompt(s) for conditional generation. Can be a single string or a list
            of strings.

        Returns
        -------
        tuple
            A tuple containing:
            - input_ids: Tokenized input IDs (torch.Tensor, shape (batch_size, max_length)).
            - attention_mask: Attention mask for tokenized inputs (torch.Tensor, same shape).

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the reverse SDE sampling process.

        Iteratively denoises random noise to generate images using the reverse SDE process
        and noise predictor. Supports conditional generation with text prompts.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width).
            If `normalize_output` is True, images are normalized to [0, 1]; otherwise,
            they are clamped to `output_range`.

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, or if
            a conditional model is specified but `conditions` is None.

        Notes
        -----
        - Sampling is performed with `torch.no_grad()` for efficiency.
        - The noise predictor, reverse SDE, and conditional model (if applicable) are set
          to evaluation mode during sampling and are left in that mode afterwards.
        - The prompt embedding does not depend on the time step, so it is computed once
          before the denoising loop instead of on every step.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        height, width = self.image_shape
        noisy_samples = torch.randn(self.batch_size, self.in_channels, height, width).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        conditional = self.conditional_model is not None
        if conditional:
            self.conditional_model.eval()

        with torch.no_grad():
            y = None
            if conditional:
                input_ids, attention_masks = self.tokenize(conditions)
                # The conditional model receives a padding mask: True marks padded positions.
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            # The "ode" method is deterministic and takes no fresh noise.
            stochastic = self.reverse.method != "ode"

            xt = noisy_samples
            for t in reversed(range(self.reverse.hyper_params.num_steps)):
                noise = torch.randn_like(xt) if stochastic else None
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)

                if conditional:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                xt = self.reverse(xt, noise, predicted_noise, time_steps)

            generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
            if normalize_output:
                # Affine map from [min, max] to [0, 1].
                generated_imgs = (generated_imgs - self.output_range[0]) / (self.output_range[1] - self.output_range[0])

        return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for computation.

        Returns
        -------
        self
            The module moved to the specified device.

        Notes
        -----
        - `noise_predictor`, `reverse`, and `conditional_model` (if set) are registered
          sub-modules, so `nn.Module.to` moves them together with this module.
        - `self.device` is updated so that newly created tensors follow the module.
        """
        self.device = device
        return super().to(device)
|
sde/text_encoder.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
from transformers import BertModel, DistilBertModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TextEncoder(torch.nn.Module):
    """Encodes token IDs into per-token embeddings.

    Two modes are supported:

    - ``use_pretrained_model=True``: a `BertModel` loaded from `model_name` with all of
      its parameters frozen, followed by a trainable linear projection from the BERT
      hidden size to `output_dimension`.
    - ``use_pretrained_model=False``: an `Embedding` layer followed by `num_layers`
      `EncoderLayer` blocks.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        Selects the pretrained BERT branch (default: True).
    model_name : str, optional
        Name passed to `BertModel.from_pretrained` (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Vocabulary size of the custom embedding (default: 30522).
    num_layers : int, optional
        Number of custom encoder layers (default: 6).
    input_dimension : int, optional
        Embedding width of the custom branch (default: 768).
    output_dimension : int, optional
        Width of the returned embeddings (default: 768).
    num_heads : int, optional
        Attention heads per custom encoder layer (default: 8).
    context_length : int, optional
        Context length given to the custom embedding (default: 77).
    dropout_rate : float, optional
        Dropout probability in the custom encoder layers (default: 0.1).
    qkv_bias : bool, optional
        Bias flag forwarded to the custom encoder layers (default: False).
    scaling_value : int, optional
        Feed-forward expansion factor of the custom encoder layers (default: 4).
    epsilon : float, optional
        LayerNorm epsilon of the custom encoder layers (default: 1e-5).
    """

    def __init__(
            self,
            use_pretrained_model=True,
            model_name="bert-base-uncased",
            vocabulary_size=30522,
            num_layers=6,
            input_dimension=768,
            output_dimension=768,
            num_heads=8,
            context_length=77,
            dropout_rate=0.1,
            qkv_bias=False,
            scaling_value=4,
            epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # BERT is used as a frozen feature extractor; only the projection trains.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # NOTE(review): every layer is built with the same input/output dimensions,
            # so stacking more than one layer presumably requires
            # input_dimension == output_dimension -- confirm.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=output_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for _ in range(num_layers)
            ])

    def forward(self, x, attention_mask=None):
        """Encodes a batch of token IDs.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs.
        attention_mask : torch.Tensor, optional
            Padding mask. In the custom branch it is handed to each encoder layer, which
            uses it as a key-padding mask (True marks positions to ignore). In the
            pretrained branch a boolean mask is interpreted the same way and inverted to
            BERT's convention (1 = attend, 0 = ignore); a non-boolean mask is assumed to
            already follow BERT's convention and is passed through unchanged.

        Returns
        -------
        torch.Tensor
            Per-token embeddings.
        """
        if self.use_pretrained_model:
            bert_mask = attention_mask
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                # Boolean masks mark padding with True, but BERT expects 1 for real
                # tokens. Without the inversion BERT would attend only to padding.
                bert_mask = (~attention_mask).long()
            x = self.bert(input_ids=x, attention_mask=bert_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x
|
|
61
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
62
|
+
class EncoderLayer(torch.nn.Module):
    """Post-norm transformer encoder layer with an optional output projection.

    Applies self-attention and a feed-forward network, each followed by dropout, a
    residual connection and layer normalization. When `input_dimension` differs from
    `output_dimension`, a linear projection changes the width before the final
    normalization; otherwise the projection is an identity.

    Parameters
    ----------
    input_dimension : int
        Width of the incoming embeddings and of both residual sub-blocks.
    output_dimension : int
        Width of the returned embeddings.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout probability used in attention and after each sub-block.
    qkv_bias : bool
        Forwarded as the `bias` flag of `nn.MultiheadAttention`.
    scaling_value : int
        Feed-forward expansion factor.
    epsilon : float, optional
        LayerNorm epsilon (default: 1e-5).
    """

    def __init__(
            self,
            input_dimension,
            output_dimension,
            num_heads,
            dropout_rate,
            qkv_bias,
            scaling_value,
            epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Runs the layer on a batch-first sequence.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings whose last dimension is `input_dimension`.
        attention_mask : torch.Tensor, optional
            Passed to the attention module as `key_padding_mask`.

        Returns
        -------
        torch.Tensor
            Embeddings whose last dimension is `output_dimension`.
        """
        attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
        # Both residual additions happen at input_dimension so the shapes always match.
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = x + self.dropout2(ff_output)
        # Change width only after the residuals. Projecting the attention output first
        # made `x + attn_output` fail whenever the two dimensions differed. With equal
        # dimensions the projection is an identity and the computation is unchanged.
        x = self.output_projection(x)
        return self.norm2(x)
|
|
98
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
99
|
+
class FeedForward(torch.nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear.

    The hidden width is `embedding_dimension * scaling_value`; the output width equals
    the input width.

    Parameters
    ----------
    embedding_dimension : int
        Width of the input and output features.
    scaling_value : int
        Multiplier giving the hidden width.
    dropout_rate : float, optional
        Dropout probability applied after the activation (default: 0.1).
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_dimension = embedding_dimension * scaling_value
        # Modules are created in the same order as they are applied.
        expand = torch.nn.Linear(embedding_dimension, hidden_dimension, bias=True)
        activation = torch.nn.GELU()
        regularizer = torch.nn.Dropout(dropout_rate)
        contract = torch.nn.Linear(hidden_dimension, embedding_dimension, bias=True)
        self.layers = torch.nn.Sequential(expand, activation, regularizer, contract)

    def forward(self, x):
        """Applies the network to `x` and returns a tensor of the same feature width."""
        transformed = self.layers(x)
        return transformed
|
|
118
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
119
|
+
class Embedding(torch.nn.Module):
    """Token embedding with additive sinusoidal positional encoding.

    Parameters
    ----------
    vocabulary_size : int
        Number of rows in the token embedding table.
    embedding_dimension : int, optional
        Width of each embedding vector (default: 768). Odd values are supported.
    context_length : int, optional
        Number of positions whose encoding is precomputed and stored in the
        `positional_encoding` buffer (default: 77).
    """

    def __init__(
            self,
            vocabulary_size,
            embedding_dimension=768,
            context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len):
        """Builds a (1, seq_len, embedding_dimension) sinusoidal encoding.

        Even columns hold sines and odd columns hold cosines of
        ``position * exp(-2i * ln(10000) / embedding_dimension)``.
        """
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float) *
                             -(math.log(10000.0) / self.embedding_dimension))
        angles = position * div_term
        pos_enc = torch.zeros((seq_len, self.embedding_dimension), device=position.device)
        pos_enc[:, 0::2] = torch.sin(angles)
        # For an odd width there is one fewer odd column than even columns, so the
        # cosine block is trimmed to fit; for even widths the slice is a no-op.
        pos_enc[:, 1::2] = torch.cos(angles[:, : self.embedding_dimension // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embeds token IDs and adds the positional encoding.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs of shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Embeddings of shape (batch_size, seq_len, embedding_dimension).

        Raises
        ------
        AssertionError
            If `token_ids` is not two-dimensional.
        """
        # Raised explicitly, not via `assert`, so the check survives `python -O`
        # while keeping the exception type unchanged.
        if token_ids.dim() != 2:
            raise AssertionError("Input token_ids should be of shape (batch_size, seq_len)")
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            # Sequence is longer than the cached table: build an encoding on the fly.
            position_encoded = self._generate_positional_encoding(seq_len).to(token_embedded.device)
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
        return token_embedded + position_encoded
|