TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
sde/sample_sde.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertTokenizer
4
+
5
+
6
+
7
+
8
class SampleSDE(nn.Module):
    """Sampler for generating images using SDE-based generative models.

    Generates images by iteratively denoising random noise using the reverse SDE process
    and a trained noise predictor, as described in Song et al. (2021). Supports both
    unconditional and conditional generation with text prompts.

    Parameters
    ----------
    reverse_diffusion : ReverseSDE
        Reverse SDE diffusion module for denoising. Must expose
        `hyper_params.num_steps` and `method`.
    noise_predictor : nn.Module
        Model to predict noise added during the forward SDE process.
    image_shape : tuple
        Shape of generated images as (height, width).
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str or BertTokenizer, optional
        Tokenizer for processing text prompts, default "bert-base-uncased".
        A string is loaded with `BertTokenizer.from_pretrained`; any other object
        is used as the tokenizer as-is.
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    reverse : ReverseSDE
        Reverse SDE diffusion module.
    noise_predictor : nn.Module
        Noise prediction model.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for text prompts.
    max_length : int
        Maximum length for tokenized prompts.
    in_channels : int
        Number of input channels.
    image_shape : tuple
        Shape of generated images (height, width).
    batch_size : int
        Batch size for generation.
    output_range : tuple
        Range for clamping generated images.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size` is not
        positive, or `output_range` is not a tuple (min, max) with min < max.
    """

    def __init__(self, reverse_diffusion, noise_predictor, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased", max_length=77, batch_size=1, in_channels=3,
                 device=None, output_range=(-1, 1)):
        super().__init__()

        # Validate cheap arguments first so a bad call fails before any module is
        # moved to a device or a tokenizer is loaded.
        if (not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2
                or not all(isinstance(s, int) and s > 0 for s in image_shape)):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if (not isinstance(output_range, (tuple, list)) or len(output_range) != 2
                or output_range[0] >= output_range[1]):
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # Explicit None check: container modules define __len__, so an empty one
        # would be falsy and silently dropped by a truthiness test.
        self.conditional_model = (
            conditional_model.to(self.device) if conditional_model is not None else None
        )
        # Honour the documented "str or BertTokenizer" contract: only a name/path
        # is resolved through from_pretrained; an instance is kept unchanged.
        if isinstance(tokenizer, str):
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.max_length = max_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized tensors using the specified tokenizer.

        Parameters
        ----------
        prompts : str or list
            Text prompt(s) for conditional generation. Can be a single string or a list
            of strings.

        Returns
        -------
        tuple
            A tuple containing:
            - input_ids: Tokenized input IDs, moved to `self.device`.
            - attention_mask: Attention mask for tokenized inputs, moved to `self.device`.

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the reverse SDE sampling process.

        Iteratively denoises random noise to generate images using the reverse SDE process
        and noise predictor. Supports conditional generation with text prompts.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width).
            If `normalize_output` is True, images are normalized to [0, 1]; otherwise,
            they are clamped to `output_range`.

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified, or if
            a conditional model is specified but `conditions` is None.

        Notes
        -----
        - Sampling is performed with `torch.no_grad()` for efficiency.
        - The noise predictor, reverse SDE, and conditional model (if applicable) are set
          to evaluation mode during sampling.
        - The prompt embedding is computed once before the loop; it does not depend
          on the time step.
        - NOTE(review): the number of prompts is not checked against `batch_size`;
          confirm the noise predictor tolerates a mismatch.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(
            self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]
        ).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        # After the two checks above, a conditional model implies conditions exist.
        use_conditions = self.conditional_model is not None

        with torch.no_grad():
            y = None
            if use_conditions:
                # Loop-invariant: tokenize and encode the prompts a single time
                # instead of once per reverse step.
                input_ids, attention_masks = self.tokenize(conditions)
                # True marks padding positions (key-padding-mask convention).
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            xt = noisy_samples
            for t in reversed(range(self.reverse.hyper_params.num_steps)):
                # The "ode" method is handed no fresh noise.
                noise = torch.randn_like(xt) if self.reverse.method != "ode" else None
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)

                if use_conditions:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                xt = self.reverse(xt, noise, predicted_noise, time_steps)

            generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
            if normalize_output:
                # Affine map from [min, max] to [0, 1].
                generated_imgs = (generated_imgs - self.output_range[0]) / (
                    self.output_range[1] - self.output_range[0]
                )

        return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for computation.

        Returns
        -------
        self
            The module moved to the specified device.

        Notes
        -----
        - Moves `noise_predictor`, `reverse`, and `conditional_model` (if applicable) to
          the specified device and records it in `self.device` so tensors created
          during sampling land on the same device.
        """
        self.device = device
        self.noise_predictor.to(device)
        self.reverse.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)
sde/text_encoder.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from transformers import BertModel, DistilBertModel
5
+
6
+
7
+
8
+
9
class TextEncoder(torch.nn.Module):
    """Encodes token ids into a sequence of embeddings.

    Two back-ends are available: a frozen pretrained BERT followed by a trainable
    linear projection, or an `Embedding` layer followed by a stack of
    `EncoderLayer` blocks.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, use `BertModel.from_pretrained(model_name)` with all of its
        parameters frozen; otherwise build the custom encoder (default: True).
    model_name : str, optional
        Name or path handed to `BertModel.from_pretrained` (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Vocabulary size of the custom embedding (default: 30522).
    num_layers : int, optional
        Number of `EncoderLayer` blocks in the custom encoder (default: 6).
    input_dimension : int, optional
        Embedding width used inside the custom encoder (default: 768).
    output_dimension : int, optional
        Width of the returned embeddings (default: 768).
    num_heads : int, optional
        Attention heads per `EncoderLayer` (default: 8).
    context_length : int, optional
        Context length handed to `Embedding` (default: 77).
    dropout_rate : float, optional
        Dropout rate used in the custom encoder (default: 0.1).
    qkv_bias : bool, optional
        Passed to each `EncoderLayer` (default: False).
    scaling_value : int, optional
        Feed-forward expansion factor passed to each `EncoderLayer` (default: 4).
    epsilon : float, optional
        Layer-norm epsilon passed to each `EncoderLayer` (default: 1e-5).
    """

    def __init__(
            self,
            use_pretrained_model=True,
            model_name="bert-base-uncased",
            vocabulary_size=30522,
            num_layers=6,
            input_dimension=768,
            output_dimension=768,
            num_heads=8,
            context_length=77,
            dropout_rate=0.1,
            qkv_bias=False,
            scaling_value=4,
            epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # The backbone is frozen; only the projection below is trainable.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # Every layer consumes `input_dimension` features, so only the final
            # layer may change the width to `output_dimension`. With equal
            # dimensions this builds exactly the same stack as before.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=(
                        output_dimension if index == num_layers - 1 else input_dimension
                    ),
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for index in range(num_layers)
            ])

    def forward(self, x, attention_mask=None):
        """Encode token ids.

        Parameters
        ----------
        x : torch.Tensor
            Token ids.
        attention_mask : torch.Tensor, optional
            Padding mask. A boolean tensor is read as a key-padding mask
            (True = padding position to ignore), which is how the custom branch
            consumes it. For the pretrained branch a boolean mask is converted to
            the backbone's convention (1 = attend, 0 = ignore); a non-boolean mask
            is forwarded to the backbone unchanged.

        Returns
        -------
        torch.Tensor
            Encoded sequence: the projected last hidden state in the pretrained
            branch, otherwise the output of the last encoder layer.
        """
        if self.use_pretrained_model:
            bert_mask = attention_mask
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                # Key-padding masks flag the positions to IGNORE, while the BERT
                # backbone expects non-zero for positions to ATTEND: invert.
                bert_mask = (~attention_mask).long()
            x = self.bert(input_ids=x, attention_mask=bert_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x
61
+ #-----------------------------------------------------------------------------------------------------------------------
62
class EncoderLayer(torch.nn.Module):
    """Post-norm transformer encoder block with an optional width change.

    Self-attention and the feed-forward network both operate on
    `input_dimension` features. When `output_dimension` differs, a linear
    projection is applied right before the final layer norm.

    Parameters
    ----------
    input_dimension : int
        Feature width of the input, the attention and the feed-forward network.
    output_dimension : int
        Feature width of the returned tensor.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention and both residual branches.
    qkv_bias : bool
        Passed as `bias` to `nn.MultiheadAttention`.
    scaling_value : int
        Expansion factor handed to `FeedForward`.
    epsilon : float, optional
        Epsilon of both layer norms (default: 1e-5).
    """

    def __init__(
            self,
            input_dimension,
            output_dimension,
            num_heads,
            dropout_rate,
            qkv_bias,
            scaling_value,
            epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        # Identity when no width change is requested, so that case has no extra
        # parameters and behaves exactly like a plain encoder block.
        self.output_projection = (
            nn.Linear(input_dimension, output_dimension)
            if input_dimension != output_dimension else nn.Identity()
        )
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply self-attention, feed-forward, and the optional projection.

        Parameters
        ----------
        x : torch.Tensor
            Input sequence with `input_dimension` features (batch-first).
        attention_mask : torch.Tensor, optional
            Forwarded to `nn.MultiheadAttention` as `key_padding_mask`.

        Returns
        -------
        torch.Tensor
            Sequence with `output_dimension` features.
        """
        attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
        # Both residual additions stay in `input_dimension`, the width that
        # `norm1` and `feedforward` were built for.
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = x + self.dropout2(ff_output)
        # Change width only here, immediately before `norm2` (which normalizes
        # `output_dimension`). Projecting the attention output earlier made the
        # first residual addition fail whenever the two dimensions differed.
        x = self.output_projection(x)
        x = self.norm2(x)
        return x
98
+ #-----------------------------------------------------------------------------------------------------------------------
99
class FeedForward(torch.nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Expands `embedding_dimension` by `scaling_value`, applies GELU and dropout,
    then contracts back to `embedding_dimension`.

    Parameters
    ----------
    embedding_dimension : int
        Width of the input and of the output.
    scaling_value : int
        Multiplier giving the hidden width.
    dropout_rate : float, optional
        Dropout applied after the activation (default: 0.1).
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_dimension = embedding_dimension * scaling_value
        # Built in the same order as they are applied, so parameter
        # initialisation consumes the random stream identically.
        expand = torch.nn.Linear(embedding_dimension, hidden_dimension, bias=True)
        activation = torch.nn.GELU()
        regularize = torch.nn.Dropout(dropout_rate)
        contract = torch.nn.Linear(hidden_dimension, embedding_dimension, bias=True)
        self.layers = torch.nn.Sequential(expand, activation, regularize, contract)

    def forward(self, x):
        """Return the network output for `x`."""
        return self.layers(x)
118
+ #-----------------------------------------------------------------------------------------------------------------------
119
class Embedding(torch.nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding.

    Parameters
    ----------
    vocabulary_size : int
        Number of rows in the token embedding table.
    embedding_dimension : int, optional
        Width of each embedding vector; may be even or odd (default: 768).
    context_length : int, optional
        Number of positions whose encoding is precomputed and stored in the
        `positional_encoding` buffer (default: 77).
    """

    def __init__(
            self,
            vocabulary_size,
            embedding_dimension=768,
            context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len):
        """Build a sinusoidal table of shape (1, seq_len, embedding_dimension).

        Even columns hold sines and odd columns hold cosines of
        `position * 10000 ** (-2i / embedding_dimension)`.
        """
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float) *
                             -(math.log(10000.0) / self.embedding_dimension))
        angles = position * div_term
        pos_enc = torch.zeros((seq_len, self.embedding_dimension), device=position.device)
        pos_enc[:, 0::2] = torch.sin(angles)
        # An odd width has one cosine column fewer than sine columns; trim the
        # frequencies so the assignment matches. For an even width the slice
        # keeps every column, leaving the table unchanged.
        pos_enc[:, 1::2] = torch.cos(angles[:, : self.embedding_dimension // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embed `token_ids` and add the positional encoding.

        Parameters
        ----------
        token_ids : torch.Tensor
            Integer tensor of shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, seq_len, embedding_dimension).
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            # Longer than the cached table: compute the encoding for this length.
            position_encoded = self._generate_positional_encoding(seq_len).to(token_embedded.device)
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
        return token_embedded + position_encoded