flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,8 @@ from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
|
25
25
|
from flax.training.dynamic_scale import DynamicScale
|
26
26
|
from flaxdiff.utils import RandomMarkovState
|
27
27
|
from flax.training import dynamic_scale as dynamic_scale_lib
|
28
|
+
from dataclasses import dataclass
|
29
|
+
import gc
|
28
30
|
|
29
31
|
PROCESS_COLOR_MAP = {
|
30
32
|
0: "green",
|
@@ -71,6 +73,7 @@ class SimpleTrainState(train_state.TrainState):
|
|
71
73
|
metrics: Metrics
|
72
74
|
dynamic_scale: dynamic_scale_lib.DynamicScale
|
73
75
|
|
76
|
+
@dataclass
|
74
77
|
class SimpleTrainer:
|
75
78
|
state: SimpleTrainState
|
76
79
|
best_state: SimpleTrainState
|
@@ -86,7 +89,6 @@ class SimpleTrainer:
|
|
86
89
|
train_state: SimpleTrainState = None,
|
87
90
|
name: str = "Simple",
|
88
91
|
load_from_checkpoint: str = None,
|
89
|
-
checkpoint_suffix: str = "",
|
90
92
|
loss_fn=optax.l2_loss,
|
91
93
|
param_transforms: Callable = None,
|
92
94
|
wandb_config: Dict[str, Any] = None,
|
@@ -94,6 +96,7 @@ class SimpleTrainer:
|
|
94
96
|
checkpoint_base_path: str = "./checkpoints",
|
95
97
|
checkpoint_step: int = None,
|
96
98
|
use_dynamic_scale: bool = False,
|
99
|
+
max_checkpoints_to_keep: int = 2,
|
97
100
|
):
|
98
101
|
if distributed_training is None or distributed_training is True:
|
99
102
|
# Auto-detect if we are running on multiple devices
|
@@ -109,10 +112,9 @@ class SimpleTrainer:
|
|
109
112
|
self.input_shapes = input_shapes
|
110
113
|
self.checkpoint_base_path = checkpoint_base_path
|
111
114
|
|
112
|
-
|
113
115
|
if wandb_config is not None and jax.process_index() == 0:
|
114
116
|
import wandb
|
115
|
-
run = wandb.init(**wandb_config)
|
117
|
+
run = wandb.init(resume='allow', **wandb_config)
|
116
118
|
self.wandb = run
|
117
119
|
|
118
120
|
# define our custom x axis metric
|
@@ -126,13 +128,18 @@ class SimpleTrainer:
|
|
126
128
|
self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
|
127
129
|
self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
|
128
130
|
|
131
|
+
if self.wandb.sweep_id:
|
132
|
+
api = wandb.Api()
|
133
|
+
self.wandb_sweep = api.sweep(f"{self.wandb.entity}/{self.wandb.project}/{self.wandb.sweep_id}")
|
134
|
+
print(f"Running sweep {self.wandb_sweep.id} with id {self.wandb.sweep_id}")
|
135
|
+
|
129
136
|
# checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
130
137
|
async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
|
131
138
|
|
132
139
|
options = orbax.checkpoint.CheckpointManagerOptions(
|
133
|
-
max_to_keep=
|
140
|
+
max_to_keep=max_checkpoints_to_keep, create=True)
|
134
141
|
self.checkpointer = orbax.checkpoint.CheckpointManager(
|
135
|
-
self.checkpoint_path()
|
142
|
+
self.checkpoint_path(), async_checkpointer, options)
|
136
143
|
|
137
144
|
if load_from_checkpoint is not None:
|
138
145
|
latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
|
@@ -248,6 +255,10 @@ class SimpleTrainer:
|
|
248
255
|
step = checkpoint_step
|
249
256
|
|
250
257
|
print("Loading model from checkpoint at step ", step)
|
258
|
+
loaded_checkpoint_path = os.path.join(
|
259
|
+
checkpoint_path if checkpoint_path else self.checkpoint_path(),
|
260
|
+
f"{step}")
|
261
|
+
self.loaded_checkpoint_path = loaded_checkpoint_path
|
251
262
|
ckpt = checkpointer.restore(step)
|
252
263
|
state = ckpt['state']
|
253
264
|
best_state = ckpt['best_state']
|
@@ -311,7 +322,7 @@ class SimpleTrainer:
|
|
311
322
|
train_step = jax.pmap(train_step)
|
312
323
|
return train_step
|
313
324
|
|
314
|
-
def
|
325
|
+
def _define_validation_step(self):
|
315
326
|
model = self.model
|
316
327
|
loss_fn = self.loss_fn
|
317
328
|
distributed_training = self.distributed_training
|
@@ -418,8 +429,8 @@ class SimpleTrainer:
|
|
418
429
|
|
419
430
|
for i in range(train_steps_per_epoch):
|
420
431
|
batch = next(train_ds)
|
421
|
-
if i == 0:
|
422
|
-
|
432
|
+
# if i == 0:
|
433
|
+
# print(f"First batch loaded at step {current_step}")
|
423
434
|
|
424
435
|
if self.distributed_training and global_device_count > 1:
|
425
436
|
# # Convert the local device batches to a unified global jax.Array
|
@@ -433,16 +444,40 @@ class SimpleTrainer:
|
|
433
444
|
# loss = jax.experimental.multihost_utils.process_allgather(loss)
|
434
445
|
loss = jnp.mean(loss) # Just to make sure its a scaler value
|
435
446
|
|
436
|
-
if loss <= 1e-8:
|
437
|
-
# If the loss is too low
|
438
|
-
print(colored(f"
|
439
|
-
|
440
|
-
|
441
|
-
|
447
|
+
if loss <= 1e-8 or jnp.isnan(loss) or jnp.isinf(loss):
|
448
|
+
# If the loss is too low or NaN/Inf, log the issue and attempt recovery
|
449
|
+
print(colored(f"Abnormal loss at step {current_step}: {loss}", 'red'))
|
450
|
+
|
451
|
+
# Check model parameters for NaN/Inf values
|
452
|
+
params = train_state.params
|
453
|
+
has_nan_or_inf = False
|
454
|
+
|
455
|
+
if isinstance(params, dict):
|
456
|
+
for key, value in params.items():
|
457
|
+
if isinstance(value, jnp.ndarray):
|
458
|
+
if jnp.isnan(value).any() or jnp.isinf(value).any():
|
459
|
+
print(colored(f"NaN/inf values found in params[{key}] at step {current_step}", 'red'))
|
460
|
+
has_nan_or_inf = True
|
461
|
+
break
|
462
|
+
|
463
|
+
if not has_nan_or_inf:
|
464
|
+
print(colored(f"Model parameters seem valid despite abnormal loss", 'yellow'))
|
465
|
+
|
466
|
+
# Try to recover - clear JAX caches and collect garbage
|
467
|
+
gc.collect()
|
468
|
+
if hasattr(jax, "clear_caches"):
|
469
|
+
jax.clear_caches()
|
470
|
+
|
471
|
+
# If we have a best state and the loss is truly invalid, consider restoring
|
472
|
+
if (loss <= 1e-8 or jnp.isnan(loss) or jnp.isinf(loss)) and self.best_state is not None:
|
473
|
+
print(colored(f"Attempting recovery by resetting model to last best state", 'yellow'))
|
442
474
|
train_state = self.best_state
|
443
475
|
loss = self.best_loss
|
444
476
|
else:
|
445
|
-
|
477
|
+
# If we can't recover, skip this step but continue training
|
478
|
+
print(colored(f"Unable to recover - continuing with current state", 'yellow'))
|
479
|
+
if loss <= 1e-8:
|
480
|
+
loss = 1.0 # Set to a reasonable default to continue training
|
446
481
|
|
447
482
|
epoch_loss += loss
|
448
483
|
current_step += 1
|
@@ -471,7 +506,7 @@ class SimpleTrainer:
|
|
471
506
|
def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
|
472
507
|
train_ds = iter(data['train']())
|
473
508
|
train_step = self._define_train_step(**train_step_args)
|
474
|
-
val_step = self.
|
509
|
+
val_step = self._define_validation_step(**validation_step_args)
|
475
510
|
train_state = self.state
|
476
511
|
rng_state = self.rngstate
|
477
512
|
process_index = jax.process_index()
|
flaxdiff/utils.py
CHANGED
@@ -2,26 +2,145 @@ import jax
|
|
2
2
|
import jax.numpy as jnp
|
3
3
|
import flax.struct as struct
|
4
4
|
import flax.linen as nn
|
5
|
-
from typing import Any
|
6
|
-
from dataclasses import dataclass
|
5
|
+
from typing import Any
|
7
6
|
from functools import partial
|
8
7
|
import numpy as np
|
8
|
+
import os
|
9
9
|
from jax.sharding import Mesh, PartitionSpec as P
|
10
|
-
from
|
10
|
+
from flaxdiff.inputs import TextEncoder, CLIPTextEncoder
|
11
|
+
|
12
|
+
# Setup mappings for dtype, precision, and activation.
# These translate the string forms that configs are stored in back into the
# corresponding JAX objects (see map_nested_config below).
DTYPE_MAP = {
    'bfloat16': jnp.bfloat16,
    'float32': jnp.float32,
    'jax.numpy.float32': jnp.float32,
    'jax.numpy.bfloat16': jnp.bfloat16,
    'None': None,
    None: None,
}

PRECISION_MAP = {
    'high': jax.lax.Precision.HIGH,
    'HIGH': jax.lax.Precision.HIGH,
    'default': jax.lax.Precision.DEFAULT,
    'DEFAULT': jax.lax.Precision.DEFAULT,
    'highest': jax.lax.Precision.HIGHEST,
    'HIGHEST': jax.lax.Precision.HIGHEST,
    'None': None,
    None: None,
}

ACTIVATION_MAP = {
    'swish': jax.nn.swish,
    'silu': jax.nn.silu,
    'jax._src.nn.functions.silu': jax.nn.silu,
    'mish': jax.nn.mish,
}


def map_nested_config(config):
    """Recursively convert serialized config values back into JAX objects.

    String values that name a dtype, precision or activation are replaced by
    the matching entry of DTYPE_MAP / PRECISION_MAP / ACTIVATION_MAP, and the
    string 'None' becomes None. Any other string containing a dot (an
    unrecognised dotted path) is dropped with a printed message. Nested dicts,
    including dicts inside lists, are processed recursively. Every other value
    (numbers, bools, tuples, None, plain strings, ...) is copied through
    unchanged.

    Args:
        config: Mapping of config keys to values.

    Returns:
        A new dict; ``config`` itself is not modified.
    """
    new_config = {}
    for key, value in config.items():
        if isinstance(value, dict):
            new_config[key] = map_nested_config(value)
        elif isinstance(value, list):
            # Recurse into dicts held inside lists; keep other items as-is.
            new_config[key] = [
                map_nested_config(item) if isinstance(item, dict) else item
                for item in value
            ]
        elif isinstance(value, str):
            if value in DTYPE_MAP:
                new_config[key] = DTYPE_MAP[value]
            elif value in PRECISION_MAP:
                new_config[key] = PRECISION_MAP[value]
            elif value in ACTIVATION_MAP:
                new_config[key] = ACTIVATION_MAP[value]
            elif value == 'None':
                new_config[key] = None
            elif '.' in value:
                # Ignore any other string that contains a dot
                print(
                    f"Ignoring key {key} with value {value} as it contains a dot.")
            else:
                # Ordinary string settings must survive the round trip;
                # previously they were silently dropped.
                new_config[key] = value
        else:
            # Non-string scalars/containers (ints, floats, bools, tuples, None)
            # were previously dropped from the result; keep them.
            new_config[key] = value
    return new_config
|
59
|
+
|
60
|
+
def serialize_model(model: "nn.Module"):
    """Serialize a model's public attributes into a plain dictionary.

    Attributes whose names start with an underscore are skipped. Callable
    values are replaced by their ``__name__`` when they have one, otherwise by
    the last dot-separated component of ``str(value)``. Dicts (and dicts or
    callables inside lists) are converted recursively. Neither the model nor
    any of its attribute containers is modified.

    Args:
        model: The module whose ``__dict__`` is serialized.

    Returns:
        A new dict describing the model's configuration.
    """
    def _convert(value):
        # Nested dicts: build a converted copy. The previous helper returned
        # None here, wiping every nested dict, and mutated the model's own dict.
        if isinstance(value, dict):
            return {k: _convert(v) for k, v in value.items()}
        # Lists: build a converted copy instead of discarding the result.
        if isinstance(value, list):
            return [_convert(item) for item in value]
        if callable(value):
            if hasattr(value, '__name__'):
                return value.__name__
            return str(value).split('.')[-1]
        return value

    return {
        k: _convert(v)
        for k, v in model.__dict__.items()
        if not k.startswith('_')
    }
|
83
|
+
|
84
|
+
def get_latest_checkpoint(checkpoint_path):
    """Return the path of the highest-numbered checkpoint in a directory.

    Checkpoints are expected to be entries named by their integer step number
    (e.g. ``<checkpoint_path>/1000``). Entries whose names are not plain
    decimal integers are skipped instead of making ``int()`` raise.

    Args:
        checkpoint_path: Directory containing per-step checkpoint entries.

    Returns:
        ``os.path.join(checkpoint_path, str(latest_step))`` for the largest step.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist (from
            ``os.listdir``) or contains no step-numbered entries.
    """
    steps = [
        int(name)
        for name in os.listdir(checkpoint_path)
        if name.isdecimal()
    ]
    if not steps:
        raise FileNotFoundError(
            f"No step-numbered checkpoints found in {checkpoint_path}")
    # Numeric max, not lexicographic: "100" must beat "20".
    latest_step = max(steps)
    return os.path.join(checkpoint_path, str(latest_step))
|
11
91
|
|
12
92
|
class MarkovState(struct.PyTreeNode):
    """Base class for state objects carried from one step to the next.

    Subclasses declare their fields as annotated class attributes.
    """


class RandomMarkovState(MarkovState):
    """Markov state that holds a JAX PRNG key.

    ``get_random_key`` splits the held key into two halves. It returns a new
    state carrying one half together with the other half for immediate use.
    Callers are expected to continue with the returned state.
    """
    rng: jax.random.PRNGKey

    def get_random_key(self):
        """Split ``self.rng``; return ``(new_state, subkey)``."""
        carried, fresh = jax.random.split(self.rng)
        next_state = RandomMarkovState(carried)
        return next_state, fresh
|
21
100
|
|
22
101
|
def clip_images(images, clip_min=-1, clip_max=1):
    """Limit every element of ``images`` to the interval ``[clip_min, clip_max]``.

    Args:
        images: Array-like of pixel values.
        clip_min: Lower bound (default -1).
        clip_max: Upper bound (default 1).

    Returns:
        A ``jnp`` array where values below/above the bounds are replaced by
        the nearest bound.
    """
    bounds = (clip_min, clip_max)
    return jnp.clip(images, *bounds)
|
24
113
|
|
114
|
+
def denormalize_images(images, target_type=jnp.uint8, source_range=(-1, 1), target_range=(0, 255)):
    """Convert images from a normalized range (e.g. [-1, 1]) to a target range (e.g. [0, 255]).

    Values are first clipped to ``source_range``, then linearly rescaled to
    ``target_range``. If ``target_type`` is an integer dtype, the values are
    rounded to the nearest integer before casting, because ``astype`` alone
    truncates (e.g. 254.99 would otherwise become 254).

    Args:
        images: Normalized images (array-like).
        target_type: Target dtype (e.g. jnp.uint8 for standard images), or
            None to keep the floating-point result.
        source_range: Tuple of (min, max) for the source normalization range.
        target_range: Tuple of (min, max) for the target range.

    Returns:
        Denormalized images, cast to ``target_type`` when it is not None.
    """
    src_min, src_max = source_range
    tgt_min, tgt_max = target_range

    # Clip first so the rescaled values are guaranteed to stay in target_range.
    images = jnp.clip(images, src_min, src_max)

    # Scale to [0, 1]
    images = (images - src_min) / (src_max - src_min)

    # Scale to target range
    images = images * (tgt_max - tgt_min) + tgt_min

    # Convert to target dtype if needed
    if target_type is not None:
        if jnp.issubdtype(jnp.dtype(target_type), jnp.integer):
            # Round rather than truncate toward zero on integer casts.
            images = jnp.round(images)
        images = images.astype(target_type)

    return images
|
143
|
+
|
25
144
|
def _build_global_shape_and_sharding(
|
26
145
|
local_shape: tuple[int, ...], global_mesh: Mesh
|
27
146
|
) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
|
@@ -117,45 +236,6 @@ class RMSNorm(nn.Module):
|
|
117
236
|
y = mul * x
|
118
237
|
return jnp.asarray(y, dtype)
|
119
238
|
|
120
|
-
@dataclass
|
121
|
-
class ConditioningEncoder(ABC):
|
122
|
-
model: nn.Module
|
123
|
-
tokenizer: Callable
|
124
|
-
|
125
|
-
def __call__(self, data):
|
126
|
-
tokens = self.tokenize(data)
|
127
|
-
outputs = self.encode_from_tokens(tokens)
|
128
|
-
return outputs
|
129
|
-
|
130
|
-
def encode_from_tokens(self, tokens):
|
131
|
-
outputs = self.model(input_ids=tokens['input_ids'],
|
132
|
-
attention_mask=tokens['attention_mask'])
|
133
|
-
last_hidden_state = outputs.last_hidden_state
|
134
|
-
return last_hidden_state
|
135
|
-
|
136
|
-
def tokenize(self, data):
|
137
|
-
tokens = self.tokenizer(data, padding="max_length",
|
138
|
-
max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
|
139
|
-
return tokens
|
140
|
-
|
141
|
-
@dataclass
|
142
|
-
class TextEncoder(ConditioningEncoder):
|
143
|
-
# def __call__(self, data):
|
144
|
-
# tokens = self.tokenize(data)
|
145
|
-
# outputs = self.encode_from_tokens(tokens)
|
146
|
-
# return outputs
|
147
|
-
|
148
|
-
# def encode_from_tokens(self, tokens):
|
149
|
-
# outputs = self.model(input_ids=tokens['input_ids'],
|
150
|
-
# attention_mask=tokens['attention_mask'])
|
151
|
-
# last_hidden_state = outputs.last_hidden_state
|
152
|
-
# # pooler_output = outputs.pooler_output # pooled (EOS token) states
|
153
|
-
# # embed_pooled = pooler_output # .astype(jnp.float16)
|
154
|
-
# embed_labels_full = last_hidden_state # .astype(jnp.float16)
|
155
|
-
|
156
|
-
# return embed_labels_full
|
157
|
-
pass
|
158
|
-
|
159
239
|
class AutoTextTokenizer:
|
160
240
|
def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
|
161
241
|
from transformers import AutoTokenizer
|
@@ -175,18 +255,9 @@ class AutoTextTokenizer:
|
|
175
255
|
|
176
256
|
def __repr__(self):
|
177
257
|
return self.__class__.__name__ + '()'
|
258
|
+
|
259
|
+
# class AutoAudioTokenizer:
|
178
260
|
|
179
|
-
def defaultTextEncodeModel(backend="jax"):
|
180
|
-
|
181
|
-
|
182
|
-
FlaxCLIPTextModel,
|
183
|
-
AutoTokenizer,
|
184
|
-
)
|
185
|
-
modelname = "openai/clip-vit-large-patch14"
|
186
|
-
if backend == "jax":
|
187
|
-
model = FlaxCLIPTextModel.from_pretrained(
|
188
|
-
modelname, dtype=jnp.bfloat16)
|
189
|
-
else:
|
190
|
-
model = CLIPTextModel.from_pretrained(modelname)
|
191
|
-
tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
|
192
|
-
return TextEncoder(model, tokenizer)
|
261
|
+
def defaultTextEncodeModel(modelname="openai/clip-vit-large-patch14", backend="jax"):
    """Build the default text encoder.

    This is a thin wrapper over ``CLIPTextEncoder.from_modelname``; both
    arguments are forwarded unchanged.

    Args:
        modelname: Model identifier passed to ``from_modelname``.
        backend: Backend selector passed to ``from_modelname``.

    Returns:
        The encoder produced by ``CLIPTextEncoder.from_modelname``.
    """
    encoder = CLIPTextEncoder.from_modelname(modelname=modelname, backend=backend)
    return encoder
|
@@ -0,0 +1,64 @@
|
|
1
|
+
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
|
3
|
+
flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
|
4
|
+
flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
|
5
|
+
flaxdiff/data/dataloaders.py,sha256=V4goNCK0JD_TthggXAEgJJD4LxJi1pUDew1x_fMCuO4,22576
|
6
|
+
flaxdiff/data/dataset_map.py,sha256=NrLG1XtIxy8GcCsZ-e6eascjgsP0Xq5lVA1z3HIIYyI,5093
|
7
|
+
flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
|
8
|
+
flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
|
9
|
+
flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
|
10
|
+
flaxdiff/data/sources/av_utils.py,sha256=n2qwMBQGouoBH025vdE7gitWC6RduUommUrs-SPdWe4,24041
|
11
|
+
flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
|
12
|
+
flaxdiff/data/sources/images.py,sha256=WpH4ywZhNol26peX3m6m5NrmDJ1K2s6fRcYHvOFlOk8,11102
|
13
|
+
flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
|
14
|
+
flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
|
15
|
+
flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
|
16
|
+
flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
flaxdiff/inference/pipeline.py,sha256=oMBRjvTtlC3Yzl1FqiBHcI4V34HXGAecCg8UvQbKoOc,8849
|
18
|
+
flaxdiff/inference/utils.py,sha256=SRNYo-YtHzEPRpNv0fD8ZrUvnRIK941Rh4tjlsOGRgM,12278
|
19
|
+
flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
|
20
|
+
flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
|
21
|
+
flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
|
22
|
+
flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
+
flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
24
|
+
flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
|
25
|
+
flaxdiff/models/__init__.py,sha256=amtDF07DfiAdnZsvWX4eaW79nwNEU1s8Zb4PB3ewtg4,118
|
26
|
+
flaxdiff/models/attention.py,sha256=-q3xqWy4vQSLG4vXtiUN3FHVBIo7ZjpQsdLT9CkML6c,13367
|
27
|
+
flaxdiff/models/common.py,sha256=7x9o5vY9UZvN4BNZ7LHzyuU3PNpsNym9B3m1Wfdddjo,10320
|
28
|
+
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
29
|
+
flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
|
30
|
+
flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
|
31
|
+
flaxdiff/models/simple_vit.py,sha256=6DNpwTeE0Gn2jSie6n0JVUmQncPoyFT7jSSBreqk458,7497
|
32
|
+
flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
|
33
|
+
flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
|
34
|
+
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
35
|
+
flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
|
36
|
+
flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
|
37
|
+
flaxdiff/models/autoencoder/simple_autoenc.py,sha256=NnGFjrkq-1z8Ouh_UvlvP0PFpkzm2LYaTQKMmN1BhkM,2109
|
38
|
+
flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
|
39
|
+
flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
|
40
|
+
flaxdiff/samplers/common.py,sha256=-uU9FLkoQp3n3bga8Kfj_onDhtOS4MBggIKNrq3S8n4,18438
|
41
|
+
flaxdiff/samplers/ddim.py,sha256=iFgXz96NBYuNMiWGMovf3gLO2TCxdrhJd-o8tvSmVUI,2054
|
42
|
+
flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
|
43
|
+
flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
|
44
|
+
flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGEpHU8,1436
|
45
|
+
flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
|
46
|
+
flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
|
47
|
+
flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
|
48
|
+
flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
|
49
|
+
flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
|
50
|
+
flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
|
51
|
+
flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
|
52
|
+
flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
|
53
|
+
flaxdiff/schedulers/karras.py,sha256=7PS6mHdnZnTqS2Xl_DacBt5YQ1f_CFyAxShyOo55eG0,3804
|
54
|
+
flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
|
55
|
+
flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
|
56
|
+
flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
|
57
|
+
flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
|
58
|
+
flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
|
59
|
+
flaxdiff/trainer/general_diffusion_trainer.py,sha256=VQ5p2ZaTv2R1LM0Epz4e719_EfK2dh1eoKK3WIysIW0,24040
|
60
|
+
flaxdiff/trainer/simple_trainer.py,sha256=CF2mMcc6AtBgcR1XiqKevRL0paGS0S9ZJofCns32nRM,24214
|
61
|
+
flaxdiff-0.2.0.dist-info/METADATA,sha256=1WLpd9RQy_mJE2E2uOdXptY5Fm3n_MTNcgZyBD7YmGw,23982
|
62
|
+
flaxdiff-0.2.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
63
|
+
flaxdiff-0.2.0.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
64
|
+
flaxdiff-0.2.0.dist-info/RECORD,,
|
flaxdiff/data/datasets.py
DELETED
@@ -1,169 +0,0 @@
|
|
1
|
-
import jax.numpy as jnp
|
2
|
-
import grain.python as pygrain
|
3
|
-
from typing import Dict
|
4
|
-
import numpy as np
|
5
|
-
import jax
|
6
|
-
from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
|
7
|
-
from .dataset_map import datasetMap, onlineDatasetMap
|
8
|
-
import traceback
|
9
|
-
from .online_loader import OnlineStreamingDataLoader
|
10
|
-
import queue
|
11
|
-
from jax.sharding import Mesh
|
12
|
-
import threading
|
13
|
-
|
14
|
-
def batch_mesh_map(mesh):
|
15
|
-
class augmenters(pygrain.MapTransform):
|
16
|
-
def __init__(self, *args, **kwargs):
|
17
|
-
super().__init__(*args, **kwargs)
|
18
|
-
|
19
|
-
def map(self, batch) -> Dict[str, jnp.array]:
|
20
|
-
return convert_to_global_tree(mesh, batch)
|
21
|
-
return augmenters
|
22
|
-
|
23
|
-
def get_dataset_grain(
|
24
|
-
data_name="cc12m",
|
25
|
-
batch_size=64,
|
26
|
-
image_scale=256,
|
27
|
-
count=None,
|
28
|
-
num_epochs=None,
|
29
|
-
method=jax.image.ResizeMethod.LANCZOS3,
|
30
|
-
worker_count=32,
|
31
|
-
read_thread_count=64,
|
32
|
-
read_buffer_size=50,
|
33
|
-
worker_buffer_size=20,
|
34
|
-
seed=0,
|
35
|
-
dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
|
36
|
-
):
|
37
|
-
dataset = datasetMap[data_name]
|
38
|
-
data_source = dataset["source"](dataset_source)
|
39
|
-
augmenter = dataset["augmenter"](image_scale, method)
|
40
|
-
|
41
|
-
local_batch_size = batch_size // jax.process_count()
|
42
|
-
|
43
|
-
sampler = pygrain.IndexSampler(
|
44
|
-
num_records=len(data_source) if count is None else count,
|
45
|
-
shuffle=True,
|
46
|
-
seed=seed,
|
47
|
-
num_epochs=num_epochs,
|
48
|
-
shard_options=pygrain.ShardByJaxProcess(),
|
49
|
-
)
|
50
|
-
|
51
|
-
def get_trainset():
|
52
|
-
transformations = [
|
53
|
-
augmenter(),
|
54
|
-
pygrain.Batch(local_batch_size, drop_remainder=True),
|
55
|
-
]
|
56
|
-
|
57
|
-
# if mesh != None:
|
58
|
-
# transformations += [batch_mesh_map(mesh)]
|
59
|
-
|
60
|
-
loader = pygrain.DataLoader(
|
61
|
-
data_source=data_source,
|
62
|
-
sampler=sampler,
|
63
|
-
operations=transformations,
|
64
|
-
worker_count=worker_count,
|
65
|
-
read_options=pygrain.ReadOptions(
|
66
|
-
read_thread_count, read_buffer_size
|
67
|
-
),
|
68
|
-
worker_buffer_size=worker_buffer_size,
|
69
|
-
)
|
70
|
-
return loader
|
71
|
-
|
72
|
-
|
73
|
-
return {
|
74
|
-
"train": get_trainset,
|
75
|
-
"train_len": len(data_source),
|
76
|
-
"local_batch_size": local_batch_size,
|
77
|
-
"global_batch_size": batch_size,
|
78
|
-
# "null_labels": null_labels,
|
79
|
-
# "null_labels_full": null_labels_full,
|
80
|
-
# "model": model,
|
81
|
-
# "tokenizer": tokenizer,
|
82
|
-
}
|
83
|
-
|
84
|
-
def generate_collate_fn():
|
85
|
-
auto_tokenize = AutoTextTokenizer(tensor_type="np")
|
86
|
-
def default_collate(batch):
|
87
|
-
try:
|
88
|
-
# urls = [sample["url"] for sample in batch]
|
89
|
-
captions = [sample["caption"] for sample in batch]
|
90
|
-
results = auto_tokenize(captions)
|
91
|
-
images = np.stack([sample["image"] for sample in batch], axis=0)
|
92
|
-
return {
|
93
|
-
"image": images,
|
94
|
-
"input_ids": results['input_ids'],
|
95
|
-
"attention_mask": results['attention_mask'],
|
96
|
-
}
|
97
|
-
except Exception as e:
|
98
|
-
print("Error in collate function", e, [sample["image"].shape for sample in batch])
|
99
|
-
traceback.print_exc()
|
100
|
-
|
101
|
-
return default_collate
|
102
|
-
|
103
|
-
def get_dataset_online(
|
104
|
-
data_name="combined_online",
|
105
|
-
batch_size=64,
|
106
|
-
image_scale=256,
|
107
|
-
count=None,
|
108
|
-
num_epochs=None,
|
109
|
-
method=jax.image.ResizeMethod.LANCZOS3,
|
110
|
-
worker_count=32,
|
111
|
-
read_thread_count=64,
|
112
|
-
read_buffer_size=50,
|
113
|
-
worker_buffer_size=20,
|
114
|
-
seed=0,
|
115
|
-
dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
|
116
|
-
):
|
117
|
-
local_batch_size = batch_size // jax.process_count()
|
118
|
-
|
119
|
-
sources = onlineDatasetMap[data_name]["source"]
|
120
|
-
dataloader = OnlineStreamingDataLoader(
|
121
|
-
sources,
|
122
|
-
batch_size=local_batch_size,
|
123
|
-
num_workers=worker_count,
|
124
|
-
num_threads=read_thread_count,
|
125
|
-
image_shape=(image_scale, image_scale),
|
126
|
-
global_process_count=jax.process_count(),
|
127
|
-
global_process_index=jax.process_index(),
|
128
|
-
prefetch=worker_buffer_size,
|
129
|
-
collate_fn=generate_collate_fn(),
|
130
|
-
default_split="train",
|
131
|
-
)
|
132
|
-
|
133
|
-
def get_trainset(mesh: Mesh = None):
|
134
|
-
if mesh != None:
|
135
|
-
class dataLoaderWithMesh:
|
136
|
-
def __init__(self, dataloader, mesh):
|
137
|
-
self.dataloader = dataloader
|
138
|
-
self.mesh = mesh
|
139
|
-
self.tmp_queue = queue.Queue(worker_buffer_size)
|
140
|
-
def batch_loader():
|
141
|
-
for batch in self.dataloader:
|
142
|
-
try:
|
143
|
-
self.tmp_queue.put(convert_to_global_tree(mesh, batch))
|
144
|
-
except Exception as e:
|
145
|
-
print("Error processing batch", e)
|
146
|
-
self.loader_thread = threading.Thread(target=batch_loader)
|
147
|
-
self.loader_thread.start()
|
148
|
-
|
149
|
-
def __iter__(self):
|
150
|
-
return self
|
151
|
-
|
152
|
-
def __next__(self):
|
153
|
-
return self.tmp_queue.get()
|
154
|
-
|
155
|
-
dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
|
156
|
-
|
157
|
-
return dataloader_with_mesh
|
158
|
-
return dataloader
|
159
|
-
|
160
|
-
return {
|
161
|
-
"train": get_trainset,
|
162
|
-
"train_len": len(dataloader) * jax.process_count(),
|
163
|
-
"local_batch_size": local_batch_size,
|
164
|
-
"global_batch_size": batch_size,
|
165
|
-
# "null_labels": null_labels,
|
166
|
-
# "null_labels_full": null_labels_full,
|
167
|
-
# "model": model,
|
168
|
-
# "tokenizer": tokenizer,
|
169
|
-
}
|