flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/sources/gcs.py
DELETED
@@ -1,81 +0,0 @@
|
|
1
|
-
import cv2
|
2
|
-
import jax.numpy as jnp
|
3
|
-
import grain.python as pygrain
|
4
|
-
from flaxdiff.utils import AutoTextTokenizer
|
5
|
-
from typing import Dict
|
6
|
-
import os
|
7
|
-
import struct as st
|
8
|
-
from functools import partial
|
9
|
-
import numpy as np
|
10
|
-
|
11
|
-
# -----------------------------------------------------------------------------------------------#
|
12
|
-
# CC12m and other GCS data sources --------------------------------------------------------------#
|
13
|
-
# -----------------------------------------------------------------------------------------------#
|
14
|
-
|
15
|
-
def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
|
16
|
-
def data_source(base="/home/mrwhite0racle/gcs_mount"):
|
17
|
-
records_path = os.path.join(base, source)
|
18
|
-
records = [os.path.join(records_path, i) for i in os.listdir(
|
19
|
-
records_path) if 'array_record' in i]
|
20
|
-
ds = pygrain.ArrayRecordDataSource(records)
|
21
|
-
return ds
|
22
|
-
return data_source
|
23
|
-
|
24
|
-
def data_source_combined_gcs(
|
25
|
-
sources=[]):
|
26
|
-
def data_source(base="/home/mrwhite0racle/gcs_mount"):
|
27
|
-
records_paths = [os.path.join(base, source) for source in sources]
|
28
|
-
records = []
|
29
|
-
for records_path in records_paths:
|
30
|
-
records += [os.path.join(records_path, i) for i in os.listdir(
|
31
|
-
records_path) if 'array_record' in i]
|
32
|
-
ds = pygrain.ArrayRecordDataSource(records)
|
33
|
-
return ds
|
34
|
-
return data_source
|
35
|
-
|
36
|
-
def unpack_dict_of_byte_arrays(packed_data):
|
37
|
-
unpacked_dict = {}
|
38
|
-
offset = 0
|
39
|
-
while offset < len(packed_data):
|
40
|
-
# Unpack the key length
|
41
|
-
key_length = st.unpack_from('I', packed_data, offset)[0]
|
42
|
-
offset += st.calcsize('I')
|
43
|
-
# Unpack the key bytes and convert to string
|
44
|
-
key = packed_data[offset:offset+key_length].decode('utf-8')
|
45
|
-
offset += key_length
|
46
|
-
# Unpack the byte array length
|
47
|
-
byte_array_length = st.unpack_from('I', packed_data, offset)[0]
|
48
|
-
offset += st.calcsize('I')
|
49
|
-
# Unpack the byte array
|
50
|
-
byte_array = packed_data[offset:offset+byte_array_length]
|
51
|
-
offset += byte_array_length
|
52
|
-
unpacked_dict[key] = byte_array
|
53
|
-
return unpacked_dict
|
54
|
-
|
55
|
-
def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
|
56
|
-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
57
|
-
image = cv2.resize(image, (image_scale, image_scale),
|
58
|
-
interpolation=cv2.INTER_AREA)
|
59
|
-
return image
|
60
|
-
|
61
|
-
def gcs_augmenters(image_scale, method):
|
62
|
-
labelizer = lambda sample : sample['txt']
|
63
|
-
class augmenters(pygrain.MapTransform):
|
64
|
-
def __init__(self, *args, **kwargs):
|
65
|
-
super().__init__(*args, **kwargs)
|
66
|
-
self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
|
67
|
-
self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)
|
68
|
-
|
69
|
-
def map(self, element) -> Dict[str, jnp.array]:
|
70
|
-
element = unpack_dict_of_byte_arrays(element)
|
71
|
-
image = np.asarray(bytearray(element['jpg']), dtype="uint8")
|
72
|
-
image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
|
73
|
-
image = self.image_augmenter(image)
|
74
|
-
caption = labelizer(element).decode('utf-8')
|
75
|
-
results = self.auto_tokenize(caption)
|
76
|
-
return {
|
77
|
-
"image": image,
|
78
|
-
"input_ids": results['input_ids'][0],
|
79
|
-
"attention_mask": results['attention_mask'][0],
|
80
|
-
}
|
81
|
-
return augmenters
|
flaxdiff/data/sources/tfds.py
DELETED
@@ -1,79 +0,0 @@
|
|
1
|
-
import cv2
|
2
|
-
import jax.numpy as jnp
|
3
|
-
import grain.python as pygrain
|
4
|
-
from flaxdiff.utils import AutoTextTokenizer
|
5
|
-
from typing import Dict
|
6
|
-
import random
|
7
|
-
import augmax
|
8
|
-
import jax
|
9
|
-
|
10
|
-
# -----------------------------------------------------------------------------------------------#
|
11
|
-
# Oxford flowers and other TFDS datasources -----------------------------------------------------#
|
12
|
-
# -----------------------------------------------------------------------------------------------#
|
13
|
-
|
14
|
-
PROMPT_TEMPLATES = [
|
15
|
-
"a photo of a {}",
|
16
|
-
"a photo of a {} flower",
|
17
|
-
"This is a photo of a {}",
|
18
|
-
"This is a photo of a {} flower",
|
19
|
-
"A photo of a {} flower",
|
20
|
-
]
|
21
|
-
|
22
|
-
def data_source_tfds(name, use_tf=True, split="all"):
|
23
|
-
import tensorflow_datasets as tfds
|
24
|
-
if use_tf:
|
25
|
-
def data_source(path_override):
|
26
|
-
return tfds.load(name, split=split, shuffle_files=True)
|
27
|
-
else:
|
28
|
-
def data_source(path_override):
|
29
|
-
return tfds.data_source(name, split=split, try_gcs=False)
|
30
|
-
return data_source
|
31
|
-
|
32
|
-
def labelizer_oxford_flowers102(path):
|
33
|
-
with open(path, "r") as f:
|
34
|
-
textlabels = [i.strip() for i in f.readlines()]
|
35
|
-
|
36
|
-
def load_labels(sample):
|
37
|
-
raw = textlabels[int(sample['label'])]
|
38
|
-
# randomly select a prompt template
|
39
|
-
template = random.choice(PROMPT_TEMPLATES)
|
40
|
-
# format the template with the label
|
41
|
-
caption = template.format(raw)
|
42
|
-
# return the caption
|
43
|
-
return caption
|
44
|
-
return load_labels
|
45
|
-
|
46
|
-
def tfds_augmenters(image_scale, method):
|
47
|
-
labelizer = labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
|
48
|
-
if image_scale > 256:
|
49
|
-
interpolation = cv2.INTER_CUBIC
|
50
|
-
else:
|
51
|
-
interpolation = cv2.INTER_AREA
|
52
|
-
|
53
|
-
from torchvision.transforms import v2
|
54
|
-
|
55
|
-
augments = v2.Compose([
|
56
|
-
v2.RandomHorizontalFlip(p=0.5),
|
57
|
-
v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
|
58
|
-
])
|
59
|
-
|
60
|
-
class augmenters(pygrain.MapTransform):
|
61
|
-
def __init__(self, *args, **kwargs):
|
62
|
-
super().__init__(*args, **kwargs)
|
63
|
-
self.tokenize = AutoTextTokenizer(tensor_type="np")
|
64
|
-
|
65
|
-
def map(self, element) -> Dict[str, jnp.array]:
|
66
|
-
image = element['image']
|
67
|
-
image = cv2.resize(image, (image_scale, image_scale),
|
68
|
-
interpolation=interpolation)
|
69
|
-
image = augments(image)
|
70
|
-
# image = (image - 127.5) / 127.5
|
71
|
-
|
72
|
-
caption = labelizer(element)
|
73
|
-
results = self.tokenize(caption)
|
74
|
-
return {
|
75
|
-
"image": image,
|
76
|
-
"input_ids": results['input_ids'][0],
|
77
|
-
"attention_mask": results['attention_mask'][0],
|
78
|
-
}
|
79
|
-
return augmenters
|
@@ -1,62 +0,0 @@
|
|
1
|
-
import flax
|
2
|
-
from flax import linen as nn
|
3
|
-
import jax
|
4
|
-
from typing import Callable
|
5
|
-
from dataclasses import field
|
6
|
-
import jax.numpy as jnp
|
7
|
-
import optax
|
8
|
-
import functools
|
9
|
-
from jax.sharding import Mesh, PartitionSpec as P
|
10
|
-
from jax.experimental.shard_map import shard_map
|
11
|
-
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
12
|
-
|
13
|
-
from ..schedulers import NoiseScheduler
|
14
|
-
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
15
|
-
|
16
|
-
from flaxdiff.utils import RandomMarkovState
|
17
|
-
|
18
|
-
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
19
|
-
|
20
|
-
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
21
|
-
from flax.training import dynamic_scale as dynamic_scale_lib
|
22
|
-
|
23
|
-
class TrainState(SimpleTrainState):
|
24
|
-
rngs: jax.random.PRNGKey
|
25
|
-
ema_params: dict
|
26
|
-
|
27
|
-
def apply_ema(self, decay: float = 0.999):
|
28
|
-
new_ema_params = jax.tree_util.tree_map(
|
29
|
-
lambda ema, param: decay * ema + (1 - decay) * param,
|
30
|
-
self.ema_params,
|
31
|
-
self.params,
|
32
|
-
)
|
33
|
-
return self.replace(ema_params=new_ema_params)
|
34
|
-
|
35
|
-
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
36
|
-
from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
|
37
|
-
|
38
|
-
class SimpleVideoDiffusionTrainer(DiffusionTrainer):
|
39
|
-
def __init__(self,
|
40
|
-
model: nn.Module,
|
41
|
-
input_shapes: Dict[str, Tuple[int]],
|
42
|
-
optimizer: optax.GradientTransformation,
|
43
|
-
noise_schedule: NoiseScheduler,
|
44
|
-
rngs: jax.random.PRNGKey,
|
45
|
-
unconditional_prob: float = 0.12,
|
46
|
-
name: str = "SimpleVideoDiffusion",
|
47
|
-
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
|
48
|
-
autoencoder: AutoEncoder = None,
|
49
|
-
**kwargs
|
50
|
-
):
|
51
|
-
super().__init__(
|
52
|
-
model=model,
|
53
|
-
input_shapes=input_shapes,
|
54
|
-
optimizer=optimizer,
|
55
|
-
noise_schedule=noise_schedule,
|
56
|
-
unconditional_prob=unconditional_prob,
|
57
|
-
autoencoder=autoencoder,
|
58
|
-
model_output_transform=model_output_transform,
|
59
|
-
rngs=rngs,
|
60
|
-
name=name,
|
61
|
-
**kwargs
|
62
|
-
)
|
@@ -1,50 +0,0 @@
|
|
1
|
-
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
|
3
|
-
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
|
5
|
-
flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
|
6
|
-
flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
|
7
|
-
flaxdiff/data/sources/gcs.py,sha256=11ZuQhvMyJRLg21DgVdzO5qEuae7zgzTXGNOskF-cbs,3380
|
8
|
-
flaxdiff/data/sources/tfds.py,sha256=k6IzCXPGnoeekvB8Ul0iFBQECz1Wq6LJd2DAjA7ULHI,2761
|
9
|
-
flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
|
10
|
-
flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
|
13
|
-
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
14
|
-
flaxdiff/models/attention.py,sha256=-q3xqWy4vQSLG4vXtiUN3FHVBIo7ZjpQsdLT9CkML6c,13367
|
15
|
-
flaxdiff/models/common.py,sha256=7x9o5vY9UZvN4BNZ7LHzyuU3PNpsNym9B3m1Wfdddjo,10320
|
16
|
-
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
17
|
-
flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
|
18
|
-
flaxdiff/models/simple_unet.py,sha256=hXdpAIA24ARfe53cufJb5Dl4CXPn1bIfWQ0iN6WPNfo,10763
|
19
|
-
flaxdiff/models/simple_vit.py,sha256=RaWA85EmEyfquXvnutYQjImDxPEvJ6CQUoSufEjS9aU,7498
|
20
|
-
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
21
|
-
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
22
|
-
flaxdiff/models/autoencoder/diffusers.py,sha256=DVWT4LRMvEtN36Yt0FTD0KzG8Isq_BvHkNpgDy6Gs40,3651
|
23
|
-
flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
|
24
|
-
flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
|
25
|
-
flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
|
26
|
-
flaxdiff/samplers/common.py,sha256=wkzalSYrnsq6oUsevEeRCVfzqwk8qfwvggAlgNTqK-o,8848
|
27
|
-
flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
|
28
|
-
flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
|
29
|
-
flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
|
30
|
-
flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGEpHU8,1436
|
31
|
-
flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
|
32
|
-
flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
|
33
|
-
flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
|
34
|
-
flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
|
35
|
-
flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
|
36
|
-
flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
|
37
|
-
flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
|
38
|
-
flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
|
39
|
-
flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
|
40
|
-
flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
|
41
|
-
flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
|
42
|
-
flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
|
43
|
-
flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
|
44
|
-
flaxdiff/trainer/diffusion_trainer.py,sha256=is7iBV8QnL1FZNmXmmqz7K6cnNGF0LaXZlmFV1lPwHU,14185
|
45
|
-
flaxdiff/trainer/simple_trainer.py,sha256=NzpCQZlp3pZ7jPxBnYTzPjP1Oh2neD1KHirUiUpcNH0,23222
|
46
|
-
flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
|
47
|
-
flaxdiff-0.1.38.1.dist-info/METADATA,sha256=PFvy6QShR5ldz7tyAjQTehexRUg9DK7hRwy9S_ceVuI,23985
|
48
|
-
flaxdiff-0.1.38.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
49
|
-
flaxdiff-0.1.38.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
50
|
-
flaxdiff-0.1.38.1.dist-info/RECORD,,
|
File without changes
|