flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,7 @@ from termcolor import colored
24
24
  from typing import Dict, Callable, Sequence, Any, Union, Tuple
25
25
  from flax.training.dynamic_scale import DynamicScale
26
26
  from flaxdiff.utils import RandomMarkovState
27
+ from flax.training import dynamic_scale as dynamic_scale_lib
27
28
 
28
29
  PROCESS_COLOR_MAP = {
29
30
  0: "green",
@@ -63,12 +64,12 @@ def convert_to_global_tree(global_mesh, pytree):
63
64
  @struct.dataclass
64
65
  class Metrics(metrics.Collection):
65
66
  accuracy: metrics.Accuracy
66
- loss: metrics.Average.from_output('loss')
67
+ loss: metrics.Average#.from_output('loss')
67
68
 
68
69
  # Define the TrainState
69
70
  class SimpleTrainState(train_state.TrainState):
70
71
  metrics: Metrics
71
- dynamic_scale: DynamicScale
72
+ dynamic_scale: dynamic_scale_lib.DynamicScale
72
73
 
73
74
  class SimpleTrainer:
74
75
  state: SimpleTrainState
@@ -110,6 +111,7 @@ class SimpleTrainer:
110
111
 
111
112
 
112
113
  if wandb_config is not None and jax.process_index() == 0:
114
+ import wandb
113
115
  run = wandb.init(**wandb_config)
114
116
  self.wandb = run
115
117
 
@@ -177,16 +179,13 @@ class SimpleTrainer:
177
179
  params = model.init(subkey, **input_vars)
178
180
  else:
179
181
  params = existing_state['params']
180
-
181
- if param_transforms is not None:
182
- params = param_transforms(params)
183
182
 
184
183
  state = SimpleTrainState.create(
185
184
  apply_fn=model.apply,
186
185
  params=params,
187
186
  tx=optimizer,
188
187
  metrics=Metrics.empty(),
189
- dynamic_scale = DynamicScale() if use_dynamic_scale else None
188
+ dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
190
189
  )
191
190
  if existing_best_state is not None:
192
191
  best_state = state.replace(
@@ -207,16 +206,16 @@ class SimpleTrainer:
207
206
  self.best_state = best_state
208
207
 
209
208
  def get_state(self):
210
- # return fully_replicated_host_local_array_to_global_array()
211
- return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
209
+ return self.get_np_tree(self.state)
212
210
 
213
211
  def get_best_state(self):
214
- # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
215
- return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
212
+ return self.get_np_tree(self.best_state)
216
213
 
217
214
  def get_rngstate(self):
218
- # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
219
- return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
215
+ return self.get_np_tree(self.rngstate)
216
+
217
+ def get_np_tree(self, pytree):
218
+ return jax.tree_util.tree_map(lambda x : np.array(x), pytree)
220
219
 
221
220
  def checkpoint_path(self):
222
221
  path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
@@ -253,29 +252,35 @@ class SimpleTrainer:
253
252
  rngstate = ckpt['rngs']
254
253
  # Convert the state to a TrainState
255
254
  self.best_loss = ckpt['best_loss']
255
+ if self.best_loss == 0:
256
+ # It cant be zero as that must have been some problem
257
+ self.best_loss = 1e9
256
258
  current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
257
259
  print(
258
260
  f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
259
261
  return current_epoch, step, state, best_state, rngstate
260
262
 
261
- def save(self, epoch=0, step=0):
263
+ def save(self, epoch=0, step=0, state=None, rngstate=None):
262
264
  print(f"Saving model at epoch {epoch} step {step}")
263
- ckpt = {
264
- # 'model': self.model,
265
- 'rngs': self.get_rngstate(),
266
- 'state': self.get_state(),
267
- 'best_state': self.get_best_state(),
268
- 'best_loss': np.array(self.best_loss),
269
- 'epoch': epoch,
270
- }
271
265
  try:
272
- save_args = orbax_utils.save_args_from_target(ckpt)
273
- self.checkpointer.save(step, ckpt, save_kwargs={
274
- 'save_args': save_args}, force=True)
275
- self.checkpointer.wait_until_finished()
276
- pass
266
+ ckpt = {
267
+ # 'model': self.model,
268
+ 'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate),
269
+ 'state': self.get_state() if state is None else self.get_np_tree(state),
270
+ 'best_state': self.get_best_state(),
271
+ 'best_loss': np.array(self.best_loss),
272
+ 'epoch': epoch,
273
+ }
274
+ try:
275
+ save_args = orbax_utils.save_args_from_target(ckpt)
276
+ self.checkpointer.save(step, ckpt, save_kwargs={
277
+ 'save_args': save_args}, force=True)
278
+ self.checkpointer.wait_until_finished()
279
+ pass
280
+ except Exception as e:
281
+ print("Error saving checkpoint", e)
277
282
  except Exception as e:
278
- print("Error saving checkpoint", e)
283
+ print("Error saving checkpoint outer", e)
279
284
 
280
285
  def _define_train_step(self, **kwargs):
281
286
  model = self.model
@@ -304,21 +309,26 @@ class SimpleTrainer:
304
309
  train_step = jax.pmap(train_step)
305
310
  return train_step
306
311
 
307
- def _define_compute_metrics(self):
312
+ def _define_vaidation_step(self):
308
313
  model = self.model
309
314
  loss_fn = self.loss_fn
315
+ distributed_training = self.distributed_training
310
316
 
311
- @jax.jit
312
- def compute_metrics(state: SimpleTrainState, batch):
317
+ def validation_step(state: SimpleTrainState, batch):
313
318
  preds = model.apply(state.params, batch['image'])
314
319
  expected_output = batch['label']
315
320
  loss = jnp.mean(loss_fn(preds, expected_output))
321
+ if distributed_training:
322
+ loss = jax.lax.pmean(loss, "data")
316
323
  metric_updates = state.metrics.single_from_model_output(
317
324
  loss=loss, logits=preds, labels=expected_output)
318
325
  metrics = state.metrics.merge(metric_updates)
319
326
  state = state.replace(metrics=metrics)
320
327
  return state
321
- return compute_metrics
328
+ if distributed_training:
329
+ validation_step = shard_map(validation_step, mesh=self.mesh, in_specs=(P(), P('data')), out_specs=(P()))
330
+ validation_step = jax.pmap(validation_step)
331
+ return validation_step
322
332
 
323
333
  def summary(self):
324
334
  input_vars = self.get_input_ones()
@@ -343,17 +353,53 @@ class SimpleTrainer:
343
353
  "batch_size": batch_size
344
354
  })
345
355
  return summary_writer
356
+
357
+ def validation_loop(
358
+ self,
359
+ val_state: SimpleTrainState,
360
+ val_step_fn: Callable,
361
+ val_ds,
362
+ val_steps_per_epoch,
363
+ current_step,
364
+ ):
365
+ global_device_count = jax.device_count()
366
+ local_device_count = jax.local_device_count()
367
+ process_index = jax.process_index()
368
+
369
+ val_ds = iter(val_ds()) if val_ds else None
370
+ # Evaluation step
371
+ try:
372
+ for i in range(val_steps_per_epoch):
373
+ if val_ds is None:
374
+ batch = None
375
+ else:
376
+ batch = next(val_ds)
377
+ if self.distributed_training and global_device_count > 1:
378
+ batch = convert_to_global_tree(self.mesh, batch)
379
+ if i == 0:
380
+ print(f"Evaluation started for process index {process_index}")
381
+ metrics = val_step_fn(val_state, batch)
382
+ if self.wandb is not None:
383
+ # metrics is a dict of metrics
384
+ if metrics and type(metrics) == dict:
385
+ for key, value in metrics.items():
386
+ if isinstance(value, jnp.ndarray):
387
+ value = np.array(value)
388
+ self.wandb.log({
389
+ f"val/{key}": value,
390
+ }, step=current_step)
391
+ except Exception as e:
392
+ print("Error logging images to wandb", e)
346
393
 
347
- def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
348
- train_ds = iter(data['train']())
349
- if 'test' in data:
350
- test_ds = data['test']
351
- else:
352
- test_ds = None
353
- train_step = self._define_train_step(**train_step_args)
354
- compute_metrics = self._define_compute_metrics()
355
- train_state = self.state
356
- rng_state = self.rngstate
394
+ def train_loop(
395
+ self,
396
+ train_state: SimpleTrainState,
397
+ train_step_fn: Callable,
398
+ train_ds,
399
+ train_steps_per_epoch,
400
+ current_step,
401
+ rng_state
402
+ ):
357
403
  global_device_count = jax.device_count()
358
404
  local_device_count = jax.local_device_count()
359
405
  process_index = jax.process_index()
@@ -361,67 +407,105 @@ class SimpleTrainer:
361
407
  global_device_indexes = jnp.arange(global_device_count)
362
408
  else:
363
409
  global_device_indexes = 0
410
+
411
+ epoch_loss = 0
412
+ current_epoch = current_step // train_steps_per_epoch
413
+ last_save_time = time.time()
414
+
415
+ if process_index == 0:
416
+ pbar = tqdm.tqdm(total=train_steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step')
417
+
418
+ for i in range(train_steps_per_epoch):
419
+ batch = next(train_ds)
420
+ if i == 0:
421
+ print(f"First batch loaded at step {current_step}")
422
+
423
+ if self.distributed_training and global_device_count > 1:
424
+ # # Convert the local device batches to a unified global jax.Array
425
+ batch = convert_to_global_tree(self.mesh, batch)
426
+ train_state, loss, rng_state = train_step_fn(train_state, rng_state, batch, global_device_indexes)
364
427
 
365
- def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
366
- epoch_loss = 0
367
- current_epoch = current_step // steps_per_epoch
368
- last_save_time = time.time()
369
- for i in range(steps_per_epoch):
370
- batch = next(train_ds)
371
- if self.distributed_training and global_device_count > 1:
372
- # Convert the local device batches to a unified global jax.Array
373
- batch = convert_to_global_tree(self.mesh, batch)
374
- train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
375
-
376
- if self.distributed_training:
377
- loss = jax.experimental.multihost_utils.process_allgather(loss)
378
- loss = jnp.mean(loss) # Just to make sure its a scaler value
428
+ if i == 0:
429
+ print(f"Training started for process index {process_index} at step {current_step}")
430
+
431
+ if self.distributed_training:
432
+ # loss = jax.experimental.multihost_utils.process_allgather(loss)
433
+ loss = jnp.mean(loss) # Just to make sure its a scaler value
379
434
 
380
- if loss <= 1e-6:
381
- # If the loss is too low, we can assume the model has diverged
382
- print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
383
- # Exit the training loop
384
- exit(1)
435
+ if loss <= 1e-6:
436
+ # If the loss is too low, we can assume the model has diverged
437
+ print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
438
+ # Reset the model to the old state
439
+ exit(1)
385
440
 
386
- epoch_loss += loss
387
- current_step += 1
388
- if i % 100 == 0:
389
- if pbar is not None:
390
- pbar.set_postfix(loss=f'{loss:.4f}')
391
- pbar.update(100)
392
- if self.wandb is not None:
393
- self.wandb.log({
394
- "train/step" : current_step,
395
- "train/loss": loss,
396
- }, step=current_step)
397
- # Save the model every 40 minutes
398
- if time.time() - last_save_time > 40 * 60:
399
- print(f"Saving model after 40 minutes at step {current_step}")
400
- self.save(current_epoch, current_step)
401
- last_save_time = time.time()
402
- print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
403
- return epoch_loss, current_step, train_state, rng_state
404
-
405
- while self.latest_step < epochs * steps_per_epoch:
406
- current_epoch = self.latest_step // steps_per_epoch
441
+ epoch_loss += loss
442
+ current_step += 1
443
+ if i % 100 == 0:
444
+ if pbar is not None:
445
+ pbar.set_postfix(loss=f'{loss:.4f}')
446
+ pbar.update(100)
447
+ if self.wandb is not None:
448
+ self.wandb.log({
449
+ "train/step" : current_step,
450
+ "train/loss": loss,
451
+ }, step=current_step)
452
+ # Save the model every few steps
453
+ if i % 10000 == 0 and i > 0:
454
+ print(f"Saving model after 10000 step {current_step}")
455
+ print(f"Devices: {len(jax.devices())}") # To sync the devices
456
+ self.save(current_epoch, current_step, train_state, rng_state)
457
+ print(f"Saving done by process index {process_index}")
458
+ last_save_time = time.time()
459
+ print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/train_steps_per_epoch}", 'green'))
460
+ if pbar is not None:
461
+ pbar.close()
462
+ return epoch_loss, current_step, train_state, rng_state
463
+
464
+
465
+ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
466
+ train_ds = iter(data['train']())
467
+ train_step = self._define_train_step(**train_step_args)
468
+ val_step = self._define_vaidation_step(**validation_step_args)
469
+ train_state = self.state
470
+ rng_state = self.rngstate
471
+ process_index = jax.process_index()
472
+
473
+ if val_steps_per_epoch > 0:
474
+ # We should first run a validation step to make sure the model is working
475
+ print(f"Validation run for sanity check for process index {process_index}")
476
+ # Validation step
477
+ self.validation_loop(
478
+ train_state,
479
+ val_step,
480
+ data.get('test', data.get('val', None)),
481
+ val_steps_per_epoch,
482
+ self.latest_step,
483
+ )
484
+ print(colored(f"Sanity Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
485
+
486
+ while self.latest_step < epochs * train_steps_per_epoch:
487
+ current_epoch = self.latest_step // train_steps_per_epoch
407
488
  print(f"\nEpoch {current_epoch}/{epochs}")
408
489
  start_time = time.time()
409
490
  epoch_loss = 0
410
-
411
- if process_index == 0:
412
- with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
413
- epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
414
- else:
415
- epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
416
- print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
491
+
492
+ epoch_loss, current_step, train_state, rng_state = self.train_loop(
493
+ train_state,
494
+ train_step,
495
+ train_ds,
496
+ train_steps_per_epoch,
497
+ self.latest_step,
498
+ rng_state,
499
+ )
500
+ print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
417
501
 
418
502
  self.latest_step = current_step
419
503
  end_time = time.time()
420
504
  self.state = train_state
421
505
  self.rngstate = rng_state
422
506
  total_time = end_time - start_time
423
- avg_time_per_step = total_time / steps_per_epoch
424
- avg_loss = epoch_loss / steps_per_epoch
507
+ avg_time_per_step = total_time / train_steps_per_epoch
508
+ avg_loss = epoch_loss / train_steps_per_epoch
425
509
  if avg_loss < self.best_loss:
426
510
  self.best_loss = avg_loss
427
511
  self.best_state = train_state
@@ -437,6 +521,18 @@ class SimpleTrainer:
437
521
  "train/epoch": current_epoch,
438
522
  }, step=current_step)
439
523
  print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
440
- print("Training done")
524
+
525
+ if val_steps_per_epoch > 0:
526
+ print(f"Validation started for process index {process_index}")
527
+ # Validation step
528
+ self.validation_loop(
529
+ train_state,
530
+ val_step,
531
+ data.get('test', None),
532
+ val_steps_per_epoch,
533
+ current_step,
534
+ )
535
+ print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
536
+
441
537
  self.save(epochs)
442
538
  return self.state
@@ -0,0 +1,62 @@
1
+ import flax
2
+ from flax import linen as nn
3
+ import jax
4
+ from typing import Callable
5
+ from dataclasses import field
6
+ import jax.numpy as jnp
7
+ import optax
8
+ import functools
9
+ from jax.sharding import Mesh, PartitionSpec as P
10
+ from jax.experimental.shard_map import shard_map
11
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
12
+
13
+ from ..schedulers import NoiseScheduler
14
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
15
+
16
+ from flaxdiff.utils import RandomMarkovState
17
+
18
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
19
+
20
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
21
+ from flax.training import dynamic_scale as dynamic_scale_lib
22
+
23
+ class TrainState(SimpleTrainState):
24
+ rngs: jax.random.PRNGKey
25
+ ema_params: dict
26
+
27
+ def apply_ema(self, decay: float = 0.999):
28
+ new_ema_params = jax.tree_util.tree_map(
29
+ lambda ema, param: decay * ema + (1 - decay) * param,
30
+ self.ema_params,
31
+ self.params,
32
+ )
33
+ return self.replace(ema_params=new_ema_params)
34
+
35
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
36
+ from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
37
+
38
+ class SimpleVideoDiffusionTrainer(DiffusionTrainer):
39
+ def __init__(self,
40
+ model: nn.Module,
41
+ input_shapes: Dict[str, Tuple[int]],
42
+ optimizer: optax.GradientTransformation,
43
+ noise_schedule: NoiseScheduler,
44
+ rngs: jax.random.PRNGKey,
45
+ unconditional_prob: float = 0.12,
46
+ name: str = "SimpleVideoDiffusion",
47
+ model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
48
+ autoencoder: AutoEncoder = None,
49
+ **kwargs
50
+ ):
51
+ super().__init__(
52
+ model=model,
53
+ input_shapes=input_shapes,
54
+ optimizer=optimizer,
55
+ noise_schedule=noise_schedule,
56
+ unconditional_prob=unconditional_prob,
57
+ autoencoder=autoencoder,
58
+ model_output_transform=model_output_transform,
59
+ rngs=rngs,
60
+ name=name,
61
+ **kwargs
62
+ )
flaxdiff/utils.py CHANGED
@@ -2,7 +2,12 @@ import jax
2
2
  import jax.numpy as jnp
3
3
  import flax.struct as struct
4
4
  import flax.linen as nn
5
- from typing import Any
5
+ from typing import Any, Callable
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ import numpy as np
9
+ from jax.sharding import Mesh, PartitionSpec as P
10
+ from abc import ABC, abstractmethod
6
11
 
7
12
  class MarkovState(struct.PyTreeNode):
8
13
  pass
@@ -17,6 +22,30 @@ class RandomMarkovState(MarkovState):
17
22
  def clip_images(images, clip_min=-1, clip_max=1):
18
23
  return jnp.clip(images, clip_min, clip_max)
19
24
 
25
+ def _build_global_shape_and_sharding(
26
+ local_shape: tuple[int, ...], global_mesh: Mesh
27
+ ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
28
+ sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
29
+ global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
30
+ return global_shape, sharding
31
+
32
+
33
+ def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
34
+ """Put local sharded array into local devices"""
35
+ global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
36
+ try:
37
+ local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
38
+ except ValueError as array_split_error:
39
+ raise ValueError(
40
+ f"Unable to put to devices shape {array.shape} with "
41
+ f"local device count {len(global_mesh.local_devices)} "
42
+ ) from array_split_error
43
+ local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
44
+ return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
45
+
46
+ def convert_to_global_tree(global_mesh, pytree):
47
+ return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
48
+
20
49
  class RMSNorm(nn.Module):
21
50
  """
22
51
  From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
@@ -86,4 +115,78 @@ class RMSNorm(nn.Module):
86
115
  ).reshape(feature_shape)
87
116
  mul *= scale
88
117
  y = mul * x
89
- return jnp.asarray(y, dtype)
118
+ return jnp.asarray(y, dtype)
119
+
120
+ @dataclass
121
+ class ConditioningEncoder(ABC):
122
+ model: nn.Module
123
+ tokenizer: Callable
124
+
125
+ def __call__(self, data):
126
+ tokens = self.tokenize(data)
127
+ outputs = self.encode_from_tokens(tokens)
128
+ return outputs
129
+
130
+ def encode_from_tokens(self, tokens):
131
+ outputs = self.model(input_ids=tokens['input_ids'],
132
+ attention_mask=tokens['attention_mask'])
133
+ last_hidden_state = outputs.last_hidden_state
134
+ return last_hidden_state
135
+
136
+ def tokenize(self, data):
137
+ tokens = self.tokenizer(data, padding="max_length",
138
+ max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
139
+ return tokens
140
+
141
+ @dataclass
142
+ class TextEncoder(ConditioningEncoder):
143
+ # def __call__(self, data):
144
+ # tokens = self.tokenize(data)
145
+ # outputs = self.encode_from_tokens(tokens)
146
+ # return outputs
147
+
148
+ # def encode_from_tokens(self, tokens):
149
+ # outputs = self.model(input_ids=tokens['input_ids'],
150
+ # attention_mask=tokens['attention_mask'])
151
+ # last_hidden_state = outputs.last_hidden_state
152
+ # # pooler_output = outputs.pooler_output # pooled (EOS token) states
153
+ # # embed_pooled = pooler_output # .astype(jnp.float16)
154
+ # embed_labels_full = last_hidden_state # .astype(jnp.float16)
155
+
156
+ # return embed_labels_full
157
+ pass
158
+
159
+ class AutoTextTokenizer:
160
+ def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
161
+ from transformers import AutoTokenizer
162
+ self.tokenizer = AutoTokenizer.from_pretrained(modelname)
163
+ self.tensor_type = tensor_type
164
+
165
+ def __call__(self, inputs):
166
+ # print(caption)
167
+ tokens = self.tokenizer(inputs, padding="max_length", max_length=self.tokenizer.model_max_length,
168
+ truncation=True, return_tensors=self.tensor_type)
169
+ # print(tokens.keys())
170
+ return {
171
+ "input_ids": tokens["input_ids"],
172
+ "attention_mask": tokens["attention_mask"],
173
+ "caption": inputs,
174
+ }
175
+
176
+ def __repr__(self):
177
+ return self.__class__.__name__ + '()'
178
+
179
+ def defaultTextEncodeModel(backend="jax"):
180
+ from transformers import (
181
+ CLIPTextModel,
182
+ FlaxCLIPTextModel,
183
+ AutoTokenizer,
184
+ )
185
+ modelname = "openai/clip-vit-large-patch14"
186
+ if backend == "jax":
187
+ model = FlaxCLIPTextModel.from_pretrained(
188
+ modelname, dtype=jnp.bfloat16)
189
+ else:
190
+ model = CLIPTextModel.from_pretrained(modelname)
191
+ tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
192
+ return TextEncoder(model, tokenizer)
@@ -1,15 +1,21 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.35.5
3
+ Version: 0.1.36
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
7
7
  Description-Content-Type: text/markdown
8
- Requires-Dist: flax >=0.8.4
9
- Requires-Dist: optax >=0.2.2
10
- Requires-Dist: jax >=0.4.28
8
+ Requires-Dist: flax>=0.8.4
9
+ Requires-Dist: optax>=0.2.2
10
+ Requires-Dist: jax>=0.4.28
11
11
  Requires-Dist: orbax
12
12
  Requires-Dist: clu
13
+ Dynamic: author
14
+ Dynamic: author-email
15
+ Dynamic: description
16
+ Dynamic: description-content-type
17
+ Dynamic: requires-dist
18
+ Dynamic: summary
13
19
 
14
20
  # ![](images/logo.jpeg "FlaxDiff")
15
21
 
@@ -0,0 +1,43 @@
1
+ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
3
+ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
+ flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
5
+ flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
6
+ flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
7
+ flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
8
+ flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
9
+ flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
10
+ flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
11
+ flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
12
+ flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
13
+ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
14
+ flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
15
+ flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
16
+ flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
17
+ flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
18
+ flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
19
+ flaxdiff/samplers/common.py,sha256=ZA08VyovxegpRx4wOQq9LSwZi0gSCz2lrbS5oVYOEYg,8488
20
+ flaxdiff/samplers/ddim.py,sha256=pB8Kod8ZLJ3GXev4uM3cOj1Uy6ibR0jsaZa-VE0fyJM,552
21
+ flaxdiff/samplers/ddpm.py,sha256=u1OchQu0XPhc_6w9JXoaFp2wo4y-zXyQNtGAIJwxNLg,2209
22
+ flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,2740
23
+ flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
24
+ flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
25
+ flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
26
+ flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
27
+ flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
28
+ flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
29
+ flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
30
+ flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
31
+ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
32
+ flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
33
+ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
34
+ flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
35
+ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
36
+ flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
37
+ flaxdiff/trainer/diffusion_trainer.py,sha256=ajOWBgFFwXP_VQScUjcuPoaB4Gk02aF0Ls5LNlA8wqA,12691
38
+ flaxdiff/trainer/simple_trainer.py,sha256=jCD9-qCwX0SC0rN3GrXUBfRrndWNqUI0HmbOAbmYBMM,21906
39
+ flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
40
+ flaxdiff-0.1.36.dist-info/METADATA,sha256=7fO1e_icIEK6dmSopv538Hm2fQnhnkOAE2Ab9inpcNE,22213
41
+ flaxdiff-0.1.36.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
42
+ flaxdiff-0.1.36.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
43
+ flaxdiff-0.1.36.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5