TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ddpm/sample_ddpm.py ADDED
@@ -0,0 +1,213 @@
1
+ """Image generation using a trained Denoising Diffusion Probabilistic Model (DDPM).
2
+
3
+ This module implements the sampling process for generating images with a trained DDPM
4
+ model, as described in Ho et al. (2020, "Denoising Diffusion Probabilistic Models").
5
+ It supports both unconditional and conditional generation with text prompts.
6
+ """
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from transformers import BertTokenizer
12
+
13
+
14
+
15
class SampleDDPM(nn.Module):
    """Image generation using a trained DDPM model.

    Implements the sampling process for DDPM (Ho et al., 2020): starting from
    Gaussian noise, the sample is iteratively denoised by a trained noise
    predictor and a reverse diffusion step. Conditional generation with text
    prompts is supported through an optional conditional model.

    Parameters
    ----------
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDPM). It must expose
        ``hyper_params.num_steps`` and is called as
        ``reverse_diffusion(xt, predicted_noise, time_steps)``.
    noise_predictor : nn.Module
        Trained noise-prediction model. Called as
        ``noise_predictor(xt, time_steps)`` for unconditional sampling and as
        ``noise_predictor(xt, time_steps, y)`` when a conditional model is used.
    image_shape : tuple
        Tuple of (height, width) specifying the generated image dimensions.
    conditional_model : nn.Module, optional
        Model producing the conditioning tensor ``y`` (e.g., a text encoder),
        default None. Called once per :meth:`forward` as
        ``conditional_model(input_ids, key_padding_mask)``, where
        ``key_padding_mask`` is a boolean tensor that is True at padding positions.
    tokenizer : str, optional
        Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per call (default: 1).
    in_channels : int, optional
        Number of channels of the generated images (default: 3).
    device : torch.device or str, optional
        Device for computation (default: CUDA if available, else CPU).
    output_range : tuple, optional
        Tuple of (min, max) used to clamp generated images (default: (-1, 1)).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    reverse : nn.Module
        Reverse diffusion module.
    noise_predictor : nn.Module
        Noise prediction model.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for processing text prompts.
    max_length : int
        Maximum length for tokenized prompts.
    in_channels : int
        Number of image channels.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Number of images generated per call.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, or `output_range` is not a valid (min, max) tuple with min < max.
    """

    def __init__(self, reverse_diffusion, noise_predictor, image_shape, conditional_model=None, tokenizer="bert-base-uncased",
                 max_length=77, batch_size=1, in_channels=3, device=None, output_range=(-1, 1)):
        super().__init__()
        # Validate cheap arguments first so bad input fails fast, before any module is
        # moved to a device or the tokenizer is loaded.
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
                isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # `is not None` rather than truthiness: container modules such as an empty
        # nn.Sequential define __len__ and would otherwise be silently dropped.
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.max_length = max_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenize text prompts for conditional generation.

        Parameters
        ----------
        prompts : str or list
            A single text prompt or a list of text prompts.

        Returns
        -------
        tuple
            ``(input_ids, attention_mask)``, both of shape (num_prompts, max_length)
            and placed on ``self.device``. ``attention_mask`` follows the Hugging Face
            convention (1 for real tokens, 0 for padding).

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generate images with the DDPM sampling process.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None. A single string
            is repeated for every image in the batch.
        normalize_output : bool, optional
            If True, rescales the clamped images from `output_range` to [0, 1]
            (default: True).

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width). They are
            always clamped to `output_range` and, if `normalize_output` is True,
            rescaled to [0, 1].

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, or if
            a conditional model is specified but `conditions` is None.
        TypeError
            If `conditions` is not a string or a list of strings.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        with torch.no_grad():
            # The conditioning depends only on the prompts, not on the time step, so it
            # is computed once instead of once per diffusion step.
            y = None
            if self.conditional_model is not None:
                if isinstance(conditions, str) and self.batch_size > 1:
                    # Broadcast a single prompt so the conditioning batch matches the image batch.
                    conditions = [conditions] * self.batch_size
                input_ids, attention_masks = self.tokenize(conditions)
                key_padding_mask = attention_masks == 0  # True at padding positions
                y = self.conditional_model(input_ids, key_padding_mask)

            xt = noisy_samples
            for t in reversed(range(self.reverse.hyper_params.num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
                if y is not None:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)
                xt = self.reverse(xt, predicted_noise, time_steps)

        generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
        if normalize_output:
            generated_imgs = (generated_imgs - self.output_range[0]) / (self.output_range[1] - self.output_range[0])

        return generated_imgs

    def to(self, device):
        """Move the sampler and its components to ``device``.

        ``reverse``, ``noise_predictor`` and ``conditional_model`` are registered
        submodules, so ``nn.Module.to`` moves them together with any other
        parameters and buffers. ``self.device`` is updated when ``device`` is a
        device specification (str, int or torch.device).

        Parameters
        ----------
        device : torch.device or str
            Target device.

        Returns
        -------
        SampleDDPM
            The module itself, moved to the specified device.
        """
        if isinstance(device, (str, int, torch.device)):
            self.device = torch.device(device)
        return super().to(device)
ddpm/text_encoder.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from transformers import BertModel, DistilBertModel
5
+
6
+
7
+
8
+
9
class TextEncoder(torch.nn.Module):
    """Encode token ids into a sequence of conditioning embeddings.

    Two backends are available:

    * ``use_pretrained_model=True``: a frozen Hugging Face ``BertModel`` loaded from
      ``model_name``, followed by a trainable linear projection from BERT's hidden
      size to ``output_dimension``.
    * otherwise: a token + sinusoidal positional ``Embedding``, ``num_layers``
      ``EncoderLayer`` blocks operating at ``input_dimension``, and a final
      projection to ``output_dimension`` (``nn.Identity`` when the two are equal).

    Parameters
    ----------
    use_pretrained_model : bool
        Select the pretrained-BERT backend (default True).
    model_name : str
        Hugging Face model name for the pretrained backend.
    vocabulary_size, num_layers, input_dimension, num_heads, context_length,
    dropout_rate, qkv_bias, scaling_value, epsilon
        Configuration of the non-pretrained backend.
    output_dimension : int
        Width of the returned embeddings for both backends.

    Notes
    -----
    ``attention_mask`` in :meth:`forward` may be either a boolean key-padding mask
    (True = padding, as produced by the samplers via ``attention_mask == 0``) or a
    Hugging Face style integer mask (1 = token, 0 = padding). It is converted to
    the convention each backend expects.
    """

    def __init__(
            self,
            use_pretrained_model=True,
            model_name="bert-base-uncased",
            vocabulary_size=30522,
            num_layers=6,
            input_dimension=768,
            output_dimension=768,
            num_heads=8,
            context_length=77,
            dropout_rate=0.1,
            qkv_bias=False,
            scaling_value=4,
            epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # BERT is a frozen feature extractor; only the projection is trained.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # Every layer keeps the residual stream at input_dimension so layers can
            # be stacked; one final projection maps to output_dimension. With equal
            # dimensions this is nn.Identity (no parameters), so the module layout
            # and state_dict match stacking square layers.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=input_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for _ in range(num_layers)
            ])
            self.projection = (nn.Linear(input_dimension, output_dimension)
                               if input_dimension != output_dimension else nn.Identity())

    def forward(self, x, attention_mask=None):
        """Encode token ids ``x`` into embeddings of width ``output_dimension``.

        Parameters
        ----------
        x : torch.Tensor
            Token ids; for the non-pretrained backend ``Embedding`` requires a 2-D
            (batch, seq_len) tensor.
        attention_mask : torch.Tensor, optional
            Boolean key-padding mask (True = padding) or integer HF mask
            (1 = token, 0 = padding).

        Returns
        -------
        torch.Tensor
            Encoded sequence with last dimension ``output_dimension``.
        """
        if self.use_pretrained_model:
            hidden = self.bert(input_ids=x, attention_mask=self._to_bert_mask(attention_mask)).last_hidden_state
            return self.projection(hidden)
        x = self.embedding(x)
        key_padding_mask = self._to_key_padding_mask(attention_mask)
        for layer in self.layers:
            x = layer(x, attention_mask=key_padding_mask)
        return self.projection(x)

    @staticmethod
    def _to_bert_mask(mask):
        """Convert a boolean key-padding mask (True = pad) to HF's 1 = attend / 0 = pad."""
        if mask is not None and mask.dtype == torch.bool:
            return (~mask).long()
        return mask

    @staticmethod
    def _to_key_padding_mask(mask):
        """Convert an integer HF mask (1 = token) to a boolean key-padding mask (True = pad).

        Boolean and floating-point masks are returned unchanged, since
        ``nn.MultiheadAttention`` accepts them directly.
        """
        if mask is not None and mask.dtype != torch.bool and not mask.dtype.is_floating_point:
            return mask == 0
        return mask
61
+ #-----------------------------------------------------------------------------------------------------------------------
62
class EncoderLayer(torch.nn.Module):
    """Post-norm Transformer encoder layer with an optional output projection.

    Self-attention and the feed-forward sub-layer both operate at
    ``input_dimension``, each followed by dropout, a residual add and layer
    normalization. If ``output_dimension`` differs from ``input_dimension``, a
    final linear projection maps the result to ``output_dimension``; otherwise
    the projection is ``nn.Identity``.

    Parameters
    ----------
    input_dimension : int
        Embedding width of the input (and of the residual stream).
    output_dimension : int
        Embedding width of the returned tensor.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout probability used in attention, residual branches and the MLP.
    qkv_bias : bool
        Passed as ``bias`` to ``nn.MultiheadAttention``.
    scaling_value : int
        Hidden-width multiplier of the feed-forward sub-layer.
    epsilon : float
        LayerNorm epsilon.
    """

    def __init__(
            self,
            input_dimension,
            output_dimension,
            num_heads,
            dropout_rate,
            qkv_bias,
            scaling_value,
            epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        # The residual stream stays at input_dimension until the final projection,
        # so norm2 normalizes input_dimension features.
        self.norm2 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply the layer to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, seq_len, input_dimension) (``batch_first=True``).
        attention_mask : torch.Tensor, optional
            Passed to ``nn.MultiheadAttention`` as ``key_padding_mask``; for a
            boolean mask, True marks positions to ignore.

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch, seq_len, output_dimension).
        """
        attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        # Project only after both residual sub-layers, so the residual adds always
        # combine tensors of the same width.
        return self.output_projection(x)
98
+ #-----------------------------------------------------------------------------------------------------------------------
99
class FeedForward(torch.nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The hidden width is ``embedding_dimension * scaling_value``; the input and
    output widths are both ``embedding_dimension``. The sub-modules live in a
    single ``nn.Sequential`` named ``layers`` (indices 0-3).
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_width = embedding_dimension * scaling_value
        widen = torch.nn.Linear(embedding_dimension, hidden_width, bias=True)
        narrow = torch.nn.Linear(hidden_width, embedding_dimension, bias=True)
        self.layers = torch.nn.Sequential(widen, torch.nn.GELU(), torch.nn.Dropout(dropout_rate), narrow)

    def forward(self, x):
        """Apply the MLP independently at every position of ``x``."""
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
118
+ #-----------------------------------------------------------------------------------------------------------------------
119
class Embedding(torch.nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding.

    Parameters
    ----------
    vocabulary_size : int
        Number of distinct token ids.
    embedding_dimension : int
        Embedding width (even or odd).
    context_length : int
        Sequence length for which the positional encoding is precomputed; longer
        inputs get an encoding generated on the fly.
    """

    def __init__(
            self,
            vocabulary_size,
            embedding_dimension=768,
            context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        # Registered as a buffer so it follows .to(device) and is saved in the state_dict.
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len, device=None):
        """Return the sinusoidal encoding of shape (1, seq_len, embedding_dimension).

        Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``,
        with ``w_i = 10000 ** (-2i / embedding_dimension)``.
        """
        position = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float, device=device) *
                             -(math.log(10000.0) / self.embedding_dimension))
        angles = position * div_term  # (seq_len, ceil(d / 2))
        pos_enc = torch.zeros((seq_len, self.embedding_dimension), device=position.device)
        pos_enc[:, 0::2] = torch.sin(angles)
        # There are floor(d / 2) odd columns, one fewer than the even ones when d is odd.
        pos_enc[:, 1::2] = torch.cos(angles[:, : self.embedding_dimension // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embed ``token_ids`` of shape (batch_size, seq_len) and add positions.

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, seq_len, embedding_dimension).
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            # Longer than the cached table: build the encoding directly on the target device.
            position_encoded = self._generate_positional_encoding(seq_len, device=token_embedded.device)
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
        return token_embedded + position_encoded