flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +140 -162
- flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff/models/common.py +322 -0
- flaxdiff/models/simple_unet.py +21 -327
- flaxdiff/trainer/__init__.py +2 -201
- flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff/trainer/diffusion_trainer.py +202 -0
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/METADATA +12 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/RECORD +15 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/simple_trainer.py

@@ -5,14 +5,60 @@ import jax
 from typing import Callable
 from dataclasses import field
 import jax.numpy as jnp
+import numpy as np
+from functools import partial
 from clu import metrics
 from flax.training import train_state # Useful dataclass to keep train state
 import optax
 from flax import struct # Flax dataclasses
+import flax
 import time
 import os
 import orbax
 from flax.training import orbax_utils
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
+from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
+from termcolor import colored
+from typing import Dict, Callable, Sequence, Any, Union, Tuple
+
+from flaxdiff.utils import RandomMarkovState
+
+PROCESS_COLOR_MAP = {
+    0: "green",
+    1: "yellow",
+    2: "magenta",
+    3: "cyan",
+    4: "white",
+    5: "light_blue",
+    6: "light_red",
+    7: "light_cyan"
+}
+
+def _build_global_shape_and_sharding(
+    local_shape: tuple[int, ...], global_mesh: Mesh
+) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
+    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+    return global_shape, sharding
+
+
+def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+    """Put local sharded array into local devices"""
+    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+    try:
+        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(global_mesh.local_devices)} "
+        ) from array_split_error
+    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+
+def convert_to_global_tree(global_mesh, pytree):
+    return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
 
 @struct.dataclass
 class Metrics(metrics.Collection):
@@ -44,41 +90,75 @@ class SimpleTrainer:
                  name: str = "Simple",
                  load_from_checkpoint: bool = False,
                  checkpoint_suffix: str = "",
+                 checkpoint_id: str = None,
                  loss_fn=optax.l2_loss,
                  param_transforms: Callable = None,
                  wandb_config: Dict[str, Any] = None,
                  distributed_training: bool = None,
+                 checkpoint_base_path: str = "./checkpoints",
                  ):
         if distributed_training is None or distributed_training is True:
             # Auto-detect if we are running on multiple devices
             distributed_training = jax.device_count() > 1
+        self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
+        # self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))
 
         self.distributed_training = distributed_training
         self.model = model
         self.name = name
         self.loss_fn = loss_fn
         self.input_shapes = input_shapes
-
-
+        self.checkpoint_base_path = checkpoint_base_path
+
+
+        if wandb_config is not None and jax.process_index() == 0:
+            import wandb
             run = wandb.init(**wandb_config)
             self.wandb = run
+
+            # define our custom x axis metric
+            self.wandb.define_metric("train/step")
+            self.wandb.define_metric("train/epoch")
+
+            self.wandb.define_metric("train/loss", step_metric="train/step")
+
+            self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
+            self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
+            self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
+            self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
+
+        if checkpoint_id is None:
+            self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
+        else:
+            self.checkpoint_id = checkpoint_id
+
+        # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+        async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
 
-        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
         options = orbax.checkpoint.CheckpointManagerOptions(
             max_to_keep=4, create=True)
         self.checkpointer = orbax.checkpoint.CheckpointManager(
-            self.checkpoint_path() + checkpoint_suffix,
+            self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
 
         if load_from_checkpoint:
-            latest_epoch, old_state, old_best_state = self.load()
+            latest_epoch, old_state, old_best_state, rngstate = self.load()
         else:
-            latest_epoch, old_state, old_best_state = 0, None, None
+            latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
 
         self.latest_epoch = latest_epoch
+
+        if rngstate:
+            self.rngstate = RandomMarkovState(**rngstate)
+        else:
+            self.rngstate = RandomMarkovState(rngs)
+
+        self.rngstate, subkey = self.rngstate.get_random_key()
 
         if train_state == None:
-
-
+            state, best_state = self.generate_states(
+                optimizer, subkey, old_state, old_best_state, model, param_transforms
+            )
+            self.init_state(state, best_state)
         else:
             self.state = train_state
             self.best_state = train_state
@@ -87,7 +167,7 @@ class SimpleTrainer:
     def get_input_ones(self):
         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
 
-    def
+    def generate_states(
         self,
         optimizer: optax.GradientTransformation,
         rngs: jax.random.PRNGKey,
@@ -96,17 +176,19 @@ class SimpleTrainer:
         model: nn.Module = None,
         param_transforms: Callable = None
     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+        print("Generating states for SimpleTrainer")
         rngs, subkey = jax.random.split(rngs)
 
         if existing_state == None:
             input_vars = self.get_input_ones()
             params = model.init(subkey, **input_vars)
+        else:
+            params = existing_state['params']
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
-            rngs=rngs,
             metrics=Metrics.empty()
         )
         if existing_best_state is not None:
@@ -119,40 +201,28 @@ class SimpleTrainer:
 
     def init_state(
         self,
-
-
-        existing_state: dict = None,
-        existing_best_state: dict = None,
-        model: nn.Module = None,
-        param_transforms: Callable = None
+        state: SimpleTrainState,
+        best_state: SimpleTrainState,
     ):
-
-        state, best_state = self.__init_fn(
-            optimizer, rngs, existing_state, existing_best_state, model, param_transforms
-        )
         self.best_loss = 1e9
 
-        if self.distributed_training:
-            devices = jax.local_devices()
-            if len(devices) > 1:
-                print("Replicating state across devices ", devices)
-                state = flax.jax_utils.replicate(state, devices)
-                best_state = flax.jax_utils.replicate(best_state, devices)
-            else:
-                print("Not replicating any state, Only single device connected to the process")
-
         self.state = state
         self.best_state = best_state
 
     def get_state(self):
-        return
+        # return fully_replicated_host_local_array_to_global_array()
+        return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
 
     def get_best_state(self):
-        return flax.jax_utils.
+        # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
+        return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
+
+    def get_rngstate(self):
+        # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
+        return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
 
     def checkpoint_path(self):
-
-        path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+        path = os.path.join(self.checkpoint_base_path, self.checkpoint_id)
         if not os.path.exists(path):
             os.makedirs(path)
         return path
@@ -170,24 +240,27 @@ class SimpleTrainer:
         ckpt = self.checkpointer.restore(epoch)
         state = ckpt['state']
         best_state = ckpt['best_state']
+        rngstate = ckpt['rngs']
         # Convert the state to a TrainState
         self.best_loss = ckpt['best_loss']
         print(
             f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
-        return epoch, state, best_state
+        return epoch, state, best_state, rngstate
 
     def save(self, epoch=0):
         print(f"Saving model at epoch {epoch}")
         ckpt = {
             # 'model': self.model,
+            'rngs': self.get_rngstate(),
             'state': self.get_state(),
             'best_state': self.get_best_state(),
-            'best_loss': self.best_loss
+            'best_loss': np.array(self.best_loss),
         }
         try:
             save_args = orbax_utils.save_args_from_target(ckpt)
             self.checkpointer.save(epoch, ckpt, save_kwargs={
                 'save_args': save_args}, force=True)
+            self.checkpointer.wait_until_finished()
             pass
         except Exception as e:
             print("Error saving checkpoint", e)
@@ -197,7 +270,7 @@ class SimpleTrainer:
         loss_fn = self.loss_fn
         distributed_training = self.distributed_training
 
-        def train_step(
+        def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
             """Train for a single step."""
             images = batch['image']
             labels = batch['label']
@@ -208,17 +281,15 @@ class SimpleTrainer:
                 nloss = loss_fn(preds, expected_output)
                 loss = jnp.mean(nloss)
                 return loss
-            loss, grads = jax.value_and_grad(model_loss)(
+            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
             if distributed_training:
-                grads = jax.lax.pmean(grads, "
-
-                return
+                grads = jax.lax.pmean(grads, "data")
+            train_state = train_state.apply_gradients(grads=grads)
+            return train_state, loss, rng_state
 
         if distributed_training:
-            train_step =
-
-            train_step = jax.jit(train_step)
-
+            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
+        train_step = jax.pmap(train_step)
         return train_step
 
     def _define_compute_metrics(self):
@@ -251,6 +322,7 @@ class SimpleTrainer:
         }
 
     def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+        from flax.metrics import tensorboard
         summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
         summary_writer.hparams({
             **self.config(),
@@ -268,56 +340,79 @@ class SimpleTrainer:
         test_ds = None
         train_step = self._define_train_step(**train_step_args)
         compute_metrics = self._define_compute_metrics()
-
-
-
-
-
-
+        train_state = self.state
+        rng_state = self.rngstate
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+        if self.distributed_training:
+            global_device_indexes = jnp.arange(global_device_count)
+        else:
+            global_device_indexes = 0
 
-
-        self.latest_epoch += 1
-        current_epoch = self.latest_epoch
-        print(f"\nEpoch {current_epoch}/{epochs}")
-        start_time = time.time()
+        def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
            epoch_loss = 0
-
-
-
-
-
-
-
+            current_step = 0
+            for i in range(steps_per_epoch):
+                batch = next(train_ds)
+                if self.distributed_training and global_device_count > 1:
+                    # Convert the local device batches to a unified global jax.Array
+                    batch = convert_to_global_tree(self.mesh, batch)
+                train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
+
+                if self.distributed_training:
+                    loss = jax.experimental.multihost_utils.process_allgather(loss)
+                    loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+                epoch_loss += loss
 
-
-            loss = jnp.mean(loss)
-
-            epoch_loss += loss
+                if pbar is not None:
                    if i % 100 == 0:
                        pbar.set_postfix(loss=f'{loss:.4f}')
                        pbar.update(100)
                        current_step = current_epoch*steps_per_epoch + i
-                        summary_writer.scalar(
-                            'Train Loss', loss, step=current_step)
                        if self.wandb is not None:
-                            self.wandb.log({
+                            self.wandb.log({
+                                "train/step" : current_step,
+                                "train/loss": loss,
+                            }, step=current_step)
+            print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
+            return epoch_loss, current_step, train_state, rng_state
+
+        while self.latest_epoch < epochs:
+            current_epoch = self.latest_epoch
+            self.latest_epoch += 1
+            print(f"\nEpoch {current_epoch}/{epochs}")
+            start_time = time.time()
+            epoch_loss = 0
 
-
+            if process_index == 0:
+                with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
+                    epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
+            else:
+                epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
+                print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP.get(process_index, 'white')))
+
            end_time = time.time()
-            self.state =
+            self.state = train_state
+            self.rngstate = rng_state
            total_time = end_time - start_time
            avg_time_per_step = total_time / steps_per_epoch
            avg_loss = epoch_loss / steps_per_epoch
            if avg_loss < self.best_loss:
                self.best_loss = avg_loss
-                self.best_state =
+                self.best_state = train_state
                self.save(current_epoch)
-
-
-
-
-
-
-
+
+            if process_index == 0:
+                if self.wandb is not None:
+                    self.wandb.log({
+                        "train/epoch_time": total_time,
+                        "train/avg_time_per_step": avg_time_per_step,
+                        "train/avg_loss": avg_loss,
+                        "train/best_loss": self.best_loss,
+                        "train/epoch": current_epoch,
+                    }, step=current_step)
+                print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
        self.save(epochs)
-        return self.state
+        return self.state
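
In the multi-host path each process only sees its own shard of the loss; `process_allgather` plus a mean is what turns that into one scalar every host agrees on before it is logged. A minimal, hedged illustration of that aggregation step (standalone, not from the package):

```python
# Standalone sketch of the cross-host loss aggregation used in the loop above.
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

local_loss = jnp.float32(0.42)                            # per-process value (made up)
gathered = multihost_utils.process_allgather(local_loss)  # stacked across all processes
scalar_loss = jnp.mean(gathered)                          # identical on every process
print(scalar_loss)
```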
{flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.4
+Version: 0.1.6
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -13,6 +13,8 @@ Requires-Dist: clu
 
 # 
 
+**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
+
 ## A Versatile and simple Diffusion Library
 
 In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -27,7 +29,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
 
 In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
 
-### Available Notebooks
+### Available Notebooks and Resources
 
 - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
 
@@ -46,6 +48,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
 
 These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
 
+#### Other resources
+
+- **[Multi-host Data parallel training script in JAX](./training.py)**
+    - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+- **[TPU utilities for making life easier](./tpu-tools/)**
+    - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
 ## Disclaimer (and About Me)
 
 I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
{flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/RECORD

@@ -1,11 +1,15 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
-flaxdiff/models/common.py,sha256=
+flaxdiff/models/attention.py,sha256=OhpKQXdxWbf8K2_yotLfS0DYdHb-zNpL2p8--ql_FAg,14503
+flaxdiff/models/common.py,sha256=RYNxX9K19hvwSWaB9Wtv7MIZLhcacdugDgD9uZDh8XM,10358
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=hAcz074E9NVdUtECPMi1c1Kw-52Dc6l_ME-5FqIg-n8,9255
 flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
+flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
+flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
+flaxdiff/models/autoencoder/diffusers.py,sha256=kwlKwHBSAegtTiEkGju_1Trltegj-e47hXFN9jCKmgY,3609
+flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
 flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
 flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
 flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -24,9 +28,11 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
 flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
 flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
-flaxdiff/trainer/__init__.py,sha256=
-flaxdiff/trainer/
-flaxdiff
-flaxdiff
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
+flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
+flaxdiff/trainer/diffusion_trainer.py,sha256=h5YxIMjBI553xDNeapzLDGF0_4y0MfGRMuHume5sPtM,7785
+flaxdiff/trainer/simple_trainer.py,sha256=f4g2KGuGM__d9v_4Ip3ng8wQubmenWZUW60VEu2ANOg,16774
+flaxdiff-0.1.6.dist-info/METADATA,sha256=sWY_oQgQhhuyW89KyRwIBrpVHBPJjRMmsk5twfgIBlo,20090
+flaxdiff-0.1.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.6.dist-info/RECORD,,
File without changes
|
File without changes
|