torchdiff-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_ldm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
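Judging from the layout above, the wheel ships both standalone top-level packages (ddpm, ddim, sde, ldm, unclip) and a consolidated torchdiff package with one module per method. A plausible import pattern, inferred from the file names alone and not verified against torchdiff/__init__.py:

    import torchdiff.ddpm      # consolidated per-method module (path taken from the listing)
    import unclip.clip_model   # the standalone packages are also importable directly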
unclip/clip_model.py
ADDED
@@ -0,0 +1,304 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Union, Optional
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


class CLIPEncoder(nn.Module):
    """Encodes images or text using a pre-trained CLIP model.

    Loads a CLIP model and processor from the transformers library, providing methods to
    encode images or text into embeddings and to compute similarity scores between them.

    Parameters
    ----------
    `model_name` : str, optional
        Name of the CLIP model to load (default: 'openai/clip-vit-base-patch32').
    `device` : str, optional
        Device to run the model on (default: 'cuda' if available, else 'cpu').
    `use_fast` : bool, optional
        Whether to use the fast image processor (torchvision-based) (default: False).
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
        use_fast: bool = False,
    ) -> None:
        super().__init__()

        # Set model name and device
        self.model_name = model_name
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Load CLIP model and processor
            self.model = CLIPModel.from_pretrained(self.model_name)
            self.processor = CLIPProcessor.from_pretrained(self.model_name, use_fast=use_fast)
            self.model = self.model.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Failed to load CLIP model or processor for {self.model_name}: {e}")

        # Set model to evaluation mode by default
        self.model.eval()

    def forward(
        self,
        data: Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]],
        data_type: str,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Encodes input data (image or text) using the CLIP model.

        Processes input data (images or text) to produce embeddings, with optional
        L2 normalization.

        Parameters
        ----------
        `data` : Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]]
            Input data to encode:
            - torch.Tensor: preprocessed image tensor (batch_size, channels, height, width).
            - List[str] or str: text or list of texts.
            - PIL.Image.Image or List[PIL.Image.Image]: single PIL image or list of PIL images.
        `data_type` : str
            Type of input data ('img' or 'text').
        `normalize` : bool, optional
            Whether to L2-normalize the output embeddings (default: True).

        Returns
        -------
        outputs : torch.Tensor
            Encoded embeddings, shape (batch_size, embedding_dim).
        """
        if data_type not in ["img", "text"]:
            raise ValueError(f"Invalid data_type: {data_type}. Must be 'img' or 'text'.")

        with torch.no_grad():
            if data_type == "img":
                outputs = self._encode_images(data)
            else:
                outputs = self._encode_texts(data)

            # Normalize embeddings if requested
            if normalize:
                outputs = F.normalize(outputs, p=2, dim=-1)

        return outputs

    def _encode_images(self, data: Union[torch.Tensor, Image.Image, List[Image.Image]]) -> torch.Tensor:
        """Encodes images into embeddings using the CLIP model.

        Processes image inputs (tensors or PIL images) to produce image embeddings.

        Parameters
        ----------
        `data` : Union[torch.Tensor, Image.Image, List[Image.Image]]
            Input images as a tensor or PIL image(s).

        Returns
        -------
        image_features : torch.Tensor
            Image embeddings, shape (batch_size, embedding_dim).
        """
        if isinstance(data, torch.Tensor):
            # Tensors are assumed to be already preprocessed; add a batch
            # dimension to a single (C, H, W) image
            if data.dim() == 3:
                data = data.unsqueeze(0)
            inputs = {"pixel_values": data.to(self.device)}
        elif isinstance(data, (Image.Image, list)):
            if isinstance(data, Image.Image):
                data = [data]
            inputs = self.processor(images=data, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
        else:
            raise ValueError(
                f"Invalid image data type: {type(data)}. Expected torch.Tensor, "
                "PIL.Image.Image, or List[PIL.Image.Image]."
            )
        return self.model.get_image_features(**inputs)

    def _encode_texts(self, data: Union[str, List[str], torch.Tensor]) -> torch.Tensor:
        """Encodes texts into embeddings using the CLIP model.

        Processes text inputs (strings or tokenized tensors) to produce text embeddings.

        Parameters
        ----------
        `data` : Union[str, List[str], torch.Tensor]
            Input texts as strings or tokenized tensor.

        Returns
        -------
        text_features : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).
        """
        if isinstance(data, torch.Tensor):
            data = data.to(self.device)
            # A 2-D tensor is treated as precomputed embeddings and returned
            # unchanged; a 1-D tensor is treated as a single tokenized sequence
            if data.dim() == 2:
                return data
            if data.dim() == 1:
                data = data.unsqueeze(0)
            attention_mask = torch.ones_like(data)
            return self.model.get_text_features(input_ids=data, attention_mask=attention_mask)

        if isinstance(data, str):
            data = [data]
        elif isinstance(data, list) and all(isinstance(t, str) for t in data):
            pass
        else:
            raise ValueError(
                f"Invalid text data type: {type(data)}. Expected str, List[str], or torch.Tensor."
            )

        inputs = self.processor(text=data, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return self.model.get_text_features(**inputs)

    def compute_similarity(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        """Computes cosine similarity between image and text embeddings.

        Calculates the cosine similarity matrix between batches of image and text embeddings.

        Parameters
        ----------
        `image_features` : torch.Tensor
            Image embeddings, shape (batch_size, embedding_dim).
        `text_features` : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).

        Returns
        -------
        similarity : torch.Tensor
            Cosine similarity scores, shape (batch_size, batch_size).
        """
        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)
        return torch.matmul(image_features, text_features.T)

"""
|
|
181
|
+
# ============================================================================
|
|
182
|
+
# USAGE EXAMPLE
|
|
183
|
+
# ============================================================================
|
|
184
|
+
|
|
185
|
+
def main():
|
|
186
|
+
"Demonstrate how to use the CLIP class."
|
|
187
|
+
|
|
188
|
+
print("=== CLIP Usage Example ===\n")
|
|
189
|
+
|
|
190
|
+
# Initialize CLIP model
|
|
191
|
+
clip_model = CLIPEncoder(model_name="openai/clip-vit-base-patch32")
|
|
192
|
+
print(f"Model loaded on device: {clip_model.device}")
|
|
193
|
+
print(f"Model name: {clip_model.model_name}\n")
|
|
194
|
+
|
|
195
|
+
# ========================================================================
|
|
196
|
+
# TEXT ENCODING EXAMPLES
|
|
197
|
+
# ========================================================================
|
|
198
|
+
print("1. TEXT ENCODING:")
|
|
199
|
+
print("-" * 50)
|
|
200
|
+
|
|
201
|
+
# Single text
|
|
202
|
+
single_text = "a photo of a cat"
|
|
203
|
+
text_features_single = clip_model(single_text, data_type="text")
|
|
204
|
+
print(f"Single text: '{single_text}'")
|
|
205
|
+
print(f"Output shape: {text_features_single.shape}")
|
|
206
|
+
print(f"Output dtype: {text_features_single.dtype}")
|
|
207
|
+
print(f"Output range: [{text_features_single.min():.4f}, {text_features_single.max():.4f}]\n")
|
|
208
|
+
|
|
209
|
+
# Multiple texts
|
|
210
|
+
multiple_texts = [
|
|
211
|
+
"a photo of a cat",
|
|
212
|
+
"a photo of a dog",
|
|
213
|
+
"a beautiful sunset over mountains",
|
|
214
|
+
"a red sports car"
|
|
215
|
+
]
|
|
216
|
+
text_features_multiple = clip_model(multiple_texts, data_type="text")
|
|
217
|
+
print(f"Multiple texts ({len(multiple_texts)} items):")
|
|
218
|
+
for i, text in enumerate(multiple_texts):
|
|
219
|
+
print(f" {i + 1}. '{text}'")
|
|
220
|
+
print(f"Output shape: {text_features_multiple.shape}")
|
|
221
|
+
print(f"Output dtype: {text_features_multiple.dtype}")
|
|
222
|
+
print(f"Output range: [{text_features_multiple.min():.4f}, {text_features_multiple.max():.4f}]\n")
|
|
223
|
+
|
|
224
|
+
# ========================================================================
|
|
225
|
+
# IMAGE ENCODING EXAMPLES (using synthetic data for demo)
|
|
226
|
+
# ========================================================================
|
|
227
|
+
print("2. IMAGE ENCODING:")
|
|
228
|
+
print("-" * 50)
|
|
229
|
+
|
|
230
|
+
# Create synthetic PIL images for demonstration
|
|
231
|
+
synthetic_images = []
|
|
232
|
+
for i in range(3):
|
|
233
|
+
# Create random RGB image
|
|
234
|
+
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
|
|
235
|
+
pil_img = Image.fromarray(img_array)
|
|
236
|
+
synthetic_images.append(pil_img)
|
|
237
|
+
|
|
238
|
+
# Single image
|
|
239
|
+
single_image = synthetic_images[0]
|
|
240
|
+
image_features_single = clip_model(single_image, data_type="img")
|
|
241
|
+
print(f"Single PIL image (size: {single_image.size})")
|
|
242
|
+
print(f"Output shape: {image_features_single.shape}")
|
|
243
|
+
print(f"Output dtype: {image_features_single.dtype}")
|
|
244
|
+
print(f"Output range: [{image_features_single.min():.4f}, {image_features_single.max():.4f}]\n")
|
|
245
|
+
|
|
246
|
+
# Multiple images
|
|
247
|
+
image_features_multiple = clip_model(synthetic_images, data_type="img")
|
|
248
|
+
print(f"Multiple PIL images ({len(synthetic_images)} images)")
|
|
249
|
+
print(f"Output shape: {image_features_multiple.shape}")
|
|
250
|
+
print(f"Output dtype: {image_features_multiple.dtype}")
|
|
251
|
+
print(f"Output range: [{image_features_multiple.min():.4f}, {image_features_multiple.max():.4f}]\n")
|
|
252
|
+
|
|
253
|
+
# Tensor input (pre-processed)
|
|
254
|
+
tensor_input = torch.randn(2, 3, 224, 224) # Batch of 2 images
|
|
255
|
+
image_features_tensor = clip_model(tensor_input, data_type="img")
|
|
256
|
+
print(f"Tensor input shape: {tensor_input.shape}")
|
|
257
|
+
print(f"Output shape: {image_features_tensor.shape}")
|
|
258
|
+
print(f"Output dtype: {image_features_tensor.dtype}\n")
|
|
259
|
+
|
|
260
|
+
# ========================================================================
|
|
261
|
+
# SIMILARITY COMPUTATION EXAMPLE
|
|
262
|
+
# ========================================================================
|
|
263
|
+
print("3. SIMILARITY COMPUTATION:")
|
|
264
|
+
print("-" * 50)
|
|
265
|
+
|
|
266
|
+
# Compute similarity between images and texts
|
|
267
|
+
similarity_matrix = clip_model.compute_similarity(image_features_multiple, text_features_multiple)
|
|
268
|
+
print(f"Similarity matrix shape: {similarity_matrix.shape}")
|
|
269
|
+
print(f"Similarity matrix (images vs texts):")
|
|
270
|
+
print(similarity_matrix.detach().cpu().numpy())
|
|
271
|
+
print()
|
|
272
|
+
|
|
273
|
+
# Find best matches
|
|
274
|
+
best_matches = similarity_matrix.argmax(dim=1)
|
|
275
|
+
print("Best text matches for each image:")
|
|
276
|
+
for i, match_idx in enumerate(best_matches):
|
|
277
|
+
print(f" Image {i + 1} -> Text {match_idx + 1}: '{multiple_texts[match_idx]}'")
|
|
278
|
+
print(f" Similarity score: {similarity_matrix[i, match_idx]:.4f}")
|
|
279
|
+
print()
|
|
280
|
+
|
|
281
|
+
# ========================================================================
|
|
282
|
+
# EXPECTED INPUT/OUTPUT SUMMARY
|
|
283
|
+
# ========================================================================
|
|
284
|
+
print("4. INPUT/OUTPUT SUMMARY:")
|
|
285
|
+
print("-" * 50)
|
|
286
|
+
print("INPUT TYPES:")
|
|
287
|
+
print(" Text:")
|
|
288
|
+
print(" - str: Single text string")
|
|
289
|
+
print(" - List[str]: List of text strings")
|
|
290
|
+
print(" Images:")
|
|
291
|
+
print(" - PIL.Image.Image: Single PIL image")
|
|
292
|
+
print(" - List[PIL.Image.Image]: List of PIL images")
|
|
293
|
+
print(" - torch.Tensor: Pre-processed tensor (C, H, W) or (B, C, H, W)")
|
|
294
|
+
print()
|
|
295
|
+
print("OUTPUT:")
|
|
296
|
+
print(" - torch.Tensor: Shape (batch_size, 512) for ViT-Base models")
|
|
297
|
+
print(" - dtype: torch.float32")
|
|
298
|
+
print(" - Range: [-1, 1] if normalized (default), varies if not normalized")
|
|
299
|
+
print(" - Device: Same as model device")
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
if __name__ == "__main__":
|
|
303
|
+
main()
|
|
304
|
+
"""
|
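Note that the usage example above is wrapped in a module-level string, so importing unclip.clip_model never executes it (and never triggers a model download). A minimal live sketch, assuming the openai/clip-vit-base-patch32 weights are reachable on the Hugging Face hub:

    from unclip.clip_model import CLIPEncoder

    encoder = CLIPEncoder()  # downloads weights on first use
    emb = encoder("a photo of a cat", data_type="text")  # shape (1, 512), L2-normalized by default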