flaxdiff 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,32 +1,24 @@
- import orbax.checkpoint
- import tqdm
  from flax import linen as nn
  import jax
  from typing import Callable
  from dataclasses import field
  import jax.numpy as jnp
- from clu import metrics
- from flax.training import train_state # Useful dataclass to keep train state
  import optax
- from flax import struct # Flax dataclasses
- import time
- import os
- import orbax
- from flax.training import orbax_utils
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental.shard_map import shard_map
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple

  from ..schedulers import NoiseScheduler
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform

- from .simple_trainer import SimpleTrainer, SimpleTrainState
+ from flaxdiff.utils import RandomMarkovState
+
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics

  class TrainState(SimpleTrainState):
      rngs: jax.random.PRNGKey
      ema_params: dict

-     def get_random_key(self):
-         rngs, subkey = jax.random.split(self.rngs)
-         return self.replace(rngs=rngs), subkey
-
      def apply_ema(self, decay: float = 0.999):
          new_ema_params = jax.tree_util.tree_map(
              lambda ema, param: decay * ema + (1 - decay) * param,
@@ -63,7 +55,7 @@ class DiffusionTrainer(SimpleTrainer):
          self.model_output_transform = model_output_transform
          self.unconditional_prob = unconditional_prob

-     def __init_fn(
+     def generate_states(
          self,
          optimizer: optax.GradientTransformation,
          rngs: jax.random.PRNGKey,
@@ -72,6 +64,7 @@ class DiffusionTrainer(SimpleTrainer):
          model: nn.Module = None,
          param_transforms: Callable = None
      ) -> Tuple[TrainState, TrainState]:
+         print("Generating states for DiffusionTrainer")
          rngs, subkey = jax.random.split(rngs)

          if existing_state == None:
@@ -102,7 +95,7 @@ class DiffusionTrainer(SimpleTrainer):
          return state, best_state

      def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
-         noise_schedule = self.noise_schedule
+         noise_schedule: NoiseScheduler = self.noise_schedule
          model = self.model
          model_output_transform = self.model_output_transform
          loss_fn = self.loss_fn
@@ -117,16 +110,19 @@ class DiffusionTrainer(SimpleTrainer):

          distributed_training = self.distributed_training

-         def train_step(state: TrainState, batch):
+         # @jax.jit
+         def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
              """Train for a single step."""
+             rng_state, subkey = rng_state.get_random_key()
+             subkey = jax.random.fold_in(subkey, local_device_index.reshape())
+             local_rng_state = RandomMarkovState(subkey)
+
              images = batch['image']
              # normalize image
              images = (images - 127.5) / 127.5

              output = text_embedder(
                  input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-             # output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-
              label_seq = output.last_hidden_state

              # Generate random probabilities to decide how much of this batch will be unconditional
@@ -134,10 +130,11 @@ class DiffusionTrainer(SimpleTrainer):
              label_seq = jnp.concat(
                  [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

-             noise_level, state = noise_schedule.generate_timesteps(
-                 images.shape[0], state)
-             state, rngs = state.get_random_key()
+             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+
+             local_rng_state, rngs = local_rng_state.get_random_key()
              noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+
              rates = noise_schedule.get_rates(noise_level)
              noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                  images, noise, rates)
@@ -154,16 +151,17 @@ class DiffusionTrainer(SimpleTrainer):
                  loss = nloss
                  return loss

-             loss, grads = jax.value_and_grad(model_loss)(state.params)
+             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
              if distributed_training:
-                 grads = jax.lax.pmean(grads, "device")
-             state = state.apply_gradients(grads=grads)
-             state = state.apply_ema(self.ema_decay)
-             return state, loss
-
+                 grads = jax.lax.pmean(grads, "data")
+                 loss = jax.lax.pmean(loss, "data")
+             train_state = train_state.apply_gradients(grads=grads)
+             train_state = train_state.apply_ema(self.ema_decay)
+             return train_state, loss, rng_state
+
          if distributed_training:
-             train_step = jax.pmap(axis_name="device")(train_step)
-         else:
+             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
+                                    out_specs=(P(), P(), P()))
              train_step = jax.jit(train_step)

          return train_step
@@ -184,18 +182,3 @@ class DiffusionTrainer(SimpleTrainer):
          text_embedder = data['model']
          super().fit(data, steps_per_epoch, epochs, {
              "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
-
-         pbar.set_postfix(loss=f'{loss:.4f}')
-         pbar.update(100)
-         end_time = time.time()
-         self.state = state
-         total_time = end_time - start_time
-         avg_time_per_step = total_time / steps_per_epoch
-         avg_loss = epoch_loss / steps_per_epoch
-         if avg_loss < self.best_loss:
-             self.best_loss = avg_loss
-             self.best_state = state
-             self.save(epoch, best=True)
-         print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
-         return self.state
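Editor's note: the hunks above replace the old `jax.pmap(axis_name="device")` wrapper with `shard_map` over a single `'data'` mesh axis, averaging gradients and loss with `jax.lax.pmean`. The following is a minimal, self-contained sketch of that data-parallel pattern using only public JAX and Optax APIs; the toy model, parameter names, and shapes are illustrative and not taken from flaxdiff.

```python
# A minimal sketch of a shard_map-based data-parallel train step (not flaxdiff's exact code).
import jax
import jax.numpy as jnp
import optax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('data',))  # one 'data' axis over all available devices
optimizer = optax.adam(1e-3)

def loss_fn(params, x, y):
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

def train_step(params, opt_state, batch):
    x, y = batch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Each shard computes on its slice of the batch; average across the 'data' axis.
    grads = jax.lax.pmean(grads, 'data')
    loss = jax.lax.pmean(loss, 'data')
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Params/opt_state are replicated (P()), the batch is sharded along 'data' (P('data')).
train_step = shard_map(train_step, mesh=mesh,
                       in_specs=(P(), P(), P('data')),
                       out_specs=(P(), P(), P()))
train_step = jax.jit(train_step)

params = {'w': jnp.zeros((8, 1)), 'b': jnp.zeros((1,))}
opt_state = optimizer.init(params)
batch = (jnp.ones((jax.device_count() * 4, 8)), jnp.ones((jax.device_count() * 4, 1)))
params, opt_state, loss = train_step(params, opt_state, batch)
```

The key difference from `pmap` is that the batch stays a single (possibly globally sharded) array and the per-device splitting is described by `in_specs`, which is what lets the same step function scale to multi-host meshes.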
@@ -5,14 +5,60 @@ import jax
  from typing import Callable
  from dataclasses import field
  import jax.numpy as jnp
+ import numpy as np
+ from functools import partial
  from clu import metrics
  from flax.training import train_state # Useful dataclass to keep train state
  import optax
  from flax import struct # Flax dataclasses
+ import flax
  import time
  import os
  import orbax
  from flax.training import orbax_utils
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental import mesh_utils
+ from jax.experimental.shard_map import shard_map
+ from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
+ from termcolor import colored
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
+
+ from flaxdiff.utils import RandomMarkovState
+
+ PROCESS_COLOR_MAP = {
+     0: "green",
+     1: "yellow",
+     2: "magenta",
+     3: "cyan",
+     4: "white",
+     5: "light_blue",
+     6: "light_red",
+     7: "light_cyan"
+ }
+
+ def _build_global_shape_and_sharding(
+     local_shape: tuple[int, ...], global_mesh: Mesh
+ ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
+     sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+     global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+     return global_shape, sharding
+
+
+ def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+     """Put local sharded array into local devices"""
+     global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+     try:
+         local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+     except ValueError as array_split_error:
+         raise ValueError(
+             f"Unable to put to devices shape {array.shape} with "
+             f"local device count {len(global_mesh.local_devices)} "
+         ) from array_split_error
+     local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+     return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+
+ def convert_to_global_tree(global_mesh, pytree):
+     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)

  @struct.dataclass
  class Metrics(metrics.Collection):
@@ -44,41 +90,75 @@ class SimpleTrainer:
          name: str = "Simple",
          load_from_checkpoint: bool = False,
          checkpoint_suffix: str = "",
+         checkpoint_id: str = None,
          loss_fn=optax.l2_loss,
          param_transforms: Callable = None,
          wandb_config: Dict[str, Any] = None,
          distributed_training: bool = None,
+         checkpoint_base_path: str = "./checkpoints",
      ):
          if distributed_training is None or distributed_training is True:
              # Auto-detect if we are running on multiple devices
              distributed_training = jax.device_count() > 1
+             self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
+             # self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))

          self.distributed_training = distributed_training
          self.model = model
          self.name = name
          self.loss_fn = loss_fn
          self.input_shapes = input_shapes
-
-         if wandb_config is not None:
+         self.checkpoint_base_path = checkpoint_base_path
+
+
+         if wandb_config is not None and jax.process_index() == 0:
+             import wandb
              run = wandb.init(**wandb_config)
              self.wandb = run
+
+             # define our custom x axis metric
+             self.wandb.define_metric("train/step")
+             self.wandb.define_metric("train/epoch")
+
+             self.wandb.define_metric("train/loss", step_metric="train/step")
+
+             self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
+             self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
+             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
+             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
+
+         if checkpoint_id is None:
+             self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
+         else:
+             self.checkpoint_id = checkpoint_id
+
+         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)

-         checkpointer = orbax.checkpoint.PyTreeCheckpointer()
          options = orbax.checkpoint.CheckpointManagerOptions(
              max_to_keep=4, create=True)
          self.checkpointer = orbax.checkpoint.CheckpointManager(
-             self.checkpoint_path() + checkpoint_suffix, checkpointer, options)
+             self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)

          if load_from_checkpoint:
-             latest_epoch, old_state, old_best_state = self.load()
+             latest_epoch, old_state, old_best_state, rngstate = self.load()
          else:
-             latest_epoch, old_state, old_best_state = 0, None, None
+             latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None

          self.latest_epoch = latest_epoch
+
+         if rngstate:
+             self.rngstate = RandomMarkovState(**rngstate)
+         else:
+             self.rngstate = RandomMarkovState(rngs)
+
+         self.rngstate, subkey = self.rngstate.get_random_key()

          if train_state == None:
-             self.init_state(optimizer, rngs, existing_state=old_state,
-                             existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
+             state, best_state = self.generate_states(
+                 optimizer, subkey, old_state, old_best_state, model, param_transforms
+             )
+             self.init_state(state, best_state)
          else:
              self.state = train_state
              self.best_state = train_state
@@ -87,7 +167,7 @@ class SimpleTrainer:
      def get_input_ones(self):
          return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}

-     def __init_fn(
+     def generate_states(
          self,
          optimizer: optax.GradientTransformation,
          rngs: jax.random.PRNGKey,
@@ -96,17 +176,19 @@ class SimpleTrainer:
          model: nn.Module = None,
          param_transforms: Callable = None
      ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+         print("Generating states for SimpleTrainer")
          rngs, subkey = jax.random.split(rngs)

          if existing_state == None:
              input_vars = self.get_input_ones()
              params = model.init(subkey, **input_vars)
+         else:
+             params = existing_state['params']

          state = SimpleTrainState.create(
              apply_fn=model.apply,
              params=params,
              tx=optimizer,
-             rngs=rngs,
              metrics=Metrics.empty()
          )
          if existing_best_state is not None:
@@ -119,40 +201,28 @@ class SimpleTrainer:

      def init_state(
          self,
-         optimizer: optax.GradientTransformation,
-         rngs: jax.random.PRNGKey,
-         existing_state: dict = None,
-         existing_best_state: dict = None,
-         model: nn.Module = None,
-         param_transforms: Callable = None
+         state: SimpleTrainState,
+         best_state: SimpleTrainState,
      ):
-
-         state, best_state = self.__init_fn(
-             optimizer, rngs, existing_state, existing_best_state, model, param_transforms
-         )
          self.best_loss = 1e9

-         if self.distributed_training:
-             devices = jax.local_devices()
-             if len(devices) > 1:
-                 print("Replicating state across devices ", devices)
-                 state = flax.jax_utils.replicate(state, devices)
-                 best_state = flax.jax_utils.replicate(best_state, devices)
-             else:
-                 print("Not replicating any state, Only single device connected to the process")
-
          self.state = state
          self.best_state = best_state

      def get_state(self):
-         return flax.jax_utils.unreplicate(self.state)
+         # return fully_replicated_host_local_array_to_global_array()
+         return jax.tree_util.tree_map(lambda x : np.array(x), self.state)

      def get_best_state(self):
-         return flax.jax_utils.unreplicate(self.best_state)
+         # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
+         return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
+
+     def get_rngstate(self):
+         # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
+         return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)

      def checkpoint_path(self):
-         experiment_name = self.name
-         path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+         path = os.path.join(self.checkpoint_base_path, self.checkpoint_id)
          if not os.path.exists(path):
              os.makedirs(path)
          return path
@@ -170,24 +240,27 @@ class SimpleTrainer:
          ckpt = self.checkpointer.restore(epoch)
          state = ckpt['state']
          best_state = ckpt['best_state']
+         rngstate = ckpt['rngs']
          # Convert the state to a TrainState
          self.best_loss = ckpt['best_loss']
          print(
              f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
-         return epoch, state, best_state
+         return epoch, state, best_state, rngstate

      def save(self, epoch=0):
          print(f"Saving model at epoch {epoch}")
          ckpt = {
              # 'model': self.model,
+             'rngs': self.get_rngstate(),
              'state': self.get_state(),
              'best_state': self.get_best_state(),
-             'best_loss': self.best_loss
+             'best_loss': np.array(self.best_loss),
          }
          try:
              save_args = orbax_utils.save_args_from_target(ckpt)
              self.checkpointer.save(epoch, ckpt, save_kwargs={
                  'save_args': save_args}, force=True)
+             self.checkpointer.wait_until_finished()
              pass
          except Exception as e:
              print("Error saving checkpoint", e)
@@ -197,7 +270,7 @@ class SimpleTrainer:
          loss_fn = self.loss_fn
          distributed_training = self.distributed_training

-         def train_step(state: SimpleTrainState, batch):
+         def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
              """Train for a single step."""
              images = batch['image']
              labels = batch['label']
@@ -208,17 +281,15 @@ class SimpleTrainer:
                  nloss = loss_fn(preds, expected_output)
                  loss = jnp.mean(nloss)
                  return loss
-             loss, grads = jax.value_and_grad(model_loss)(state.params)
+             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
              if distributed_training:
-                 grads = jax.lax.pmean(grads, "device")
-             state = state.apply_gradients(grads=grads)
-             return state, loss
+                 grads = jax.lax.pmean(grads, "data")
+             train_state = train_state.apply_gradients(grads=grads)
+             return train_state, loss, rng_state

          if distributed_training:
-             train_step = jax.pmap(axis_name="device")(train_step)
-         else:
-             train_step = jax.jit(train_step)
-
+             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
+         train_step = jax.pmap(train_step)
          return train_step

      def _define_compute_metrics(self):
@@ -251,6 +322,7 @@ class SimpleTrainer:
          }

      def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+         from flax.metrics import tensorboard
          summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
          summary_writer.hparams({
              **self.config(),
@@ -268,56 +340,79 @@ class SimpleTrainer:
          test_ds = None
          train_step = self._define_train_step(**train_step_args)
          compute_metrics = self._define_compute_metrics()
-         state = self.state
-         device_count = jax.local_device_count()
-         # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
-
-         summary_writer = self.init_tensorboard(
-             data['global_batch_size'], steps_per_epoch, epochs)
+         train_state = self.state
+         rng_state = self.rngstate
+         global_device_count = jax.device_count()
+         local_device_count = jax.local_device_count()
+         process_index = jax.process_index()
+         if self.distributed_training:
+             global_device_indexes = jnp.arange(global_device_count)
+         else:
+             global_device_indexes = 0

-         while self.latest_epoch <= epochs:
-             self.latest_epoch += 1
-             current_epoch = self.latest_epoch
-             print(f"\nEpoch {current_epoch}/{epochs}")
-             start_time = time.time()
+         def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
              epoch_loss = 0
-
-             with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
-                 for i in range(steps_per_epoch):
-                     batch = next(train_ds)
-                     if self.distributed_training and device_count > 1:
-                         batch = jax.tree.map(lambda x: x.reshape(
-                             (device_count, -1, *x.shape[1:])), batch)
+             current_step = 0
+             for i in range(steps_per_epoch):
+                 batch = next(train_ds)
+                 if self.distributed_training and global_device_count > 1:
+                     # Convert the local device batches to a unified global jax.Array
+                     batch = convert_to_global_tree(self.mesh, batch)
+                 train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
+
+                 if self.distributed_training:
+                     loss = jax.experimental.multihost_utils.process_allgather(loss)
+                     loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+                 epoch_loss += loss

-                     state, loss = train_step(state, batch)
-                     loss = jnp.mean(loss)
-
-                     epoch_loss += loss
+                 if pbar is not None:
                      if i % 100 == 0:
                          pbar.set_postfix(loss=f'{loss:.4f}')
                          pbar.update(100)
                          current_step = current_epoch*steps_per_epoch + i
-                         summary_writer.scalar(
-                             'Train Loss', loss, step=current_step)
                          if self.wandb is not None:
-                             self.wandb.log({"train/loss": loss})
+                             self.wandb.log({
+                                 "train/step" : current_step,
+                                 "train/loss": loss,
+                             }, step=current_step)
+             print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
+             return epoch_loss, current_step, train_state, rng_state
+
+         while self.latest_epoch < epochs:
+             current_epoch = self.latest_epoch
+             self.latest_epoch += 1
+             print(f"\nEpoch {current_epoch}/{epochs}")
+             start_time = time.time()
+             epoch_loss = 0

-             print(f"\n\tEpoch done")
+             if process_index == 0:
+                 with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
+                     epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
+             else:
+                 epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
+                 print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP.get(process_index, 'white')))
+
              end_time = time.time()
-             self.state = state
+             self.state = train_state
+             self.rngstate = rng_state
              total_time = end_time - start_time
             avg_time_per_step = total_time / steps_per_epoch
             avg_loss = epoch_loss / steps_per_epoch
             if avg_loss < self.best_loss:
                 self.best_loss = avg_loss
-                 self.best_state = state
+                 self.best_state = train_state
                 self.save(current_epoch)
-
-             # Compute Metrics
-             metrics_str = ''
-
-             print(
-                 f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")
-
+
+             if process_index == 0:
+                 if self.wandb is not None:
+                     self.wandb.log({
+                         "train/epoch_time": total_time,
+                         "train/avg_time_per_step": avg_time_per_step,
+                         "train/avg_loss": avg_loss,
+                         "train/best_loss": self.best_loss,
+                         "train/epoch": current_epoch,
+                     }, step=current_step)
+                 print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
          self.save(epochs)
-         return self.state
+         return self.state
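Editor's note: among the trainer changes above is the switch from a synchronous `PyTreeCheckpointer` to an `AsyncCheckpointer`, with an explicit `wait_until_finished()` after each save. Below is a minimal sketch of that pattern in isolation, using the same (legacy-style) Orbax `CheckpointManager` API that appears in the diff; the directory and the checkpoint contents are placeholders.

```python
# A minimal sketch of async Orbax checkpointing (assumes the legacy CheckpointManager API used above).
import numpy as np
import orbax.checkpoint
from flax.training import orbax_utils

checkpointer = orbax.checkpoint.AsyncCheckpointer(
    orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
manager = orbax.checkpoint.CheckpointManager('/tmp/flaxdiff_ckpt_demo', checkpointer, options)

# Placeholder checkpoint pytree; in the trainer this holds state, best_state, rngs, best_loss.
ckpt = {'state': {'w': np.zeros((4, 4))}, 'best_loss': np.array(1e9)}
save_args = orbax_utils.save_args_from_target(ckpt)

# save() returns quickly because the write happens asynchronously;
# wait_until_finished() blocks until the checkpoint is fully on disk.
manager.save(0, ckpt, save_kwargs={'save_args': save_args}, force=True)
manager.wait_until_finished()

restored = manager.restore(0)
print(restored['best_loss'])
```

The async handler lets the next epoch start while the previous checkpoint is still being written; `wait_until_finished()` is only needed when the files must be guaranteed on disk, for example right before exiting.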
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.4
+ Version: 0.1.5
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -27,7 +27,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.

  In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.

- ### Available Notebooks
+ ### Available Notebooks and Resources

  - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**

@@ -46,6 +46,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var

  These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.

+ #### Other resources
+
+ - **[Multi-host Data parallel training script in JAX](./training.py)**
+   - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+ - **[TPU utilities for making life easier](./tpu-tools/)**
+   - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
  ## Disclaimer (and About Me)

  I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
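Editor's note: the multi-host training script listed under "Other resources" above relies on the same building blocks this release adds to the trainer: initialize the JAX process group, build a single `'data'` mesh over every device on every host, and assemble each host's local batch into one global `jax.Array`. A skeletal, hedged sketch follows; the coordinator address and process counts are hypothetical and normally come from your launcher.

```python
# A skeletal multi-host data-parallel setup (illustrative, not flaxdiff's training.py itself).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# On a real multi-host run the launcher provides these; the values here are hypothetical.
# jax.distributed.initialize(coordinator_address="10.0.0.1:1234", num_processes=2, process_id=0)

mesh = Mesh(jax.devices(), ('data',))        # all devices, across all hosts, form one 'data' axis
sharding = NamedSharding(mesh, P('data'))

local_batch = np.ones((8, 32), dtype=np.float32)  # this host's slice of the global batch

# Split the host-local batch over the host-local devices, then stitch the pieces into a
# single global jax.Array -- the same pattern as form_global_array/convert_to_global_tree above.
shards = np.split(local_batch, len(mesh.local_devices), axis=0)
buffers = jax.device_put(shards, mesh.local_devices)
global_shape = (jax.process_count() * local_batch.shape[0],) + local_batch.shape[1:]
global_batch = jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)

print(global_batch.shape)  # (8 * process_count, 32)
```

One copy of such a script runs per host; after initialization, `jax.devices()` reports every device in the job, while `mesh.local_devices` is what each process actually feeds.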
@@ -1,11 +1,14 @@
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
- flaxdiff/models/attention.py,sha256=SL9cvINjmabW1LPvXLAFZNHv-FF1Ez_d3J7n5uHBTyQ,15301
- flaxdiff/models/common.py,sha256=CjC4iRLjkF3oQ0f6rAqfiLaiHllZGtCOwN3rXDUndbE,274
+ flaxdiff/models/attention.py,sha256=KiAUyfujGpUZR13aJR6RVnL6pBXk5UcyM62VIXhojMg,14468
+ flaxdiff/models/common.py,sha256=jlyRB4uF7BmeuExor1YHaqEbBjSuyaDZ4mDsSW3rWKE,7948
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
- flaxdiff/models/simple_unet.py,sha256=WlLry6v18syHBzcN8zAJ-zIVtq6ItMEIBWbeCcX0MLU,18693
+ flaxdiff/models/simple_unet.py,sha256=o1DCa9yvqarEGTiUKsTqE70q-h6bRU6HcU0lZpb65jc,11418
  flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
+ flaxdiff/models/autoencoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ flaxdiff/models/autoencoder/autoencoder.py,sha256=At-DhcmrZ0Gao4PUa4l9D25FTdTPwbE4gu6LKcFKzUQ,433
+ flaxdiff/models/autoencoder/diffusers.py,sha256=gwyD98277vQGKVPFbyd6w6CupoxMsNgKlN67AtzLCtg,3267
  flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
  flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
  flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -24,9 +27,9 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
  flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
  flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
  flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
- flaxdiff/trainer/__init__.py,sha256=kwzkm-BD97hffFIXZUP1Hb3_D85fZ4SRNO7bviEwHU8,7591
- flaxdiff/trainer/simple_trainer.py,sha256=jafxr-yZ6FXn0Qi-iTSnlf275QWnIO4GnSvNAeB3H-Q,11651
- flaxdiff-0.1.4.dist-info/METADATA,sha256=G8OijdrrYWuKyAfCNtD_dKwdfBmdME56vpR-EYIZKXg,19229
- flaxdiff-0.1.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
- flaxdiff-0.1.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.1.4.dist-info/RECORD,,
+ flaxdiff/trainer/__init__.py,sha256=17qKQFITCfaXQFKYElMzkE-c-EPrv5iUL66gY1gKOsQ,7243
+ flaxdiff/trainer/simple_trainer.py,sha256=f4g2KGuGM__d9v_4Ip3ng8wQubmenWZUW60VEu2ANOg,16774
+ flaxdiff-0.1.5.dist-info/METADATA,sha256=tGKayFhkYSJJnLY_sHiaCJ60kJZqnO-kcLM3uH3JSN4,19811
+ flaxdiff-0.1.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+ flaxdiff-0.1.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+ flaxdiff-0.1.5.dist-info/RECORD,,