torchdiff-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/clip_model.py ADDED
@@ -0,0 +1,304 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Union, Optional
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


class CLIPEncoder(nn.Module):
    """Encodes images or text using a pre-trained CLIP model.

    Loads a CLIP model and processor from the transformers library, providing methods to
    encode images or text into embeddings and compute similarity scores between them.

    Parameters
    ----------
    `model_name` : str, optional
        Name of the CLIP model to load (default: 'openai/clip-vit-base-patch32').
    `device` : str, optional
        Device to run the model on (default: 'cuda' if available, else 'cpu').
    `use_fast` : bool, optional
        Whether to use the fast image processor (torchvision-based) (default: False).
    """
    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
        use_fast: bool = False,
    ) -> None:
        super().__init__()

        # Set model name and device
        self.model_name = model_name
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Load CLIP model and processor
            self.model = CLIPModel.from_pretrained(self.model_name)
            self.processor = CLIPProcessor.from_pretrained(self.model_name, use_fast=use_fast)
            self.model = self.model.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Failed to load CLIP model or processor for {self.model_name}: {e}")

        # set model to evaluation mode by default
        self.model.eval()

    def forward(
        self,
        data: Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]],
        data_type: str,
        normalize: bool = True
    ) -> torch.Tensor:
        """Encodes input data (image or text) using the CLIP model.

        Processes input data (images or text) to produce embeddings, with optional L2
        normalization.

        Parameters
        ----------
        `data` : Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]]
            Input data to encode:
            - torch.Tensor: Preprocessed image tensor (batch_size, channels, height, width).
            - List[str] or str: Text or list of texts.
            - PIL.Image.Image or List[PIL.Image.Image]: Single or list of PIL images.
        `data_type` : str
            Type of input data ('img' or 'text').
        `normalize` : bool, optional
            Whether to L2-normalize the output embeddings (default: True).

        Returns
        -------
        outputs : torch.Tensor
            Encoded embeddings, shape (batch_size, embedding_dim).
        """
        if data_type not in ["img", "text"]:
            raise ValueError(f"Invalid data_type: {data_type}. Must be 'img' or 'text'.")

        with torch.no_grad():
            if data_type == "img":
                outputs = self._encode_images(data)
            else:
                outputs = self._encode_texts(data)

        # normalize embeddings if requested
        if normalize:
            outputs = F.normalize(outputs, p=2, dim=-1)

        return outputs

    def _encode_images(self, data: Union[torch.Tensor, Image.Image, List[Image.Image]]) -> torch.Tensor:
        """Encodes images into embeddings using the CLIP model.

        Processes image inputs (tensors or PIL images) to produce image embeddings.

        Parameters
        ----------
        `data` : Union[torch.Tensor, Image.Image, List[Image.Image]]
            Input images as a tensor or PIL image(s).

        Returns
        -------
        image_features : torch.Tensor
            Image embeddings, shape (batch_size, embedding_dim).
        """
        if isinstance(data, torch.Tensor):
            if data.dim() == 3:
                data = data.unsqueeze(0)
            inputs = {"pixel_values": data.to(self.device)}
        elif isinstance(data, (Image.Image, list)):
            if isinstance(data, Image.Image):
                data = [data]
            inputs = self.processor(images=data, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
        else:
            raise ValueError(f"Invalid image data type: {type(data)}. Expected torch.Tensor, PIL.Image.Image, or List[PIL.Image.Image].")
        return self.model.get_image_features(**inputs)

    def _encode_texts(self, data: Union[str, List[str], torch.Tensor]) -> torch.Tensor:
        """Encodes texts into embeddings using the CLIP model.

        Processes text inputs (strings or tokenized tensors) to produce text embeddings.

        Parameters
        ----------
        `data` : Union[str, List[str], torch.Tensor]
            Input texts as strings or tokenized tensor.

        Returns
        -------
        text_features : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).
        """
        if isinstance(data, torch.Tensor):
            data = data.to(self.device)
            if data.dim() == 2:
                return data
            if data.dim() == 1:
                data = data.unsqueeze(0)
            attention_mask = torch.ones_like(data)
            return self.model.get_text_features(input_ids=data, attention_mask=attention_mask)

        if isinstance(data, str):
            data = [data]
        elif isinstance(data, list) and all(isinstance(t, str) for t in data):
            pass
        else:
            raise ValueError(
                f"Invalid text data type: {type(data)}. Expected str, List[str], or torch.Tensor."
            )

        inputs = self.processor(text=data, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return self.model.get_text_features(**inputs)

    def compute_similarity(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        """Computes cosine similarity between image and text embeddings.

        Calculates the cosine similarity matrix between batches of image and text embeddings.

        Parameters
        ----------
        `image_features` : torch.Tensor
            Image embeddings, shape (batch_size, embedding_dim).
        `text_features` : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).

        Returns
        -------
        similarity : torch.Tensor
            Cosine similarity scores, shape (batch_size, batch_size).
        """
        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)
        return torch.matmul(image_features, text_features.T)


"""
# ============================================================================
# USAGE EXAMPLE
# ============================================================================

import numpy as np  # needed for the synthetic demo images below


def main():
    "Demonstrate how to use the CLIPEncoder class."

    print("=== CLIP Usage Example ===\n")

    # Initialize CLIP model
    clip_model = CLIPEncoder(model_name="openai/clip-vit-base-patch32")
    print(f"Model loaded on device: {clip_model.device}")
    print(f"Model name: {clip_model.model_name}\n")

    # ========================================================================
    # TEXT ENCODING EXAMPLES
    # ========================================================================
    print("1. TEXT ENCODING:")
    print("-" * 50)

    # Single text
    single_text = "a photo of a cat"
    text_features_single = clip_model(single_text, data_type="text")
    print(f"Single text: '{single_text}'")
    print(f"Output shape: {text_features_single.shape}")
    print(f"Output dtype: {text_features_single.dtype}")
    print(f"Output range: [{text_features_single.min():.4f}, {text_features_single.max():.4f}]\n")

    # Multiple texts
    multiple_texts = [
        "a photo of a cat",
        "a photo of a dog",
        "a beautiful sunset over mountains",
        "a red sports car"
    ]
    text_features_multiple = clip_model(multiple_texts, data_type="text")
    print(f"Multiple texts ({len(multiple_texts)} items):")
    for i, text in enumerate(multiple_texts):
        print(f"  {i + 1}. '{text}'")
    print(f"Output shape: {text_features_multiple.shape}")
    print(f"Output dtype: {text_features_multiple.dtype}")
    print(f"Output range: [{text_features_multiple.min():.4f}, {text_features_multiple.max():.4f}]\n")

    # ========================================================================
    # IMAGE ENCODING EXAMPLES (using synthetic data for demo)
    # ========================================================================
    print("2. IMAGE ENCODING:")
    print("-" * 50)

    # Create synthetic PIL images for demonstration
    synthetic_images = []
    for i in range(3):
        # Create random RGB image
        img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        pil_img = Image.fromarray(img_array)
        synthetic_images.append(pil_img)

    # Single image
    single_image = synthetic_images[0]
    image_features_single = clip_model(single_image, data_type="img")
    print(f"Single PIL image (size: {single_image.size})")
    print(f"Output shape: {image_features_single.shape}")
    print(f"Output dtype: {image_features_single.dtype}")
    print(f"Output range: [{image_features_single.min():.4f}, {image_features_single.max():.4f}]\n")

    # Multiple images
    image_features_multiple = clip_model(synthetic_images, data_type="img")
    print(f"Multiple PIL images ({len(synthetic_images)} images)")
    print(f"Output shape: {image_features_multiple.shape}")
    print(f"Output dtype: {image_features_multiple.dtype}")
    print(f"Output range: [{image_features_multiple.min():.4f}, {image_features_multiple.max():.4f}]\n")

    # Tensor input (pre-processed)
    tensor_input = torch.randn(2, 3, 224, 224)  # Batch of 2 images
    image_features_tensor = clip_model(tensor_input, data_type="img")
    print(f"Tensor input shape: {tensor_input.shape}")
    print(f"Output shape: {image_features_tensor.shape}")
    print(f"Output dtype: {image_features_tensor.dtype}\n")

    # ========================================================================
    # SIMILARITY COMPUTATION EXAMPLE
    # ========================================================================
    print("3. SIMILARITY COMPUTATION:")
    print("-" * 50)

    # Compute similarity between images and texts
    similarity_matrix = clip_model.compute_similarity(image_features_multiple, text_features_multiple)
    print(f"Similarity matrix shape: {similarity_matrix.shape}")
    print("Similarity matrix (images vs texts):")
    print(similarity_matrix.detach().cpu().numpy())
    print()

    # Find best matches
    best_matches = similarity_matrix.argmax(dim=1)
    print("Best text matches for each image:")
    for i, match_idx in enumerate(best_matches.tolist()):
        print(f"  Image {i + 1} -> Text {match_idx + 1}: '{multiple_texts[match_idx]}'")
        print(f"  Similarity score: {similarity_matrix[i, match_idx]:.4f}")
    print()

    # ========================================================================
    # EXPECTED INPUT/OUTPUT SUMMARY
    # ========================================================================
    print("4. INPUT/OUTPUT SUMMARY:")
    print("-" * 50)
    print("INPUT TYPES:")
    print("  Text:")
    print("    - str: Single text string")
    print("    - List[str]: List of text strings")
    print("  Images:")
    print("    - PIL.Image.Image: Single PIL image")
    print("    - List[PIL.Image.Image]: List of PIL images")
    print("    - torch.Tensor: Pre-processed tensor (C, H, W) or (B, C, H, W)")
    print()
    print("OUTPUT:")
    print("  - torch.Tensor: Shape (batch_size, 512) for ViT-Base models")
    print("  - dtype: torch.float32")
    print("  - Range: [-1, 1] if normalized (default), varies if not normalized")
    print("  - Device: Same as model device")


if __name__ == "__main__":
    main()
"""
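
Note: the following is a minimal sketch, not part of the packaged file, illustrating the tensor-input text paths that the `_encode_texts` docstring describes but the bundled example above does not exercise. It assumes the wheel's `unclip` package layout and the `CLIPEncoder` class as defined in this file; the variable names are illustrative.

# Sketch (assumption: unclip is importable as a top-level package from this wheel).
from unclip.clip_model import CLIPEncoder

clip = CLIPEncoder()  # loads openai/clip-vit-base-patch32 by default

# Pre-tokenized text: a 1-D input_ids tensor is unsqueezed to a batch of one and encoded.
tokens = clip.processor(text="a photo of a cat", return_tensors="pt")
ids_1d = tokens["input_ids"][0]          # shape (seq_len,)
feats = clip(ids_1d, data_type="text")   # shape (1, embedding_dim), L2-normalized by default

# Pre-computed embeddings: a 2-D float tensor is returned as-is by _encode_texts,
# apart from the optional L2 normalization applied in forward().
reused = clip(feats, data_type="text")
print(feats.shape, reused.shape)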