flaxdiff 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/__init__.py +0 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +489 -0
- flaxdiff/models/common.py +7 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +519 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +113 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +43 -0
- flaxdiff/samplers/euler.py +59 -0
- flaxdiff/samplers/heun_sampler.py +28 -0
- flaxdiff/samplers/multistep_dpm.py +60 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +216 -0
- flaxdiff/utils.py +89 -0
- flaxdiff-0.1.1.dist-info/METADATA +326 -0
- flaxdiff-0.1.1.dist-info/RECORD +30 -0
- flaxdiff-0.1.1.dist-info/WHEEL +5 -0
- flaxdiff-0.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,326 @@
Metadata-Version: 2.1
Name: flaxdiff
Version: 0.1.1
Summary: A versatile and easy to understand Diffusion library
Author: Ashish Kumar Singh
Author-email: ashishkmr472@gmail.com
Description-Content-Type: text/markdown
Requires-Dist: flax >=0.8.4
Requires-Dist: optax >=0.2.2
Requires-Dist: jax >=0.4.28
Requires-Dist: orbax
Requires-Dist: clu

## A Versatile and Easy-to-Understand Diffusion Library

In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.

FlaxDiff is a library of tools (schedulers, samplers, models, etc.) designed and implemented to be easy to understand. The focus is on understandability and readability over performance. I started this project as a hobby to familiarize myself with Flax and JAX and to learn about diffusion and the latest research in generative AI.

I initially started this project in Keras, being familiar with TensorFlow 2.0, but transitioned to Flax, powered by JAX, for its performance and ease of use. The old notebooks and models, including my first Flax models, are also provided.

The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments. Several checkpoints are uploaded to the `pretrained` folder, along with a copy of the working notebook associated with each checkpoint. *You may need to copy the notebook to the working root for it to function properly.*

## Example Notebooks from scratch

In the `tutorial notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.

### Available Notebooks

- **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**

  - **WORK IN PROGRESS** An in-depth exploration of diffusion-based generative models, DDPM (Denoising Diffusion Probabilistic Models), DDIM (Denoising Diffusion Implicit Models), and the SDE/ODE generalizations of diffusion, with step-by-step explanations and code.

<a target="_blank" href="https://colab.research.google.com/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

- **[EDM (Elucidating the Design Space of Diffusion-based Generative Models)](tutorial%20notebooks/edm%20tutorial.ipynb)**
  - **TODO** A thorough guide to EDM, discussing the innovative approaches and techniques used in this advanced diffusion model.

<a target="_blank" href="https://colab.research.google.com/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/edm%20tutorial.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

These notebooks aim to provide an easy-to-understand, step-by-step guide to various diffusion models and techniques. They are designed to be beginner-friendly, so they may not follow the exact formulations and implementations of the original papers; this keeps them more understandable and generalizable, but I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.

## Disclaimer (and About Me)

I worked as a Machine Learning Researcher at Hyperverge from 2019 to 2021, focusing on computer vision, specifically facial anti-spoofing and face detection and recognition. Since switching to my current job in 2021, I haven't done as much R&D work, which led me to start this pet project to revisit the fundamentals and get familiar with the state of the art. My current role is primarily Golang systems engineering with some applied ML work sprinkled in, so the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.

Also, some of the text may have been generated with the help of GitHub Copilot, so please excuse any mistakes in it.

## Index

- [A Versatile and Easy-to-Understand Diffusion Library](#a-versatile-and-easy-to-understand-diffusion-library)
- [Disclaimer (and About Me)](#disclaimer-and-about-me)
- [Features](#features)
  - [Schedulers](#schedulers)
  - [Model Predictors](#model-predictors)
  - [Samplers](#samplers)
  - [Training](#training)
  - [Models](#models)
- [Installation of FlaxDiff](#installation)
- [Getting Started with FlaxDiff](#getting-started)
  - [Training Example](#training-example)
  - [Inference Example](#inference-example)
- [References and Acknowledgements](#references-and-acknowledgements)
- [Pending things to do list](#pending-things-to-do-list)
- [Gallery](#gallery)
- [Contribution](#contribution)
- [License](#license)

## Features

### Schedulers
Implemented in `flaxdiff.schedulers`:
- **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
- **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
- **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
- **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
- **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous, sigma-parameterized cosine scheduler.
- **KarrasVENoiseScheduler** (`flaxdiff.schedulers.KarrasVENoiseScheduler`): A sigma-parameterized continuous scheduler proposed by Karras et al. (2022), best suited for inference.
- **EDMNoiseScheduler** (`flaxdiff.schedulers.EDMNoiseScheduler`): A sigma-parameterized continuous scheduler based on EDM ("Elucidating the Design Space of Diffusion-Based Generative Models", Karras et al. 2022), best suited for training, paired with `KarrasVENoiseScheduler` for inference.
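
To make "beta-parameterized" and "sigma-parameterized" concrete, here is a minimal pure-Python sketch of a linear and a cosine beta schedule (with the cumulative `alpha_bar` they imply) and the Karras et al. (2022) sigma spacing. The function names and default constants (`beta_start=1e-4`, `beta_end=0.02`, `s=0.008`, `sigma_min=0.002`) are assumptions taken from the DDPM, Improved DDPM and EDM papers, not flaxdiff's API; `sigma_max=80` and `rho=7` match the values used in the training example below.

```python
import math


def linear_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Betas spaced linearly between beta_start and beta_end (DDPM defaults)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cosine_betas(timesteps, s=0.008, max_beta=0.999):
    """Betas implied by the cosine alpha_bar curve of Improved DDPM, clipped at max_beta."""
    def alpha_bar(t):
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2
    return [min(1 - alpha_bar(i + 1) / alpha_bar(i), max_beta) for i in range(timesteps)]


def alpha_bars(betas):
    """Cumulative product of (1 - beta): how much of the signal survives at each step."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """n noise levels from sigma_max down to sigma_min, interpolated in sigma**(1/rho) space."""
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return [(hi + i / (n - 1) * (lo - hi)) ** rho for i in range(n)]


print(karras_sigmas(10))
```

The `rho` exponent packs most of the Karras noise levels near the low-noise end, where fine detail is resolved, while a discrete schedule instead fixes one beta per timestep.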

### Model Predictors
Implemented in `flaxdiff.predictors`:
- **EpsilonPredictor** (`flaxdiff.predictors.EpsilonPredictor`): Predicts the noise in the data.
- **X0Predictor** (`flaxdiff.predictors.X0Predictor`): Predicts the original data from the noisy data.
- **VPredictor** (`flaxdiff.predictors.VPredictor`): Predicts a linear combination of the data and the noise (the "v-prediction" parameterization).
- **KarrasEDMPredictor** (`flaxdiff.predictors.KarrasEDMPredictor`): A generalized predictor based on EDM that unifies the various parameterizations.
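
All of these predictors parameterize the same denoiser in different ways. The sketch below assumes the forward process `x_t = alpha * x0 + sigma * eps` (the `get_rates` call in the inference example returns such an `(alpha, sigma)` pair) and shows how epsilon- and v-predictions convert back to an `x0` estimate, plus the preconditioning constants from the EDM paper. Whether `KarrasEDMPredictor` uses exactly these scalings is an assumption, and all function names here are hypothetical.

```python
import math


def x0_from_eps(x_t, eps_hat, alpha, sigma):
    """Epsilon-prediction: invert x_t = alpha * x0 + sigma * eps for x0."""
    return (x_t - sigma * eps_hat) / alpha


def v_target(x0, eps, alpha, sigma):
    """v-prediction target: v = alpha * eps - sigma * x0."""
    return alpha * eps - sigma * x0


def x0_from_v(x_t, v_hat, alpha, sigma):
    """Recover x0 from v; exact only when alpha**2 + sigma**2 == 1 (variance preserving)."""
    return alpha * x_t - sigma * v_hat


def edm_preconditioning(sigma, sigma_data=0.5):
    """EDM scalings so that D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = math.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise


# Round-trip check on a variance-preserving point (alpha = 0.6, sigma = 0.8).
alpha, sigma, x0, eps = 0.6, 0.8, 1.5, -0.3
x_t = alpha * x0 + sigma * eps
print(x0_from_eps(x_t, eps, alpha, sigma))
print(x0_from_v(x_t, v_target(x0, eps, alpha, sigma), alpha, sigma))
```

At low noise `c_skip` approaches 1 and `c_out` approaches 0, so the network output barely perturbs the input; at high noise the roles reverse and the network effectively predicts the clean data.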

### Samplers
Implemented in `flaxdiff.samplers`:
- **DDPMSampler** (`flaxdiff.samplers.DDPMSampler`): Implements the Denoising Diffusion Probabilistic Model (DDPM) sampling process.
- **DDIMSampler** (`flaxdiff.samplers.DDIMSampler`): Implements the Denoising Diffusion Implicit Model (DDIM) sampling process.
- **EulerSampler** (`flaxdiff.samplers.EulerSampler`): An ODE solver sampler using Euler's method.
- **HeunSampler** (`flaxdiff.samplers.HeunSampler`): An ODE solver sampler using Heun's method.
- **RK4Sampler** (`flaxdiff.samplers.RK4Sampler`): An ODE solver sampler using the classic fourth-order Runge-Kutta (RK4) method.
- **MultiStepDPM** (`flaxdiff.samplers.MultiStepDPM`): Implements a multi-step sampling method inspired by the multistep DPM-Solver, as presented in [tonyduan/diffusion](https://github.com/tonyduan/diffusion/blob/fcc0ed829baf29e1493b460b073e735a848c08ea/src/samplers.py#L44).
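
To see why higher-order ODE samplers can get away with fewer steps, here is a self-contained toy comparison of Euler and Heun steps on the probability-flow ODE `dx/dsigma = (x - D(x, sigma)) / sigma`. It assumes 1-D Gaussian data `N(0, 0.5**2)`, for which the ideal denoiser `D` and the exact ODE solution are known in closed form; it is not flaxdiff's sampler code, and the log-spaced noise levels are an arbitrary illustrative choice.

```python
import math

SIGMA_DATA = 0.5  # std of the toy 1-D Gaussian "dataset"


def denoiser(x, sigma):
    """Ideal denoiser E[x0 | x] for data ~ N(0, SIGMA_DATA**2) plus noise of std sigma."""
    return SIGMA_DATA ** 2 / (sigma ** 2 + SIGMA_DATA ** 2) * x


def slope(x, sigma):
    """Probability-flow ODE in sigma: dx/dsigma = (x - D(x, sigma)) / sigma."""
    return (x - denoiser(x, sigma)) / sigma


def log_spaced_sigmas(n_steps, sigma_max=80.0, sigma_min=0.002):
    """n_steps + 1 noise levels, evenly spaced in log(sigma), from high to low."""
    return [sigma_max * (sigma_min / sigma_max) ** (i / n_steps) for i in range(n_steps + 1)]


def solve(x, sigmas, method="euler"):
    """Integrate from sigmas[0] to sigmas[-1]; returns (x, number of denoiser calls)."""
    calls = 0
    for s_cur, s_next in zip(sigmas, sigmas[1:]):
        h = s_next - s_cur
        d_cur = slope(x, s_cur)
        calls += 1
        x_pred = x + h * d_cur                      # plain Euler step
        if method == "euler":
            x = x_pred
        elif method == "heun":
            d_next = slope(x_pred, s_next)          # re-evaluate at the predicted point
            calls += 1
            x = x + h * 0.5 * (d_cur + d_next)      # trapezoidal (Heun) correction
        else:
            raise ValueError(f"unknown method: {method}")
    return x, calls


def exact(x_start, s_start, s_end):
    """Closed-form ODE solution: x scales with sqrt(sigma**2 + SIGMA_DATA**2)."""
    return x_start * math.sqrt(s_end ** 2 + SIGMA_DATA ** 2) / math.sqrt(s_start ** 2 + SIGMA_DATA ** 2)


sigmas = log_spaced_sigmas(50)
target = exact(1.0, sigmas[0], sigmas[-1])
for method in ("euler", "heun"):
    x_end, calls = solve(1.0, sigmas, method)
    print(method, calls, abs(x_end - target))
```

Heun evaluates the denoiser twice per step, so 10 Heun steps cost 20 model evaluations, but its global error shrinks quadratically rather than linearly with the step size.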

### Training
Implemented in `flaxdiff.trainer`:
- **DiffusionTrainer** (`flaxdiff.trainer.DiffusionTrainer`): A class designed to facilitate the training of diffusion models. It manages the training loop, loss calculation, and model updates.
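
The exact loss `DiffusionTrainer` computes is not spelled out here. As a hedged illustration of the kind of objective such a trainer minimizes, this is the standard DDPM epsilon-prediction loss for scalar data in pure Python; `model`, `alpha_bars` and the per-sample timestep sampling are assumptions, and flaxdiff's trainer may instead use EDM-style noise sampling and weighting through `model_output_transform`.

```python
import math
import random


def epsilon_loss(model, x0_batch, alpha_bars, rng):
    """Monte-Carlo DDPM loss: mean of (model(x_t, t) - eps)**2 over the batch."""
    total = 0.0
    for x0 in x0_batch:
        t = rng.randrange(len(alpha_bars))       # random timestep per sample
        eps = rng.gauss(0.0, 1.0)                # the noise the model must recover
        x_t = math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1.0 - alpha_bars[t]) * eps
        total += (model(x_t, t) - eps) ** 2
    return total / len(x0_batch)


alpha_bars = [0.9, 0.5, 0.1]  # hypothetical 3-step schedule


def oracle(x_t, t):
    # For all-zero data, x_t = sqrt(1 - alpha_bar) * eps, so eps is recoverable exactly.
    return x_t / math.sqrt(1.0 - alpha_bars[t])


zeros = [0.0] * 8
print(epsilon_loss(oracle, zeros, alpha_bars, random.Random(0)))
print(epsilon_loss(lambda x, t: 0.0, zeros, alpha_bars, random.Random(0)))
```

In a real trainer the same computation runs on image batches inside a jitted step, and the gradient of this loss with respect to the model parameters is what the optimizer applies.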

### Models
Implemented in `flaxdiff.models`:
- **UNet** (`flaxdiff.models.simple_unet.SimpleUNet`): A sample UNet architecture for diffusion models.
- **Layers**: A library of layers including upsampling (`flaxdiff.models.simple_unet.Upsample`), downsampling (`flaxdiff.models.simple_unet.Downsample`), time embeddings (`flaxdiff.models.simple_unet.FourierEmbedding`), attention (`flaxdiff.models.simple_unet.AttentionBlock`), and residual blocks (`flaxdiff.models.simple_unet.ResidualBlock`).
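
A time (or noise-level) embedding turns a scalar into a feature vector the UNet can consume. One common variant, random Gaussian Fourier features, is sketched below; whether the Fourier time-embedding layer listed above uses this exact layout, scale or frequency sampling is an assumption, and the function name is hypothetical.

```python
import math
import random


def gaussian_fourier_features(t, dim, scale=16.0, seed=0):
    """Embed a scalar t as [sin(2*pi*f*t)..., cos(2*pi*f*t)...] for random frequencies f."""
    if dim % 2:
        raise ValueError("dim must be even")
    rng = random.Random(seed)  # fixed seed: the same frequencies on every call
    freqs = [rng.gauss(0.0, scale) for _ in range(dim // 2)]
    angles = [2.0 * math.pi * f * t for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


print(gaussian_fourier_features(0.5, 8))
```

Keeping the frequencies fixed (rather than resampling them) matters: the network must see the same embedding for the same noise level during training and sampling.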

## Installation

To install FlaxDiff, you need Python 3.10 or higher. Install the required dependencies using:

```bash
pip install -r requirements.txt
```

The models were trained and tested with `jax==0.4.28` and `flax==0.8.4`. However, when I updated to the then-latest `jax==0.4.30` and `flax==0.8.5`, the models stopped training. Some major change seems to have broken the training dynamics, so I recommend sticking to the versions pinned in `requirements.txt`.
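
Because the package metadata only sets lower bounds (`jax >=0.4.28`, `flax >=0.8.4`), pip may resolve newer, untested versions. A small stdlib-only check like the following (a hypothetical helper, not part of FlaxDiff) reports any mismatch with the versions the models were tested on:

```python
from importlib.metadata import PackageNotFoundError, version

RECOMMENDED = {"jax": "0.4.28", "flax": "0.8.4"}


def check_pins(pins, get_version=version):
    """Return {package: (installed, recommended)} for every mismatched or missing package."""
    mismatches = {}
    for pkg, wanted in pins.items():
        try:
            installed = get_version(pkg)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[pkg] = (installed, wanted)
    return mismatches


if __name__ == "__main__":
    for pkg, (have, want) in check_pins(RECOMMENDED).items():
        print(f"{pkg}: installed {have}, recommended {want}")
```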

## Getting Started

### Training Example

Here is a simplified example to get you started with training a diffusion model using FlaxDiff:

```python
from flaxdiff.schedulers import EDMNoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import SimpleUNet as UNet
from flaxdiff.trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime

BATCH_SIZE = 16
IMAGE_SIZE = 64

# Define noise scheduler
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)

# Define model
unet = UNet(emb_features=256,
            feature_depths=[64, 128, 256, 512],
            attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
            num_res_blocks=2,
            num_middle_res_blocks=1)

# Load dataset
# NOTE: `get_dataset` is not imported above; it is assumed to be a user-supplied helper
# that returns a batched data iterator and the total number of samples.
data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
batches = datalen // BATCH_SIZE  # full batches per epoch

# Define optimizer
solver = optax.adam(2e-4)

# Create trainer
trainer = DiffusionTrainer(unet, optimizer=solver,
                           noise_schedule=edm_schedule,
                           rngs=jax.random.PRNGKey(4),
                           name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                           model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))

# Train the model
final_state = trainer.fit(data, batches, epochs=2000)
```
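
The `batches` value passed to `trainer.fit` is the number of full batches per epoch. Oxford 102 Flowers has 8,189 images across its train, validation and test splits; assuming `get_dataset` uses all of them (an assumption, though consistent with the 511 steps per epoch reported in the Gallery), the arithmetic works out as follows:

```python
def steps_per_epoch(num_samples, batch_size):
    """Full batches per epoch; integer division drops the final partial batch."""
    return num_samples // batch_size


def total_steps(num_samples, batch_size, epochs):
    """Total optimizer updates over a whole training run."""
    return steps_per_epoch(num_samples, batch_size) * epochs


print(steps_per_epoch(8189, 16))    # 511 (the 13 leftover images are skipped each epoch)
print(total_steps(8189, 16, 2000))  # 1022000
```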

### Inference Example

Here is a simplified example of generating images with a trained model; it continues from the training example, reusing `trainer` and `edm_schedule`:

```python
from flaxdiff.samplers import DiffusionSampler

# Illustrative subclass; flaxdiff.samplers also ships an EulerSampler (see Samplers above).
class EulerSampler(DiffusionSampler):
    def take_next_step(self, current_samples, reconstructed_samples, pred_noise, current_step, state, next_step=None):
        # Signal (alpha) and noise (sigma) scales at the current and next steps
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
        dt = next_sigma - current_sigma
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt
        # Slope of the sample with respect to sigma, using the model's x0 estimate
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples = current_samples + dx * dt
        return next_samples, state

# Create sampler from the trained model, its EMA parameters and the noise schedule
sampler = EulerSampler(trainer.model, trainer.state.ema_params, edm_schedule, model_output_transform=trainer.model_output_transform)

# Generate images
samples = sampler.generate_images(num_images=64, diffusion_steps=100, start_step=1000, end_step=0)
# NOTE: `plotImages` is not imported above; it is assumed to be a user-supplied plotting helper.
plotImages(samples, dpi=300)
```
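
Substituting `x = alpha * x0_hat + sigma * eps_hat` into `take_next_step` shows that the update is algebraically the deterministic DDIM (eta = 0) step, `x_next = (sigma_next / sigma) * x + (alpha_next - alpha * sigma_next / sigma) * x0_hat`; for `alpha = 1` (VE) it is also exactly an Euler step in sigma, which is why the two names coincide. The scalar check below compares both forms numerically; the function names are ours, not flaxdiff's.

```python
import math


def example_step(x, x0_hat, alpha, sigma, alpha_next, sigma_next):
    """Scalar version of take_next_step from the EulerSampler example."""
    dt = sigma_next - sigma
    x_0_coeff = (alpha * sigma_next - alpha_next * sigma) / dt
    dx = (x - x_0_coeff * x0_hat) / sigma
    return x + dx * dt


def ddim_step(x, x0_hat, alpha, sigma, alpha_next, sigma_next):
    """Deterministic DDIM: infer eps from x and x0_hat, then re-noise x0_hat at the next level."""
    eps_hat = (x - alpha * x0_hat) / sigma
    return alpha_next * x0_hat + sigma_next * eps_hat


# Variance-preserving rates alpha = cos(theta), sigma = sin(theta) at two noise levels.
args = (0.7, 0.2, math.cos(1.2), math.sin(1.2), math.cos(1.0), math.sin(1.0))
print(example_step(*args), ddim_step(*args))
```

The update divides by `next_sigma - current_sigma`, so it assumes consecutive steps never share the same noise level.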

## References and Acknowledgements

### Research papers and preprints
- The original Denoising Diffusion Probabilistic Models (DDPM) [paper](https://arxiv.org/abs/2006.11239)
- Denoising Diffusion Implicit Models (DDIM) [paper](https://arxiv.org/abs/2010.02502)
- Improved Denoising Diffusion Probabilistic Models [paper](https://arxiv.org/abs/2102.09672)
- Diffusion Models Beat GANs on Image Synthesis [paper](https://arxiv.org/pdf/2105.05233)
- Score-Based Generative Modeling through Stochastic Differential Equations [paper](https://arxiv.org/pdf/2011.13456)
- Elucidating the Design Space of Diffusion-Based Generative Models (EDM) [paper](https://arxiv.org/abs/2206.00364)
- Perception Prioritized Training of Diffusion Models (P2 Weighting) [paper](https://arxiv.org/abs/2204.00227)
- Pseudo Numerical Methods for Diffusion Models on Manifolds (PNDM) [paper](https://arxiv.org/abs/2202.09778)
- DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps [paper](https://arxiv.org/pdf/2206.00927)

### Useful blogs and codebases

- An incredible series of blogs on various diffusion-related topics by [Sander Dieleman](https://sander.ai/posts/). In particular, the posts on [diffusion models](https://sander.ai/2022/01/31/diffusion.html), [Typicality](https://sander.ai/2020/09/01/typicality.html), [Geometry of Diffusion Guidance](https://sander.ai/2023/08/28/geometry.html#warning) and [Noise Schedules](https://sander.ai/2024/06/14/noise-schedules.html) are must-reads.
- An awesome blog series by Tony Duan on [Diffusion models from scratch](https://www.tonyduan.com/diffusion/index.html). Although it only trains models on MNIST and the implementations are fairly basic, the math is explained very nicely. The codebase is [here](https://github.com/tonyduan/diffusion).
- The [k-diffusion](https://github.com/crowsonkb/k-diffusion/) codebase by Katherine Crowson, which hosts an exhaustive PyTorch implementation of the EDM paper (Karras et al.) along with DPM-Solver and DPM-Solver++ (both 2S and 2M). Most other diffusion libraries borrow from it.
- The [Official EDM implementation](https://github.com/NVlabs/edm) by Tero Karras, in PyTorch. Really neat code, and the reference implementation for all the Karras-based samplers/schedules.
- The [Hugging Face Diffusers Library](https://github.com/huggingface/diffusers), arguably the most complete set of implementations of the latest state-of-the-art techniques and concepts in this field. It is written mainly in PyTorch, with Flax implementations available for many of the concepts, and focuses on completeness as well as ease of understanding.
- The [Keras DDPM Tutorial](https://keras.io/examples/generative/ddpm/) by A_K Nain, and the [Keras DDIM implementation](https://keras.io/examples/generative/ddim/) by András Béres, which are great starting points for beginners to understand the basics of diffusion models. I started my journey by trying to implement the concepts introduced in these tutorials from scratch.
- Special thanks to ChatGPT-4 by OpenAI for helping clear my doubts.

## Pending things to do list

- **Advanced solvers like DPM/DPM2/DPM++ etc.**
- **SDE versions of the current ODE solvers, i.e., ancestral sampling**
- **Text-conditioned image generation**
- **Classifier and Classifier-Free Guidance**

## Gallery

### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
Images generated from the following prompts using classifier-free guidance with guidance factor = 2:
`'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`

**Params**:
- `Dataset: oxford_flowers102`
- `Batch size: 16`
- `Image Size: 128`
- `Training Epochs: 1000`
- `Steps per epoch: 511`
- `Training Noise Schedule: EDMNoiseScheduler`
- `Inference Noise Schedule: KarrasEDMPredictor`

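
Classifier-free guidance combines an unconditional and a conditional model prediction at every sampling step. A minimal sketch of the usual combination follows; papers differ on whether the guidance factor is applied as `u + w * (c - u)` (used here, so `w = 1` is plain conditional sampling) or as `c + w * (c - u)`, and which convention the guidance factors in this Gallery follow is an assumption.

```python
def cfg_combine(pred_uncond, pred_cond, guidance_scale):
    """Classifier-free guidance: u + w * (c - u), applied elementwise."""
    return [u + guidance_scale * (c - u) for u, c in zip(pred_uncond, pred_cond)]


uncond, cond = [0.0, 1.0, -0.5], [1.0, 1.0, 0.5]
print(cfg_combine(uncond, cond, 1.0))  # [1.0, 1.0, 0.5]: w = 1 returns the conditional prediction
print(cfg_combine(uncond, cond, 2.0))  # [2.0, 1.0, 1.5]: w > 1 pushes further from unconditional
```

Larger factors usually trade sample diversity for closer adherence to the prompt.
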
### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
Images generated from the following prompts using classifier-free guidance with guidance factor = 4:
`'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'`

**Params**:
- `Dataset: oxford_flowers102`
- `Batch size: 16`
- `Image Size: 128`
- `Training Epochs: 1000`
- `Steps per epoch: 511`
- `Training Noise Schedule: EDMNoiseScheduler`
- `Inference Noise Schedule: KarrasEDMPredictor`

### Images generated by DDPM Sampler in 1000 steps [Unconditional]

**Params**:
- `Dataset: oxford_flowers102`
- `Batch size: 16`
- `Image Size: 64`
- `Training Epochs: 1000`
- `Steps per epoch: 511`
- `Training Noise Schedule: CosineNoiseSchedule`
- `Inference Noise Schedule: CosineNoiseSchedule`
- `Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)`

### Images generated by Heun Sampler in 10 steps (20 model evaluations, as Heun needs two per step) [Unconditional]

**Params**:
- `Dataset: oxford_flowers102`
- `Batch size: 16`
- `Image Size: 64`
- `Training Epochs: 1000`
- `Steps per epoch: 511`
- `Training Noise Schedule: EDMNoiseScheduler`
- `Inference Noise Schedule: KarrasEDMPredictor`
- `Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)`

## Contribution

Feel free to contribute by opening issues or submitting pull requests. Let's make FlaxDiff better together!

## License

This project is licensed under the MIT License.
@@ -0,0 +1,30 @@
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
flaxdiff/models/attention.py,sha256=enyqoZP4NMbIn07UdnduxvohtfpbsYW-n7nALE3K_s4,18369
flaxdiff/models/common.py,sha256=WUCbuqSa8jEWAUt0UbEStTlpt5j1Mw8oZmZXYj5VwWQ,241
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
flaxdiff/models/simple_unet.py,sha256=EExRXSo0nvpiDUF_3lPKp4eQVGBa05PSskNs1ER0sqU,19273
flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
flaxdiff/samplers/ddim.py,sha256=XHMBX06S5hMTnKMaGh6fmq189pQcaGkA6fnX6YPbHP0,511
flaxdiff/samplers/ddpm.py,sha256=d_58hfVShJHsRPQf5h1-4YKrD43-HGjWz9Vd8hltZBg,2627
flaxdiff/samplers/euler.py,sha256=Epf7LBKUky7B8b-1ZyIlLWdRMgjmP08BQraGSKmr_3I,2726
flaxdiff/samplers/heun_sampler.py,sha256=hhWnSM26OfOIFAcsuWYa1z-2QPjASuoYTop2byLWqzE,1388
flaxdiff/samplers/multistep_dpm.py,sha256=ocmEq2sCvsULy6oTFaD5BhTU4c8VHsge4bdg6tfxW80,2724
flaxdiff/samplers/rk4_sampler.py,sha256=BF-dMV1KauO-SYShqrCfm3U3V-1n4clqQXBeoG8RWQo,1728
flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
flaxdiff/trainer/__init__.py,sha256=iXnrIugF2g2ZLgW3HxZZBzgsoxJx7bWvLxqVmWpmAbo,8536
flaxdiff-0.1.1.dist-info/METADATA,sha256=ZcNAw19k8s40DKgBILh3CriHkieOXuwhUbUJjx_YW8U,19229
flaxdiff-0.1.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
flaxdiff-0.1.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
flaxdiff-0.1.1.dist-info/RECORD,,
@@ -0,0 +1 @@
flaxdiff