TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ddim/sample_ddim.py ADDED
@@ -0,0 +1,219 @@
1
+ """Image generation using a trained Denoising Diffusion Implicit Model (DDIM).
2
+
3
+ This module implements the sampling process for generating images with a trained DDIM
4
+ model, as described in Song et al. (2021, "Denoising Diffusion Implicit Models"). It
5
+ supports both unconditional and conditional generation with text prompts, using a
6
+ subsampled time step schedule for faster sampling.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from transformers import BertTokenizer
12
+
13
+
14
+
15
class SampleDDIM(nn.Module):
    """Image generation using a trained DDIM model.

    Implements the sampling process for DDIM, generating images by iteratively denoising
    random noise using a trained noise predictor and reverse diffusion process with a
    subsampled time step schedule. Supports conditional generation with text prompts,
    as inspired by Song et al. (2021).

    Parameters
    ----------
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDIM) for the reverse process. It must
        expose ``hyper_params.tau_num_steps`` and return a 2-tuple whose first element
        is the updated sample.
    noise_predictor : nn.Module
        Trained model to predict noise at each time step.
    image_shape : tuple
        Tuple of (height, width) specifying the generated image dimensions.
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str, optional
        Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    output_range : tuple, optional
        Tuple of (min, max) for clamping generated images (default: (-1, 1)).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    reverse : nn.Module
        Reverse diffusion module.
    noise_predictor : nn.Module
        Noise prediction model.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for processing text prompts.
    max_length : int
        Maximum length for tokenized prompts.
    in_channels : int
        Number of input channels.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Batch size for generation.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, or `output_range` is not a valid (min, max) tuple with min < max.
    """

    def __init__(self, reverse_diffusion, noise_predictor, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased",
                 max_length=77, batch_size=1, in_channels=3, device=None, output_range=(-1, 1)):
        super().__init__()

        # Validate cheap arguments first so that a bad call fails before any module is
        # moved to a device or a tokenizer is loaded.
        if (not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2
                or not all(isinstance(s, int) and s > 0 for s in image_shape)):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if (not isinstance(output_range, (tuple, list)) or len(output_range) != 2
                or output_range[0] >= output_range[1]):
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # `is not None` rather than truthiness: container modules define __len__, so an
        # empty container would otherwise be silently treated as "no model".
        self.conditional_model = (
            conditional_model.to(self.device) if conditional_model is not None else None
        )
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.max_length = max_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized input IDs and attention masks using the
        configured tokenizer, padding/truncating every prompt to `max_length`.

        Parameters
        ----------
        prompts : str or list
            A single text prompt or a list of text prompts.

        Returns
        -------
        tuple
            A tuple containing:
            - input_ids: Tokenized input IDs, moved to `self.device`.
            - attention_mask: Attention mask from the tokenizer, moved to `self.device`.

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the DDIM sampling process.

        Iteratively denoises random noise to generate images using the reverse diffusion
        process with a subsampled time step schedule and noise predictor. Supports
        conditional generation with text prompts.

        The noise predictor, reverse module and conditional model are switched to
        evaluation mode and are left in that mode when the call returns.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width).
            If `normalize_output` is True, images are normalized to [0, 1]; otherwise,
            they are clamped to `output_range`.

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, or if
            a conditional model is specified but `conditions` is None.
        TypeError
            If `conditions` is neither a string nor a list of strings.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        # Sampled on the default (CPU) generator and then moved, so seeded runs keep
        # producing the same starting noise regardless of the target device.
        noisy_samples = torch.randn(
            self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]
        ).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        # After the two checks above, a conditional model implies conditions were given.
        use_condition = self.conditional_model is not None

        with torch.no_grad():
            # The prompt embedding does not depend on the time step, so it is computed
            # once here instead of being re-tokenized and re-encoded on every iteration.
            y = None
            if use_condition:
                input_ids, attention_masks = self.tokenize(conditions)
                key_padding_mask = (attention_masks == 0)  # True where the token is padding
                y = self.conditional_model(input_ids, key_padding_mask)

            xt = noisy_samples
            for t in reversed(range(self.reverse.hyper_params.tau_num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
                # The final step (t == 0) maps onto itself rather than onto index -1.
                prev_time_steps = torch.full(
                    (self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long
                )

                if use_condition:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)

            generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
            if normalize_output:
                generated_imgs = (generated_imgs - self.output_range[0]) / (
                    self.output_range[1] - self.output_range[0]
                )

        return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Updates the device attribute and moves the reverse diffusion, noise predictor,
        and conditional model (if present) to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for the module and its components.

        Returns
        -------
        SampleDDIM
            The module itself, moved to the specified device.
        """
        self.device = device
        # Submodules are moved through their own `to` before delegating to
        # nn.Module.to, presumably so that components overriding `to` (as this class
        # does) can refresh their own device bookkeeping.
        self.noise_predictor.to(device)
        self.reverse.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)
ddim/text_encoder.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from transformers import BertModel, DistilBertModel
5
+
6
+
7
+
8
+
9
class TextEncoder(torch.nn.Module):
    """Encode token ids into per-token feature vectors.

    Two back ends are supported, selected by `use_pretrained_model`:

    * pretrained: a frozen `BertModel` followed by a trainable linear projection
      from the BERT hidden size to `output_dimension`;
    * custom: an `Embedding` layer followed by `num_layers` `EncoderLayer` blocks.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, use the frozen BERT back end; otherwise build the custom encoder.
    model_name : str, optional
        Name passed to `BertModel.from_pretrained` (pretrained back end only).
    vocabulary_size : int, optional
        Vocabulary size of the custom `Embedding`.
    num_layers : int, optional
        Number of `EncoderLayer` blocks in the custom encoder.
    input_dimension : int, optional
        Embedding width and working width of the custom encoder.
    output_dimension : int, optional
        Width of the returned features. In the custom encoder only the last layer
        changes the width, so that stacked layers stay shape-compatible.
    num_heads : int, optional
        Number of attention heads in each `EncoderLayer`.
    context_length : int, optional
        Context length passed to the custom `Embedding`.
    dropout_rate : float, optional
        Dropout rate passed to each `EncoderLayer`.
    qkv_bias : bool, optional
        Bias flag passed to each `EncoderLayer`.
    scaling_value : int, optional
        Feed-forward expansion factor passed to each `EncoderLayer`.
    epsilon : float, optional
        Layer-norm epsilon passed to each `EncoderLayer`.
    """

    def __init__(
            self,
            use_pretrained_model=True,
            model_name="bert-base-uncased",
            vocabulary_size=30522,
            num_layers=6,
            input_dimension=768,
            output_dimension=768,
            num_heads=8,
            context_length=77,
            dropout_rate=0.1,
            qkv_bias=False,
            scaling_value=4,
            epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # BERT is used as a fixed feature extractor; only the projection is trained.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            last_index = num_layers - 1
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    # Every layer consumes `input_dimension` features, so only the
                    # final layer may emit `output_dimension`; identical to the old
                    # construction whenever the two dimensions are equal.
                    output_dimension=output_dimension if index == last_index else input_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for index in range(num_layers)
            ])

    def forward(self, x, attention_mask=None):
        """Encode a batch of token ids.

        Parameters
        ----------
        x : torch.Tensor
            Token ids.
        attention_mask : torch.Tensor, optional
            Padding information. A boolean tensor is interpreted as a key-padding
            mask (True marks padding to be ignored), which is the convention of
            `nn.MultiheadAttention` used by the custom encoder. For the pretrained
            back end such a mask is inverted before it reaches BERT, whose
            `attention_mask` uses 1 for tokens to attend to. A non-boolean tensor is
            forwarded to BERT unchanged.

        Returns
        -------
        torch.Tensor
            Per-token features produced by the selected back end.
        """
        if self.use_pretrained_model:
            bert_mask = attention_mask
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                # Key-padding mask (True = pad) -> Hugging Face mask (1 = real token).
                bert_mask = (~attention_mask).long()
            x = self.bert(input_ids=x, attention_mask=bert_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x
61
+ #-----------------------------------------------------------------------------------------------------------------------
62
class EncoderLayer(torch.nn.Module):
    """Single post-norm transformer encoder block.

    Applies self-attention and a feed-forward network, each followed by a residual
    connection and layer normalisation. When `output_dimension` differs from
    `input_dimension`, a linear projection changes the feature width just before
    the final normalisation; otherwise that projection is the identity.

    Parameters
    ----------
    input_dimension : int
        Feature width of the input, the attention module and the feed-forward net.
    output_dimension : int
        Feature width of the returned tensor.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout probability used inside attention and on both residual branches.
    qkv_bias : bool
        Passed as `bias` to `nn.MultiheadAttention`.
    scaling_value : int
        Expansion factor of the feed-forward hidden layer.
    epsilon : float, optional
        Epsilon used by both layer norms (default: 1e-5).
    """

    def __init__(
            self,
            input_dimension,
            output_dimension,
            num_heads,
            dropout_rate,
            qkv_bias,
            scaling_value,
            epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = (
            nn.Linear(input_dimension, output_dimension)
            if input_dimension != output_dimension
            else nn.Identity()
        )
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Run the block on `x`.

        Parameters
        ----------
        x : torch.Tensor
            Input features whose last dimension is `input_dimension`
            (batch-first, as configured on the attention module).
        attention_mask : torch.Tensor, optional
            Forwarded to the attention module as `key_padding_mask`.

        Returns
        -------
        torch.Tensor
            Features whose last dimension is `output_dimension`.
        """
        attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
        # Both residual sums are taken at `input_dimension`, where `x`, the attention
        # output, `norm1` and the feed-forward net all agree on the feature width.
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        # Project only after the second residual: projecting the attention output
        # earlier made it incompatible with the un-projected residual whenever the
        # two dimensions differ. With equal dimensions this is an identity.
        x = self.output_projection(x + self.dropout2(ff_output))
        return self.norm2(x)
98
+ #-----------------------------------------------------------------------------------------------------------------------
99
class FeedForward(torch.nn.Module):
    """Position-wise two-layer perceptron.

    Expands the features by `scaling_value`, applies GELU and dropout, then maps
    back to `embedding_dimension`.

    Parameters
    ----------
    embedding_dimension : int
        Width of the input and output features.
    scaling_value : int
        Multiplier giving the hidden width (`embedding_dimension * scaling_value`).
    dropout_rate : float, optional
        Dropout probability applied after the activation (default: 0.1).
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_dimension = embedding_dimension * scaling_value
        # Kept as a Sequential named `layers` so parameter names stay
        # `layers.0.*` and `layers.3.*`.
        stages = [
            torch.nn.Linear(embedding_dimension, hidden_dimension, bias=True),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(hidden_dimension, embedding_dimension, bias=True),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Apply the network to the last dimension of `x` and return the result."""
        transformed = self.layers(x)
        return transformed
118
+ #-----------------------------------------------------------------------------------------------------------------------
119
class Embedding(torch.nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding.

    Parameters
    ----------
    vocabulary_size : int
        Number of rows in the token embedding table.
    embedding_dimension : int, optional
        Width of each embedding vector (default: 768). Odd values are supported.
    context_length : int, optional
        Number of positions whose encoding is precomputed and stored as a buffer
        (default: 77). Longer inputs get their encoding computed on the fly.
    """

    def __init__(
            self,
            vocabulary_size,
            embedding_dimension=768,
            context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len):
        """Return the sinusoidal encoding for `seq_len` positions.

        Even feature indices hold sines and odd indices hold cosines of the position
        scaled by geometrically spaced frequencies. The result has shape
        (1, seq_len, embedding_dimension).
        """
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float) *
                             -(math.log(10000.0) / self.embedding_dimension))
        pos_enc = torch.zeros((seq_len, self.embedding_dimension), device=position.device)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        # For an odd embedding dimension there is one cosine column fewer than sine
        # columns, so the frequency vector is trimmed to the number of odd indices.
        pos_enc[:, 1::2] = torch.cos(position * div_term[: self.embedding_dimension // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embed `token_ids` and add the positional encoding.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token ids of shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, seq_len, embedding_dimension).
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            # Input is longer than the cached table: build an encoding of the right
            # length for this call only.
            position_encoded = self._generate_positional_encoding(seq_len).to(token_embedded.device)
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
        return token_embedded + position_encoded