TorchDiff 2.5.0__tar.gz → 2.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchdiff-2.5.0/TorchDiff.egg-info → torchdiff-2.6.0}/PKG-INFO +25 -27
- {torchdiff-2.5.0 → torchdiff-2.6.0}/README.md +24 -26
- {torchdiff-2.5.0 → torchdiff-2.6.0/TorchDiff.egg-info}/PKG-INFO +25 -27
- {torchdiff-2.5.0 → torchdiff-2.6.0}/TorchDiff.egg-info/SOURCES.txt +10 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/setup.py +1 -1
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/__init__.py +1 -1
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/ddim.py +37 -29
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/ddpm.py +47 -26
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/ldm.py +59 -37
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/sde.py +73 -65
- torchdiff-2.6.0/torchdiff/tests/bench_ddim.py +151 -0
- torchdiff-2.6.0/torchdiff/tests/bench_ddpm.py +181 -0
- torchdiff-2.6.0/torchdiff/tests/bench_ldm.py +135 -0
- torchdiff-2.6.0/torchdiff/tests/bench_sde.py +175 -0
- torchdiff-2.6.0/torchdiff/tests/bench_unclip.py +211 -0
- torchdiff-2.6.0/torchdiff/tests/test_ddp_ddim.py +365 -0
- torchdiff-2.6.0/torchdiff/tests/test_ddp_ddpm.py +382 -0
- torchdiff-2.6.0/torchdiff/tests/test_ddp_ldm.py +382 -0
- torchdiff-2.6.0/torchdiff/tests/test_ddp_sde.py +370 -0
- torchdiff-2.6.0/torchdiff/tests/test_ddp_unclip.py +604 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/tests/test_sde.py +27 -25
- torchdiff-2.6.0/torchdiff/tests/test_utils.py +756 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/unclip.py +133 -93
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/utils.py +28 -7
- torchdiff-2.5.0/torchdiff/tests/test_utils.py +0 -316
- {torchdiff-2.5.0 → torchdiff-2.6.0}/LICENSE +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/MANIFEST.in +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/TorchDiff.egg-info/dependency_links.txt +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/TorchDiff.egg-info/requires.txt +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/TorchDiff.egg-info/top_level.txt +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/forward_ddim.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/reverse_ddim.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/sample_ddim.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/scheduler.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/test_ddim.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddim/train_ddim.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/forward_ddpm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/reverse_ddpm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/sample_ddpm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/scheduler.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/test_ddpm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ddpm/train_ddpm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ldm/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ldm/autoencoder.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ldm/sample_ldm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ldm/train_autoencoder.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/ldm/train_ldm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/forward_sde.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/reverse_sde.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/sample_sde.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/scheduler.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/test_sde.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/sde/train_sde.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/setup.cfg +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/tests/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/tests/test_ddim.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/tests/test_ddpm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/tests/test_ldm.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/torchdiff/tests/test_unclip.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/clip_encoder.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/forward_unclip.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/projections.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/reverse_unclip.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/scheduler.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/train_unclip_decoder.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/train_unclip_prior.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/unclip_decoder.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/unclip_sampler.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/unclip_trainstormer_prior.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/upsampler_trainer.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/unclip/upsampler_unclip.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/utils/__init__.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/utils/diff_net.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/utils/losses.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/utils/metrics.py +0 -0
- {torchdiff-2.5.0 → torchdiff-2.6.0}/utils/text_encoder.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: TorchDiff
|
|
3
|
-
Version: 2.5.0
|
|
3
|
+
Version: 2.6.0
|
|
4
4
|
Summary: A PyTorch-based library for diffusion models
|
|
5
5
|
Home-page: https://github.com/LoqmanSamani/TorchDiff
|
|
6
6
|
Author: Loghman Samani
|
|
@@ -61,7 +61,7 @@ Dynamic: summary
|
|
|
61
61
|
|
|
62
62
|
[](https://opensource.org/licenses/MIT)
|
|
63
63
|
[](https://pytorch.org/)
|
|
64
|
-
[](https://pypi.org/project/torchdiff/)
|
|
65
65
|
[](https://www.python.org/)
|
|
66
66
|
[](https://pepy.tech/project/torchdiff)
|
|
67
67
|
[](https://github.com/LoqmanSamani/TorchDiff)
|
|
@@ -76,7 +76,7 @@ Dynamic: summary
|
|
|
76
76
|
|
|
77
77
|
**TorchDiff** is a PyTorch library for diffusion models, implementing foundational architectures from recent research. The library provides modular components for building, training, and sampling from diffusion-based generative models.
|
|
78
78
|
|
|
79
|
-
Version 2.
|
|
79
|
+
Version 2.6.0 includes five major model families grounded in the diffusion modeling literature. **DDPM** (Ho et al., 2020) and **DDIM** (Song et al., 2021a) establish the core discrete-time framework. **SDE-based diffusion** (Song et al., 2021b) extends this to continuous stochastic processes with variance-exploding and variance-preserving formulations. **LDM** (Rombach et al., 2022) moves diffusion into learned latent spaces via variational autoencoders. **UnCLIP** (Ramesh et al., 2022) combines CLIP embeddings with hierarchical generation for text-to-image synthesis.
|
|
80
80
|
|
|
81
81
|
<div align="center">
|
|
82
82
|
<img src="https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/imgs/mount.png?raw=true" alt="Diffusion Model Process" width="1000"/>
|
|
@@ -93,17 +93,9 @@ We also provide evaluation utilities including standard metrics (MSE, PSNR, SSIM
|
|
|
93
93
|
|
|
94
94
|
---
|
|
95
95
|
|
|
96
|
-
## What's New in v2.5.0
|
|
97
96
|
|
|
98
|
-
- **UnCLIP improvements**: Fixed CLIPContextProjection output dimension handling, corrected sampling loop index arithmetic, resolved NaN loss in upsampler/prior training via bfloat16 autocast, and fixed CLIPEmbeddingProjection reconstruction loss bug.
|
|
99
|
-
- **Expanded test coverage**: Added test suites for LDM (AutoencoderLDM), UnCLIP (Scheduler, Forward/Reverse, Projections, TransformerPrior), and Utils (DiffusionNetwork, loss functions, Metrics).
|
|
100
|
-
- **API completeness**: `TrainUnCLIPPrior` now properly exported; removed duplicate `SampleUnCLIP` import.
|
|
101
|
-
- **Documentation**: Aligned all RST titles, added `torchmetrics` to mock imports for ReadTheDocs builds.
|
|
102
|
-
- **Build fixes**: Corrected ReadTheDocs URL in setup.py, removed trailing commas from requirements.txt, unified README for both GitHub and PyPI.
|
|
103
97
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
## Installation
|
|
98
|
+
### Installation
|
|
107
99
|
|
|
108
100
|
Install the stable release from PyPI.
|
|
109
101
|
|
|
@@ -213,7 +205,8 @@ DDPM (Ho et al., 2020) frames generation as learning to reverse a Markov chain t
|
|
|
213
205
|
|
|
214
206
|
The implementation supports both unconditional generation and conditional variants where generation is guided by auxiliary information like class labels or text embeddings.
|
|
215
207
|
|
|
216
|
-
**Paper:** [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
208
|
+
**Paper:** [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
209
|
+
|
|
217
210
|
**Example:** [DDPM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ddpm/ddpm.ipynb)
|
|
218
211
|
|
|
219
212
|
---
|
|
@@ -224,7 +217,8 @@ DDIM (Song et al., 2021a) reformulates the generative process as a non-Markovian
|
|
|
224
217
|
|
|
225
218
|
Like DDPM, both conditional and unconditional generation modes are supported.
|
|
226
219
|
|
|
227
|
-
**Paper:** [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
|
220
|
+
**Paper:** [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
|
221
|
+
|
|
228
222
|
**Example:** [DDIM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ddim/ddim.ipynb)
|
|
229
223
|
|
|
230
224
|
---
|
|
@@ -235,7 +229,8 @@ The SDE framework (Song et al., 2021b) generalizes diffusion models as continuou
|
|
|
235
229
|
|
|
236
230
|
We implement variance-exploding (VE), variance-preserving (VP), and sub-VP formulations. The reverse process can be simulated using either stochastic differential equations or their deterministic probability flow ODE counterparts. This unifies score matching with denoising diffusion and enables more flexible sampling strategies.
|
|
237
231
|
|
|
238
|
-
**Paper:** [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
|
|
232
|
+
**Paper:** [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
|
|
233
|
+
|
|
239
234
|
**Example:** [SDE Notebooks](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/sde/)
|
|
240
235
|
|
|
241
236
|
---
|
|
@@ -246,7 +241,8 @@ LDM (Rombach et al., 2022) addresses the computational cost of pixel-space diffu
|
|
|
246
241
|
|
|
247
242
|
Any of the diffusion backends (DDPM, DDIM, SDE) can operate in this latent space. The architecture enables high-resolution synthesis that would be impractical in pixel space.
|
|
248
243
|
|
|
249
|
-
**Paper:** [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
|
|
244
|
+
**Paper:** [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
|
|
245
|
+
|
|
250
246
|
**Example:** [LDM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ldm/ldm.ipynb)
|
|
251
247
|
|
|
252
248
|
---
|
|
@@ -259,7 +255,8 @@ This hierarchical approach leverages CLIP's multimodal embedding space where tex
|
|
|
259
255
|
|
|
260
256
|
Given the complexity, UnCLIP training requires more extensive setup than other models in this library.
|
|
261
257
|
|
|
262
|
-
**Paper:** [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)
|
|
258
|
+
**Paper:** [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)
|
|
259
|
+
|
|
263
260
|
**Example:** [UnCLIP Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/unclip/unclip.ipynb)
|
|
264
261
|
|
|
265
262
|
---
|
|
@@ -268,15 +265,16 @@ Given the complexity, UnCLIP training requires more extensive setup than other m
|
|
|
268
265
|
|
|
269
266
|
TorchDiff breaks each model into reusable components:
|
|
270
267
|
|
|
271
|
-
| Component
|
|
272
|
-
|
|
273
|
-
| **Forward Diffusion** | Adds noise to data following model-specific schedules
|
|
274
|
-
| **Reverse Diffusion** | Removes noise to recover data via learned denoising
|
|
275
|
-
| **Scheduler**
|
|
276
|
-
| **Training**
|
|
277
|
-
| **Sampling**
|
|
268
|
+
| Component | Description |
|
|
269
|
+
| --------------------------- | ----------------------------------------------------------------------- |
|
|
270
|
+
| **Forward Diffusion** | Adds noise to data following model-specific schedules |
|
|
271
|
+
| **Reverse Diffusion** | Removes noise to recover data via learned denoising |
|
|
272
|
+
| **Scheduler** | Controls variance/noise schedules across timesteps |
|
|
273
|
+
| **Training** | Complete training pipelines with mixed precision, gradient accumulation |
|
|
274
|
+
| **Sampling** | Efficient inference and image generation routines |
|
|
278
275
|
|
|
279
276
|
Additional utilities:
|
|
277
|
+
|
|
280
278
|
- **DiffusionNetwork**: U-Net architecture with attention and time embeddings
|
|
281
279
|
- **TextEncoder**: Transformer-based encoder for conditional generation
|
|
282
280
|
- **Metrics**: Evaluation suite (MSE, PSNR, SSIM, FID, LPIPS)
|
|
@@ -297,13 +295,13 @@ Documentation and additional materials are available online.
|
|
|
297
295
|
|
|
298
296
|
We are actively developing TorchDiff with several improvements planned for future releases.
|
|
299
297
|
|
|
300
|
-
**Model Extensions**
|
|
298
|
+
**Model Extensions**
|
|
301
299
|
New diffusion variants and training algorithms from recent literature will be added as they become established. We are particularly interested in methods that improve sample efficiency or generation quality.
|
|
302
300
|
|
|
303
|
-
**Performance Optimization**
|
|
301
|
+
**Performance Optimization**
|
|
304
302
|
Sampling speed and memory efficiency remain active areas of research. We plan to integrate faster sampling methods and more efficient architectures as they emerge.
|
|
305
303
|
|
|
306
|
-
**Experimental Utilities**
|
|
304
|
+
**Experimental Utilities**
|
|
307
305
|
Additional tools for hyperparameter tuning, ablation studies, and model comparison will make experimentation more straightforward.
|
|
308
306
|
|
|
309
307
|
---
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
[](https://opensource.org/licenses/MIT)
|
|
10
10
|
[](https://pytorch.org/)
|
|
11
|
-
[](https://pypi.org/project/torchdiff/)
|
|
12
12
|
[](https://www.python.org/)
|
|
13
13
|
[](https://pepy.tech/project/torchdiff)
|
|
14
14
|
[](https://github.com/LoqmanSamani/TorchDiff)
|
|
@@ -23,7 +23,7 @@
|
|
|
23
23
|
|
|
24
24
|
**TorchDiff** is a PyTorch library for diffusion models, implementing foundational architectures from recent research. The library provides modular components for building, training, and sampling from diffusion-based generative models.
|
|
25
25
|
|
|
26
|
-
Version 2.
|
|
26
|
+
Version 2.6.0 includes five major model families grounded in the diffusion modeling literature. **DDPM** (Ho et al., 2020) and **DDIM** (Song et al., 2021a) establish the core discrete-time framework. **SDE-based diffusion** (Song et al., 2021b) extends this to continuous stochastic processes with variance-exploding and variance-preserving formulations. **LDM** (Rombach et al., 2022) moves diffusion into learned latent spaces via variational autoencoders. **UnCLIP** (Ramesh et al., 2022) combines CLIP embeddings with hierarchical generation for text-to-image synthesis.
|
|
27
27
|
|
|
28
28
|
<div align="center">
|
|
29
29
|
<img src="https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/imgs/mount.png?raw=true" alt="Diffusion Model Process" width="1000"/>
|
|
@@ -40,17 +40,9 @@ We also provide evaluation utilities including standard metrics (MSE, PSNR, SSIM
|
|
|
40
40
|
|
|
41
41
|
---
|
|
42
42
|
|
|
43
|
-
## What's New in v2.5.0
|
|
44
43
|
|
|
45
|
-
- **UnCLIP improvements**: Fixed CLIPContextProjection output dimension handling, corrected sampling loop index arithmetic, resolved NaN loss in upsampler/prior training via bfloat16 autocast, and fixed CLIPEmbeddingProjection reconstruction loss bug.
|
|
46
|
-
- **Expanded test coverage**: Added test suites for LDM (AutoencoderLDM), UnCLIP (Scheduler, Forward/Reverse, Projections, TransformerPrior), and Utils (DiffusionNetwork, loss functions, Metrics).
|
|
47
|
-
- **API completeness**: `TrainUnCLIPPrior` now properly exported; removed duplicate `SampleUnCLIP` import.
|
|
48
|
-
- **Documentation**: Aligned all RST titles, added `torchmetrics` to mock imports for ReadTheDocs builds.
|
|
49
|
-
- **Build fixes**: Corrected ReadTheDocs URL in setup.py, removed trailing commas from requirements.txt, unified README for both GitHub and PyPI.
|
|
50
44
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
## Installation
|
|
45
|
+
### Installation
|
|
54
46
|
|
|
55
47
|
Install the stable release from PyPI.
|
|
56
48
|
|
|
@@ -160,7 +152,8 @@ DDPM (Ho et al., 2020) frames generation as learning to reverse a Markov chain t
|
|
|
160
152
|
|
|
161
153
|
The implementation supports both unconditional generation and conditional variants where generation is guided by auxiliary information like class labels or text embeddings.
|
|
162
154
|
|
|
163
|
-
**Paper:** [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
155
|
+
**Paper:** [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
156
|
+
|
|
164
157
|
**Example:** [DDPM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ddpm/ddpm.ipynb)
|
|
165
158
|
|
|
166
159
|
---
|
|
@@ -171,7 +164,8 @@ DDIM (Song et al., 2021a) reformulates the generative process as a non-Markovian
|
|
|
171
164
|
|
|
172
165
|
Like DDPM, both conditional and unconditional generation modes are supported.
|
|
173
166
|
|
|
174
|
-
**Paper:** [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
|
167
|
+
**Paper:** [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
|
168
|
+
|
|
175
169
|
**Example:** [DDIM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ddim/ddim.ipynb)
|
|
176
170
|
|
|
177
171
|
---
|
|
@@ -182,7 +176,8 @@ The SDE framework (Song et al., 2021b) generalizes diffusion models as continuou
|
|
|
182
176
|
|
|
183
177
|
We implement variance-exploding (VE), variance-preserving (VP), and sub-VP formulations. The reverse process can be simulated using either stochastic differential equations or their deterministic probability flow ODE counterparts. This unifies score matching with denoising diffusion and enables more flexible sampling strategies.
|
|
184
178
|
|
|
185
|
-
**Paper:** [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
|
|
179
|
+
**Paper:** [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
|
|
180
|
+
|
|
186
181
|
**Example:** [SDE Notebooks](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/sde/)
|
|
187
182
|
|
|
188
183
|
---
|
|
@@ -193,7 +188,8 @@ LDM (Rombach et al., 2022) addresses the computational cost of pixel-space diffu
|
|
|
193
188
|
|
|
194
189
|
Any of the diffusion backends (DDPM, DDIM, SDE) can operate in this latent space. The architecture enables high-resolution synthesis that would be impractical in pixel space.
|
|
195
190
|
|
|
196
|
-
**Paper:** [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
|
|
191
|
+
**Paper:** [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
|
|
192
|
+
|
|
197
193
|
**Example:** [LDM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ldm/ldm.ipynb)
|
|
198
194
|
|
|
199
195
|
---
|
|
@@ -206,7 +202,8 @@ This hierarchical approach leverages CLIP's multimodal embedding space where tex
|
|
|
206
202
|
|
|
207
203
|
Given the complexity, UnCLIP training requires more extensive setup than other models in this library.
|
|
208
204
|
|
|
209
|
-
**Paper:** [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)
|
|
205
|
+
**Paper:** [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)
|
|
206
|
+
|
|
210
207
|
**Example:** [UnCLIP Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/unclip/unclip.ipynb)
|
|
211
208
|
|
|
212
209
|
---
|
|
@@ -215,15 +212,16 @@ Given the complexity, UnCLIP training requires more extensive setup than other m
|
|
|
215
212
|
|
|
216
213
|
TorchDiff breaks each model into reusable components:
|
|
217
214
|
|
|
218
|
-
| Component
|
|
219
|
-
|
|
220
|
-
| **Forward Diffusion** | Adds noise to data following model-specific schedules
|
|
221
|
-
| **Reverse Diffusion** | Removes noise to recover data via learned denoising
|
|
222
|
-
| **Scheduler**
|
|
223
|
-
| **Training**
|
|
224
|
-
| **Sampling**
|
|
215
|
+
| Component | Description |
|
|
216
|
+
| --------------------------- | ----------------------------------------------------------------------- |
|
|
217
|
+
| **Forward Diffusion** | Adds noise to data following model-specific schedules |
|
|
218
|
+
| **Reverse Diffusion** | Removes noise to recover data via learned denoising |
|
|
219
|
+
| **Scheduler** | Controls variance/noise schedules across timesteps |
|
|
220
|
+
| **Training** | Complete training pipelines with mixed precision, gradient accumulation |
|
|
221
|
+
| **Sampling** | Efficient inference and image generation routines |
|
|
225
222
|
|
|
226
223
|
Additional utilities:
|
|
224
|
+
|
|
227
225
|
- **DiffusionNetwork**: U-Net architecture with attention and time embeddings
|
|
228
226
|
- **TextEncoder**: Transformer-based encoder for conditional generation
|
|
229
227
|
- **Metrics**: Evaluation suite (MSE, PSNR, SSIM, FID, LPIPS)
|
|
@@ -244,13 +242,13 @@ Documentation and additional materials are available online.
|
|
|
244
242
|
|
|
245
243
|
We are actively developing TorchDiff with several improvements planned for future releases.
|
|
246
244
|
|
|
247
|
-
**Model Extensions**
|
|
245
|
+
**Model Extensions**
|
|
248
246
|
New diffusion variants and training algorithms from recent literature will be added as they become established. We are particularly interested in methods that improve sample efficiency or generation quality.
|
|
249
247
|
|
|
250
|
-
**Performance Optimization**
|
|
248
|
+
**Performance Optimization**
|
|
251
249
|
Sampling speed and memory efficiency remain active areas of research. We plan to integrate faster sampling methods and more efficient architectures as they emerge.
|
|
252
250
|
|
|
253
|
-
**Experimental Utilities**
|
|
251
|
+
**Experimental Utilities**
|
|
254
252
|
Additional tools for hyperparameter tuning, ablation studies, and model comparison will make experimentation more straightforward.
|
|
255
253
|
|
|
256
254
|
---
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: TorchDiff
|
|
3
|
-
Version: 2.5.0
|
|
3
|
+
Version: 2.6.0
|
|
4
4
|
Summary: A PyTorch-based library for diffusion models
|
|
5
5
|
Home-page: https://github.com/LoqmanSamani/TorchDiff
|
|
6
6
|
Author: Loghman Samani
|
|
@@ -61,7 +61,7 @@ Dynamic: summary
|
|
|
61
61
|
|
|
62
62
|
[](https://opensource.org/licenses/MIT)
|
|
63
63
|
[](https://pytorch.org/)
|
|
64
|
-
[](https://pypi.org/project/torchdiff/)
|
|
65
65
|
[](https://www.python.org/)
|
|
66
66
|
[](https://pepy.tech/project/torchdiff)
|
|
67
67
|
[](https://github.com/LoqmanSamani/TorchDiff)
|
|
@@ -76,7 +76,7 @@ Dynamic: summary
|
|
|
76
76
|
|
|
77
77
|
**TorchDiff** is a PyTorch library for diffusion models, implementing foundational architectures from recent research. The library provides modular components for building, training, and sampling from diffusion-based generative models.
|
|
78
78
|
|
|
79
|
-
Version 2.
|
|
79
|
+
Version 2.6.0 includes five major model families grounded in the diffusion modeling literature. **DDPM** (Ho et al., 2020) and **DDIM** (Song et al., 2021a) establish the core discrete-time framework. **SDE-based diffusion** (Song et al., 2021b) extends this to continuous stochastic processes with variance-exploding and variance-preserving formulations. **LDM** (Rombach et al., 2022) moves diffusion into learned latent spaces via variational autoencoders. **UnCLIP** (Ramesh et al., 2022) combines CLIP embeddings with hierarchical generation for text-to-image synthesis.
|
|
80
80
|
|
|
81
81
|
<div align="center">
|
|
82
82
|
<img src="https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/imgs/mount.png?raw=true" alt="Diffusion Model Process" width="1000"/>
|
|
@@ -93,17 +93,9 @@ We also provide evaluation utilities including standard metrics (MSE, PSNR, SSIM
|
|
|
93
93
|
|
|
94
94
|
---
|
|
95
95
|
|
|
96
|
-
## What's New in v2.5.0
|
|
97
96
|
|
|
98
|
-
- **UnCLIP improvements**: Fixed CLIPContextProjection output dimension handling, corrected sampling loop index arithmetic, resolved NaN loss in upsampler/prior training via bfloat16 autocast, and fixed CLIPEmbeddingProjection reconstruction loss bug.
|
|
99
|
-
- **Expanded test coverage**: Added test suites for LDM (AutoencoderLDM), UnCLIP (Scheduler, Forward/Reverse, Projections, TransformerPrior), and Utils (DiffusionNetwork, loss functions, Metrics).
|
|
100
|
-
- **API completeness**: `TrainUnCLIPPrior` now properly exported; removed duplicate `SampleUnCLIP` import.
|
|
101
|
-
- **Documentation**: Aligned all RST titles, added `torchmetrics` to mock imports for ReadTheDocs builds.
|
|
102
|
-
- **Build fixes**: Corrected ReadTheDocs URL in setup.py, removed trailing commas from requirements.txt, unified README for both GitHub and PyPI.
|
|
103
97
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
## Installation
|
|
98
|
+
### Installation
|
|
107
99
|
|
|
108
100
|
Install the stable release from PyPI.
|
|
109
101
|
|
|
@@ -213,7 +205,8 @@ DDPM (Ho et al., 2020) frames generation as learning to reverse a Markov chain t
|
|
|
213
205
|
|
|
214
206
|
The implementation supports both unconditional generation and conditional variants where generation is guided by auxiliary information like class labels or text embeddings.
|
|
215
207
|
|
|
216
|
-
**Paper:** [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
208
|
+
**Paper:** [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
209
|
+
|
|
217
210
|
**Example:** [DDPM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ddpm/ddpm.ipynb)
|
|
218
211
|
|
|
219
212
|
---
|
|
@@ -224,7 +217,8 @@ DDIM (Song et al., 2021a) reformulates the generative process as a non-Markovian
|
|
|
224
217
|
|
|
225
218
|
Like DDPM, both conditional and unconditional generation modes are supported.
|
|
226
219
|
|
|
227
|
-
**Paper:** [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
|
220
|
+
**Paper:** [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
|
221
|
+
|
|
228
222
|
**Example:** [DDIM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ddim/ddim.ipynb)
|
|
229
223
|
|
|
230
224
|
---
|
|
@@ -235,7 +229,8 @@ The SDE framework (Song et al., 2021b) generalizes diffusion models as continuou
|
|
|
235
229
|
|
|
236
230
|
We implement variance-exploding (VE), variance-preserving (VP), and sub-VP formulations. The reverse process can be simulated using either stochastic differential equations or their deterministic probability flow ODE counterparts. This unifies score matching with denoising diffusion and enables more flexible sampling strategies.
|
|
237
231
|
|
|
238
|
-
**Paper:** [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
|
|
232
|
+
**Paper:** [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
|
|
233
|
+
|
|
239
234
|
**Example:** [SDE Notebooks](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/sde/)
|
|
240
235
|
|
|
241
236
|
---
|
|
@@ -246,7 +241,8 @@ LDM (Rombach et al., 2022) addresses the computational cost of pixel-space diffu
|
|
|
246
241
|
|
|
247
242
|
Any of the diffusion backends (DDPM, DDIM, SDE) can operate in this latent space. The architecture enables high-resolution synthesis that would be impractical in pixel space.
|
|
248
243
|
|
|
249
|
-
**Paper:** [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
|
|
244
|
+
**Paper:** [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
|
|
245
|
+
|
|
250
246
|
**Example:** [LDM Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/ldm/ldm.ipynb)
|
|
251
247
|
|
|
252
248
|
---
|
|
@@ -259,7 +255,8 @@ This hierarchical approach leverages CLIP's multimodal embedding space where tex
|
|
|
259
255
|
|
|
260
256
|
Given the complexity, UnCLIP training requires more extensive setup than other models in this library.
|
|
261
257
|
|
|
262
|
-
**Paper:** [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)
|
|
258
|
+
**Paper:** [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)
|
|
259
|
+
|
|
263
260
|
**Example:** [UnCLIP Notebook](https://github.com/LoqmanSamani/TorchDiff/blob/systembiology/examples/unclip/unclip.ipynb)
|
|
264
261
|
|
|
265
262
|
---
|
|
@@ -268,15 +265,16 @@ Given the complexity, UnCLIP training requires more extensive setup than other m
|
|
|
268
265
|
|
|
269
266
|
TorchDiff breaks each model into reusable components:
|
|
270
267
|
|
|
271
|
-
| Component
|
|
272
|
-
|
|
273
|
-
| **Forward Diffusion** | Adds noise to data following model-specific schedules
|
|
274
|
-
| **Reverse Diffusion** | Removes noise to recover data via learned denoising
|
|
275
|
-
| **Scheduler**
|
|
276
|
-
| **Training**
|
|
277
|
-
| **Sampling**
|
|
268
|
+
| Component | Description |
|
|
269
|
+
| --------------------------- | ----------------------------------------------------------------------- |
|
|
270
|
+
| **Forward Diffusion** | Adds noise to data following model-specific schedules |
|
|
271
|
+
| **Reverse Diffusion** | Removes noise to recover data via learned denoising |
|
|
272
|
+
| **Scheduler** | Controls variance/noise schedules across timesteps |
|
|
273
|
+
| **Training** | Complete training pipelines with mixed precision, gradient accumulation |
|
|
274
|
+
| **Sampling** | Efficient inference and image generation routines |
|
|
278
275
|
|
|
279
276
|
Additional utilities:
|
|
277
|
+
|
|
280
278
|
- **DiffusionNetwork**: U-Net architecture with attention and time embeddings
|
|
281
279
|
- **TextEncoder**: Transformer-based encoder for conditional generation
|
|
282
280
|
- **Metrics**: Evaluation suite (MSE, PSNR, SSIM, FID, LPIPS)
|
|
@@ -297,13 +295,13 @@ Documentation and additional materials are available online.
|
|
|
297
295
|
|
|
298
296
|
We are actively developing TorchDiff with several improvements planned for future releases.
|
|
299
297
|
|
|
300
|
-
**Model Extensions**
|
|
298
|
+
**Model Extensions**
|
|
301
299
|
New diffusion variants and training algorithms from recent literature will be added as they become established. We are particularly interested in methods that improve sample efficiency or generation quality.
|
|
302
300
|
|
|
303
|
-
**Performance Optimization**
|
|
301
|
+
**Performance Optimization**
|
|
304
302
|
Sampling speed and memory efficiency remain active areas of research. We plan to integrate faster sampling methods and more efficient architectures as they emerge.
|
|
305
303
|
|
|
306
|
-
**Experimental Utilities**
|
|
304
|
+
**Experimental Utilities**
|
|
307
305
|
Additional tools for hyperparameter tuning, ablation studies, and model comparison will make experimentation more straightforward.
|
|
308
306
|
|
|
309
307
|
---
|
|
@@ -41,7 +41,17 @@ torchdiff/sde.py
|
|
|
41
41
|
torchdiff/unclip.py
|
|
42
42
|
torchdiff/utils.py
|
|
43
43
|
torchdiff/tests/__init__.py
|
|
44
|
+
torchdiff/tests/bench_ddim.py
|
|
45
|
+
torchdiff/tests/bench_ddpm.py
|
|
46
|
+
torchdiff/tests/bench_ldm.py
|
|
47
|
+
torchdiff/tests/bench_sde.py
|
|
48
|
+
torchdiff/tests/bench_unclip.py
|
|
44
49
|
torchdiff/tests/test_ddim.py
|
|
50
|
+
torchdiff/tests/test_ddp_ddim.py
|
|
51
|
+
torchdiff/tests/test_ddp_ddpm.py
|
|
52
|
+
torchdiff/tests/test_ddp_ldm.py
|
|
53
|
+
torchdiff/tests/test_ddp_sde.py
|
|
54
|
+
torchdiff/tests/test_ddp_unclip.py
|
|
45
55
|
torchdiff/tests/test_ddpm.py
|
|
46
56
|
torchdiff/tests/test_ldm.py
|
|
47
57
|
torchdiff/tests/test_sde.py
|
|
@@ -46,6 +46,14 @@ from typing_extensions import Self
|
|
|
46
46
|
from .utils import LossAdapter
|
|
47
47
|
import os
|
|
48
48
|
|
|
49
|
+
__all__ = [
|
|
50
|
+
"ForwardDDIM",
|
|
51
|
+
"ReverseDDIM",
|
|
52
|
+
"SchedulerDDIM",
|
|
53
|
+
"TrainDDIM",
|
|
54
|
+
"SampleDDIM",
|
|
55
|
+
]
|
|
56
|
+
|
|
49
57
|
|
|
50
58
|
###==================================================================================================================###
|
|
51
59
|
|
|
@@ -376,8 +384,7 @@ class SchedulerDDIM(nn.Module):
|
|
|
376
384
|
Reshaped tensor suitable for broadcasting.
|
|
377
385
|
"""
|
|
378
386
|
batch_size = t.shape[0]
|
|
379
|
-
|
|
380
|
-
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
|
|
387
|
+
return t.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
|
|
381
388
|
|
|
382
389
|
|
|
383
390
|
###==================================================================================================================###
|
|
@@ -506,6 +513,7 @@ class TrainDDIM(nn.Module):
|
|
|
506
513
|
factor=0.5
|
|
507
514
|
)
|
|
508
515
|
self.warmup_lr_scheduler = self.warmup_scheduler(self.optim, warmup_steps)
|
|
516
|
+
self._device_type = self.device.type if hasattr(self.device, 'type') else ('cuda' if 'cuda' in str(self.device) else 'cpu')
|
|
509
517
|
if tokenizer is None:
|
|
510
518
|
try:
|
|
511
519
|
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
@@ -524,15 +532,17 @@ class TrainDDIM(nn.Module):
|
|
|
524
532
|
raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
|
|
525
533
|
if "WORLD_SIZE" not in os.environ:
|
|
526
534
|
raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
|
|
527
|
-
if not torch.cuda.is_available():
|
|
528
|
-
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
529
535
|
if not torch.distributed.is_initialized():
|
|
530
|
-
|
|
536
|
+
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
|
537
|
+
init_process_group(backend=backend)
|
|
531
538
|
self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
|
|
532
539
|
self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
|
|
533
540
|
self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
|
|
534
|
-
|
|
535
|
-
|
|
541
|
+
if torch.cuda.is_available():
|
|
542
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
543
|
+
torch.cuda.set_device(self.device)
|
|
544
|
+
else:
|
|
545
|
+
self.device = torch.device("cpu")
|
|
536
546
|
self.master_process = self.ddp_rank == 0
|
|
537
547
|
if self.master_process:
|
|
538
548
|
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
@@ -641,17 +651,12 @@ class TrainDDIM(nn.Module):
|
|
|
641
651
|
def _wrap_models_for_ddp(self) -> None:
|
|
642
652
|
"""Wrap models with DistributedDataParallel for multi-GPU training."""
|
|
643
653
|
if self.use_ddp:
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
device_ids=[self.ddp_local_rank]
|
|
647
|
-
|
|
648
|
-
)
|
|
654
|
+
ddp_kwargs = dict(find_unused_parameters=False)
|
|
655
|
+
if self._device_type == 'cuda':
|
|
656
|
+
ddp_kwargs['device_ids'] = [self.ddp_local_rank]
|
|
657
|
+
self.diff_net = DDP(self.diff_net, **ddp_kwargs)
|
|
649
658
|
if self.cond_net is not None:
|
|
650
|
-
self.cond_net = DDP(
|
|
651
|
-
self.cond_net,
|
|
652
|
-
device_ids=[self.ddp_local_rank],
|
|
653
|
-
find_unused_parameters=True
|
|
654
|
-
)
|
|
659
|
+
self.cond_net = DDP(self.cond_net, **ddp_kwargs)
|
|
655
660
|
|
|
656
661
|
def forward(self) -> Dict:
|
|
657
662
|
"""Trains the DDIM model to predict noise added by the forward diffusion process.
|
|
@@ -678,7 +683,10 @@ class TrainDDIM(nn.Module):
|
|
|
678
683
|
print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
679
684
|
|
|
680
685
|
self._wrap_models_for_ddp()
|
|
681
|
-
|
|
686
|
+
use_amp = self._device_type == 'cuda'
|
|
687
|
+
scaler = torch.amp.GradScaler(self._device_type, enabled=use_amp)
|
|
688
|
+
if use_amp:
|
|
689
|
+
torch.backends.cudnn.benchmark = True
|
|
682
690
|
wait = 0
|
|
683
691
|
for epoch in range(self.max_epochs):
|
|
684
692
|
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not self.master_process)
|
|
@@ -691,7 +699,7 @@ class TrainDDIM(nn.Module):
|
|
|
691
699
|
y_encoded = self._process_conditional_input(y)
|
|
692
700
|
else:
|
|
693
701
|
y_encoded = None
|
|
694
|
-
with torch.autocast(device_type=
|
|
702
|
+
with torch.autocast(device_type=self._device_type, enabled=use_amp):
|
|
695
703
|
noise = torch.randn_like(x)
|
|
696
704
|
t = torch.randint(0, self.fwd_ddim.vs.train_steps, (x.shape[0],), device=x.device)
|
|
697
705
|
xt, target = self.fwd_ddim(x, t, noise)
|
|
@@ -706,8 +714,8 @@ class TrainDDIM(nn.Module):
|
|
|
706
714
|
torch.nn.utils.clip_grad_norm_(self.cond_net.parameters(), max_norm=1.0)
|
|
707
715
|
scaler.step(self.optim)
|
|
708
716
|
scaler.update()
|
|
709
|
-
self.optim.zero_grad()
|
|
710
|
-
if self.global_step
|
|
717
|
+
self.optim.zero_grad(set_to_none=True)
|
|
718
|
+
if self.global_step < self.warmup_steps:
|
|
711
719
|
self.warmup_lr_scheduler.step()
|
|
712
720
|
self.global_step += 1
|
|
713
721
|
pbar.set_postfix({'Loss': f'{loss.item() * self.grad_acc:.4f}'})
|
|
@@ -1042,7 +1050,7 @@ class SampleDDIM(nn.Module):
|
|
|
1042
1050
|
if conds is None and self.cond_net is not None:
|
|
1043
1051
|
raise ValueError("Conditions must be provided for conditional model")
|
|
1044
1052
|
|
|
1045
|
-
init_samps = torch.randn(self.batch_size, self.in_channels, self.img_size[0], self.img_size[1]
|
|
1053
|
+
init_samps = torch.randn(self.batch_size, self.in_channels, self.img_size[0], self.img_size[1], device=self.device)
|
|
1046
1054
|
self.diff_net.eval()
|
|
1047
1055
|
if self.cond_net:
|
|
1048
1056
|
self.cond_net.eval()
|
|
@@ -1055,14 +1063,13 @@ class SampleDDIM(nn.Module):
|
|
|
1055
1063
|
dynamic_ncols=True,
|
|
1056
1064
|
leave=True,
|
|
1057
1065
|
)
|
|
1058
|
-
if self.cond_net is not None and conds is not None:
|
|
1059
|
-
input_ids, attention_masks = self.tokenize(conds)
|
|
1060
|
-
key_padding_mask = (attention_masks == 0)
|
|
1061
|
-
y = self.cond_net(input_ids, key_padding_mask)
|
|
1062
|
-
else:
|
|
1063
|
-
y = None
|
|
1064
|
-
|
|
1065
1066
|
with torch.no_grad():
|
|
1067
|
+
if self.cond_net is not None and conds is not None:
|
|
1068
|
+
input_ids, attention_masks = self.tokenize(conds)
|
|
1069
|
+
key_padding_mask = (attention_masks == 0)
|
|
1070
|
+
y = self.cond_net(input_ids, key_padding_mask)
|
|
1071
|
+
else:
|
|
1072
|
+
y = None
|
|
1066
1073
|
xt = init_samps
|
|
1067
1074
|
for i in iterator:
|
|
1068
1075
|
t_current = timesteps[i].item()
|
|
@@ -1099,6 +1106,7 @@ class SampleDDIM(nn.Module):
|
|
|
1099
1106
|
"""
|
|
1100
1107
|
self.device = device
|
|
1101
1108
|
self.diff_net.to(device)
|
|
1109
|
+
self.rwd_ddim.to(device)
|
|
1102
1110
|
if self.cond_net:
|
|
1103
1111
|
self.cond_net.to(device)
|
|
1104
1112
|
return super().to(device)
|