flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/simple_trainer.py CHANGED
@@ -25,6 +25,8 @@ from typing import Dict, Callable, Sequence, Any, Union, Tuple
 from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState
 from flax.training import dynamic_scale as dynamic_scale_lib
+from dataclasses import dataclass
+import gc
 
 PROCESS_COLOR_MAP = {
     0: "green",
@@ -71,6 +73,7 @@ class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
     dynamic_scale: dynamic_scale_lib.DynamicScale
 
+@dataclass
 class SimpleTrainer:
     state: SimpleTrainState
     best_state: SimpleTrainState
@@ -86,7 +89,6 @@ class SimpleTrainer:
                 train_state: SimpleTrainState = None,
                 name: str = "Simple",
                 load_from_checkpoint: str = None,
-                checkpoint_suffix: str = "",
                 loss_fn=optax.l2_loss,
                 param_transforms: Callable = None,
                 wandb_config: Dict[str, Any] = None,
@@ -94,6 +96,7 @@ class SimpleTrainer:
                 checkpoint_base_path: str = "./checkpoints",
                 checkpoint_step: int = None,
                 use_dynamic_scale: bool = False,
+                max_checkpoints_to_keep: int = 2,
                 ):
        if distributed_training is None or distributed_training is True:
            # Auto-detect if we are running on multiple devices
@@ -109,10 +112,9 @@ class SimpleTrainer:
         self.input_shapes = input_shapes
         self.checkpoint_base_path = checkpoint_base_path
 
-
         if wandb_config is not None and jax.process_index() == 0:
             import wandb
-            run = wandb.init(**wandb_config)
+            run = wandb.init(resume='allow', **wandb_config)
             self.wandb = run
 
             # define our custom x axis metric
@@ -126,13 +128,18 @@ class SimpleTrainer:
             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
 
+            if self.wandb.sweep_id:
+                api = wandb.Api()
+                self.wandb_sweep = api.sweep(f"{self.wandb.entity}/{self.wandb.project}/{self.wandb.sweep_id}")
+                print(f"Running sweep {self.wandb_sweep.id} with id {self.wandb.sweep_id}")
+
         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
 
         options = orbax.checkpoint.CheckpointManagerOptions(
-            max_to_keep=4, create=True)
+            max_to_keep=max_checkpoints_to_keep, create=True)
         self.checkpointer = orbax.checkpoint.CheckpointManager(
-            self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
+            self.checkpoint_path(), async_checkpointer, options)
 
         if load_from_checkpoint is not None:
             latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
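For readers unfamiliar with orbax, the manager wired up above reduces to the following standalone sketch; the directory path is a placeholder, as flaxdiff derives the real one from `checkpoint_base_path` and the trainer name:

```python
import orbax.checkpoint

# Mirrors the construction in SimpleTrainer.__init__ with the new default
# max_checkpoints_to_keep=2 (previously hardcoded as max_to_keep=4).
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
async_checkpointer = orbax.checkpoint.AsyncCheckpointer(
    orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
manager = orbax.checkpoint.CheckpointManager(
    "./checkpoints/Simple", async_checkpointer, options)  # path is illustrative
```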
@@ -248,6 +255,10 @@ class SimpleTrainer:
            step = checkpoint_step
 
        print("Loading model from checkpoint at step ", step)
+       loaded_checkpoint_path = os.path.join(
+           checkpoint_path if checkpoint_path else self.checkpoint_path(),
+           f"{step}")
+       self.loaded_checkpoint_path = loaded_checkpoint_path
        ckpt = checkpointer.restore(step)
        state = ckpt['state']
        best_state = ckpt['best_state']
@@ -311,7 +322,7 @@ class SimpleTrainer:
            train_step = jax.pmap(train_step)
        return train_step
 
-    def _define_vaidation_step(self):
+    def _define_validation_step(self):
        model = self.model
        loss_fn = self.loss_fn
        distributed_training = self.distributed_training
@@ -418,8 +429,8 @@ class SimpleTrainer:
 
        for i in range(train_steps_per_epoch):
            batch = next(train_ds)
-           if i == 0:
-               print(f"First batch loaded at step {current_step}")
+           # if i == 0:
+           #     print(f"First batch loaded at step {current_step}")
 
            if self.distributed_training and global_device_count > 1:
                # # Convert the local device batches to a unified global jax.Array
@@ -433,16 +444,40 @@ class SimpleTrainer:
                # loss = jax.experimental.multihost_utils.process_allgather(loss)
                loss = jnp.mean(loss)  # Just to make sure it's a scalar value
 
-           if loss <= 1e-8:
-               # If the loss is too low, we can assume the model has diverged
-               print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
-               # Reset the model to the old state
-               if self.best_state is not None:
-                   print(colored(f"Resetting model to best state", 'red'))
+           if loss <= 1e-8 or jnp.isnan(loss) or jnp.isinf(loss):
+               # If the loss is too low or NaN/Inf, log the issue and attempt recovery
+               print(colored(f"Abnormal loss at step {current_step}: {loss}", 'red'))
+
+               # Check model parameters for NaN/Inf values
+               params = train_state.params
+               has_nan_or_inf = False
+
+               if isinstance(params, dict):
+                   for key, value in params.items():
+                       if isinstance(value, jnp.ndarray):
+                           if jnp.isnan(value).any() or jnp.isinf(value).any():
+                               print(colored(f"NaN/inf values found in params[{key}] at step {current_step}", 'red'))
+                               has_nan_or_inf = True
+                               break
+
+               if not has_nan_or_inf:
+                   print(colored(f"Model parameters seem valid despite abnormal loss", 'yellow'))
+
+               # Try to recover - clear JAX caches and collect garbage
+               gc.collect()
+               if hasattr(jax, "clear_caches"):
+                   jax.clear_caches()
+
+               # If we have a best state and the loss is truly invalid, consider restoring
+               if (loss <= 1e-8 or jnp.isnan(loss) or jnp.isinf(loss)) and self.best_state is not None:
+                   print(colored(f"Attempting recovery by resetting model to last best state", 'yellow'))
                    train_state = self.best_state
                    loss = self.best_loss
                else:
-                   exit(1)
+                   # If we can't recover, skip this step but continue training
+                   print(colored(f"Unable to recover - continuing with current state", 'yellow'))
+                   if loss <= 1e-8:
+                       loss = 1.0  # Set to a reasonable default to continue training
 
            epoch_loss += loss
            current_step += 1
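One caveat worth noting: the parameter scan in this hunk only inspects the top level of `params`, while Flax parameter trees are typically nested. A tree-wide check would look something like the following sketch using `jax.tree_util` (illustrative only, not part of the package):

```python
import jax
import jax.numpy as jnp

def tree_has_nan_or_inf(params) -> bool:
    """Check every leaf of a (possibly nested) parameter tree for NaN/Inf."""
    return any(
        bool(jnp.isnan(leaf).any() or jnp.isinf(leaf).any())
        for leaf in jax.tree_util.tree_leaves(params)
    )
```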
@@ -471,7 +506,7 @@ class SimpleTrainer:
    def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
        train_ds = iter(data['train']())
        train_step = self._define_train_step(**train_step_args)
-       val_step = self._define_vaidation_step(**validation_step_args)
+       val_step = self._define_validation_step(**validation_step_args)
        train_state = self.state
        rng_state = self.rngstate
        process_index = jax.process_index()
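Taken together, the trainer changes are user-visible: `checkpoint_suffix` is gone, checkpoint retention is configurable, and wandb runs resume with `resume='allow'`. A hedged sketch of a 0.2.0 call site, assuming the pre-existing required arguments (`model`, `input_shapes`, `optimizer`, `rngs`) are unchanged; the model here is a placeholder:

```python
import jax
import optax
import flax.linen as nn
from flaxdiff.trainer.simple_trainer import SimpleTrainer

class TinyModel(nn.Module):  # placeholder model for illustration
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

trainer = SimpleTrainer(
    model=TinyModel(),
    input_shapes={"x": (4,)},
    optimizer=optax.adam(1e-4),
    rngs=jax.random.PRNGKey(0),
    name="Simple",
    max_checkpoints_to_keep=2,   # new in 0.2.0; replaces the hardcoded max_to_keep=4
    # checkpoint_suffix="..."    # removed in 0.2.0
)
```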
flaxdiff/utils.py CHANGED
@@ -2,26 +2,145 @@ import jax
 import jax.numpy as jnp
 import flax.struct as struct
 import flax.linen as nn
-from typing import Any, Callable
-from dataclasses import dataclass
+from typing import Any
 from functools import partial
 import numpy as np
+import os
 from jax.sharding import Mesh, PartitionSpec as P
-from abc import ABC, abstractmethod
+from flaxdiff.inputs import TextEncoder, CLIPTextEncoder
+
+# Setup mappings for dtype, precision, and activation
+DTYPE_MAP = {
+    'bfloat16': jnp.bfloat16,
+    'float32': jnp.float32,
+    'jax.numpy.float32': jnp.float32,
+    'jax.numpy.bfloat16': jnp.bfloat16,
+    'None': None,
+    None: None,
+}
+
+PRECISION_MAP = {
+    'high': jax.lax.Precision.HIGH,
+    'HIGH': jax.lax.Precision.HIGH,
+    'default': jax.lax.Precision.DEFAULT,
+    'DEFAULT': jax.lax.Precision.DEFAULT,
+    'highest': jax.lax.Precision.HIGHEST,
+    'HIGHEST': jax.lax.Precision.HIGHEST,
+    'None': None,
+    None: None,
+}
+
+ACTIVATION_MAP = {
+    'swish': jax.nn.swish,
+    'silu': jax.nn.silu,
+    'jax._src.nn.functions.silu': jax.nn.silu,
+    'mish': jax.nn.mish,
+}
+
+def map_nested_config(config):
+    new_config = {}
+    for key, value in config.items():
+        if isinstance(value, dict):
+            new_config[key] = map_nested_config(value)
+        elif isinstance(value, str):
+            if value in DTYPE_MAP:
+                new_config[key] = DTYPE_MAP[value]
+            elif value in PRECISION_MAP:
+                new_config[key] = PRECISION_MAP[value]
+            elif value in ACTIVATION_MAP:
+                new_config[key] = ACTIVATION_MAP[value]
+            elif value == 'None':
+                new_config[key] = None
+            elif '.' in value:
+                # Ignore any other string that contains a dot
+                print(
+                    f"Ignoring key {key} with value {value} as it contains a dot.")
+    return new_config
+
+def serialize_model(model: nn.Module):
+    """
+    Serializes the model to a dictionary format.
+    """
+    model_dict = model.__dict__
+    model_dict = {k: v for k, v in model_dict.items() if not k.startswith('_')}
+    # Convert all callable attributes to their string representation
+    def map(model_dict):
+        for k, v in model_dict.items():
+            if isinstance(v, dict):
+                # Recursively serialize nested dictionaries
+                model_dict[k] = map(v)
+            elif isinstance(v, list):
+                # Recursively serialize lists
+                [map(item) if isinstance(item, dict) else item for item in v]
+            elif callable(v):
+                # If the attribute has __name__, use that as the key
+                if hasattr(v, '__name__'):
+                    model_dict[k] = v.__name__
+                else:
+                    model_dict[k] = str(v).split('.')[-1]
+    map(model_dict)
+    return model_dict
+
+def get_latest_checkpoint(checkpoint_path):
+    checkpoint_files = os.listdir(checkpoint_path)
+    # Sort files by step number
+    checkpoint_files = sorted([int(i) for i in checkpoint_files])
+    latest_step = checkpoint_files[-1]
+    latest_checkpoint = os.path.join(checkpoint_path, str(latest_step))
+    return latest_checkpoint
 
 class MarkovState(struct.PyTreeNode):
     pass
 
 class RandomMarkovState(MarkovState):
     rng: jax.random.PRNGKey
-
     def get_random_key(self):
         rng, subkey = jax.random.split(self.rng)
         return RandomMarkovState(rng), subkey
 
 def clip_images(images, clip_min=-1, clip_max=1):
+    """Clip image values to a specified range.
+
+    Args:
+        images: Images to clip
+        clip_min: Minimum value
+        clip_max: Maximum value
+
+    Returns:
+        Clipped images
+    """
     return jnp.clip(images, clip_min, clip_max)
 
+def denormalize_images(images, target_type=jnp.uint8, source_range=(-1, 1), target_range=(0, 255)):
+    """Convert images from normalized range (e.g. [-1, 1]) to target range (e.g. [0, 255]).
+
+    Args:
+        images: Normalized images
+        target_type: Target dtype (e.g. jnp.uint8 for standard images)
+        source_range: Tuple of (min, max) for the source normalization range
+        target_range: Tuple of (min, max) for the target range
+
+    Returns:
+        Denormalized images in the target dtype
+    """
+    src_min, src_max = source_range
+    tgt_min, tgt_max = target_range
+
+    # First clip to ensure we're in the expected source range
+    images = clip_images(images, src_min, src_max)
+
+    # Scale to [0, 1]
+    images = (images - src_min) / (src_max - src_min)
+
+    # Scale to target range
+    images = images * (tgt_max - tgt_min) + tgt_min
+
+    # Convert to target dtype if needed
+    if target_type is not None:
+        images = images.astype(target_type)
+
+    return images
+
 def _build_global_shape_and_sharding(
     local_shape: tuple[int, ...], global_mesh: Mesh
 ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
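The new helpers are straightforward to exercise. A short sketch covering `map_nested_config` and `denormalize_images`; the config dict is invented for illustration:

```python
import jax.numpy as jnp
from flaxdiff.utils import denormalize_images, map_nested_config

# String-serialized config entries resolve back to live JAX objects,
# including entries in nested dicts.
config = {"dtype": "bfloat16", "precision": "high", "model": {"activation": "silu"}}
resolved = map_nested_config(config)
assert resolved["dtype"] is jnp.bfloat16

# Model outputs in [-1, 1] become displayable uint8 pixels in [0, 255].
samples = jnp.zeros((2, 8, 8, 3))      # placeholder batch in [-1, 1]
pixels = denormalize_images(samples)   # uint8, values in [0, 255]
```

Note that, as written, `map_nested_config` only carries over values that are dicts or recognized strings; anything else is dropped from the result.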
@@ -117,45 +236,6 @@ class RMSNorm(nn.Module):
         y = mul * x
         return jnp.asarray(y, dtype)
 
-@dataclass
-class ConditioningEncoder(ABC):
-    model: nn.Module
-    tokenizer: Callable
-
-    def __call__(self, data):
-        tokens = self.tokenize(data)
-        outputs = self.encode_from_tokens(tokens)
-        return outputs
-
-    def encode_from_tokens(self, tokens):
-        outputs = self.model(input_ids=tokens['input_ids'],
-                             attention_mask=tokens['attention_mask'])
-        last_hidden_state = outputs.last_hidden_state
-        return last_hidden_state
-
-    def tokenize(self, data):
-        tokens = self.tokenizer(data, padding="max_length",
-                                max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
-        return tokens
-
-@dataclass
-class TextEncoder(ConditioningEncoder):
-    # def __call__(self, data):
-    #     tokens = self.tokenize(data)
-    #     outputs = self.encode_from_tokens(tokens)
-    #     return outputs
-
-    # def encode_from_tokens(self, tokens):
-    #     outputs = self.model(input_ids=tokens['input_ids'],
-    #                          attention_mask=tokens['attention_mask'])
-    #     last_hidden_state = outputs.last_hidden_state
-    #     # pooler_output = outputs.pooler_output # pooled (EOS token) states
-    #     # embed_pooled = pooler_output # .astype(jnp.float16)
-    #     embed_labels_full = last_hidden_state # .astype(jnp.float16)
-
-    #     return embed_labels_full
-    pass
-
 class AutoTextTokenizer:
     def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
         from transformers import AutoTokenizer
@@ -175,18 +255,9 @@ class AutoTextTokenizer:
 
     def __repr__(self):
         return self.__class__.__name__ + '()'
+
+# class AutoAudioTokenizer:
 
-def defaultTextEncodeModel(backend="jax"):
-    from transformers import (
-        CLIPTextModel,
-        FlaxCLIPTextModel,
-        AutoTokenizer,
-    )
-    modelname = "openai/clip-vit-large-patch14"
-    if backend == "jax":
-        model = FlaxCLIPTextModel.from_pretrained(
-            modelname, dtype=jnp.bfloat16)
-    else:
-        model = CLIPTextModel.from_pretrained(modelname)
-    tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
-    return TextEncoder(model, tokenizer)
+def defaultTextEncodeModel(modelname = "openai/clip-vit-large-patch14", backend="jax"):
+    """Default text encoder model."""
+    return CLIPTextEncoder.from_modelname(modelname=modelname, backend=backend)
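With the encoder classes moved out to `flaxdiff.inputs`, the factory above becomes a one-liner. A hedged usage sketch, assuming `CLIPTextEncoder.from_modelname` keeps the signature shown and the returned encoder is callable on a list of strings, as the removed `ConditioningEncoder` was:

```python
from flaxdiff.utils import defaultTextEncodeModel

encoder = defaultTextEncodeModel(
    modelname="openai/clip-vit-large-patch14", backend="jax")
# The removed ConditioningEncoder tokenized and encoded in one call;
# the flaxdiff.inputs encoders are assumed to behave the same way.
embeddings = encoder(["a painting of a sunset over the ocean"])
```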
{flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.38
+Version: 0.2.0
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
flaxdiff-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,64 @@
+flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
+flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
+flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
+flaxdiff/data/dataloaders.py,sha256=V4goNCK0JD_TthggXAEgJJD4LxJi1pUDew1x_fMCuO4,22576
+flaxdiff/data/dataset_map.py,sha256=NrLG1XtIxy8GcCsZ-e6eascjgsP0Xq5lVA1z3HIIYyI,5093
+flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
+flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
+flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
+flaxdiff/data/sources/av_utils.py,sha256=n2qwMBQGouoBH025vdE7gitWC6RduUommUrs-SPdWe4,24041
+flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
+flaxdiff/data/sources/images.py,sha256=WpH4ywZhNol26peX3m6m5NrmDJ1K2s6fRcYHvOFlOk8,11102
+flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
+flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
+flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
+flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/inference/pipeline.py,sha256=oMBRjvTtlC3Yzl1FqiBHcI4V34HXGAecCg8UvQbKoOc,8849
+flaxdiff/inference/utils.py,sha256=SRNYo-YtHzEPRpNv0fD8ZrUvnRIK941Rh4tjlsOGRgM,12278
+flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
+flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
+flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
+flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
+flaxdiff/models/__init__.py,sha256=amtDF07DfiAdnZsvWX4eaW79nwNEU1s8Zb4PB3ewtg4,118
+flaxdiff/models/attention.py,sha256=-q3xqWy4vQSLG4vXtiUN3FHVBIo7ZjpQsdLT9CkML6c,13367
+flaxdiff/models/common.py,sha256=7x9o5vY9UZvN4BNZ7LHzyuU3PNpsNym9B3m1Wfdddjo,10320
+flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
+flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
+flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
+flaxdiff/models/simple_vit.py,sha256=6DNpwTeE0Gn2jSie6n0JVUmQncPoyFT7jSSBreqk458,7497
+flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
+flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
+flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
+flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
+flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
+flaxdiff/models/autoencoder/simple_autoenc.py,sha256=NnGFjrkq-1z8Ouh_UvlvP0PFpkzm2LYaTQKMmN1BhkM,2109
+flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
+flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
+flaxdiff/samplers/common.py,sha256=-uU9FLkoQp3n3bga8Kfj_onDhtOS4MBggIKNrq3S8n4,18438
+flaxdiff/samplers/ddim.py,sha256=iFgXz96NBYuNMiWGMovf3gLO2TCxdrhJd-o8tvSmVUI,2054
+flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
+flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
+flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGEpHU8,1436
+flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
+flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
+flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
+flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
+flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
+flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
+flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
+flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
+flaxdiff/schedulers/karras.py,sha256=7PS6mHdnZnTqS2Xl_DacBt5YQ1f_CFyAxShyOo55eG0,3804
+flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
+flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
+flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
+flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
+flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
+flaxdiff/trainer/general_diffusion_trainer.py,sha256=VQ5p2ZaTv2R1LM0Epz4e719_EfK2dh1eoKK3WIysIW0,24040
+flaxdiff/trainer/simple_trainer.py,sha256=CF2mMcc6AtBgcR1XiqKevRL0paGS0S9ZJofCns32nRM,24214
+flaxdiff-0.2.0.dist-info/METADATA,sha256=1WLpd9RQy_mJE2E2uOdXptY5Fm3n_MTNcgZyBD7YmGw,23982
+flaxdiff-0.2.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
+flaxdiff-0.2.0.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.2.0.dist-info/RECORD,,
{flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (78.1.0)
+Generator: setuptools (78.1.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
flaxdiff/data/datasets.py DELETED
@@ -1,169 +0,0 @@
-import jax.numpy as jnp
-import grain.python as pygrain
-from typing import Dict
-import numpy as np
-import jax
-from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
-from .dataset_map import datasetMap, onlineDatasetMap
-import traceback
-from .online_loader import OnlineStreamingDataLoader
-import queue
-from jax.sharding import Mesh
-import threading
-
-def batch_mesh_map(mesh):
-    class augmenters(pygrain.MapTransform):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-
-        def map(self, batch) -> Dict[str, jnp.array]:
-            return convert_to_global_tree(mesh, batch)
-    return augmenters
-
-def get_dataset_grain(
-    data_name="cc12m",
-    batch_size=64,
-    image_scale=256,
-    count=None,
-    num_epochs=None,
-    method=jax.image.ResizeMethod.LANCZOS3,
-    worker_count=32,
-    read_thread_count=64,
-    read_buffer_size=50,
-    worker_buffer_size=20,
-    seed=0,
-    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
-):
-    dataset = datasetMap[data_name]
-    data_source = dataset["source"](dataset_source)
-    augmenter = dataset["augmenter"](image_scale, method)
-
-    local_batch_size = batch_size // jax.process_count()
-
-    sampler = pygrain.IndexSampler(
-        num_records=len(data_source) if count is None else count,
-        shuffle=True,
-        seed=seed,
-        num_epochs=num_epochs,
-        shard_options=pygrain.ShardByJaxProcess(),
-    )
-
-    def get_trainset():
-        transformations = [
-            augmenter(),
-            pygrain.Batch(local_batch_size, drop_remainder=True),
-        ]
-
-        # if mesh != None:
-        #     transformations += [batch_mesh_map(mesh)]
-
-        loader = pygrain.DataLoader(
-            data_source=data_source,
-            sampler=sampler,
-            operations=transformations,
-            worker_count=worker_count,
-            read_options=pygrain.ReadOptions(
-                read_thread_count, read_buffer_size
-            ),
-            worker_buffer_size=worker_buffer_size,
-        )
-        return loader
-
-
-    return {
-        "train": get_trainset,
-        "train_len": len(data_source),
-        "local_batch_size": local_batch_size,
-        "global_batch_size": batch_size,
-        # "null_labels": null_labels,
-        # "null_labels_full": null_labels_full,
-        # "model": model,
-        # "tokenizer": tokenizer,
-    }
-
-def generate_collate_fn():
-    auto_tokenize = AutoTextTokenizer(tensor_type="np")
-    def default_collate(batch):
-        try:
-            # urls = [sample["url"] for sample in batch]
-            captions = [sample["caption"] for sample in batch]
-            results = auto_tokenize(captions)
-            images = np.stack([sample["image"] for sample in batch], axis=0)
-            return {
-                "image": images,
-                "input_ids": results['input_ids'],
-                "attention_mask": results['attention_mask'],
-            }
-        except Exception as e:
-            print("Error in collate function", e, [sample["image"].shape for sample in batch])
-            traceback.print_exc()
-
-    return default_collate
-
-def get_dataset_online(
-    data_name="combined_online",
-    batch_size=64,
-    image_scale=256,
-    count=None,
-    num_epochs=None,
-    method=jax.image.ResizeMethod.LANCZOS3,
-    worker_count=32,
-    read_thread_count=64,
-    read_buffer_size=50,
-    worker_buffer_size=20,
-    seed=0,
-    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
-):
-    local_batch_size = batch_size // jax.process_count()
-
-    sources = onlineDatasetMap[data_name]["source"]
-    dataloader = OnlineStreamingDataLoader(
-        sources,
-        batch_size=local_batch_size,
-        num_workers=worker_count,
-        num_threads=read_thread_count,
-        image_shape=(image_scale, image_scale),
-        global_process_count=jax.process_count(),
-        global_process_index=jax.process_index(),
-        prefetch=worker_buffer_size,
-        collate_fn=generate_collate_fn(),
-        default_split="train",
-    )
-
-    def get_trainset(mesh: Mesh = None):
-        if mesh != None:
-            class dataLoaderWithMesh:
-                def __init__(self, dataloader, mesh):
-                    self.dataloader = dataloader
-                    self.mesh = mesh
-                    self.tmp_queue = queue.Queue(worker_buffer_size)
-                    def batch_loader():
-                        for batch in self.dataloader:
-                            try:
-                                self.tmp_queue.put(convert_to_global_tree(mesh, batch))
-                            except Exception as e:
-                                print("Error processing batch", e)
-                    self.loader_thread = threading.Thread(target=batch_loader)
-                    self.loader_thread.start()
-
-                def __iter__(self):
-                    return self
-
-                def __next__(self):
-                    return self.tmp_queue.get()
-
-            dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
-
-            return dataloader_with_mesh
-        return dataloader
-
-    return {
-        "train": get_trainset,
-        "train_len": len(dataloader) * jax.process_count(),
-        "local_batch_size": local_batch_size,
-        "global_batch_size": batch_size,
-        # "null_labels": null_labels,
-        # "null_labels_full": null_labels_full,
-        # "model": model,
-        # "tokenizer": tokenizer,
-    }