flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,81 +0,0 @@
1
- import cv2
2
- import jax.numpy as jnp
3
- import grain.python as pygrain
4
- from flaxdiff.utils import AutoTextTokenizer
5
- from typing import Dict
6
- import os
7
- import struct as st
8
- from functools import partial
9
- import numpy as np
10
-
11
- # -----------------------------------------------------------------------------------------------#
12
- # CC12m and other GCS data sources --------------------------------------------------------------#
13
- # -----------------------------------------------------------------------------------------------#
14
-
15
def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
    """Build a factory that opens the ArrayRecord shards of one dataset directory.

    Args:
        source: Dataset directory, relative to the mount point that is later
            passed to the returned factory as ``base``.

    Returns:
        A callable ``data_source(base=...)``. When called it lists
        ``<base>/<source>``, keeps every entry whose name contains
        ``'array_record'`` and returns a ``pygrain.ArrayRecordDataSource``
        built from those paths.
    """
    def data_source(base="/home/mrwhite0racle/gcs_mount"):
        records_path = os.path.join(base, source)
        # os.listdir() returns entries in arbitrary order; sort them so the
        # shard order handed to the data source is reproducible between runs
        # and between hosts.
        records = sorted(
            os.path.join(records_path, entry)
            for entry in os.listdir(records_path)
            if 'array_record' in entry
        )
        return pygrain.ArrayRecordDataSource(records)
    return data_source
23
-
24
def data_source_combined_gcs(
        sources=None):
    """Build a factory that opens the ArrayRecord shards of several directories.

    Args:
        sources: Iterable of dataset directories, each relative to the mount
            point later passed to the returned factory as ``base``. ``None``
            (the default) is treated as "no sources".

    Returns:
        A callable ``data_source(base=...)``. When called it walks the
        directories in the order given, collects every entry whose name
        contains ``'array_record'`` and returns a single
        ``pygrain.ArrayRecordDataSource`` over all of them.
    """
    # Use None as the sentinel instead of a list literal shared by all calls.
    if sources is None:
        sources = []

    def data_source(base="/home/mrwhite0racle/gcs_mount"):
        records = []
        for source in sources:
            records_path = os.path.join(base, source)
            # os.listdir() order is arbitrary; sort within each directory so
            # the combined shard order is reproducible while the caller's
            # ordering of `sources` is preserved.
            records.extend(sorted(
                os.path.join(records_path, entry)
                for entry in os.listdir(records_path)
                if 'array_record' in entry
            ))
        return pygrain.ArrayRecordDataSource(records)
    return data_source
35
-
36
def unpack_dict_of_byte_arrays(packed_data):
    """Decode a length-prefixed blob into a ``{key: bytes}`` dictionary.

    The blob is read as a back-to-back sequence of entries, each laid out as::

        <key length> <key bytes, UTF-8> <value length> <value bytes>

    where both length fields are read with the ``struct`` format ``'I'``
    (native byte order and size).

    Args:
        packed_data: ``bytes`` or ``bytearray`` holding zero or more entries.

    Returns:
        Dict mapping each decoded key to a slice of ``packed_data`` holding
        its value. An empty input yields an empty dict. If a key occurs more
        than once, the last occurrence wins.

    Raises:
        struct.error: If the buffer ends in the middle of a length field.
        ValueError: If a declared key or value length runs past the end of
            the buffer (previously this silently returned a shortened slice).
        UnicodeDecodeError: If a key is not valid UTF-8.
    """
    # Compile the length-field format once instead of re-parsing 'I' per field.
    length_field = st.Struct('I')
    total = len(packed_data)
    unpacked_dict = {}
    offset = 0
    while offset < total:
        # Key: length prefix followed by UTF-8 bytes.
        (key_length,) = length_field.unpack_from(packed_data, offset)
        offset += length_field.size
        key_end = offset + key_length
        if key_end > total:
            raise ValueError(
                f"truncated key at offset {offset}: "
                f"need {key_length} bytes, have {total - offset}"
            )
        key = packed_data[offset:key_end].decode('utf-8')
        offset = key_end

        # Value: length prefix followed by raw bytes.
        (byte_array_length,) = length_field.unpack_from(packed_data, offset)
        offset += length_field.size
        value_end = offset + byte_array_length
        if value_end > total:
            raise ValueError(
                f"truncated value for key {key!r} at offset {offset}: "
                f"need {byte_array_length} bytes, have {total - offset}"
            )
        unpacked_dict[key] = packed_data[offset:value_end]
        offset = value_end
    return unpacked_dict
54
-
55
def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
    """Swap a BGR image to RGB channel order and resize it to a square.

    Args:
        image: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2RGB``.
        image_scale: Edge length of the square output, in pixels.
        method: Accepted for interface compatibility.
            NOTE(review): this argument is not used; the resize below always
            runs with ``cv2.INTER_AREA``. Confirm what callers pass before
            wiring it through.

    Returns:
        The converted image resized to ``(image_scale, image_scale)``.
    """
    target_size = (image_scale, image_scale)
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(rgb, target_size, interpolation=cv2.INTER_AREA)
60
-
61
def gcs_augmenters(image_scale, method):
    """Create the grain map-transform class used for packed GCS records.

    Args:
        image_scale: Edge length forwarded to ``image_augmenter``.
        method: Forwarded to ``image_augmenter`` as its ``method`` keyword.

    Returns:
        The ``augmenters`` class (not an instance), a
        ``pygrain.MapTransform`` subclass.
    """
    class augmenters(pygrain.MapTransform):
        """Turn one packed record into an image plus tokenized caption."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
            self.image_augmenter = partial(
                image_augmenter, image_scale=image_scale, method=method)

        def map(self, element) -> Dict[str, jnp.array]:
            # The record is a packed {name: bytes} blob; 'jpg' holds the
            # encoded image and 'txt' the UTF-8 caption.
            fields = unpack_dict_of_byte_arrays(element)
            encoded = np.asarray(bytearray(fields['jpg']), dtype="uint8")
            decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
            image = self.image_augmenter(decoded)
            caption = fields['txt'].decode('utf-8')
            tokens = self.auto_tokenize(caption)
            return {
                "image": image,
                "input_ids": tokens['input_ids'][0],
                "attention_mask": tokens['attention_mask'][0],
            }
    return augmenters
@@ -1,79 +0,0 @@
1
- import cv2
2
- import jax.numpy as jnp
3
- import grain.python as pygrain
4
- from flaxdiff.utils import AutoTextTokenizer
5
- from typing import Dict
6
- import random
7
- import augmax
8
- import jax
9
-
10
- # -----------------------------------------------------------------------------------------------#
11
- # Oxford flowers and other TFDS datasources -----------------------------------------------------#
12
- # -----------------------------------------------------------------------------------------------#
13
-
14
# Caption templates; the single `{}` placeholder is filled with a class label
# via str.format().
PROMPT_TEMPLATES = [
    "a photo of a {}",
    "a photo of a {} flower",
    "This is a photo of a {}",
    "This is a photo of a {} flower",
    "A photo of a {} flower",
]
21
-
22
def data_source_tfds(name, use_tf=True, split="all"):
    """Build a factory that loads a TensorFlow Datasets dataset by name.

    Args:
        name: Dataset name passed to ``tensorflow_datasets``.
        use_tf: When true the factory calls ``tfds.load`` (with
            ``shuffle_files=True``); otherwise it calls ``tfds.data_source``
            (with ``try_gcs=False``).
        split: Split specifier forwarded to either loader.

    Returns:
        A callable ``data_source(path_override)``. The ``path_override``
        argument is accepted but not used by either variant.
    """
    # Imported lazily so the dependency is only needed when this is called.
    import tensorflow_datasets as tfds

    if not use_tf:
        def data_source(path_override):
            return tfds.data_source(name, split=split, try_gcs=False)
        return data_source

    def data_source(path_override):
        return tfds.load(name, split=split, shuffle_files=True)
    return data_source
31
-
32
def labelizer_oxford_flowers102(path):
    """Build a caption generator from a one-label-per-line text file.

    Args:
        path: Text file whose i-th line is the label for class index i.
            The file is read once, up front, and each line is stripped.

    Returns:
        A callable ``load_labels(sample)`` that looks up
        ``int(sample['label'])`` in the loaded labels and formats it into a
        randomly chosen entry of ``PROMPT_TEMPLATES``.
    """
    with open(path, "r") as label_file:
        textlabels = [line.strip() for line in label_file.readlines()]

    def load_labels(sample):
        class_name = textlabels[int(sample['label'])]
        # A fresh template is drawn on every call.
        return random.choice(PROMPT_TEMPLATES).format(class_name)
    return load_labels
45
-
46
def tfds_augmenters(image_scale, method):
    """Create the grain map-transform class used for the TFDS flowers data.

    Args:
        image_scale: Edge length of the square output image. Values above
            256 select ``cv2.INTER_CUBIC`` for resizing, otherwise
            ``cv2.INTER_AREA``.
        method: Accepted but not used by this function.

    Returns:
        The ``augmenters`` class (not an instance), a
        ``pygrain.MapTransform`` subclass.
    """
    labelizer = labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
    interpolation = cv2.INTER_CUBIC if image_scale > 256 else cv2.INTER_AREA

    # Imported lazily so torchvision is only needed when this is called.
    from torchvision.transforms import v2

    augments = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
    ])

    class augmenters(pygrain.MapTransform):
        """Resize and augment one sample and tokenize its generated caption."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.tokenize = AutoTextTokenizer(tensor_type="np")

        def map(self, element) -> Dict[str, jnp.array]:
            resized = cv2.resize(
                element['image'],
                (image_scale, image_scale),
                interpolation=interpolation,
            )
            # Pixel values are not rescaled here; only flip/colour jitter.
            image = augments(resized)

            caption = labelizer(element)
            tokens = self.tokenize(caption)
            return {
                "image": image,
                "input_ids": tokens['input_ids'][0],
                "attention_mask": tokens['attention_mask'][0],
            }
    return augmenters
@@ -1,62 +0,0 @@
1
- import flax
2
- from flax import linen as nn
3
- import jax
4
- from typing import Callable
5
- from dataclasses import field
6
- import jax.numpy as jnp
7
- import optax
8
- import functools
9
- from jax.sharding import Mesh, PartitionSpec as P
10
- from jax.experimental.shard_map import shard_map
11
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
12
-
13
- from ..schedulers import NoiseScheduler
14
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
15
-
16
- from flaxdiff.utils import RandomMarkovState
17
-
18
- from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
19
-
20
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
21
- from flax.training import dynamic_scale as dynamic_scale_lib
22
-
23
class TrainState(SimpleTrainState):
    """Train state extended with an RNG field and an EMA copy of the params."""

    rngs: jax.random.PRNGKey
    ema_params: dict

    def apply_ema(self, decay: float = 0.999):
        """Return a copy of this state with ``ema_params`` moved toward ``params``.

        Each leaf is updated as ``decay * ema + (1 - decay) * param``.

        Args:
            decay: Weight kept from the previous EMA value.

        Returns:
            The result of ``self.replace(ema_params=...)``; presumably a new
            state object, as ``replace`` comes from the base class.
        """
        def blend(ema, param):
            return decay * ema + (1 - decay) * param

        updated = jax.tree_util.tree_map(blend, self.ema_params, self.params)
        return self.replace(ema_params=updated)
34
-
35
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
36
- from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
37
-
38
class SimpleVideoDiffusionTrainer(DiffusionTrainer):
    """``DiffusionTrainer`` subclass that only changes the default trainer name.

    The constructor forwards every argument unchanged to
    ``DiffusionTrainer.__init__``; the only difference visible here is the
    default ``name`` of ``"SimpleVideoDiffusion"``.
    """

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 noise_schedule: NoiseScheduler,
                 rngs: jax.random.PRNGKey,
                 unconditional_prob: float = 0.12,
                 name: str = "SimpleVideoDiffusion",
                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                 autoencoder: AutoEncoder = None,
                 **kwargs
                 ):
        # Group the explicit arguments, then hand them on by keyword together
        # with any extra keyword arguments.
        forwarded = dict(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            noise_schedule=noise_schedule,
            unconditional_prob=unconditional_prob,
            autoencoder=autoencoder,
            model_output_transform=model_output_transform,
            rngs=rngs,
            name=name,
        )
        super().__init__(**forwarded, **kwargs)
@@ -1,50 +0,0 @@
1
- flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
3
- flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
5
- flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
6
- flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
7
- flaxdiff/data/sources/gcs.py,sha256=11ZuQhvMyJRLg21DgVdzO5qEuae7zgzTXGNOskF-cbs,3380
8
- flaxdiff/data/sources/tfds.py,sha256=k6IzCXPGnoeekvB8Ul0iFBQECz1Wq6LJd2DAjA7ULHI,2761
9
- flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
10
- flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
13
- flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
14
- flaxdiff/models/attention.py,sha256=-q3xqWy4vQSLG4vXtiUN3FHVBIo7ZjpQsdLT9CkML6c,13367
15
- flaxdiff/models/common.py,sha256=7x9o5vY9UZvN4BNZ7LHzyuU3PNpsNym9B3m1Wfdddjo,10320
16
- flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
17
- flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
18
- flaxdiff/models/simple_unet.py,sha256=hXdpAIA24ARfe53cufJb5Dl4CXPn1bIfWQ0iN6WPNfo,10763
19
- flaxdiff/models/simple_vit.py,sha256=RaWA85EmEyfquXvnutYQjImDxPEvJ6CQUoSufEjS9aU,7498
20
- flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
21
- flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
22
- flaxdiff/models/autoencoder/diffusers.py,sha256=DVWT4LRMvEtN36Yt0FTD0KzG8Isq_BvHkNpgDy6Gs40,3651
23
- flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
24
- flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
25
- flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
26
- flaxdiff/samplers/common.py,sha256=wkzalSYrnsq6oUsevEeRCVfzqwk8qfwvggAlgNTqK-o,8848
27
- flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
28
- flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
29
- flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
30
- flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGEpHU8,1436
31
- flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
32
- flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
33
- flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
34
- flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
35
- flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
36
- flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
37
- flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
38
- flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
39
- flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
40
- flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
41
- flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
42
- flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
43
- flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
44
- flaxdiff/trainer/diffusion_trainer.py,sha256=is7iBV8QnL1FZNmXmmqz7K6cnNGF0LaXZlmFV1lPwHU,14185
45
- flaxdiff/trainer/simple_trainer.py,sha256=NzpCQZlp3pZ7jPxBnYTzPjP1Oh2neD1KHirUiUpcNH0,23222
46
- flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
47
- flaxdiff-0.1.38.1.dist-info/METADATA,sha256=PFvy6QShR5ldz7tyAjQTehexRUg9DK7hRwy9S_ceVuI,23985
48
- flaxdiff-0.1.38.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
49
- flaxdiff-0.1.38.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
50
- flaxdiff-0.1.38.1.dist-info/RECORD,,