TorchDiff 2.1.0.tar.gz → 2.2.0.tar.gz
This diff shows the changes between two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {torchdiff-2.1.0 → torchdiff-2.2.0}/PKG-INFO +59 -25
- {torchdiff-2.1.0 → torchdiff-2.2.0}/README.md +57 -25
- {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/PKG-INFO +59 -25
- {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/SOURCES.txt +12 -16
- {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/top_level.txt +1 -0
- torchdiff-2.2.0/ddim/forward_ddim.py +63 -0
- torchdiff-2.2.0/ddim/reverse_ddim.py +135 -0
- torchdiff-2.2.0/ddim/sample_ddim.py +195 -0
- torchdiff-2.2.0/ddim/scheduler.py +140 -0
- torchdiff-2.2.0/ddim/test_ddim.py +426 -0
- torchdiff-2.2.0/ddim/train_ddim.py +550 -0
- torchdiff-2.2.0/ddpm/forward_ddpm.py +53 -0
- torchdiff-2.2.0/ddpm/reverse_ddpm.py +116 -0
- torchdiff-2.2.0/ddpm/sample_ddpm.py +195 -0
- torchdiff-2.2.0/ddpm/scheduler.py +83 -0
- torchdiff-2.2.0/ddpm/test_ddpm.py +439 -0
- torchdiff-2.2.0/ddpm/train_ddpm.py +548 -0
- torchdiff-2.2.0/ldm/autoencoder.py +652 -0
- torchdiff-2.2.0/ldm/sample_ldm.py +243 -0
- torchdiff-2.2.0/ldm/train_autoencoder.py +409 -0
- torchdiff-2.2.0/ldm/train_ldm.py +598 -0
- torchdiff-2.2.0/sde/forward_sde.py +156 -0
- torchdiff-2.2.0/sde/reverse_sde.py +170 -0
- torchdiff-2.2.0/sde/sample_sde.py +208 -0
- torchdiff-2.2.0/sde/scheduler.py +133 -0
- torchdiff-2.2.0/sde/test_sde.py +546 -0
- torchdiff-2.2.0/sde/train_sde.py +612 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/setup.py +2 -5
- torchdiff-2.2.0/torchdiff/__init__.py +8 -0
- torchdiff-2.2.0/torchdiff/ddim.py +1107 -0
- torchdiff-2.2.0/torchdiff/ddpm.py +1113 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/ldm.py +552 -767
- torchdiff-2.2.0/torchdiff/sde.py +1301 -0
- torchdiff-2.2.0/torchdiff/tests/test_ddim.py +426 -0
- torchdiff-2.2.0/torchdiff/tests/test_ddpm.py +439 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/tests/test_ldm.py +45 -127
- torchdiff-2.2.0/torchdiff/tests/test_sde.py +546 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/tests/test_unclip.py +26 -77
- {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/unclip.py +2 -1
- torchdiff-2.2.0/torchdiff/utils.py +1190 -0
- torchdiff-2.2.0/utils/__init__.py +0 -0
- torchdiff-2.2.0/utils/diff_net.py +354 -0
- torchdiff-2.2.0/utils/losses.py +68 -0
- {torchdiff-2.1.0/ldm → torchdiff-2.2.0/utils}/metrics.py +49 -49
- {torchdiff-2.1.0/ldm → torchdiff-2.2.0/utils}/text_encoder.py +244 -151
- torchdiff-2.1.0/ddim/forward_ddim.py +0 -79
- torchdiff-2.1.0/ddim/hyper_param.py +0 -225
- torchdiff-2.1.0/ddim/noise_predictor.py +0 -521
- torchdiff-2.1.0/ddim/reverse_ddim.py +0 -91
- torchdiff-2.1.0/ddim/sample_ddim.py +0 -219
- torchdiff-2.1.0/ddim/text_encoder.py +0 -152
- torchdiff-2.1.0/ddim/train_ddim.py +0 -394
- torchdiff-2.1.0/ddpm/forward_ddpm.py +0 -89
- torchdiff-2.1.0/ddpm/hyper_param.py +0 -180
- torchdiff-2.1.0/ddpm/noise_predictor.py +0 -521
- torchdiff-2.1.0/ddpm/reverse_ddpm.py +0 -102
- torchdiff-2.1.0/ddpm/sample_ddpm.py +0 -213
- torchdiff-2.1.0/ddpm/text_encoder.py +0 -152
- torchdiff-2.1.0/ddpm/train_ddpm.py +0 -386
- torchdiff-2.1.0/ldm/autoencoder.py +0 -855
- torchdiff-2.1.0/ldm/forward_idm.py +0 -100
- torchdiff-2.1.0/ldm/hyper_param.py +0 -239
- torchdiff-2.1.0/ldm/noise_predictor.py +0 -1074
- torchdiff-2.1.0/ldm/reverse_ldm.py +0 -119
- torchdiff-2.1.0/ldm/sample_ldm.py +0 -254
- torchdiff-2.1.0/ldm/train_autoencoder.py +0 -216
- torchdiff-2.1.0/ldm/train_ldm.py +0 -412
- torchdiff-2.1.0/sde/forward_sde.py +0 -98
- torchdiff-2.1.0/sde/hyper_param.py +0 -200
- torchdiff-2.1.0/sde/noise_predictor.py +0 -521
- torchdiff-2.1.0/sde/reverse_sde.py +0 -115
- torchdiff-2.1.0/sde/sample_sde.py +0 -216
- torchdiff-2.1.0/sde/text_encoder.py +0 -152
- torchdiff-2.1.0/sde/train_sde.py +0 -400
- torchdiff-2.1.0/torchdiff/__init__.py +0 -8
- torchdiff-2.1.0/torchdiff/ddim.py +0 -1225
- torchdiff-2.1.0/torchdiff/ddpm.py +0 -1153
- torchdiff-2.1.0/torchdiff/sde.py +0 -1231
- torchdiff-2.1.0/torchdiff/tests/test_ddim.py +0 -551
- torchdiff-2.1.0/torchdiff/tests/test_ddpm.py +0 -1188
- torchdiff-2.1.0/torchdiff/tests/test_sde.py +0 -626
- torchdiff-2.1.0/torchdiff/utils.py +0 -1664
- {torchdiff-2.1.0 → torchdiff-2.2.0}/LICENSE +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/dependency_links.txt +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/requires.txt +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/ddim/__init__.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/ddpm/__init__.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/ldm/__init__.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/sde/__init__.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/setup.cfg +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/tests/__init__.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/__init__.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/clip_model.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/ddim_model.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/decoder_model.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/prior_diff.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/prior_model.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/project_decoder.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/project_prior.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/train_decoder.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/train_prior.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/unclip_sampler.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/upsampler.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/upsampler_trainer.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/utils.py +0 -0
- {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/val_metrics.py +0 -0
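
The file list above already tells the story of the 2.2.0 refactor: each model family (`ddim/`, `ddpm/`, `sde/`) loses its private `hyper_param.py`, `noise_predictor.py`, and `text_encoder.py`, gains a `scheduler.py` and a co-located test module, and the shared building blocks move into a new top-level `utils/` package (`diff_net.py`, `losses.py`, plus `metrics.py` and `text_encoder.py` relocated from `ldm/`). A rough import-level sketch of what the reorganization means is below; the 2.2.0 names are taken from the README diff further down, while the 2.1.0 side is inferred from the deleted file names only, so the old symbols themselves are not shown.

```python
# TorchDiff 2.1.0 -> 2.2.0, module-level view (illustrative sketch, not exhaustive).

# 2.1.0: every model family carried its own copies:
#   ddpm/hyper_param.py, ddpm/noise_predictor.py, ddpm/text_encoder.py
#   (and the same trio under ddim/ and sde/)

# 2.2.0: a per-model scheduler plus one shared utils package:
from torchdiff.ddpm import SchedulerDDPM, ForwardDDPM, ReverseDDPM  # scheduler replaces hyper_param
from torchdiff.utils import DiffusionNetwork, mse_loss              # shared network replaces per-model noise_predictor
```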
--- torchdiff-2.1.0/PKG-INFO
+++ torchdiff-2.2.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: TorchDiff
-Version: 2.1.0
+Version: 2.2.0
 Summary: A PyTorch-based library for diffusion models
 Home-page: https://github.com/LoqmanSamani/TorchDiff
 Author: Loghman Samani
@@ -10,9 +10,6 @@ Project-URL: Homepage, https://loqmansamani.github.io/torchdiff
 Project-URL: Documentation, https://torchdiff.readthedocs.io
 Project-URL: Source, https://github.com/LoqmanSamani/TorchDiff
 Keywords: diffusion models,pytorch,machine learning,deep learning
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
@@ -22,7 +19,7 @@ Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: lpips>=0.1.4
@@ -64,7 +61,7 @@ Dynamic: summary
 
 [](https://opensource.org/licenses/MIT)
 [](https://pytorch.org/)
-[](https://pypi.org/project/torchdiff/)
+[](https://pypi.org/project/torchdiff/)
 [](https://www.python.org/)
 [](https://pepy.tech/project/torchdiff)
 
@@ -76,7 +73,7 @@ Dynamic: summary
 
 **TorchDiff** is a PyTorch-based library for building and experimenting with diffusion models, inspired by leading research papers.
 
-The **TorchDiff 2.
+The **TorchDiff 2.1.0** release includes implementations of five major diffusion model families:
 - **DDPM** (Denoising Diffusion Probabilistic Models)
 - **DDIM** (Denoising Diffusion Implicit Models)
 - **SDE-based Diffusion**
@@ -107,34 +104,71 @@ import torch.nn as nn
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
 
-from torchdiff.ddpm import
-
+from torchdiff.ddpm import (SchedulerDDPM, ForwardDDPM,
+                            ReverseDDPM, TrainDDPM, SampleDDPM)
+from torchdiff.utils import DiffusionNetwork, mse_loss
 
-#
+# dataset: CIFAR10
 transform = transforms.Compose([
     transforms.Resize(32),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
 ])
-train_dataset = datasets.CIFAR10(
-
+train_dataset = datasets.CIFAR10(
+    root="./data", train=True, download=True, transform=transform
+)
+train_loader = DataLoader(
+    train_dataset, batch_size=64, shuffle=True
+)
+device = 'cuda'  # gpu is used for training and sampling
+
+# model components
+diff_net = DiffusionNetwork(
+    in_channels = 3,
+    down_channels = [32, 64, 128],
+    mid_channels = [128, 128],
+    up_channels = [128, 64, 32],
+    down_sampling = [True, True],
+    time_embed_dim = 128,
+    y_embed_dim = 128,
+    num_down_blocks = 2,
+    num_mid_blocks = 2,
+    num_up_blocks = 2,
+    dropout_rate = 0.1,
+    cont_time = False  # time is not continuous; for SDE models it should be True
+)
+print(sum(p.numel() for p in diff_net.parameters()))
 
-
-
-
-fwd, rev = ForwardDDPM(vs), ReverseDDPM(vs)
+vs = SchedulerDDPM(time_steps = 400)
+fwd = ForwardDDPM(vs, 'noise')  # network is trained to predict noise
+rwd = ReverseDDPM(vs, 'noise')
 
-#
+# optimizer
+optim = torch.optim.Adam(diff_net.parameters(), lr=1e-5)
+
+# training algorithm
 trainer = TrainDDPM(
-
-
-
+    diff_net = diff_net,
+    fwd_ddpm = fwd,
+    rwd_ddpm = rwd,
+    train_loader = train_loader,
+    optim = optim,
+    loss_fn = mse_loss,
+    max_epochs = 10,
+    device = device,
+    grad_acc = 2
 )
-trainer()
+#trainer()
 
 # Sampling
-sampler = SampleDDPM(
-
+sampler = SampleDDPM(
+    rwd_ddpm = rwd,
+    diff_net = diff_net,
+    img_size = (32, 32),
+    batch_size = 10,
+    in_channels = 3,
+    device = device
+)
 images = sampler()
 ```
 
@@ -164,12 +198,12 @@ DALL·E 2 architecture leveraging CLIP latents for text-to-image generation.
 TorchDiff breaks each model into reusable components:
 - **Forward Diffusion**: Adds noise to data
 - **Reverse Diffusion**: Removes noise to recover data
-- **
+- **Scheduler**: Controls noise schedules
 - **Training**: Complete training pipelines
 - **Sampling**: Efficient inference and generation
 
 Additional utilities:
-- **
+- **Diffusion Network**: U-Net-like model with attention and time embeddings
 - **Text Encoder**: Transformer-based for conditional generation
 - **Metrics**: Evaluation suite (MSE, PSNR, SSIM, FID, LPIPS)
 
--- torchdiff-2.1.0/README.md
+++ torchdiff-2.2.0/README.md
@@ -9,7 +9,7 @@
 
 [](https://opensource.org/licenses/MIT)
 [](https://pytorch.org/)
-[](https://pypi.org/project/torchdiff/)
+[](https://pypi.org/project/torchdiff/)
 [](https://www.python.org/)
 [](https://pepy.tech/project/torchdiff)
 [](https://github.com/LoqmanSamani/TorchDiff)
@@ -43,12 +43,12 @@ These models support both **conditional** (e.g., text-to-image) and **unconditional**
 TorchDiff is designed with **modularity** in mind. Each model is broken down into reusable components:
 - **Forward Diffusion**: Adds noise (e.g., `ForwardDDPM`).
 - **Reverse Diffusion**: Removes noise to recover data (e.g., `ReverseDDPM`).
-- **
+- **Scheduler**: Controls noise schedules (e.g., `SchedulerDDPM`).
 - **Training**: Full training pipelines (e.g., `TrainDDPM`).
 - **Sampling**: Efficient inference and generation (e.g., `SampleDDPM`).
 
 Additional utilities:
-- **
+- **Diffusion Network**: A U-Net-like model with attention and time embeddings, used as the main model.
 - **Text Encoder**: Transformer-based (e.g., BERT) for conditional generation.
 - **Metrics**: Evaluation suite including MSE, PSNR, SSIM, FID, and LPIPS.
 
@@ -56,7 +56,7 @@ Additional utilities:
 
 ## ⚡ Quick Start
 
-Here’s a minimal working example to train and sample with **DDPM** on dummy data:
+Here’s a minimal working example to train and sample with **DDPM** on dummy data:
 
 ```python
 import torch
@@ -64,40 +64,72 @@ import torch.nn as nn
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
 
-from torchdiff.ddpm import
-
+from torchdiff.ddpm import (SchedulerDDPM, ForwardDDPM,
+                            ReverseDDPM, TrainDDPM, SampleDDPM)
+from torchdiff.utils import DiffusionNetwork, mse_loss
 
-#
+# dataset: CIFAR10
 transform = transforms.Compose([
     transforms.Resize(32),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
 ])
-train_dataset = datasets.CIFAR10(
-
+train_dataset = datasets.CIFAR10(
+    root="./data", train=True, download=True, transform=transform
+)
+train_loader = DataLoader(
+    train_dataset, batch_size=64, shuffle=True
+)
+device = 'cuda'  # gpu is used for training and sampling
+
+# model components
+diff_net = DiffusionNetwork(
+    in_channels = 3,
+    down_channels = [32, 64, 128],
+    mid_channels = [128, 128],
+    up_channels = [128, 64, 32],
+    down_sampling = [True, True],
+    time_embed_dim = 128,
+    y_embed_dim = 128,
+    num_down_blocks = 2,
+    num_mid_blocks = 2,
+    num_up_blocks = 2,
+    dropout_rate = 0.1,
+    cont_time = False  # time is not continuous; for SDE models it should be True
+)
+print(sum(p.numel() for p in diff_net.parameters()))
 
-
-
-
-fwd, rev = ForwardDDPM(vs), ReverseDDPM(vs)
+vs = SchedulerDDPM(time_steps = 400)
+fwd = ForwardDDPM(vs, 'noise')  # network is trained to predict noise
+rwd = ReverseDDPM(vs, 'noise')
 
-#
-optim = torch.optim.Adam(
-loss_fn = nn.MSELoss()
+# optimizer
+optim = torch.optim.Adam(diff_net.parameters(), lr=1e-5)
 
-#
+# training algorithm
 trainer = TrainDDPM(
-
-
-
+    diff_net = diff_net,
+    fwd_ddpm = fwd,
+    rwd_ddpm = rwd,
+    train_loader = train_loader,
+    optim = optim,
+    loss_fn = mse_loss,
+    max_epochs = 10,
+    device = device,
+    grad_acc = 2
 )
-trainer()
+#trainer()
 
 # Sampling
-sampler = SampleDDPM(
-
+sampler = SampleDDPM(
+    rwd_ddpm = rwd,
+    diff_net = diff_net,
+    img_size = (32, 32),
+    batch_size = 10,
+    in_channels = 3,
+    device = device
+)
 images = sampler()
-print("Generated images shape:", images.shape)
 ```
 
 For detailed examples, check the [examples/](https://github.com/LoqmanSamani/TorchDiff/tree/systembiology/examples) directory.
@@ -130,7 +162,7 @@ pip install -r requirements.txt
 pip install .
 ```
 
-> Requires **Python 3.
+> Requires **Python 3.10+**. For GPU acceleration, ensure PyTorch is installed with the correct CUDA version.
 
 ---
 
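
The Quick Start in the README diff above downloads CIFAR-10 even though the surrounding text promises dummy data. For a quick smoke test of the same 2.2.0 API, the loader can be fed random tensors instead. The sketch below uses only the classes and keyword arguments that appear in the diff; that the loader must yield `(image, label)` pairs, as CIFAR-10 does, is an assumption.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchdiff.ddpm import SchedulerDDPM, ForwardDDPM, ReverseDDPM, TrainDDPM, SampleDDPM
from torchdiff.utils import DiffusionNetwork, mse_loss

# Dummy data: 64 random "images" in [-1, 1], the shape CIFAR-10 would have.
images = torch.rand(64, 3, 32, 32) * 2 - 1
labels = torch.zeros(64, dtype=torch.long)  # placeholder labels, assumed unused
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # README assumes a GPU

diff_net = DiffusionNetwork(
    in_channels=3, down_channels=[32, 64, 128], mid_channels=[128, 128],
    up_channels=[128, 64, 32], down_sampling=[True, True],
    time_embed_dim=128, y_embed_dim=128,
    num_down_blocks=2, num_mid_blocks=2, num_up_blocks=2,
    dropout_rate=0.1, cont_time=False,
)

vs = SchedulerDDPM(time_steps=400)
fwd, rwd = ForwardDDPM(vs, 'noise'), ReverseDDPM(vs, 'noise')
optim = torch.optim.Adam(diff_net.parameters(), lr=1e-5)

trainer = TrainDDPM(diff_net=diff_net, fwd_ddpm=fwd, rwd_ddpm=rwd,
                    train_loader=train_loader, optim=optim, loss_fn=mse_loss,
                    max_epochs=1, device=device, grad_acc=2)
trainer()  # one epoch over the dummy data

sampler = SampleDDPM(rwd_ddpm=rwd, diff_net=diff_net, img_size=(32, 32),
                     batch_size=4, in_channels=3, device=device)
samples = sampler()  # tensor of generated images
```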
--- torchdiff-2.1.0/TorchDiff.egg-info/PKG-INFO
+++ torchdiff-2.2.0/TorchDiff.egg-info/PKG-INFO
(Identical to the PKG-INFO diff above; the egg-info copy of the metadata carries the same +59/-25 changes, so the hunks are not repeated here.)
--- torchdiff-2.1.0/TorchDiff.egg-info/SOURCES.txt
+++ torchdiff-2.2.0/TorchDiff.egg-info/SOURCES.txt
@@ -8,38 +8,29 @@ TorchDiff.egg-info/requires.txt
 TorchDiff.egg-info/top_level.txt
 ddim/__init__.py
 ddim/forward_ddim.py
-ddim/hyper_param.py
-ddim/noise_predictor.py
 ddim/reverse_ddim.py
 ddim/sample_ddim.py
-ddim/text_encoder.py
+ddim/scheduler.py
+ddim/test_ddim.py
 ddim/train_ddim.py
 ddpm/__init__.py
 ddpm/forward_ddpm.py
-ddpm/hyper_param.py
-ddpm/noise_predictor.py
 ddpm/reverse_ddpm.py
 ddpm/sample_ddpm.py
-ddpm/text_encoder.py
+ddpm/scheduler.py
+ddpm/test_ddpm.py
 ddpm/train_ddpm.py
 ldm/__init__.py
 ldm/autoencoder.py
-ldm/forward_idm.py
-ldm/hyper_param.py
-ldm/metrics.py
-ldm/noise_predictor.py
-ldm/reverse_ldm.py
 ldm/sample_ldm.py
-ldm/text_encoder.py
 ldm/train_autoencoder.py
 ldm/train_ldm.py
 sde/__init__.py
 sde/forward_sde.py
-sde/hyper_param.py
-sde/noise_predictor.py
 sde/reverse_sde.py
 sde/sample_sde.py
-sde/text_encoder.py
+sde/scheduler.py
+sde/test_sde.py
 sde/train_sde.py
 torchdiff/__init__.py
 torchdiff/ddim.py
@@ -68,4 +59,9 @@ unclip/unclip_sampler.py
 unclip/upsampler.py
 unclip/upsampler_trainer.py
 unclip/utils.py
-unclip/val_metrics.py
+unclip/val_metrics.py
+utils/__init__.py
+utils/diff_net.py
+utils/losses.py
+utils/metrics.py
+utils/text_encoder.py
torchdiff-2.2.0/ddim/forward_ddim.py (new file, +63 lines; reproduced below with the missing `Tuple` import added, since the packaged file uses the annotation without importing it):

```python
import torch
import torch.nn as nn
from typing import Tuple  # not in the packaged file, but required by the forward() annotation below


class ForwardDDIM(nn.Module):
    """
    Implements the forward (noising) process of DDIM.

    This module samples x_t from the forward diffusion distribution:

        q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)

    It also computes the appropriate training target depending on the
    prediction parameterization (noise, x0, or v-prediction).

    Args:
        scheduler: Noise scheduler containing precomputed diffusion coefficients.
        pred_type: Type of model prediction. One of ["noise", "x0", "v"].
    """
    def __init__(
        self,
        scheduler: nn.Module,
        pred_type: str = "noise"
    ):
        super().__init__()
        valid_types = ["noise", "x0", "v"]
        if pred_type not in valid_types:
            raise ValueError(f"prediction_type must be one of {valid_types}, got {pred_type}")
        self.vs = scheduler
        self.pred_type = pred_type

    def forward(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform the forward diffusion step and compute the training target.

        Samples x_t by adding noise to the clean input x_0 at timestep t,
        and returns the corresponding supervision target for training.

        Args:
            x0: Clean input data of shape (batch, ...).
            t: Discrete diffusion timesteps of shape (batch,).
            noise: Gaussian noise of same shape as x0.

        Returns:
            xt: Noised data x_t of shape (batch, ...).
            target: Training target corresponding to pred_type:
                - "noise": the added noise ε
                - "x0": the original clean input x0
                - "v": the velocity parameterization
        """
        sqrt_alpha_cumprod_t = self.vs.sqrt_alphas_cumprod[t]
        sqrt_one_minus_alpha_cumprod_t = self.vs.sqrt_one_minus_alphas_cumprod[t]
        sqrt_alpha_cumprod_t = self.vs.get_index(sqrt_alpha_cumprod_t, x0.shape)
        sqrt_one_minus_alpha_cumprod_t = self.vs.get_index(sqrt_one_minus_alpha_cumprod_t, x0.shape)
        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
        if self.pred_type == "noise":
            target = noise
        elif self.pred_type == "x0":
            target = x0
        elif self.pred_type == "v":
            target = sqrt_alpha_cumprod_t * noise - sqrt_one_minus_alpha_cumprod_t * x0
        return xt, target
```
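
The three `pred_type` targets in `ForwardDDIM` are connected by simple identities. Writing `a_t = sqrt(alpha_bar_t)` and `b_t = sqrt(1 - alpha_bar_t)`, the module computes `x_t = a_t*x0 + b_t*eps` and, for v-prediction, `v = a_t*eps - b_t*x0`; since `a_t**2 + b_t**2 = 1`, a model that predicts `v` determines both `x0 = a_t*x_t - b_t*v` and `eps = b_t*x_t + a_t*v`. The snippet below numerically checks those identities by replicating the module's arithmetic with a stub scheduler that exposes only the three members `ForwardDDIM` actually reads (`sqrt_alphas_cumprod`, `sqrt_one_minus_alphas_cumprod`, `get_index`); the real `SchedulerDDIM` constructor does not appear in this diff, so the linear beta schedule used here is an assumption.

```python
import torch

class StubScheduler:
    """Minimal stand-in exposing only what ForwardDDIM reads (assumed linear beta schedule)."""
    def __init__(self, time_steps: int = 1000):
        betas = torch.linspace(1e-4, 2e-2, time_steps)
        alpha_bar = torch.cumprod(1.0 - betas, dim=0)
        self.sqrt_alphas_cumprod = alpha_bar.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - alpha_bar).sqrt()

    def get_index(self, vals: torch.Tensor, shape: torch.Size) -> torch.Tensor:
        # Broadcast per-sample coefficients over (batch, C, H, W).
        return vals.view(-1, *([1] * (len(shape) - 1)))

vs = StubScheduler()
x0 = torch.randn(8, 3, 32, 32)
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (8,))

a = vs.get_index(vs.sqrt_alphas_cumprod[t], x0.shape)            # a_t
b = vs.get_index(vs.sqrt_one_minus_alphas_cumprod[t], x0.shape)  # b_t
xt = a * x0 + b * noise   # same arithmetic as ForwardDDIM's xt
v = a * noise - b * x0    # same arithmetic as the "v" target

assert torch.allclose(a * xt - b * v, x0, atol=1e-5)     # x0 recovered from (x_t, v)
assert torch.allclose(b * xt + a * v, noise, atol=1e-5)  # noise recovered from (x_t, v)
```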