flaxdiff 0.1.36__tar.gz → 0.1.36.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/PKG-INFO +13 -10
  2. flaxdiff-0.1.36.2/pyproject.toml +29 -0
  3. flaxdiff-0.1.36.2/src/data/sources/gcs.py +81 -0
  4. flaxdiff-0.1.36.2/src/data/sources/tfds.py +67 -0
  5. {flaxdiff-0.1.36 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/PKG-INFO +13 -10
  6. flaxdiff-0.1.36.2/src/flaxdiff.egg-info/SOURCES.txt +50 -0
  7. flaxdiff-0.1.36.2/src/flaxdiff.egg-info/requires.txt +14 -0
  8. flaxdiff-0.1.36.2/src/flaxdiff.egg-info/top_level.txt +9 -0
  9. flaxdiff-0.1.36.2/src/metrics/inception.py +658 -0
  10. flaxdiff-0.1.36.2/src/metrics/utils.py +49 -0
  11. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/simple_trainer.py +2 -0
  12. flaxdiff-0.1.36/flaxdiff.egg-info/SOURCES.txt +0 -46
  13. flaxdiff-0.1.36/flaxdiff.egg-info/requires.txt +0 -5
  14. flaxdiff-0.1.36/flaxdiff.egg-info/top_level.txt +0 -1
  15. flaxdiff-0.1.36/setup.py +0 -21
  16. {flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/README.md +0 -0
  17. {flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/setup.cfg +0 -0
  18. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/__init__.py +0 -0
  19. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/__init__.py +0 -0
  20. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/dataset_map.py +0 -0
  21. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/datasets.py +0 -0
  22. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/online_loader.py +0 -0
  23. {flaxdiff-0.1.36 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/dependency_links.txt +0 -0
  24. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/__init__.py +0 -0
  25. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/attention.py +0 -0
  26. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/__init__.py +0 -0
  27. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/autoencoder.py +0 -0
  28. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/diffusers.py +0 -0
  29. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/simple_autoenc.py +0 -0
  30. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/common.py +0 -0
  31. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/favor_fastattn.py +0 -0
  32. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/simple_unet.py +0 -0
  33. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/simple_vit.py +0 -0
  34. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/predictors/__init__.py +0 -0
  35. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/__init__.py +0 -0
  36. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/common.py +0 -0
  37. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/ddim.py +0 -0
  38. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/ddpm.py +0 -0
  39. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/euler.py +0 -0
  40. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/heun_sampler.py +0 -0
  41. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/multistep_dpm.py +0 -0
  42. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/rk4_sampler.py +0 -0
  43. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/__init__.py +0 -0
  44. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/common.py +0 -0
  45. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/continuous.py +0 -0
  46. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/cosine.py +0 -0
  47. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/discrete.py +0 -0
  48. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/exp.py +0 -0
  49. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/karras.py +0 -0
  50. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/linear.py +0 -0
  51. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/sqrt.py +0 -0
  52. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/__init__.py +0 -0
  53. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/autoencoder_trainer.py +0 -0
  54. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/diffusion_trainer.py +0 -0
  55. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/video_diffusion_trainer.py +0 -0
  56. {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/utils.py +0 -0

{flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/PKG-INFO
@@ -1,21 +1,24 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.36
+ Version: 0.1.36.2
  Summary: A versatile and easy to understand Diffusion library
- Author: Ashish Kumar Singh
- Author-email: ashishkmr472@gmail.com
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+ License-Expression: MIT
  Description-Content-Type: text/markdown
  Requires-Dist: flax>=0.8.4
- Requires-Dist: optax>=0.2.2
  Requires-Dist: jax>=0.4.28
+ Requires-Dist: optax>=0.2.2
  Requires-Dist: orbax
+ Requires-Dist: numpy
  Requires-Dist: clu
- Dynamic: author
- Dynamic: author-email
- Dynamic: description
- Dynamic: description-content-type
- Dynamic: requires-dist
- Dynamic: summary
+ Requires-Dist: einops
+ Requires-Dist: tqdm
+ Requires-Dist: grain
+ Requires-Dist: termcolor
+ Requires-Dist: augmax
+ Requires-Dist: albumentations
+ Requires-Dist: rich
+ Requires-Dist: python-dotenv
 
  # ![](images/logo.jpeg "FlaxDiff")
 

flaxdiff-0.1.36.2/pyproject.toml ADDED
@@ -0,0 +1,29 @@
+ [build-system]
+ requires = ["setuptools", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "flaxdiff"
+ version = "0.1.36.2"
+ description = "A versatile and easy to understand Diffusion library"
+ readme = "README.md"
+ authors = [
+     { name="Ashish Kumar Singh", email="ashishkmr472@gmail.com" }
+ ]
+ dependencies = [
+     "flax>=0.8.4",
+     "jax>=0.4.28",
+     "optax>=0.2.2",
+     "orbax",
+     "numpy",
+     "clu",
+     "einops",
+     "tqdm",
+     "grain",
+     "termcolor",
+     "augmax",
+     "albumentations",
+     "rich",
+     "python-dotenv",
+ ]
+ license = "MIT"

flaxdiff-0.1.36.2/src/data/sources/gcs.py ADDED
@@ -0,0 +1,81 @@
+ import cv2
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from flaxdiff.utils import AutoTextTokenizer
+ from typing import Dict
+ import os
+ import struct as st
+ from functools import partial
+ import numpy as np
+
+ # -----------------------------------------------------------------------------------------------#
+ # CC12m and other GCS data sources --------------------------------------------------------------#
+ # -----------------------------------------------------------------------------------------------#
+
+ def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
+     def data_source(base="/home/mrwhite0racle/gcs_mount"):
+         records_path = os.path.join(base, source)
+         records = [os.path.join(records_path, i) for i in os.listdir(
+             records_path) if 'array_record' in i]
+         ds = pygrain.ArrayRecordDataSource(records)
+         return ds
+     return data_source
+
+ def data_source_combined_gcs(
+         sources=[]):
+     def data_source(base="/home/mrwhite0racle/gcs_mount"):
+         records_paths = [os.path.join(base, source) for source in sources]
+         records = []
+         for records_path in records_paths:
+             records += [os.path.join(records_path, i) for i in os.listdir(
+                 records_path) if 'array_record' in i]
+         ds = pygrain.ArrayRecordDataSource(records)
+         return ds
+     return data_source
+
+ def unpack_dict_of_byte_arrays(packed_data):
+     unpacked_dict = {}
+     offset = 0
+     while offset < len(packed_data):
+         # Unpack the key length
+         key_length = st.unpack_from('I', packed_data, offset)[0]
+         offset += st.calcsize('I')
+         # Unpack the key bytes and convert to string
+         key = packed_data[offset:offset+key_length].decode('utf-8')
+         offset += key_length
+         # Unpack the byte array length
+         byte_array_length = st.unpack_from('I', packed_data, offset)[0]
+         offset += st.calcsize('I')
+         # Unpack the byte array
+         byte_array = packed_data[offset:offset+byte_array_length]
+         offset += byte_array_length
+         unpacked_dict[key] = byte_array
+     return unpacked_dict
+
+ def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = cv2.resize(image, (image_scale, image_scale),
+                        interpolation=cv2.INTER_AREA)
+     return image
+
+ def gcs_augmenters(image_scale, method):
+     labelizer = lambda sample : sample['txt']
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
+             self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)
+
+         def map(self, element) -> Dict[str, jnp.array]:
+             element = unpack_dict_of_byte_arrays(element)
+             image = np.asarray(bytearray(element['jpg']), dtype="uint8")
+             image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
+             image = self.image_augmenter(image)
+             caption = labelizer(element).decode('utf-8')
+             results = self.auto_tokenize(caption)
+             return {
+                 "image": image,
+                 "input_ids": results['input_ids'][0],
+                 "attention_mask": results['attention_mask'][0],
+             }
+     return augmenters
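
Note: the snippet below is not part of the diff. It is a minimal sketch of the writer side implied by unpack_dict_of_byte_arrays above, shown only to clarify the record format the GCS loader expects: each entry is a uint32 key length, the UTF-8 key bytes, a uint32 value length, and the raw value bytes, using the same native 'I' struct layout as the unpacker. pack_dict_of_byte_arrays is a hypothetical helper, not a function in the package.

import struct as st

def pack_dict_of_byte_arrays(data):
    # Hypothetical inverse of unpack_dict_of_byte_arrays: length-prefix the
    # key and value of every entry and concatenate into one buffer.
    packed = bytearray()
    for key, value in data.items():
        key_bytes = key.encode('utf-8')
        packed += st.pack('I', len(key_bytes))
        packed += key_bytes
        packed += st.pack('I', len(value))
        packed += value
    return bytes(packed)

# Round trip against the unpacker defined in gcs.py:
sample = {"jpg": b"<jpeg bytes>", "txt": b"a photo of a flower"}
assert unpack_dict_of_byte_arrays(pack_dict_of_byte_arrays(sample)) == sample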

flaxdiff-0.1.36.2/src/data/sources/tfds.py ADDED
@@ -0,0 +1,67 @@
+ import cv2
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from flaxdiff.utils import AutoTextTokenizer
+ from typing import Dict
+ import random
+
+ # -----------------------------------------------------------------------------------------------#
+ # Oxford flowers and other TFDS datasources -----------------------------------------------------#
+ # -----------------------------------------------------------------------------------------------#
+
+ PROMPT_TEMPLATES = [
+     "a photo of a {}",
+     "a photo of a {} flower",
+     "This is a photo of a {}",
+     "This is a photo of a {} flower",
+     "A photo of a {} flower",
+ ]
+
+ def data_source_tfds(name, use_tf=True, split="all"):
+     import tensorflow_datasets as tfds
+     if use_tf:
+         def data_source(path_override):
+             return tfds.load(name, split=split, shuffle_files=True)
+     else:
+         def data_source(path_override):
+             return tfds.data_source(name, split=split, try_gcs=False)
+     return data_source
+
+ def labelizer_oxford_flowers102(path):
+     with open(path, "r") as f:
+         textlabels = [i.strip() for i in f.readlines()]
+
+     def load_labels(sample):
+         raw = textlabels[int(sample['label'])]
+         # randomly select a prompt template
+         template = random.choice(PROMPT_TEMPLATES)
+         # format the template with the label
+         caption = template.format(raw)
+         # return the caption
+         return caption
+     return load_labels
+
+ def tfds_augmenters(image_scale, method):
+     labelizer = labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
+     if image_scale > 256:
+         interpolation = cv2.INTER_CUBIC
+     else:
+         interpolation = cv2.INTER_AREA
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             self.tokenize = AutoTextTokenizer(tensor_type="np")
+
+         def map(self, element) -> Dict[str, jnp.array]:
+             image = element['image']
+             image = cv2.resize(image, (image_scale, image_scale),
+                                interpolation=interpolation)
+             # image = (image - 127.5) / 127.5
+             caption = labelizer(element)
+             results = self.tokenize(caption)
+             return {
+                 "image": image,
+                 "input_ids": results['input_ids'][0],
+                 "attention_mask": results['attention_mask'][0],
+             }
+     return augmenters
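
Note: the snippet below is not part of the diff. It is a minimal usage sketch of how the TFDS source and augmenter above might be wired together, assuming the standard grain.python DataLoader, IndexSampler, NoSharding, and Batch APIs; it also assumes the hardcoded oxford_flowers102 label file exists locally, as tfds_augmenters requires.

import grain.python as pygrain

ds = data_source_tfds("oxford_flowers102", use_tf=False)(None)  # pygrain-compatible random-access source
augmenter = tfds_augmenters(image_scale=128, method=None)()     # method is unused; interpolation is chosen from image_scale

sampler = pygrain.IndexSampler(
    num_records=len(ds),
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)
loader = pygrain.DataLoader(
    data_source=ds,
    sampler=sampler,
    operations=[augmenter, pygrain.Batch(batch_size=16, drop_remainder=True)],
    worker_count=0,
)
for batch in loader:
    print(batch["image"].shape, batch["input_ids"].shape)  # e.g. (16, 128, 128, 3), (16, seq_len)
    break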

{flaxdiff-0.1.36 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/PKG-INFO
@@ -1,21 +1,24 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.36
+ Version: 0.1.36.2
  Summary: A versatile and easy to understand Diffusion library
- Author: Ashish Kumar Singh
- Author-email: ashishkmr472@gmail.com
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+ License-Expression: MIT
  Description-Content-Type: text/markdown
  Requires-Dist: flax>=0.8.4
- Requires-Dist: optax>=0.2.2
  Requires-Dist: jax>=0.4.28
+ Requires-Dist: optax>=0.2.2
  Requires-Dist: orbax
+ Requires-Dist: numpy
  Requires-Dist: clu
- Dynamic: author
- Dynamic: author-email
- Dynamic: description
- Dynamic: description-content-type
- Dynamic: requires-dist
- Dynamic: summary
+ Requires-Dist: einops
+ Requires-Dist: tqdm
+ Requires-Dist: grain
+ Requires-Dist: termcolor
+ Requires-Dist: augmax
+ Requires-Dist: albumentations
+ Requires-Dist: rich
+ Requires-Dist: python-dotenv
 
  # ![](images/logo.jpeg "FlaxDiff")
 

flaxdiff-0.1.36.2/src/flaxdiff.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,50 @@
+ README.md
+ pyproject.toml
+ src/__init__.py
+ src/utils.py
+ src/data/__init__.py
+ src/data/dataset_map.py
+ src/data/datasets.py
+ src/data/online_loader.py
+ src/data/sources/gcs.py
+ src/data/sources/tfds.py
+ src/flaxdiff.egg-info/PKG-INFO
+ src/flaxdiff.egg-info/SOURCES.txt
+ src/flaxdiff.egg-info/dependency_links.txt
+ src/flaxdiff.egg-info/requires.txt
+ src/flaxdiff.egg-info/top_level.txt
+ src/metrics/inception.py
+ src/metrics/utils.py
+ src/models/__init__.py
+ src/models/attention.py
+ src/models/common.py
+ src/models/favor_fastattn.py
+ src/models/simple_unet.py
+ src/models/simple_vit.py
+ src/models/autoencoder/__init__.py
+ src/models/autoencoder/autoencoder.py
+ src/models/autoencoder/diffusers.py
+ src/models/autoencoder/simple_autoenc.py
+ src/predictors/__init__.py
+ src/samplers/__init__.py
+ src/samplers/common.py
+ src/samplers/ddim.py
+ src/samplers/ddpm.py
+ src/samplers/euler.py
+ src/samplers/heun_sampler.py
+ src/samplers/multistep_dpm.py
+ src/samplers/rk4_sampler.py
+ src/schedulers/__init__.py
+ src/schedulers/common.py
+ src/schedulers/continuous.py
+ src/schedulers/cosine.py
+ src/schedulers/discrete.py
+ src/schedulers/exp.py
+ src/schedulers/karras.py
+ src/schedulers/linear.py
+ src/schedulers/sqrt.py
+ src/trainer/__init__.py
+ src/trainer/autoencoder_trainer.py
+ src/trainer/diffusion_trainer.py
+ src/trainer/simple_trainer.py
+ src/trainer/video_diffusion_trainer.py

flaxdiff-0.1.36.2/src/flaxdiff.egg-info/requires.txt ADDED
@@ -0,0 +1,14 @@
+ flax>=0.8.4
+ jax>=0.4.28
+ optax>=0.2.2
+ orbax
+ numpy
+ clu
+ einops
+ tqdm
+ grain
+ termcolor
+ augmax
+ albumentations
+ rich
+ python-dotenv

flaxdiff-0.1.36.2/src/flaxdiff.egg-info/top_level.txt ADDED
@@ -0,0 +1,9 @@
+ __init__
+ data
+ metrics
+ models
+ predictors
+ samplers
+ schedulers
+ trainer
+ utils

flaxdiff-0.1.36.2/src/metrics/inception.py ADDED
@@ -0,0 +1,658 @@
+ # Mostly derived from
+ # https://github.com/matthias-wright/jax-fid
+
+ import jax
+ from jax import lax
+ from jax.nn import initializers
+ import jax.numpy as jnp
+ import flax
+ from flax.linen.module import merge_param
+ import flax.linen as nn
+ from typing import Callable, Iterable, Optional, Tuple, Union, Any
+ import functools
+ import pickle
+ from . import utils
+
+ PRNGKey = Any
+ Array = Any
+ Shape = Tuple[int]
+ Dtype = Any
+
+
+ class InceptionV3(nn.Module):
+     """
+     InceptionV3 network.
+     Reference: https://arxiv.org/abs/1512.00567
+     Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py
+
+     Attributes:
+         include_head (bool): If True, include classifier head.
+         num_classes (int): Number of classes.
+         pretrained (bool): If True, use pretrained weights.
+         transform_input (bool): If True, preprocesses the input according to the method with which it
+             was trained on ImageNet.
+         aux_logits (bool): If True, add an auxiliary branch that can improve training.
+         dtype (str): Data type.
+     """
+     include_head: bool=False
+     num_classes: int=1000
+     pretrained: bool=False
+     transform_input: bool=False
+     aux_logits: bool=False
+     ckpt_path: str='https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1'
+     dtype: str='float32'
+
+     def setup(self):
+         if self.pretrained:
+             ckpt_file = utils.download(self.ckpt_path)
+             self.params_dict = pickle.load(open(ckpt_file, 'rb'))
+             self.num_classes_ = 1000
+         else:
+             self.params_dict = None
+             self.num_classes_ = self.num_classes
+
+     @nn.compact
+     def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
+         """
+         Args:
+             x (tensor): Input image, shape [B, H, W, C].
+             train (bool): If True, training mode.
+             rng (jax.random.PRNGKey): Random seed.
+         """
+         x = self._transform_input(x)
+         x = BasicConv2d(out_channels=32,
+                         kernel_size=(3, 3),
+                         strides=(2, 2),
+                         params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
+                         dtype=self.dtype)(x, train)
+         x = BasicConv2d(out_channels=32,
+                         kernel_size=(3, 3),
+                         params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
+                         dtype=self.dtype)(x, train)
+         x = BasicConv2d(out_channels=64,
+                         kernel_size=(3, 3),
+                         padding=((1, 1), (1, 1)),
+                         params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
+                         dtype=self.dtype)(x, train)
+         x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+         x = BasicConv2d(out_channels=80,
+                         kernel_size=(1, 1),
+                         params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
+                         dtype=self.dtype)(x, train)
+         x = BasicConv2d(out_channels=192,
+                         kernel_size=(3, 3),
+                         params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
+                         dtype=self.dtype)(x, train)
+         x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+         x = InceptionA(pool_features=32,
+                        params_dict=utils.get(self.params_dict, 'Mixed_5b'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionA(pool_features=64,
+                        params_dict=utils.get(self.params_dict, 'Mixed_5c'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionA(pool_features=64,
+                        params_dict=utils.get(self.params_dict, 'Mixed_5d'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionC(channels_7x7=128,
+                        params_dict=utils.get(self.params_dict, 'Mixed_6b'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionC(channels_7x7=160,
+                        params_dict=utils.get(self.params_dict, 'Mixed_6c'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionC(channels_7x7=160,
+                        params_dict=utils.get(self.params_dict, 'Mixed_6d'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionC(channels_7x7=192,
+                        params_dict=utils.get(self.params_dict, 'Mixed_6e'),
+                        dtype=self.dtype)(x, train)
+         aux = None
+         if self.aux_logits and train:
+             aux = InceptionAux(num_classes=self.num_classes_,
+                                params_dict=utils.get(self.params_dict, 'AuxLogits'),
+                                dtype=self.dtype)(x, train)
+         x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
+                        dtype=self.dtype)(x, train)
+         x = InceptionE(avg_pool, params_dict=utils.get(self.params_dict, 'Mixed_7b'),
+                        dtype=self.dtype)(x, train)
+         # Following the implementation by @mseitzer, we use max pooling instead
+         # of average pooling here.
+         # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
+         x = InceptionE(nn.max_pool, params_dict=utils.get(self.params_dict, 'Mixed_7c'),
+                        dtype=self.dtype)(x, train)
+         x = jnp.mean(x, axis=(1, 2), keepdims=True)
+         if not self.include_head:
+             return x
+         x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
+         x = jnp.reshape(x, newshape=(x.shape[0], -1))
+         x = Dense(features=self.num_classes_,
+                   params_dict=utils.get(self.params_dict, 'fc'),
+                   dtype=self.dtype)(x)
+         if self.aux_logits:
+             return x, aux
+         return x
+
+     def _transform_input(self, x):
+         if self.transform_input:
+             x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+             x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+             x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+             x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
+         return x
+
+
+ class Dense(nn.Module):
+     features: int
+     kernel_init: functools.partial=nn.initializers.lecun_normal()
+     bias_init: functools.partial=nn.initializers.zeros
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x):
+         x = nn.Dense(features=self.features,
+                      kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['kernel']),
+                      bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['bias']))(x)
+         return x
+
+
+ class BasicConv2d(nn.Module):
+     out_channels: int
+     kernel_size: Union[int, Iterable[int]]=(3, 3)
+     strides: Optional[Iterable[int]]=(1, 1)
+     padding: Union[str, Iterable[Tuple[int, int]]]='valid'
+     use_bias: bool=False
+     kernel_init: functools.partial=nn.initializers.lecun_normal()
+     bias_init: functools.partial=nn.initializers.zeros
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         x = nn.Conv(features=self.out_channels,
+                     kernel_size=self.kernel_size,
+                     strides=self.strides,
+                     padding=self.padding,
+                     use_bias=self.use_bias,
+                     kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['kernel']),
+                     bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['bias']),
+                     dtype=self.dtype)(x)
+         if self.params_dict is None:
+             x = BatchNorm(epsilon=0.001,
+                           momentum=0.1,
+                           use_running_average=not train,
+                           dtype=self.dtype)(x)
+         else:
+             x = BatchNorm(epsilon=0.001,
+                           momentum=0.1,
+                           bias_init=lambda *_ : jnp.array(self.params_dict['bn']['bias']),
+                           scale_init=lambda *_ : jnp.array(self.params_dict['bn']['scale']),
+                           mean_init=lambda *_ : jnp.array(self.params_dict['bn']['mean']),
+                           var_init=lambda *_ : jnp.array(self.params_dict['bn']['var']),
+                           use_running_average=not train,
+                           dtype=self.dtype)(x)
+         x = jax.nn.relu(x)
+         return x
+
+
+ class InceptionA(nn.Module):
+     pool_features: int
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         branch1x1 = BasicConv2d(out_channels=64,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch1x1'),
+                                 dtype=self.dtype)(x, train)
+         branch5x5 = BasicConv2d(out_channels=48,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch5x5_1'),
+                                 dtype=self.dtype)(x, train)
+         branch5x5 = BasicConv2d(out_channels=64,
+                                 kernel_size=(5, 5),
+                                 padding=((2, 2), (2, 2)),
+                                 params_dict=utils.get(self.params_dict, 'branch5x5_2'),
+                                 dtype=self.dtype)(branch5x5, train)
+
+         branch3x3dbl = BasicConv2d(out_channels=64,
+                                    kernel_size=(1, 1),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
+                                    dtype=self.dtype)(x, train)
+         branch3x3dbl = BasicConv2d(out_channels=96,
+                                    kernel_size=(3, 3),
+                                    padding=((1, 1), (1, 1)),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
+                                    dtype=self.dtype)(branch3x3dbl, train)
+         branch3x3dbl = BasicConv2d(out_channels=96,
+                                    kernel_size=(3, 3),
+                                    padding=((1, 1), (1, 1)),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
+                                    dtype=self.dtype)(branch3x3dbl, train)
+
+         branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
+         branch_pool = BasicConv2d(out_channels=self.pool_features,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch_pool'),
+                                   dtype=self.dtype)(branch_pool, train)
+
+         output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
+         return output
+
+
+ class InceptionB(nn.Module):
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         branch3x3 = BasicConv2d(out_channels=384,
+                                 kernel_size=(3, 3),
+                                 strides=(2, 2),
+                                 params_dict=utils.get(self.params_dict, 'branch3x3'),
+                                 dtype=self.dtype)(x, train)
+
+         branch3x3dbl = BasicConv2d(out_channels=64,
+                                    kernel_size=(1, 1),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
+                                    dtype=self.dtype)(x, train)
+         branch3x3dbl = BasicConv2d(out_channels=96,
+                                    kernel_size=(3, 3),
+                                    padding=((1, 1), (1, 1)),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
+                                    dtype=self.dtype)(branch3x3dbl, train)
+         branch3x3dbl = BasicConv2d(out_channels=96,
+                                    kernel_size=(3, 3),
+                                    strides=(2, 2),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
+                                    dtype=self.dtype)(branch3x3dbl, train)
+
+         branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+
+         output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
+         return output
+
+
+ class InceptionC(nn.Module):
+     channels_7x7: int
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         branch1x1 = BasicConv2d(out_channels=192,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch1x1'),
+                                 dtype=self.dtype)(x, train)
+
+         branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch7x7_1'),
+                                 dtype=self.dtype)(x, train)
+         branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
+                                 kernel_size=(1, 7),
+                                 padding=((0, 0), (3, 3)),
+                                 params_dict=utils.get(self.params_dict, 'branch7x7_2'),
+                                 dtype=self.dtype)(branch7x7, train)
+         branch7x7 = BasicConv2d(out_channels=192,
+                                 kernel_size=(7, 1),
+                                 padding=((3, 3), (0, 0)),
+                                 params_dict=utils.get(self.params_dict, 'branch7x7_3'),
+                                 dtype=self.dtype)(branch7x7, train)
+
+         branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                    kernel_size=(1, 1),
+                                    params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
+                                    dtype=self.dtype)(x, train)
+         branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                    kernel_size=(7, 1),
+                                    padding=((3, 3), (0, 0)),
+                                    params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
+                                    dtype=self.dtype)(branch7x7dbl, train)
+         branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                    kernel_size=(1, 7),
+                                    padding=((0, 0), (3, 3)),
+                                    params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
+                                    dtype=self.dtype)(branch7x7dbl, train)
+         branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                    kernel_size=(7, 1),
+                                    padding=((3, 3), (0, 0)),
+                                    params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
+                                    dtype=self.dtype)(branch7x7dbl, train)
+         branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                    kernel_size=(1, 7),
+                                    padding=((0, 0), (3, 3)),
+                                    params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
+                                    dtype=self.dtype)(branch7x7dbl, train)
+
+         branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
+         branch_pool = BasicConv2d(out_channels=192,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch_pool'),
+                                   dtype=self.dtype)(branch_pool, train)
+
+         output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
+         return output
+
+
+ class InceptionD(nn.Module):
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         branch3x3 = BasicConv2d(out_channels=192,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch3x3_1'),
+                                 dtype=self.dtype)(x, train)
+         branch3x3 = BasicConv2d(out_channels=320,
+                                 kernel_size=(3, 3),
+                                 strides=(2, 2),
+                                 params_dict=utils.get(self.params_dict, 'branch3x3_2'),
+                                 dtype=self.dtype)(branch3x3, train)
+
+         branch7x7x3 = BasicConv2d(out_channels=192,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
+                                   dtype=self.dtype)(x, train)
+         branch7x7x3 = BasicConv2d(out_channels=192,
+                                   kernel_size=(1, 7),
+                                   padding=((0, 0), (3, 3)),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
+                                   dtype=self.dtype)(branch7x7x3, train)
+         branch7x7x3 = BasicConv2d(out_channels=192,
+                                   kernel_size=(7, 1),
+                                   padding=((3, 3), (0, 0)),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
+                                   dtype=self.dtype)(branch7x7x3, train)
+         branch7x7x3 = BasicConv2d(out_channels=192,
+                                   kernel_size=(3, 3),
+                                   strides=(2, 2),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
+                                   dtype=self.dtype)(branch7x7x3, train)
+
+         branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+
+         output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
+         return output
+
+
+ class InceptionE(nn.Module):
+     pooling: Callable
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         branch1x1 = BasicConv2d(out_channels=320,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch1x1'),
+                                 dtype=self.dtype)(x, train)
+
+         branch3x3 = BasicConv2d(out_channels=384,
+                                 kernel_size=(1, 1),
+                                 params_dict=utils.get(self.params_dict, 'branch3x3_1'),
+                                 dtype=self.dtype)(x, train)
+         branch3x3_a = BasicConv2d(out_channels=384,
+                                   kernel_size=(1, 3),
+                                   padding=((0, 0), (1, 1)),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
+                                   dtype=self.dtype)(branch3x3, train)
+         branch3x3_b = BasicConv2d(out_channels=384,
+                                   kernel_size=(3, 1),
+                                   padding=((1, 1), (0, 0)),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
+                                   dtype=self.dtype)(branch3x3, train)
+         branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)
+
+         branch3x3dbl = BasicConv2d(out_channels=448,
+                                    kernel_size=(1, 1),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
+                                    dtype=self.dtype)(x, train)
+         branch3x3dbl = BasicConv2d(out_channels=384,
+                                    kernel_size=(3, 3),
+                                    padding=((1, 1), (1, 1)),
+                                    params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
+                                    dtype=self.dtype)(branch3x3dbl, train)
+         branch3x3dbl_a = BasicConv2d(out_channels=384,
+                                      kernel_size=(1, 3),
+                                      padding=((0, 0), (1, 1)),
+                                      params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
+                                      dtype=self.dtype)(branch3x3dbl, train)
+         branch3x3dbl_b = BasicConv2d(out_channels=384,
+                                      kernel_size=(3, 1),
+                                      padding=((1, 1), (0, 0)),
+                                      params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
+                                      dtype=self.dtype)(branch3x3dbl, train)
+         branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)
+
+         branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
+         branch_pool = BasicConv2d(out_channels=192,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch_pool'),
+                                   dtype=self.dtype)(branch_pool, train)
+
+         output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
+         return output
+
+
+ class InceptionAux(nn.Module):
+     num_classes: int
+     kernel_init: functools.partial=nn.initializers.lecun_normal()
+     bias_init: functools.partial=nn.initializers.zeros
+     params_dict: dict=None
+     dtype: str='float32'
+
+     @nn.compact
+     def __call__(self, x, train=True):
+         x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
+         x = BasicConv2d(out_channels=128,
+                         kernel_size=(1, 1),
+                         params_dict=utils.get(self.params_dict, 'conv0'),
+                         dtype=self.dtype)(x, train)
+         x = BasicConv2d(out_channels=768,
+                         kernel_size=(5, 5),
+                         params_dict=utils.get(self.params_dict, 'conv1'),
+                         dtype=self.dtype)(x, train)
+         x = jnp.mean(x, axis=(1, 2))
+         x = jnp.reshape(x, newshape=(x.shape[0], -1))
+         x = Dense(features=self.num_classes,
+                   params_dict=utils.get(self.params_dict, 'fc'),
+                   dtype=self.dtype)(x)
+         return x
+
+ def _absolute_dims(rank, dims):
+     return tuple([rank + dim if dim < 0 else dim for dim in dims])
+
+
+ class BatchNorm(nn.Module):
+     """BatchNorm Module.
+     Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
+     Attributes:
+         use_running_average: if True, the statistics stored in batch_stats
+             will be used instead of computing the batch statistics on the input.
+         axis: the feature or non-batch axis of the input.
+         momentum: decay rate for the exponential moving average of the batch statistics.
+         epsilon: a small float added to variance to avoid dividing by zero.
+         dtype: the dtype of the computation (default: float32).
+         use_bias: if True, bias (beta) is added.
+         use_scale: if True, multiply by scale (gamma).
+             When the next layer is linear (also e.g. nn.relu), this can be disabled
+             since the scaling will be done by the next layer.
+         bias_init: initializer for bias, by default, zero.
+         scale_init: initializer for scale, by default, one.
+         axis_name: the axis name used to combine batch statistics from multiple
+             devices. See `jax.pmap` for a description of axis names (default: None).
+         axis_index_groups: groups of axis indices within that named axis
+             representing subsets of devices to reduce over (default: None). For
+             example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+             the examples on the first two and last two devices. See `jax.lax.psum`
+             for more details.
+     """
+     use_running_average: Optional[bool] = None
+     axis: int = -1
+     momentum: float = 0.99
+     epsilon: float = 1e-5
+     dtype: Dtype = jnp.float32
+     use_bias: bool = True
+     use_scale: bool = True
+     bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
+     scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
+     mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
+     var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
+     axis_name: Optional[str] = None
+     axis_index_groups: Any = None
+
+     @nn.compact
+     def __call__(self, x, use_running_average: Optional[bool] = None):
+         """Normalizes the input using batch statistics.
+
+         NOTE:
+             During initialization (when parameters are mutable) the running average
+             of the batch statistics will not be updated. Therefore, the inputs
+             fed during initialization don't need to match that of the actual input
+             distribution and the reduction axis (set with `axis_name`) does not have
+             to exist.
+         Args:
+             x: the input to be normalized.
+             use_running_average: if true, the statistics stored in batch_stats
+                 will be used instead of computing the batch statistics on the input.
+         Returns:
+             Normalized inputs (the same shape as inputs).
+         """
+         use_running_average = merge_param(
+             'use_running_average', self.use_running_average, use_running_average)
+         x = jnp.asarray(x, jnp.float32)
+         axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
+         axis = _absolute_dims(x.ndim, axis)
+         feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
+         reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
+         reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
+
+         # see NOTE above on initialization behavior
+         initializing = self.is_mutable_collection('params')
+
+         ra_mean = self.variable('batch_stats', 'mean',
+                                 self.mean_init,
+                                 reduced_feature_shape)
+         ra_var = self.variable('batch_stats', 'var',
+                                self.var_init,
+                                reduced_feature_shape)
+
+         if use_running_average:
+             mean, var = ra_mean.value, ra_var.value
+         else:
+             mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
+             mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
+             if self.axis_name is not None and not initializing:
+                 concatenated_mean = jnp.concatenate([mean, mean2])
+                 mean, mean2 = jnp.split(
+                     lax.pmean(
+                         concatenated_mean,
+                         axis_name=self.axis_name,
+                         axis_index_groups=self.axis_index_groups), 2)
+             var = mean2 - lax.square(mean)
+
+             if not initializing:
+                 ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
+                 ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
+
+         y = x - mean.reshape(feature_shape)
+         mul = lax.rsqrt(var + self.epsilon)
+         if self.use_scale:
+             scale = self.param('scale',
+                                self.scale_init,
+                                reduced_feature_shape).reshape(feature_shape)
+             mul = mul * scale
+         y = y * mul
+         if self.use_bias:
+             bias = self.param('bias',
+                               self.bias_init,
+                               reduced_feature_shape).reshape(feature_shape)
+             y = y + bias
+         return jnp.asarray(y, self.dtype)
+
+
+ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
+     """
+     Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py
+
+     Helper function to define pooling functions.
+     Pooling functions are implemented using the ReduceWindow XLA op.
+     NOTE: Be aware that pooling is not generally differentiable.
+     That means providing a reduce_fn that is differentiable does not imply
+     that pool is differentiable.
+     Args:
+         inputs: input data with dimensions (batch, window dims..., features).
+         init: the initial value for the reduction
+         reduce_fn: a reduce function of the form `(T, T) -> T`.
+         window_shape: a shape tuple defining the window to reduce over.
+         strides: a sequence of `n` integers, representing the inter-window
+             strides.
+         padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
+             of `n` `(low, high)` integer pairs that give the padding to apply before
+             and after each spatial dimension.
+     Returns:
+         The output of the reduction for each window slice.
+     """
+     strides = strides or (1,) * len(window_shape)
+     assert len(window_shape) == len(strides), (
+         f"len({window_shape}) == len({strides})")
+     strides = (1,) + strides + (1,)
+     dims = (1,) + window_shape + (1,)
+
+     is_single_input = False
+     if inputs.ndim == len(dims) - 1:
+         # add singleton batch dimension because lax.reduce_window always
+         # needs a batch dimension.
+         inputs = inputs[None]
+         is_single_input = True
+
+     assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
+     if not isinstance(padding, str):
+         padding = tuple(map(tuple, padding))
+         assert(len(padding) == len(window_shape)), (
+             f"padding {padding} must specify pads for same number of dims as "
+             f"window_shape {window_shape}")
+         assert(all([len(x) == 2 for x in padding])), (
+             f"each entry in padding {padding} must be length 2")
+         padding = ((0,0),) + padding + ((0,0),)
+     y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
+     if is_single_input:
+         y = jnp.squeeze(y, axis=0)
+     return y
+
+
+ def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
+     """
+     Pools the input by taking the average over a window.
+
+     In comparison to flax.linen.avg_pool, this pooling operation does not
+     consider the padded zero's for the average computation.
+
+     Args:
+         inputs: input data with dimensions (batch, window dims..., features).
+         window_shape: a shape tuple defining the window to reduce over.
+         strides: a sequence of `n` integers, representing the inter-window
+             strides (default: `(1, ..., 1)`).
+         padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
+             of `n` `(low, high)` integer pairs that give the padding to apply before
+             and after each spatial dimension (default: `'VALID'`).
+     Returns:
+         The average for each window slice.
+     """
+     assert inputs.ndim == 4
+     assert len(window_shape) == 2
+
+     y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
+     ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
+     counts = jax.lax.conv_general_dilated(ones,
+                                           jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)),
+                                           window_strides=(1, 1),
+                                           padding=((1, 1), (1, 1)),
+                                           dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
+                                           feature_group_count=1)
+     y = y / counts
+     return y
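
Note: the snippet below is not part of the diff. It is a minimal sketch of extracting pooled InceptionV3 features for FID with the module above. The import path follows the top_level.txt shown earlier (this release installs metrics as a top-level package); pretrained=True downloads the ported pickle weights, and with include_head=False the module returns [B, 1, 1, 2048] pooled features. Inputs are assumed to be scaled to [-1, 1], matching the commented-out normalization in tfds.py.

import jax
import jax.numpy as jnp
from metrics.inception import InceptionV3

model = InceptionV3(pretrained=True)  # include_head defaults to False
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 256, 256, 3)), train=False)

@jax.jit
def extract_features(params, images):
    # images: [B, H, W, 3] in [-1, 1]; returns [B, 2048] pooled features
    feats = model.apply(params, images, train=False)
    return feats.squeeze(axis=(1, 2))

features = extract_features(params, jnp.zeros((4, 256, 256, 3)))
print(features.shape)  # (4, 2048)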

flaxdiff-0.1.36.2/src/metrics/utils.py ADDED
@@ -0,0 +1,49 @@
+ # Mostly derived from
+ # https://github.com/matthias-wright/jax-fid
+
+ import jax
+ import flax
+ import numpy as np
+ from tqdm import tqdm
+ import requests
+ import os
+ import tempfile
+
+
+ def download(url, ckpt_dir=None):
+     name = url[url.rfind('/') + 1 : url.rfind('?')]
+     if ckpt_dir is None:
+         ckpt_dir = tempfile.gettempdir()
+     ckpt_dir = os.path.join(ckpt_dir, 'jax_fid')
+     ckpt_file = os.path.join(ckpt_dir, name)
+     if not os.path.exists(ckpt_file):
+         print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
+         if not os.path.exists(ckpt_dir):
+             os.makedirs(ckpt_dir)
+
+         response = requests.get(url, stream=True)
+         total_size_in_bytes = int(response.headers.get('content-length', 0))
+         progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+
+         # first create temp file, in case the download fails
+         ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
+         with open(ckpt_file_temp, 'wb') as file:
+             for data in response.iter_content(chunk_size=1024):
+                 progress_bar.update(len(data))
+                 file.write(data)
+         progress_bar.close()
+
+         if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+             print('An error occured while downloading, please try again.')
+             if os.path.exists(ckpt_file_temp):
+                 os.remove(ckpt_file_temp)
+         else:
+             # if download was successful, rename the temp file
+             os.rename(ckpt_file_temp, ckpt_file)
+     return ckpt_file
+
+
+ def get(dictionary, key):
+     if dictionary is None or key not in dictionary:
+         return None
+     return dictionary[key]

{flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/simple_trainer.py
@@ -219,6 +219,8 @@ class SimpleTrainer:
 
      def checkpoint_path(self):
          path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
+         # Convert the path to an absolute path
+         path = os.path.abspath(path)
          if not os.path.exists(path):
              os.makedirs(path)
          return path

flaxdiff-0.1.36/flaxdiff.egg-info/SOURCES.txt DELETED
@@ -1,46 +0,0 @@
- README.md
- setup.py
- flaxdiff/__init__.py
- flaxdiff/utils.py
- flaxdiff.egg-info/PKG-INFO
- flaxdiff.egg-info/SOURCES.txt
- flaxdiff.egg-info/dependency_links.txt
- flaxdiff.egg-info/requires.txt
- flaxdiff.egg-info/top_level.txt
- flaxdiff/data/__init__.py
- flaxdiff/data/dataset_map.py
- flaxdiff/data/datasets.py
- flaxdiff/data/online_loader.py
- flaxdiff/models/__init__.py
- flaxdiff/models/attention.py
- flaxdiff/models/common.py
- flaxdiff/models/favor_fastattn.py
- flaxdiff/models/simple_unet.py
- flaxdiff/models/simple_vit.py
- flaxdiff/models/autoencoder/__init__.py
- flaxdiff/models/autoencoder/autoencoder.py
- flaxdiff/models/autoencoder/diffusers.py
- flaxdiff/models/autoencoder/simple_autoenc.py
- flaxdiff/predictors/__init__.py
- flaxdiff/samplers/__init__.py
- flaxdiff/samplers/common.py
- flaxdiff/samplers/ddim.py
- flaxdiff/samplers/ddpm.py
- flaxdiff/samplers/euler.py
- flaxdiff/samplers/heun_sampler.py
- flaxdiff/samplers/multistep_dpm.py
- flaxdiff/samplers/rk4_sampler.py
- flaxdiff/schedulers/__init__.py
- flaxdiff/schedulers/common.py
- flaxdiff/schedulers/continuous.py
- flaxdiff/schedulers/cosine.py
- flaxdiff/schedulers/discrete.py
- flaxdiff/schedulers/exp.py
- flaxdiff/schedulers/karras.py
- flaxdiff/schedulers/linear.py
- flaxdiff/schedulers/sqrt.py
- flaxdiff/trainer/__init__.py
- flaxdiff/trainer/autoencoder_trainer.py
- flaxdiff/trainer/diffusion_trainer.py
- flaxdiff/trainer/simple_trainer.py
- flaxdiff/trainer/video_diffusion_trainer.py

flaxdiff-0.1.36/flaxdiff.egg-info/requires.txt DELETED
@@ -1,5 +0,0 @@
- flax>=0.8.4
- optax>=0.2.2
- jax>=0.4.28
- orbax
- clu

flaxdiff-0.1.36/flaxdiff.egg-info/top_level.txt DELETED
@@ -1 +0,0 @@
- flaxdiff

flaxdiff-0.1.36/setup.py DELETED
@@ -1,21 +0,0 @@
- from setuptools import find_packages, setup
-
- required_packages=[
-     'flax>=0.8.4',
-     'optax>=0.2.2',
-     'jax>=0.4.28',
-     'orbax',
-     'clu',
- ]
-
- setup(
-     name='flaxdiff',
-     packages=find_packages(),
-     version='0.1.36',
-     description='A versatile and easy to understand Diffusion library',
-     long_description=open('README.md').read(),
-     long_description_content_type='text/markdown',
-     author='Ashish Kumar Singh',
-     author_email='ashishkmr472@gmail.com',
-     install_requires=required_packages,
- )