flaxdiff 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +132 -155
- flaxdiff/models/autoencoder/__init__.py +0 -0
- flaxdiff/models/autoencoder/autoencoder.py +14 -0
- flaxdiff/models/autoencoder/diffusers.py +88 -0
- flaxdiff/models/common.py +243 -0
- flaxdiff/models/simple_unet.py +17 -252
- flaxdiff/trainer/__init__.py +28 -45
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/METADATA +10 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/RECORD +12 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/__init__.py
CHANGED
@@ -1,32 +1,24 @@
-import orbax.checkpoint
-import tqdm
 from flax import linen as nn
 import jax
 from typing import Callable
 from dataclasses import field
 import jax.numpy as jnp
-from clu import metrics
-from flax.training import train_state # Useful dataclass to keep train state
 import optax
-from
-import
-import
-import orbax
-from flax.training import orbax_utils
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+from typing import Dict, Callable, Sequence, Any, Union, Tuple
 
 from ..schedulers import NoiseScheduler
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 
-from .
+from flaxdiff.utils import RandomMarkovState
+
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
 
 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
     ema_params: dict
 
-    def get_random_key(self):
-        rngs, subkey = jax.random.split(self.rngs)
-        return self.replace(rngs=rngs), subkey
-
     def apply_ema(self, decay: float = 0.999):
         new_ema_params = jax.tree_util.tree_map(
             lambda ema, param: decay * ema + (1 - decay) * param,
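Note: the new `TrainState` drops its own `get_random_key` helper (RNG handling moves to `RandomMarkovState`) and keeps an `ema_params` pytree that `apply_ema` updates leaf-wise. A minimal, self-contained sketch of that EMA update; the parameter names and shapes are illustrative, not taken from flaxdiff:

```python
import jax
import jax.numpy as jnp

def apply_ema(ema_params, params, decay=0.999):
    # new_ema = decay * ema + (1 - decay) * current, applied leaf-wise over the pytree
    return jax.tree_util.tree_map(
        lambda ema, p: decay * ema + (1 - decay) * p, ema_params, params)

# Illustrative parameter pytree (not a real flaxdiff model).
params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
ema_params = jax.tree_util.tree_map(jnp.copy, params)  # EMA starts as a copy of the params
ema_params = apply_ema(ema_params, params)
```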
@@ -63,7 +55,7 @@ class DiffusionTrainer(SimpleTrainer):
         self.model_output_transform = model_output_transform
         self.unconditional_prob = unconditional_prob
 
-    def
+    def generate_states(
         self,
         optimizer: optax.GradientTransformation,
         rngs: jax.random.PRNGKey,
@@ -72,6 +64,7 @@ class DiffusionTrainer(SimpleTrainer):
         model: nn.Module = None,
         param_transforms: Callable = None
     ) -> Tuple[TrainState, TrainState]:
+        print("Generating states for DiffusionTrainer")
         rngs, subkey = jax.random.split(rngs)
 
         if existing_state == None:
@@ -102,7 +95,7 @@ class DiffusionTrainer(SimpleTrainer):
         return state, best_state
 
     def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
-        noise_schedule = self.noise_schedule
+        noise_schedule: NoiseScheduler = self.noise_schedule
         model = self.model
         model_output_transform = self.model_output_transform
         loss_fn = self.loss_fn
@@ -117,16 +110,19 @@ class DiffusionTrainer(SimpleTrainer):
 
         distributed_training = self.distributed_training
 
-
+        # @jax.jit
+        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
             """Train for a single step."""
+            rng_state, subkey = rng_state.get_random_key()
+            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
+            local_rng_state = RandomMarkovState(subkey)
+
             images = batch['image']
             # normalize image
             images = (images - 127.5) / 127.5
 
             output = text_embedder(
                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-            # output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-
             label_seq = output.last_hidden_state
 
             # Generate random probabilities to decide how much of this batch will be unconditional
@@ -134,10 +130,11 @@ class DiffusionTrainer(SimpleTrainer):
             label_seq = jnp.concat(
                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
 
-            noise_level,
-
-
+            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+
+            local_rng_state, rngs = local_rng_state.get_random_key()
             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+
             rates = noise_schedule.get_rates(noise_level)
             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                 images, noise, rates)
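Note: inside the train step, the scheduler samples per-image timesteps, Gaussian noise is drawn from the per-device RNG, and `model_output_transform.forward_diffusion` builds the noisy inputs and training target. A generic DDPM-style sketch of that forward-diffusion step, assuming a simple signal-rate/noise-rate parameterisation rather than the exact flaxdiff `NoiseScheduler` API:

```python
import jax
import jax.numpy as jnp

def get_rates(t, T=1000):
    # Simple cosine schedule returning per-sample (signal_rate, noise_rate);
    # flaxdiff's NoiseScheduler.get_rates plays this role in the hunk above.
    alpha = jnp.cos(0.5 * jnp.pi * t / T)
    return alpha, jnp.sqrt(1.0 - alpha ** 2)

def forward_diffusion(key, images, T=1000):
    t_key, eps_key = jax.random.split(key)
    t = jax.random.randint(t_key, (images.shape[0],), 0, T)    # timesteps / noise levels
    eps = jax.random.normal(eps_key, images.shape)             # Gaussian noise
    signal_rate, noise_rate = get_rates(t.astype(jnp.float32), T)
    expand = lambda r: r.reshape(-1, 1, 1, 1)                  # broadcast rates over H, W, C
    noisy = expand(signal_rate) * images + expand(noise_rate) * eps
    return noisy, eps                                          # epsilon is the training target here

key, data_key = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.uniform(data_key, (4, 32, 32, 3)) * 2.0 - 1.0  # images normalised to [-1, 1]
noisy, target = forward_diffusion(key, x0)
```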
@@ -154,16 +151,17 @@ class DiffusionTrainer(SimpleTrainer):
                 loss = nloss
                 return loss
 
-            loss, grads = jax.value_and_grad(model_loss)(
+            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
             if distributed_training:
-                grads = jax.lax.pmean(grads, "
-
-
-
-
+                grads = jax.lax.pmean(grads, "data")
+                loss = jax.lax.pmean(loss, "data")
+            train_state = train_state.apply_gradients(grads=grads)
+            train_state = train_state.apply_ema(self.ema_decay)
+            return train_state, loss, rng_state
+
         if distributed_training:
-            train_step =
-
+            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
+                                   out_specs=(P(), P(), P()))
         train_step = jax.jit(train_step)
 
         return train_step
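Note: the rewritten step replaces the old `pmap`-style replication with `shard_map` over a one-dimensional `'data'` mesh: the batch and per-device indices are sharded along `'data'`, gradients and loss are averaged with `jax.lax.pmean` on that axis, and the mapped function is then `jax.jit`-compiled. A stripped-down sketch of the same pattern; the linear model, learning rate, and batch here are placeholders, not flaxdiff code:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('data',))          # 1-D data-parallel mesh, as in the trainer

def loss_fn(params, batch):
    preds = batch['x'] @ params['w']
    return jnp.mean((preds - batch['y']) ** 2)

def train_step(params, batch):
    # Runs per shard: each device sees only its slice of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, 'data')       # average gradients across the mesh axis
    loss = jax.lax.pmean(loss, 'data')
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

train_step = shard_map(train_step, mesh=mesh,
                       in_specs=(P(), P('data')),   # params replicated, batch sharded on 'data'
                       out_specs=(P(), P()))        # updated params and loss come back replicated
train_step = jax.jit(train_step)

params = {'w': jnp.zeros((8, 1))}
batch = {'x': jnp.ones((8 * jax.device_count(), 8)),
         'y': jnp.ones((8 * jax.device_count(), 1))}
params, loss = train_step(params, batch)
```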
@@ -184,18 +182,3 @@ class DiffusionTrainer(SimpleTrainer):
         text_embedder = data['model']
         super().fit(data, steps_per_epoch, epochs, {
             "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
-
-                pbar.set_postfix(loss=f'{loss:.4f}')
-                pbar.update(100)
-            end_time = time.time()
-            self.state = state
-            total_time = end_time - start_time
-            avg_time_per_step = total_time / steps_per_epoch
-            avg_loss = epoch_loss / steps_per_epoch
-            if avg_loss < self.best_loss:
-                self.best_loss = avg_loss
-                self.best_state = state
-                self.save(epoch, best=True)
-            print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
-        return self.state
flaxdiff/trainer/simple_trainer.py
CHANGED
@@ -5,14 +5,60 @@ import jax
 from typing import Callable
 from dataclasses import field
 import jax.numpy as jnp
+import numpy as np
+from functools import partial
 from clu import metrics
 from flax.training import train_state # Useful dataclass to keep train state
 import optax
 from flax import struct # Flax dataclasses
+import flax
 import time
 import os
 import orbax
 from flax.training import orbax_utils
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
+from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
+from termcolor import colored
+from typing import Dict, Callable, Sequence, Any, Union, Tuple
+
+from flaxdiff.utils import RandomMarkovState
+
+PROCESS_COLOR_MAP = {
+    0: "green",
+    1: "yellow",
+    2: "magenta",
+    3: "cyan",
+    4: "white",
+    5: "light_blue",
+    6: "light_red",
+    7: "light_cyan"
+}
+
+def _build_global_shape_and_sharding(
+    local_shape: tuple[int, ...], global_mesh: Mesh
+) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
+    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+    return global_shape, sharding
+
+
+def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+    """Put local sharded array into local devices"""
+    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+    try:
+        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(global_mesh.local_devices)} "
+        ) from array_split_error
+    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+
+def convert_to_global_tree(global_mesh, pytree):
+    return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
 
 @struct.dataclass
 class Metrics(metrics.Collection):
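Note: `form_global_array` splits a host-local numpy batch across the local devices and stitches the shards into one global `jax.Array` via `jax.make_array_from_single_device_arrays`; `convert_to_global_tree` maps that over a whole batch pytree. A hypothetical usage sketch, assuming the helper is importable from `flaxdiff.trainer.simple_trainer` and that the host-local batch size is divisible by the number of local devices:

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Helper defined in the hunk above; import path assumed from this package layout.
from flaxdiff.trainer.simple_trainer import convert_to_global_tree

mesh = Mesh(jax.devices(), ('data',))

# Hypothetical host-local batch from a data loader; the leading axis (32 here)
# is split across this host's local devices.
local_batch = {
    'image': np.zeros((32, 64, 64, 3), dtype=np.float32),
    'label': np.zeros((32,), dtype=np.int32),
}

global_batch = convert_to_global_tree(mesh, local_batch)
# Each leaf is now a global jax.Array sharded over 'data'; with multiple processes
# its leading dimension is jax.process_count() * 32.
print(jax.tree_util.tree_map(lambda x: x.shape, global_batch))
```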
@@ -44,41 +90,75 @@ class SimpleTrainer:
         name: str = "Simple",
         load_from_checkpoint: bool = False,
         checkpoint_suffix: str = "",
+        checkpoint_id: str = None,
         loss_fn=optax.l2_loss,
         param_transforms: Callable = None,
         wandb_config: Dict[str, Any] = None,
         distributed_training: bool = None,
+        checkpoint_base_path: str = "./checkpoints",
     ):
         if distributed_training is None or distributed_training is True:
             # Auto-detect if we are running on multiple devices
             distributed_training = jax.device_count() > 1
+            self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
+            # self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))
 
         self.distributed_training = distributed_training
         self.model = model
         self.name = name
         self.loss_fn = loss_fn
         self.input_shapes = input_shapes
-
-
+        self.checkpoint_base_path = checkpoint_base_path
+
+
+        if wandb_config is not None and jax.process_index() == 0:
+            import wandb
             run = wandb.init(**wandb_config)
             self.wandb = run
+
+            # define our custom x axis metric
+            self.wandb.define_metric("train/step")
+            self.wandb.define_metric("train/epoch")
+
+            self.wandb.define_metric("train/loss", step_metric="train/step")
+
+            self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
+            self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
+            self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
+            self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
+
+        if checkpoint_id is None:
+            self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
+        else:
+            self.checkpoint_id = checkpoint_id
+
+        # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+        async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
 
-        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
         options = orbax.checkpoint.CheckpointManagerOptions(
             max_to_keep=4, create=True)
         self.checkpointer = orbax.checkpoint.CheckpointManager(
-            self.checkpoint_path() + checkpoint_suffix,
+            self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
 
         if load_from_checkpoint:
-            latest_epoch, old_state, old_best_state = self.load()
+            latest_epoch, old_state, old_best_state, rngstate = self.load()
         else:
-            latest_epoch, old_state, old_best_state = 0, None, None
+            latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
 
         self.latest_epoch = latest_epoch
+
+        if rngstate:
+            self.rngstate = RandomMarkovState(**rngstate)
+        else:
+            self.rngstate = RandomMarkovState(rngs)
+
+        self.rngstate, subkey = self.rngstate.get_random_key()
 
         if train_state == None:
-
-
+            state, best_state = self.generate_states(
+                optimizer, subkey, old_state, old_best_state, model, param_transforms
+            )
+            self.init_state(state, best_state)
         else:
             self.state = train_state
             self.best_state = train_state
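Note: the constructor now initialises Weights & Biases only on process 0 and registers `train/step` and `train/epoch` as custom x-axes, so step-level and epoch-level metrics are plotted against the right counters. A minimal standalone sketch of that pattern; the project and run names are placeholders:

```python
import jax
import wandb

if jax.process_index() == 0:                      # log from a single host only
    run = wandb.init(project="flaxdiff-demo", name="trainer-run")  # placeholder names

    # Custom x-axes: step-level and epoch-level metrics use their own counters.
    run.define_metric("train/step")
    run.define_metric("train/epoch")
    run.define_metric("train/loss", step_metric="train/step")
    run.define_metric("train/avg_loss", step_metric="train/epoch")

    run.log({"train/step": 0, "train/loss": 1.23})
    run.log({"train/epoch": 0, "train/avg_loss": 1.23})
```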
@@ -87,7 +167,7 @@ class SimpleTrainer:
     def get_input_ones(self):
         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
 
-    def
+    def generate_states(
         self,
         optimizer: optax.GradientTransformation,
         rngs: jax.random.PRNGKey,
@@ -96,17 +176,19 @@ class SimpleTrainer:
         model: nn.Module = None,
         param_transforms: Callable = None
     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+        print("Generating states for SimpleTrainer")
         rngs, subkey = jax.random.split(rngs)
 
         if existing_state == None:
             input_vars = self.get_input_ones()
             params = model.init(subkey, **input_vars)
+        else:
+            params = existing_state['params']
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
-            rngs=rngs,
             metrics=Metrics.empty()
         )
         if existing_best_state is not None:
@@ -119,40 +201,28 @@ class SimpleTrainer:
 
     def init_state(
         self,
-
-
-        existing_state: dict = None,
-        existing_best_state: dict = None,
-        model: nn.Module = None,
-        param_transforms: Callable = None
+        state: SimpleTrainState,
+        best_state: SimpleTrainState,
     ):
-
-        state, best_state = self.__init_fn(
-            optimizer, rngs, existing_state, existing_best_state, model, param_transforms
-        )
         self.best_loss = 1e9
 
-        if self.distributed_training:
-            devices = jax.local_devices()
-            if len(devices) > 1:
-                print("Replicating state across devices ", devices)
-                state = flax.jax_utils.replicate(state, devices)
-                best_state = flax.jax_utils.replicate(best_state, devices)
-            else:
-                print("Not replicating any state, Only single device connected to the process")
-
         self.state = state
         self.best_state = best_state
 
     def get_state(self):
-        return
+        # return fully_replicated_host_local_array_to_global_array()
+        return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
 
     def get_best_state(self):
-        return flax.jax_utils.
+        # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
+        return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
+
+    def get_rngstate(self):
+        # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
+        return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
 
     def checkpoint_path(self):
-
-        path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+        path = os.path.join(self.checkpoint_base_path, self.checkpoint_id)
         if not os.path.exists(path):
             os.makedirs(path)
         return path
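Note: `get_state`, `get_best_state`, and `get_rngstate` now copy every leaf back to host memory with `np.array` before the pytree is handed to the checkpointer. A tiny sketch of that device-to-host conversion; the pytree is a stand-in, not a real `TrainState`:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for a TrainState pytree living on device.
state = {'params': {'w': jnp.ones((4, 4))}, 'step': jnp.array(10)}

# Leaf-wise device -> host copy: the result is a plain numpy pytree that the
# checkpointer can serialise without touching sharded jax.Arrays.
host_state = jax.tree_util.tree_map(lambda x: np.array(x), state)
print(jax.tree_util.tree_map(lambda x: type(x).__name__, host_state))
```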
@@ -170,24 +240,27 @@ class SimpleTrainer:
         ckpt = self.checkpointer.restore(epoch)
         state = ckpt['state']
         best_state = ckpt['best_state']
+        rngstate = ckpt['rngs']
         # Convert the state to a TrainState
         self.best_loss = ckpt['best_loss']
         print(
             f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
-        return epoch, state, best_state
+        return epoch, state, best_state, rngstate
 
     def save(self, epoch=0):
         print(f"Saving model at epoch {epoch}")
         ckpt = {
             # 'model': self.model,
+            'rngs': self.get_rngstate(),
             'state': self.get_state(),
             'best_state': self.get_best_state(),
-            'best_loss': self.best_loss
+            'best_loss': np.array(self.best_loss),
         }
         try:
             save_args = orbax_utils.save_args_from_target(ckpt)
             self.checkpointer.save(epoch, ckpt, save_kwargs={
                 'save_args': save_args}, force=True)
+            self.checkpointer.wait_until_finished()
             pass
         except Exception as e:
             print("Error saving checkpoint", e)
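Note: checkpointing is now built around Orbax's `AsyncCheckpointer`, and `save` blocks on `wait_until_finished()` so the asynchronous write completes before training continues. A condensed sketch of the same save/restore round trip using the calls shown in this hunk; the directory and checkpoint contents are placeholders:

```python
import numpy as np
import orbax.checkpoint
from flax.training import orbax_utils

ckpt_dir = "/tmp/flaxdiff-ckpt-demo"  # placeholder directory
async_checkpointer = orbax.checkpoint.AsyncCheckpointer(
    orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
manager = orbax.checkpoint.CheckpointManager(ckpt_dir, async_checkpointer, options)

# Host-side pytree, loosely mirroring the keys saved above.
ckpt = {'state': {'w': np.ones((2, 2))}, 'best_loss': np.array(1e9)}
save_args = orbax_utils.save_args_from_target(ckpt)
manager.save(0, ckpt, save_kwargs={'save_args': save_args}, force=True)
manager.wait_until_finished()  # block until the asynchronous write has landed

restored = manager.restore(0)
print(restored['best_loss'])
```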
@@ -197,7 +270,7 @@ class SimpleTrainer:
         loss_fn = self.loss_fn
         distributed_training = self.distributed_training
 
-        def train_step(
+        def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
             """Train for a single step."""
             images = batch['image']
             labels = batch['label']
@@ -208,17 +281,15 @@ class SimpleTrainer:
                 nloss = loss_fn(preds, expected_output)
                 loss = jnp.mean(nloss)
                 return loss
-            loss, grads = jax.value_and_grad(model_loss)(
+            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
             if distributed_training:
-                grads = jax.lax.pmean(grads, "
-
-                return
+                grads = jax.lax.pmean(grads, "data")
+            train_state = train_state.apply_gradients(grads=grads)
+            return train_state, loss, rng_state
 
         if distributed_training:
-            train_step =
-
-            train_step = jax.jit(train_step)
-
+            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
+            train_step = jax.pmap(train_step)
         return train_step
 
     def _define_compute_metrics(self):
@@ -251,6 +322,7 @@ class SimpleTrainer:
         }
 
     def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+        from flax.metrics import tensorboard
         summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
         summary_writer.hparams({
             **self.config(),
@@ -268,56 +340,79 @@ class SimpleTrainer:
         test_ds = None
         train_step = self._define_train_step(**train_step_args)
         compute_metrics = self._define_compute_metrics()
-
-
-
-
-
-
+        train_state = self.state
+        rng_state = self.rngstate
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+        if self.distributed_training:
+            global_device_indexes = jnp.arange(global_device_count)
+        else:
+            global_device_indexes = 0
 
-
-            self.latest_epoch += 1
-            current_epoch = self.latest_epoch
-            print(f"\nEpoch {current_epoch}/{epochs}")
-            start_time = time.time()
+        def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
             epoch_loss = 0
-
-
-
-
-
-
-
+            current_step = 0
+            for i in range(steps_per_epoch):
+                batch = next(train_ds)
+                if self.distributed_training and global_device_count > 1:
+                    # Convert the local device batches to a unified global jax.Array
+                    batch = convert_to_global_tree(self.mesh, batch)
+                train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
+
+                if self.distributed_training:
+                    loss = jax.experimental.multihost_utils.process_allgather(loss)
+                    loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+                epoch_loss += loss
 
-
-                loss = jnp.mean(loss)
-
-                epoch_loss += loss
+                if pbar is not None:
                     if i % 100 == 0:
                         pbar.set_postfix(loss=f'{loss:.4f}')
                         pbar.update(100)
                         current_step = current_epoch*steps_per_epoch + i
-                        summary_writer.scalar(
-                            'Train Loss', loss, step=current_step)
                         if self.wandb is not None:
-                            self.wandb.log({
+                            self.wandb.log({
+                                "train/step" : current_step,
+                                "train/loss": loss,
+                            }, step=current_step)
+            print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
+            return epoch_loss, current_step, train_state, rng_state
+
+        while self.latest_epoch < epochs:
+            current_epoch = self.latest_epoch
+            self.latest_epoch += 1
+            print(f"\nEpoch {current_epoch}/{epochs}")
+            start_time = time.time()
+            epoch_loss = 0
 
-
+            if process_index == 0:
+                with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
+                    epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
+            else:
+                epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
+                print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP.get(process_index, 'white')))
+
             end_time = time.time()
-            self.state =
+            self.state = train_state
+            self.rngstate = rng_state
             total_time = end_time - start_time
            avg_time_per_step = total_time / steps_per_epoch
            avg_loss = epoch_loss / steps_per_epoch
            if avg_loss < self.best_loss:
                self.best_loss = avg_loss
-                self.best_state =
+                self.best_state = train_state
                 self.save(current_epoch)
-
-
-
-
-
-
-
+
+            if process_index == 0:
+                if self.wandb is not None:
+                    self.wandb.log({
+                        "train/epoch_time": total_time,
+                        "train/avg_time_per_step": avg_time_per_step,
+                        "train/avg_loss": avg_loss,
+                        "train/best_loss": self.best_loss,
+                        "train/epoch": current_epoch,
+                    }, step=current_step)
+            print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
         self.save(epochs)
-        return self.state
+        return self.state
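Note: in the new `fit` loop every process runs the same steps, but only process 0 owns a tqdm bar, and per-step losses are averaged across hosts with `multihost_utils.process_allgather`. A small sketch of that aggregation and single-host reporting pattern; the loss value is a stand-in:

```python
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

def aggregate_loss(local_loss):
    # Gather the per-host values onto every process, then reduce to one scalar.
    gathered = multihost_utils.process_allgather(local_loss)
    return jnp.mean(gathered)

local_loss = jnp.float32(0.42)        # stand-in for the loss returned by train_step
mean_loss = aggregate_loss(local_loss)
if jax.process_index() == 0:          # report from one host, like the tqdm bar in fit()
    print(f"loss={float(mean_loss):.4f}")
```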
{flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.4
+Version: 0.1.5
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -27,7 +27,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
 
 In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
 
-### Available Notebooks
+### Available Notebooks and Resources
 
 - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
 
@@ -46,6 +46,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
 
 These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
 
+#### Other resources
+
+- **[Multi-host Data parallel training script in JAX](./training.py)**
+  - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+- **[TPU utilities for making life easier](./tpu-tools/)**
+  - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
 ## Disclaimer (and About Me)
 
 I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
{flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/RECORD
CHANGED
@@ -1,11 +1,14 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
-flaxdiff/models/common.py,sha256=
+flaxdiff/models/attention.py,sha256=KiAUyfujGpUZR13aJR6RVnL6pBXk5UcyM62VIXhojMg,14468
+flaxdiff/models/common.py,sha256=jlyRB4uF7BmeuExor1YHaqEbBjSuyaDZ4mDsSW3rWKE,7948
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=o1DCa9yvqarEGTiUKsTqE70q-h6bRU6HcU0lZpb65jc,11418
 flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
+flaxdiff/models/autoencoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/models/autoencoder/autoencoder.py,sha256=At-DhcmrZ0Gao4PUa4l9D25FTdTPwbE4gu6LKcFKzUQ,433
+flaxdiff/models/autoencoder/diffusers.py,sha256=gwyD98277vQGKVPFbyd6w6CupoxMsNgKlN67AtzLCtg,3267
 flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
 flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
 flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -24,9 +27,9 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
 flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
 flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
-flaxdiff/trainer/__init__.py,sha256=
-flaxdiff/trainer/simple_trainer.py,sha256=
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff/trainer/__init__.py,sha256=17qKQFITCfaXQFKYElMzkE-c-EPrv5iUL66gY1gKOsQ,7243
+flaxdiff/trainer/simple_trainer.py,sha256=f4g2KGuGM__d9v_4Ip3ng8wQubmenWZUW60VEu2ANOg,16774
+flaxdiff-0.1.5.dist-info/METADATA,sha256=tGKayFhkYSJJnLY_sHiaCJ60kJZqnO-kcLM3uH3JSN4,19811
+flaxdiff-0.1.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.5.dist-info/RECORD,,
{flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/top_level.txt
File without changes