TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class CLIPContextProjection(nn.Module):
    """Projects CLIP image embeddings into multiple context tokens.

    A single linear layer maps each embedding of size ``clip_embedding_dim`` to
    ``num_tokens * clip_embedding_dim`` features. The result is split into
    ``num_tokens`` tokens of size ``clip_embedding_dim``, and each token is
    layer-normalized.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the input CLIP embedding (e.g., 319 or 512). Every
        output token has this same dimensionality.
    `num_tokens` : int, optional
        Number of context tokens to generate (default: 4).
    """
    def __init__(self, clip_embedding_dim, num_tokens=4):
        super().__init__()
        self.clip_embedding_dim = clip_embedding_dim
        self.num_tokens = num_tokens
        # One projection produces all tokens at once; forward() splits the result.
        self.clip_projection = nn.Linear(clip_embedding_dim, clip_embedding_dim * num_tokens)
        # Shared LayerNorm, applied over the last (embedding) dim of every token.
        self.clip_embedding_norm = nn.LayerNorm(clip_embedding_dim)

    def forward(self, z_i):
        """Projects CLIP image embeddings into context tokens.

        Applies the linear projection and splits its output into
        ``num_tokens`` tokens. Layer normalization is then applied to each
        token.

        Parameters
        ----------
        `z_i` : torch.Tensor
            Input CLIP image embedding of shape (..., clip_embedding_dim),
            typically (batch_size, clip_embedding_dim).

        Returns
        -------
        c : torch.Tensor
            Context tokens of shape (..., num_tokens, clip_embedding_dim),
            typically (batch_size, num_tokens, clip_embedding_dim).
        """
        projected = self.clip_projection(z_i)
        # Split only the last dimension into (num_tokens, clip_embedding_dim).
        # Keeping every leading dimension preserves the (batch, dim) behaviour
        # and also supports unbatched or multi-batch-dimension inputs. The
        # previous view(shape[0], ...) rejected those inputs.
        c = projected.reshape(*z_i.shape[:-1], self.num_tokens, self.clip_embedding_dim)
        return self.clip_embedding_norm(c)
46
+
47
+
48
"""
# Example usage
batch_size = 32
embed_dim = 319  # Example CLIP embedding dim after PCA

projector = CLIPContextProjection(clip_embedding_dim=embed_dim)
z_i = torch.randn(batch_size, embed_dim)
c = projector(z_i)  # Shape: [batch_size, 4, embed_dim]
print(f"Shape of c: {c.shape}")  # Expected: torch.Size([32, 4, 319])
"""
@@ -0,0 +1,170 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class CLIPEmbeddingProjection(nn.Module):
    """Learned projection between CLIP and transformer embedding spaces.

    Holds two independent MLPs:
    a forward projection (``clip_embedding_dim -> transformer_embedding_dim``)
    and an inverse projection (``transformer_embedding_dim -> clip_embedding_dim``).
    ``reconstruction_loss`` trains the pair as an encoder/decoder.
    Each MLP has ``num_layers`` linear layers. Every hidden layer is followed by
    optional LayerNorm, GELU and Dropout; the final layer is a plain linear layer.

    Parameters
    ----------
    `clip_embedding_dim` : int, optional
        Input dimensionality of the forward projection (default: 1024).
    `transformer_embedding_dim` : int, optional
        Output dimensionality of the forward projection (default: 320).
    `hidden_dim` : int, optional
        Hidden layer dimensionality (default: 512).
    `num_layers` : int, optional
        Number of linear layers in each projection network; must be >= 1
        (default: 2).
    `dropout_rate` : float, optional
        Dropout probability applied after each hidden layer (default: 0.2).
    `use_layer_norm` : bool, optional
        Whether to apply layer normalization after hidden layers (default: True).

    Raises
    ------
    ValueError
        If ``num_layers`` is less than 1.
    """
    def __init__(
        self,
        clip_embedding_dim: int = 1024,
        transformer_embedding_dim: int = 320,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout_rate: float = 0.2,
        use_layer_norm: bool = True
    ) -> None:
        super().__init__()

        # Previously num_layers <= 0 silently produced a single linear layer;
        # reject it explicitly instead of hiding the misconfiguration.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.clip_embedding_dim = clip_embedding_dim
        self.transformer_embedding_dim = transformer_embedding_dim

        # Forward projection: clip_embedding_dim -> transformer_embedding_dim
        self.forward_projection = self._build_projection_network(
            clip_embedding_dim, transformer_embedding_dim, hidden_dim, num_layers, dropout_rate, use_layer_norm
        )

        # Inverse projection: transformer_embedding_dim -> clip_embedding_dim
        self.inverse_projection = self._build_projection_network(
            transformer_embedding_dim, clip_embedding_dim, hidden_dim, num_layers, dropout_rate, use_layer_norm
        )

    def _build_projection_network(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
        use_layer_norm: bool
    ) -> nn.Sequential:
        """Builds an MLP with ``num_layers`` linear layers.

        The first ``num_layers - 1`` layers map to ``hidden_dim``. Each is
        followed by optional LayerNorm, GELU and Dropout. A final linear layer
        maps to ``output_dim`` with no activation.

        Parameters
        ----------
        `input_dim` : int
            Input dimensionality for the network.
        `output_dim` : int
            Output dimensionality for the network.
        `hidden_dim` : int
            Hidden layer dimensionality.
        `num_layers` : int
            Total number of linear layers in the network.
        `dropout` : float
            Dropout probability applied after each hidden layer.
        `use_layer_norm` : bool
            Whether to apply layer normalization after hidden layers.

        Returns
        -------
        network : nn.Sequential
            Sequential container of the projection network layers.
        """
        layers = []
        current_dim = input_dim

        # Hidden layers (none when num_layers == 1)
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(current_dim, hidden_dim))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            current_dim = hidden_dim

        # Output layer: plain linear map, no normalization/activation.
        layers.append(nn.Linear(current_dim, output_dim))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects CLIP embeddings into the transformer embedding space.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor of shape (batch_size, clip_embedding_dim).

        Returns
        -------
        x_reduced : torch.Tensor
            Projected tensor of shape (batch_size, transformer_embedding_dim).
        """
        return self.forward_projection(x)

    def inverse_transform(self, x_reduced: torch.Tensor) -> torch.Tensor:
        """Maps transformer-space embeddings back to the CLIP embedding space.

        Parameters
        ----------
        `x_reduced` : torch.Tensor
            Tensor of shape (batch_size, transformer_embedding_dim).

        Returns
        -------
        x_reconstructed : torch.Tensor
            Reconstructed tensor of shape (batch_size, clip_embedding_dim).
        """
        return self.inverse_projection(x_reduced)

    def reconstruction_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the round-trip reconstruction loss.

        Runs ``x`` through the forward projection and then the inverse
        projection. Returns the mean squared error against the original.

        Parameters
        ----------
        `x` : torch.Tensor
            Original input tensor of shape (batch_size, clip_embedding_dim).

        Returns
        -------
        loss : torch.Tensor
            Scalar mean squared error between the original and reconstructed tensors.
        """
        x_reduced = self.forward(x)
        x_reconstructed = self.inverse_transform(x_reduced)
        return F.mse_loss(x_reconstructed, x)
155
+
156
"""
p = CLIPEmbeddingProjection(
    clip_embedding_dim=1024,
    transformer_embedding_dim=512,
    hidden_dim=768,
    num_layers=2,
    dropout_rate=0.1,
    use_layer_norm=True
)

x = torch.randn((100, 1024))
o = p(x)
print(o.size())  # Expected: torch.Size([100, 512])
"""
170
+