TorchDiff 2.1.0.tar.gz → 2.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {torchdiff-2.1.0 → torchdiff-2.2.0}/PKG-INFO +59 -25
  2. {torchdiff-2.1.0 → torchdiff-2.2.0}/README.md +57 -25
  3. {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/PKG-INFO +59 -25
  4. {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/SOURCES.txt +12 -16
  5. {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/top_level.txt +1 -0
  6. torchdiff-2.2.0/ddim/forward_ddim.py +63 -0
  7. torchdiff-2.2.0/ddim/reverse_ddim.py +135 -0
  8. torchdiff-2.2.0/ddim/sample_ddim.py +195 -0
  9. torchdiff-2.2.0/ddim/scheduler.py +140 -0
  10. torchdiff-2.2.0/ddim/test_ddim.py +426 -0
  11. torchdiff-2.2.0/ddim/train_ddim.py +550 -0
  12. torchdiff-2.2.0/ddpm/forward_ddpm.py +53 -0
  13. torchdiff-2.2.0/ddpm/reverse_ddpm.py +116 -0
  14. torchdiff-2.2.0/ddpm/sample_ddpm.py +195 -0
  15. torchdiff-2.2.0/ddpm/scheduler.py +83 -0
  16. torchdiff-2.2.0/ddpm/test_ddpm.py +439 -0
  17. torchdiff-2.2.0/ddpm/train_ddpm.py +548 -0
  18. torchdiff-2.2.0/ldm/autoencoder.py +652 -0
  19. torchdiff-2.2.0/ldm/sample_ldm.py +243 -0
  20. torchdiff-2.2.0/ldm/train_autoencoder.py +409 -0
  21. torchdiff-2.2.0/ldm/train_ldm.py +598 -0
  22. torchdiff-2.2.0/sde/forward_sde.py +156 -0
  23. torchdiff-2.2.0/sde/reverse_sde.py +170 -0
  24. torchdiff-2.2.0/sde/sample_sde.py +208 -0
  25. torchdiff-2.2.0/sde/scheduler.py +133 -0
  26. torchdiff-2.2.0/sde/test_sde.py +546 -0
  27. torchdiff-2.2.0/sde/train_sde.py +612 -0
  28. {torchdiff-2.1.0 → torchdiff-2.2.0}/setup.py +2 -5
  29. torchdiff-2.2.0/torchdiff/__init__.py +8 -0
  30. torchdiff-2.2.0/torchdiff/ddim.py +1107 -0
  31. torchdiff-2.2.0/torchdiff/ddpm.py +1113 -0
  32. {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/ldm.py +552 -767
  33. torchdiff-2.2.0/torchdiff/sde.py +1301 -0
  34. torchdiff-2.2.0/torchdiff/tests/test_ddim.py +426 -0
  35. torchdiff-2.2.0/torchdiff/tests/test_ddpm.py +439 -0
  36. {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/tests/test_ldm.py +45 -127
  37. torchdiff-2.2.0/torchdiff/tests/test_sde.py +546 -0
  38. {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/tests/test_unclip.py +26 -77
  39. {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/unclip.py +2 -1
  40. torchdiff-2.2.0/torchdiff/utils.py +1190 -0
  41. torchdiff-2.2.0/utils/__init__.py +0 -0
  42. torchdiff-2.2.0/utils/diff_net.py +354 -0
  43. torchdiff-2.2.0/utils/losses.py +68 -0
  44. {torchdiff-2.1.0/ldm → torchdiff-2.2.0/utils}/metrics.py +49 -49
  45. {torchdiff-2.1.0/ldm → torchdiff-2.2.0/utils}/text_encoder.py +244 -151
  46. torchdiff-2.1.0/ddim/forward_ddim.py +0 -79
  47. torchdiff-2.1.0/ddim/hyper_param.py +0 -225
  48. torchdiff-2.1.0/ddim/noise_predictor.py +0 -521
  49. torchdiff-2.1.0/ddim/reverse_ddim.py +0 -91
  50. torchdiff-2.1.0/ddim/sample_ddim.py +0 -219
  51. torchdiff-2.1.0/ddim/text_encoder.py +0 -152
  52. torchdiff-2.1.0/ddim/train_ddim.py +0 -394
  53. torchdiff-2.1.0/ddpm/forward_ddpm.py +0 -89
  54. torchdiff-2.1.0/ddpm/hyper_param.py +0 -180
  55. torchdiff-2.1.0/ddpm/noise_predictor.py +0 -521
  56. torchdiff-2.1.0/ddpm/reverse_ddpm.py +0 -102
  57. torchdiff-2.1.0/ddpm/sample_ddpm.py +0 -213
  58. torchdiff-2.1.0/ddpm/text_encoder.py +0 -152
  59. torchdiff-2.1.0/ddpm/train_ddpm.py +0 -386
  60. torchdiff-2.1.0/ldm/autoencoder.py +0 -855
  61. torchdiff-2.1.0/ldm/forward_idm.py +0 -100
  62. torchdiff-2.1.0/ldm/hyper_param.py +0 -239
  63. torchdiff-2.1.0/ldm/noise_predictor.py +0 -1074
  64. torchdiff-2.1.0/ldm/reverse_ldm.py +0 -119
  65. torchdiff-2.1.0/ldm/sample_ldm.py +0 -254
  66. torchdiff-2.1.0/ldm/train_autoencoder.py +0 -216
  67. torchdiff-2.1.0/ldm/train_ldm.py +0 -412
  68. torchdiff-2.1.0/sde/forward_sde.py +0 -98
  69. torchdiff-2.1.0/sde/hyper_param.py +0 -200
  70. torchdiff-2.1.0/sde/noise_predictor.py +0 -521
  71. torchdiff-2.1.0/sde/reverse_sde.py +0 -115
  72. torchdiff-2.1.0/sde/sample_sde.py +0 -216
  73. torchdiff-2.1.0/sde/text_encoder.py +0 -152
  74. torchdiff-2.1.0/sde/train_sde.py +0 -400
  75. torchdiff-2.1.0/torchdiff/__init__.py +0 -8
  76. torchdiff-2.1.0/torchdiff/ddim.py +0 -1225
  77. torchdiff-2.1.0/torchdiff/ddpm.py +0 -1153
  78. torchdiff-2.1.0/torchdiff/sde.py +0 -1231
  79. torchdiff-2.1.0/torchdiff/tests/test_ddim.py +0 -551
  80. torchdiff-2.1.0/torchdiff/tests/test_ddpm.py +0 -1188
  81. torchdiff-2.1.0/torchdiff/tests/test_sde.py +0 -626
  82. torchdiff-2.1.0/torchdiff/utils.py +0 -1664
  83. {torchdiff-2.1.0 → torchdiff-2.2.0}/LICENSE +0 -0
  84. {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/dependency_links.txt +0 -0
  85. {torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/requires.txt +0 -0
  86. {torchdiff-2.1.0 → torchdiff-2.2.0}/ddim/__init__.py +0 -0
  87. {torchdiff-2.1.0 → torchdiff-2.2.0}/ddpm/__init__.py +0 -0
  88. {torchdiff-2.1.0 → torchdiff-2.2.0}/ldm/__init__.py +0 -0
  89. {torchdiff-2.1.0 → torchdiff-2.2.0}/sde/__init__.py +0 -0
  90. {torchdiff-2.1.0 → torchdiff-2.2.0}/setup.cfg +0 -0
  91. {torchdiff-2.1.0 → torchdiff-2.2.0}/torchdiff/tests/__init__.py +0 -0
  92. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/__init__.py +0 -0
  93. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/clip_model.py +0 -0
  94. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/ddim_model.py +0 -0
  95. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/decoder_model.py +0 -0
  96. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/prior_diff.py +0 -0
  97. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/prior_model.py +0 -0
  98. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/project_decoder.py +0 -0
  99. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/project_prior.py +0 -0
  100. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/train_decoder.py +0 -0
  101. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/train_prior.py +0 -0
  102. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/unclip_sampler.py +0 -0
  103. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/upsampler.py +0 -0
  104. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/upsampler_trainer.py +0 -0
  105. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/utils.py +0 -0
  106. {torchdiff-2.1.0 → torchdiff-2.2.0}/unclip/val_metrics.py +0 -0
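
The listing reflects the 2.2.0 restructuring: each model directory drops its private hyper_param.py, noise_predictor.py, and text_encoder.py modules, gains a scheduler.py and a test module, and the shared pieces (diff_net, losses, metrics, text_encoder) move into a new top-level utils package. A rough sketch of the import surface this implies for user code (only the names below actually appear in the README diff further down; any other import path would be an unverified assumption):

    # imports shown in the 2.2.0 quick-start example below
    from torchdiff.ddpm import SchedulerDDPM, ForwardDDPM, ReverseDDPM, TrainDDPM, SampleDDPM
    from torchdiff.utils import DiffusionNetwork, mse_loss  # replace 2.1.0's NoisePredictor and nn.MSELoss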
{torchdiff-2.1.0 → torchdiff-2.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: TorchDiff
-Version: 2.1.0
+Version: 2.2.0
 Summary: A PyTorch-based library for diffusion models
 Home-page: https://github.com/LoqmanSamani/TorchDiff
 Author: Loghman Samani
@@ -10,9 +10,6 @@ Project-URL: Homepage, https://loqmansamani.github.io/torchdiff
 Project-URL: Documentation, https://torchdiff.readthedio
 Project-URL: Source, https://github.com/LoqmanSamani/TorchDiff
 Keywords: diffusion models,pytorch,machine learning,deep learning
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
@@ -22,7 +19,7 @@ Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
-Requires-Python: >=3.8
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: lpips>=0.1.4
@@ -64,7 +61,7 @@ Dynamic: summary
 
 [![License: MIT](https://img.shields.io/badge/license-MIT-red?style=plastic)](https://opensource.org/licenses/MIT)
 [![PyTorch](https://img.shields.io/badge/PyTorch-white?style=plastic&logo=pytorch&logoColor=red)](https://pytorch.org/)
-[![Version](https://img.shields.io/badge/version-2.0.0-blue?style=plastic)](https://pypi.org/project/torchdiff/)
+[![Version](https://img.shields.io/badge/version-2.1.0-blue?style=plastic)](https://pypi.org/project/torchdiff/)
 [![Python](https://img.shields.io/badge/python-3.8%2B-blue?style=plastic&logo=python&logoColor=white)](https://www.python.org/)
 [![Downloads](https://pepy.tech/badge/torchdiff)](https://pepy.tech/project/torchdiff)
 
@@ -76,7 +73,7 @@ Dynamic: summary
 
 **TorchDiff** is a PyTorch-based library for building and experimenting with diffusion models, inspired by leading research papers.
 
-The **TorchDiff 2.0.0** release includes implementations of five major diffusion model families:
+The **TorchDiff 2.1.0** release includes implementations of five major diffusion model families:
 - **DDPM** (Denoising Diffusion Probabilistic Models)
 - **DDIM** (Denoising Diffusion Implicit Models)
 - **SDE-based Diffusion**
@@ -107,34 +104,71 @@ import torch.nn as nn
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
 
-from torchdiff.ddpm import VarianceSchedulerDDPM, ForwardDDPM, ReverseDDPM, TrainDDPM, SampleDDPM
-from torchdiff.utils import NoisePredictor
+from torchdiff.ddpm import (SchedulerDDPM, ForwardDDPM,
+                            ReverseDDPM, TrainDDPM, SampleDDPM)
+from torchdiff.utils import DiffusionNetwork, mse_loss
 
-# Dataset setup
+# dataset: CIFAR10
 transform = transforms.Compose([
     transforms.Resize(32),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
 ])
-train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
-train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
+train_dataset = datasets.CIFAR10(
+    root="./data", train=True, download=True, transform=transform
+)
+train_loader = DataLoader(
+    train_dataset, batch_size=64, shuffle=True
+)
+device = 'cuda'  # GPU is used for training and sampling
+
+# model components
+diff_net = DiffusionNetwork(
+    in_channels=3,
+    down_channels=[32, 64, 128],
+    mid_channels=[128, 128],
+    up_channels=[128, 64, 32],
+    down_sampling=[True, True],
+    time_embed_dim=128,
+    y_embed_dim=128,
+    num_down_blocks=2,
+    num_mid_blocks=2,
+    num_up_blocks=2,
+    dropout_rate=0.1,
+    cont_time=False  # time is discrete here; for SDE models it should be True
+)
+print(sum(p.numel() for p in diff_net.parameters()))
 
-# Model components
-noise_pred = NoisePredictor(in_channels=3)
-vs = VarianceSchedulerDDPM(num_steps=1000)
-fwd, rev = ForwardDDPM(vs), ReverseDDPM(vs)
+vs = SchedulerDDPM(time_steps=400)
+fwd = ForwardDDPM(vs, 'noise')  # the network is trained to predict noise
+rwd = ReverseDDPM(vs, 'noise')
 
-# Training
+# optimizer
+optim = torch.optim.Adam(diff_net.parameters(), lr=1e-5)
+
+# training algorithm
 trainer = TrainDDPM(
-    noise_predictor=noise_pred, forward_diffusion=fwd, reverse_diffusion=rev,
-    conditional_model=None, optimizer=torch.optim.Adam(noise_pred.parameters(), lr=1e-4),
-    objective=nn.MSELoss(), data_loader=train_loader, max_epochs=1, device="cpu"
+    diff_net=diff_net,
+    fwd_ddpm=fwd,
+    rwd_ddpm=rwd,
+    train_loader=train_loader,
+    optim=optim,
+    loss_fn=mse_loss,
+    max_epochs=10,
+    device=device,
+    grad_acc=2
 )
-trainer()
+#trainer()
 
 # Sampling
-sampler = SampleDDPM(reverse_diffusion=rev, noise_predictor=noise_pred,
-                     image_shape=(32, 32), batch_size=4, in_channels=3, device="cpu")
+sampler = SampleDDPM(
+    rwd_ddpm=rwd,
+    diff_net=diff_net,
+    img_size=(32, 32),
+    batch_size=10,
+    in_channels=3,
+    device=device
+)
 images = sampler()
 ```
 
@@ -164,12 +198,12 @@ DALL·E 2 architecture leveraging CLIP latents for text-to-image generation.
 TorchDiff breaks each model into reusable components:
 - **Forward Diffusion**: Adds noise to data
 - **Reverse Diffusion**: Removes noise to recover data
-- **Variance Scheduler**: Controls noise schedules
+- **Scheduler**: Controls noise schedules
 - **Training**: Complete training pipelines
 - **Sampling**: Efficient inference and generation
 
 Additional utilities:
-- **Noise Predictor**: U-Net-like model with attention and time embeddings
+- **Diffusion Network**: U-Net-like model with attention and time embeddings
 - **Text Encoder**: Transformer-based for conditional generation
 - **Metrics**: Evaluation suite (MSE, PSNR, SSIM, FID, LPIPS)
 
{torchdiff-2.1.0 → torchdiff-2.2.0}/README.md

@@ -9,7 +9,7 @@
 [![License: MIT](https://img.shields.io/badge/license-MIT-red?style=plastic)](https://opensource.org/licenses/MIT)
 [![PyTorch](https://img.shields.io/badge/PyTorch-white?style=plastic&logo=pytorch&logoColor=red)](https://pytorch.org/)
-[![Version](https://img.shields.io/badge/version-2.0.0-blue?style=plastic)](https://pypi.org/project/torchdiff/)
+[![Version](https://img.shields.io/badge/version-2.1.0-blue?style=plastic)](https://pypi.org/project/torchdiff/)
 [![Python](https://img.shields.io/badge/python-3.8%2B-blue?style=plastic&logo=python&logoColor=white)](https://www.python.org/)
 [![Downloads](https://pepy.tech/badge/torchdiff)](https://pepy.tech/project/torchdiff)
 [![Stars](https://img.shields.io/github/stars/LoqmanSamani/TorchDiff?style=plastic&color=yellow)](https://github.com/LoqmanSamani/TorchDiff)
@@ -43,12 +43,12 @@ These models support both **conditional** (e.g., text-to-image) and **unconditio
 TorchDiff is designed with **modularity** in mind. Each model is broken down into reusable components:
 - **Forward Diffusion**: Adds noise (e.g., `ForwardDDPM`).
 - **Reverse Diffusion**: Removes noise to recover data (e.g., `ReverseDDPM`).
-- **Variance Scheduler**: Controls noise schedules (e.g., `VarianceSchedulerDDPM`).
+- **Scheduler**: Controls noise schedules (e.g., `SchedulerDDPM`).
 - **Training**: Full training pipelines (e.g., `TrainDDPM`).
 - **Sampling**: Efficient inference and generation (e.g., `SampleDDPM`).
 
 Additional utilities:
-- **Noise Predictor**: A U-Net-like model with attention and time embeddings.
+- **Diffusion Network**: A U-Net-like model with attention and time embeddings, used as the main model.
 - **Text Encoder**: Transformer-based (e.g., BERT) for conditional generation.
 - **Metrics**: Evaluation suite including MSE, PSNR, SSIM, FID, and LPIPS.
 
@@ -56,7 +56,7 @@ Additional utilities:
 
 ## ⚡ Quick Start
 
-Here’s a minimal working example to train and sample with **DDPM** on dummy data:
+Here’s a minimal working example to train and sample with **DDPM** on dummy data:
 
 ```python
 import torch
@@ -64,40 +64,72 @@ import torch.nn as nn
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
 
-from torchdiff.ddpm import VarianceSchedulerDDPM, ForwardDDPM, ReverseDDPM, TrainDDPM, SampleDDPM
-from torchdiff.utils import NoisePredictor
+from torchdiff.ddpm import (SchedulerDDPM, ForwardDDPM,
+                            ReverseDDPM, TrainDDPM, SampleDDPM)
+from torchdiff.utils import DiffusionNetwork, mse_loss
 
-# Dataset (CIFAR10 for demo)
+# dataset: CIFAR10
 transform = transforms.Compose([
     transforms.Resize(32),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
 ])
-train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
-train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
+train_dataset = datasets.CIFAR10(
+    root="./data", train=True, download=True, transform=transform
+)
+train_loader = DataLoader(
+    train_dataset, batch_size=64, shuffle=True
+)
+device = 'cuda'  # GPU is used for training and sampling
+
+# model components
+diff_net = DiffusionNetwork(
+    in_channels=3,
+    down_channels=[32, 64, 128],
+    mid_channels=[128, 128],
+    up_channels=[128, 64, 32],
+    down_sampling=[True, True],
+    time_embed_dim=128,
+    y_embed_dim=128,
+    num_down_blocks=2,
+    num_mid_blocks=2,
+    num_up_blocks=2,
+    dropout_rate=0.1,
+    cont_time=False  # time is discrete here; for SDE models it should be True
+)
+print(sum(p.numel() for p in diff_net.parameters()))
 
-# Model components
-noise_pred = NoisePredictor(in_channels=3)
-vs = VarianceSchedulerDDPM(num_steps=1000)
-fwd, rev = ForwardDDPM(vs), ReverseDDPM(vs)
+vs = SchedulerDDPM(time_steps=400)
+fwd = ForwardDDPM(vs, 'noise')  # the network is trained to predict noise
+rwd = ReverseDDPM(vs, 'noise')
 
-# Optimizer & loss
-optim = torch.optim.Adam(noise_pred.parameters(), lr=1e-4)
-loss_fn = nn.MSELoss()
+# optimizer
+optim = torch.optim.Adam(diff_net.parameters(), lr=1e-5)
 
-# Training
+# training algorithm
 trainer = TrainDDPM(
-    noise_predictor=noise_pred, forward_diffusion=fwd, reverse_diffusion=rev,
-    conditional_model=None, optimizer=optim, objective=loss_fn,
-    data_loader=train_loader, max_epochs=1, device="cpu"
+    diff_net=diff_net,
+    fwd_ddpm=fwd,
+    rwd_ddpm=rwd,
+    train_loader=train_loader,
+    optim=optim,
+    loss_fn=mse_loss,
+    max_epochs=10,
+    device=device,
+    grad_acc=2
 )
-trainer()
+#trainer()
 
 # Sampling
-sampler = SampleDDPM(reverse_diffusion=rev, noise_predictor=noise_pred,
-                     image_shape=(32, 32), batch_size=4, in_channels=3, device="cpu")
+sampler = SampleDDPM(
+    rwd_ddpm=rwd,
+    diff_net=diff_net,
+    img_size=(32, 32),
+    batch_size=10,
+    in_channels=3,
+    device=device
+)
 images = sampler()
-print("Generated images shape:", images.shape)
 ```
 
 For detailed examples, check the [examples/](https://github.com/LoqmanSamani/TorchDiff/tree/systembiology/examples) directory.
@@ -130,7 +162,7 @@ pip install -r requirements.txt
 pip install .
 ```
 
-> Requires **Python 3.8+**. For GPU acceleration, ensure PyTorch is installed with the correct CUDA version.
+> Requires **Python 3.10+**. For GPU acceleration, ensure PyTorch is installed with the correct CUDA version.
 
 ---
 
{torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/PKG-INFO

Identical to the PKG-INFO diff above (the egg-info file is a generated copy of the package metadata), so the repeated hunks are omitted.
{torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/SOURCES.txt

@@ -8,38 +8,29 @@ TorchDiff.egg-info/requires.txt
 TorchDiff.egg-info/top_level.txt
 ddim/__init__.py
 ddim/forward_ddim.py
-ddim/hyper_param.py
-ddim/noise_predictor.py
 ddim/reverse_ddim.py
 ddim/sample_ddim.py
-ddim/text_encoder.py
+ddim/scheduler.py
+ddim/test_ddim.py
 ddim/train_ddim.py
 ddpm/__init__.py
 ddpm/forward_ddpm.py
-ddpm/hyper_param.py
-ddpm/noise_predictor.py
 ddpm/reverse_ddpm.py
 ddpm/sample_ddpm.py
-ddpm/text_encoder.py
+ddpm/scheduler.py
+ddpm/test_ddpm.py
 ddpm/train_ddpm.py
 ldm/__init__.py
 ldm/autoencoder.py
-ldm/forward_idm.py
-ldm/hyper_param.py
-ldm/metrics.py
-ldm/noise_predictor.py
-ldm/reverse_ldm.py
 ldm/sample_ldm.py
-ldm/text_encoder.py
 ldm/train_autoencoder.py
 ldm/train_ldm.py
 sde/__init__.py
 sde/forward_sde.py
-sde/hyper_param.py
-sde/noise_predictor.py
 sde/reverse_sde.py
 sde/sample_sde.py
-sde/text_encoder.py
+sde/scheduler.py
+sde/test_sde.py
 sde/train_sde.py
 torchdiff/__init__.py
 torchdiff/ddim.py
@@ -68,4 +59,9 @@ unclip/unclip_sampler.py
 unclip/upsampler.py
 unclip/upsampler_trainer.py
 unclip/utils.py
-unclip/val_metrics.py
+unclip/val_metrics.py
+utils/__init__.py
+utils/diff_net.py
+utils/losses.py
+utils/metrics.py
+utils/text_encoder.py
{torchdiff-2.1.0 → torchdiff-2.2.0}/TorchDiff.egg-info/top_level.txt

@@ -4,3 +4,4 @@ ldm
 sde
 torchdiff
 unclip
+utils
torchdiff-2.2.0/ddim/forward_ddim.py (new file)

@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+
+
+
+class ForwardDDIM(nn.Module):
+    """
+    Implements the forward (noising) process of DDIM.
+
+    This module samples x_t from the forward diffusion distribution:
+
+        q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
+
+    It also computes the appropriate training target depending on the
+    prediction parameterization (noise, x0, or v-prediction).
+
+    Args:
+        scheduler: Noise scheduler containing precomputed diffusion coefficients.
+        pred_type: Type of model prediction. One of ["noise", "x0", "v"].
+    """
+    def __init__(
+            self,
+            scheduler: nn.Module,
+            pred_type: str = "noise"
+    ):
+        super().__init__()
+        valid_types = ["noise", "x0", "v"]
+        if pred_type not in valid_types:
+            raise ValueError(f"pred_type must be one of {valid_types}, got {pred_type}")
+        self.vs = scheduler
+        self.pred_type = pred_type
+
+    def forward(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Perform the forward diffusion step and compute the training target.
+
+        Samples x_t by adding noise to the clean input x_0 at timestep t,
+        and returns the corresponding supervision target for training.
+
+        Args:
+            x0: Clean input data of shape (batch, ...).
+            t: Discrete diffusion timesteps of shape (batch,).
+            noise: Gaussian noise of the same shape as x0.
+
+        Returns:
+            xt: Noised data x_t of shape (batch, ...).
+            target: Training target corresponding to pred_type:
+                - "noise": the added noise ε
+                - "x0": the original clean input x0
+                - "v": the velocity parameterization
+        """
+        sqrt_alpha_cumprod_t = self.vs.sqrt_alphas_cumprod[t]
+        sqrt_one_minus_alpha_cumprod_t = self.vs.sqrt_one_minus_alphas_cumprod[t]
+        sqrt_alpha_cumprod_t = self.vs.get_index(sqrt_alpha_cumprod_t, x0.shape)
+        sqrt_one_minus_alpha_cumprod_t = self.vs.get_index(sqrt_one_minus_alpha_cumprod_t, x0.shape)
+        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
+        if self.pred_type == "noise":
+            target = noise
+        elif self.pred_type == "x0":
+            target = x0
+        elif self.pred_type == "v":
+            target = sqrt_alpha_cumprod_t * noise - sqrt_one_minus_alpha_cumprod_t * x0
+        return xt, target
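
For orientation, the class above computes x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise and, under v-prediction, the target v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0. Below is a minimal usage sketch; StubScheduler is a hypothetical stand-in that exposes the sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod buffers and the get_index broadcasting helper the class expects (the real ddim/scheduler.py is not shown in this diff, so its actual interface may differ):

    import torch
    import torch.nn as nn
    from ddim.forward_ddim import ForwardDDIM  # the file added above

    class StubScheduler(nn.Module):
        """Hypothetical stand-in for ddim/scheduler.py with a linear beta schedule."""
        def __init__(self, time_steps=400):
            super().__init__()
            betas = torch.linspace(1e-4, 0.02, time_steps)
            alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
            self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
            self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

        def get_index(self, vals, x_shape):
            # reshape (batch,) coefficients to (batch, 1, 1, 1) for broadcasting
            return vals.reshape(-1, *([1] * (len(x_shape) - 1)))

    vs = StubScheduler()
    fwd = ForwardDDIM(vs, pred_type="v")
    x0 = torch.randn(8, 3, 32, 32)       # a clean batch
    t = torch.randint(0, 400, (8,))      # per-sample timesteps
    noise = torch.randn_like(x0)
    xt, target = fwd(x0, t, noise)       # noised input and its v target
    assert xt.shape == target.shape == x0.shape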