flaxdiff 0.1.36__tar.gz → 0.1.36.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/PKG-INFO +13 -10
- flaxdiff-0.1.36.2/pyproject.toml +29 -0
- flaxdiff-0.1.36.2/src/data/sources/gcs.py +81 -0
- flaxdiff-0.1.36.2/src/data/sources/tfds.py +67 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/PKG-INFO +13 -10
- flaxdiff-0.1.36.2/src/flaxdiff.egg-info/SOURCES.txt +50 -0
- flaxdiff-0.1.36.2/src/flaxdiff.egg-info/requires.txt +14 -0
- flaxdiff-0.1.36.2/src/flaxdiff.egg-info/top_level.txt +9 -0
- flaxdiff-0.1.36.2/src/metrics/inception.py +658 -0
- flaxdiff-0.1.36.2/src/metrics/utils.py +49 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/simple_trainer.py +2 -0
- flaxdiff-0.1.36/flaxdiff.egg-info/SOURCES.txt +0 -46
- flaxdiff-0.1.36/flaxdiff.egg-info/requires.txt +0 -5
- flaxdiff-0.1.36/flaxdiff.egg-info/top_level.txt +0 -1
- flaxdiff-0.1.36/setup.py +0 -21
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/README.md +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/setup.cfg +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/dataset_map.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/datasets.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/data/online_loader.py +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/attention.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/common.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/simple_unet.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/models/simple_vit.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/predictors/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/common.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/ddim.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/euler.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/common.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/exp.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/karras.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/linear.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/__init__.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/utils.py +0 -0
{flaxdiff-0.1.36 → flaxdiff-0.1.36.2}/PKG-INFO
@@ -1,21 +1,24 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36
+Version: 0.1.36.2
 Summary: A versatile and easy to understand Diffusion library
-Author: Ashish Kumar Singh
-
+Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+License-Expression: MIT
 Description-Content-Type: text/markdown
 Requires-Dist: flax>=0.8.4
-Requires-Dist: optax>=0.2.2
 Requires-Dist: jax>=0.4.28
+Requires-Dist: optax>=0.2.2
 Requires-Dist: orbax
+Requires-Dist: numpy
 Requires-Dist: clu
-
-
-
-
-
-
+Requires-Dist: einops
+Requires-Dist: tqdm
+Requires-Dist: grain
+Requires-Dist: termcolor
+Requires-Dist: augmax
+Requires-Dist: albumentations
+Requires-Dist: rich
+Requires-Dist: python-dotenv
 
 # 
 
flaxdiff-0.1.36.2/pyproject.toml
@@ -0,0 +1,29 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flaxdiff"
+version = "0.1.36.2"
+description = "A versatile and easy to understand Diffusion library"
+readme = "README.md"
+authors = [
+    { name="Ashish Kumar Singh", email="ashishkmr472@gmail.com" }
+]
+dependencies = [
+    "flax>=0.8.4",
+    "jax>=0.4.28",
+    "optax>=0.2.2",
+    "orbax",
+    "numpy",
+    "clu",
+    "einops",
+    "tqdm",
+    "grain",
+    "termcolor",
+    "augmax",
+    "albumentations",
+    "rich",
+    "python-dotenv",
+]
+license = "MIT"
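
Note: the packaging metadata moved from setup.py (deleted below) to a PEP 621 pyproject.toml, which is what regenerates the PKG-INFO shown above: Author became Author-email, the license is now the SPDX expression MIT, and the dependency list grew from 5 to 14 entries. A quick sanity check of the installed metadata, assuming `pip install flaxdiff==0.1.36.2` has already run:

from importlib.metadata import metadata, version

assert version("flaxdiff") == "0.1.36.2"
meta = metadata("flaxdiff")
print(meta["License-Expression"])     # MIT
print(meta.get_all("Requires-Dist"))  # the 14 dependencies declared above
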
flaxdiff-0.1.36.2/src/data/sources/gcs.py
@@ -0,0 +1,81 @@
+import cv2
+import jax.numpy as jnp
+import grain.python as pygrain
+from flaxdiff.utils import AutoTextTokenizer
+from typing import Dict
+import os
+import struct as st
+from functools import partial
+import numpy as np
+
+# -----------------------------------------------------------------------------------------------#
+# CC12m and other GCS data sources --------------------------------------------------------------#
+# -----------------------------------------------------------------------------------------------#
+
+def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
+    def data_source(base="/home/mrwhite0racle/gcs_mount"):
+        records_path = os.path.join(base, source)
+        records = [os.path.join(records_path, i) for i in os.listdir(
+            records_path) if 'array_record' in i]
+        ds = pygrain.ArrayRecordDataSource(records)
+        return ds
+    return data_source
+
+def data_source_combined_gcs(
+        sources=[]):
+    def data_source(base="/home/mrwhite0racle/gcs_mount"):
+        records_paths = [os.path.join(base, source) for source in sources]
+        records = []
+        for records_path in records_paths:
+            records += [os.path.join(records_path, i) for i in os.listdir(
+                records_path) if 'array_record' in i]
+        ds = pygrain.ArrayRecordDataSource(records)
+        return ds
+    return data_source
+
+def unpack_dict_of_byte_arrays(packed_data):
+    unpacked_dict = {}
+    offset = 0
+    while offset < len(packed_data):
+        # Unpack the key length
+        key_length = st.unpack_from('I', packed_data, offset)[0]
+        offset += st.calcsize('I')
+        # Unpack the key bytes and convert to string
+        key = packed_data[offset:offset+key_length].decode('utf-8')
+        offset += key_length
+        # Unpack the byte array length
+        byte_array_length = st.unpack_from('I', packed_data, offset)[0]
+        offset += st.calcsize('I')
+        # Unpack the byte array
+        byte_array = packed_data[offset:offset+byte_array_length]
+        offset += byte_array_length
+        unpacked_dict[key] = byte_array
+    return unpacked_dict
+
+def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = cv2.resize(image, (image_scale, image_scale),
+                       interpolation=cv2.INTER_AREA)
+    return image
+
+def gcs_augmenters(image_scale, method):
+    labelizer = lambda sample: sample['txt']
+    class augmenters(pygrain.MapTransform):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
+            self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)
+
+        def map(self, element) -> Dict[str, jnp.array]:
+            element = unpack_dict_of_byte_arrays(element)
+            image = np.asarray(bytearray(element['jpg']), dtype="uint8")
+            image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
+            image = self.image_augmenter(image)
+            caption = labelizer(element).decode('utf-8')
+            results = self.auto_tokenize(caption)
+            return {
+                "image": image,
+                "input_ids": results['input_ids'][0],
+                "attention_mask": results['attention_mask'][0],
+            }
+    return augmenters
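
Note: this new module builds grain ArrayRecordDataSource objects from `.array_record` shards under a mounted GCS path, and gcs_augmenters decodes, resizes, and tokenizes each record (image_augmenter accepts a `method` argument but the resize call hard-codes cv2.INTER_AREA). Each record is a flat buffer of length-prefixed key/value pairs: a native-order uint32 key length, the UTF-8 key, a uint32 value length, then the raw value bytes. A hypothetical inverse of unpack_dict_of_byte_arrays (not part of the package) makes the layout concrete:

import struct as st

def pack_dict_of_byte_arrays(d):
    # Hypothetical helper, shown only to document the record layout that
    # unpack_dict_of_byte_arrays expects: [key_len][key][val_len][val], repeated.
    packed = bytearray()
    for key, value in d.items():
        key_bytes = key.encode('utf-8')
        packed += st.pack('I', len(key_bytes)) + key_bytes
        packed += st.pack('I', len(value)) + value
    return bytes(packed)

sample = {'jpg': b'<jpeg bytes>', 'txt': b'a painting of a sunset'}
assert unpack_dict_of_byte_arrays(pack_dict_of_byte_arrays(sample)) == sample
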
flaxdiff-0.1.36.2/src/data/sources/tfds.py
@@ -0,0 +1,67 @@
+import cv2
+import jax.numpy as jnp
+import grain.python as pygrain
+from flaxdiff.utils import AutoTextTokenizer
+from typing import Dict
+import random
+
+# -----------------------------------------------------------------------------------------------#
+# Oxford flowers and other TFDS datasources -----------------------------------------------------#
+# -----------------------------------------------------------------------------------------------#
+
+PROMPT_TEMPLATES = [
+    "a photo of a {}",
+    "a photo of a {} flower",
+    "This is a photo of a {}",
+    "This is a photo of a {} flower",
+    "A photo of a {} flower",
+]
+
+def data_source_tfds(name, use_tf=True, split="all"):
+    import tensorflow_datasets as tfds
+    if use_tf:
+        def data_source(path_override):
+            return tfds.load(name, split=split, shuffle_files=True)
+    else:
+        def data_source(path_override):
+            return tfds.data_source(name, split=split, try_gcs=False)
+    return data_source
+
+def labelizer_oxford_flowers102(path):
+    with open(path, "r") as f:
+        textlabels = [i.strip() for i in f.readlines()]
+
+    def load_labels(sample):
+        raw = textlabels[int(sample['label'])]
+        # randomly select a prompt template
+        template = random.choice(PROMPT_TEMPLATES)
+        # format the template with the label
+        caption = template.format(raw)
+        # return the caption
+        return caption
+    return load_labels
+
+def tfds_augmenters(image_scale, method):
+    labelizer = labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
+    if image_scale > 256:
+        interpolation = cv2.INTER_CUBIC
+    else:
+        interpolation = cv2.INTER_AREA
+    class augmenters(pygrain.MapTransform):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.tokenize = AutoTextTokenizer(tensor_type="np")
+
+        def map(self, element) -> Dict[str, jnp.array]:
+            image = element['image']
+            image = cv2.resize(image, (image_scale, image_scale),
+                               interpolation=interpolation)
+            # image = (image - 127.5) / 127.5
+            caption = labelizer(element)
+            results = self.tokenize(caption)
+            return {
+                "image": image,
+                "input_ids": results['input_ids'][0],
+                "attention_mask": results['attention_mask'][0],
+            }
+    return augmenters
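
Note: this module mirrors the GCS source for TFDS datasets, captioning Oxford Flowers 102 samples by formatting a randomly chosen prompt template with the class label read from a hard-coded labels file. Wiring it into a pipeline is left to the caller; a minimal sketch, assuming grain's IndexSampler/DataLoader API and that both the TFDS data and the labels file exist locally:

import grain.python as pygrain

source = data_source_tfds("oxford_flowers102", use_tf=False)(None)
sampler = pygrain.IndexSampler(
    num_records=len(source),
    num_epochs=1,
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    seed=0,
)
loader = pygrain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[tfds_augmenters(image_scale=128, method=None)(),
                pygrain.Batch(batch_size=16)],
    worker_count=0,
)
batch = next(iter(loader))  # keys: 'image', 'input_ids', 'attention_mask'
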
{flaxdiff-0.1.36 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/PKG-INFO
@@ -1,21 +1,24 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36
+Version: 0.1.36.2
 Summary: A versatile and easy to understand Diffusion library
-Author: Ashish Kumar Singh
-
+Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+License-Expression: MIT
 Description-Content-Type: text/markdown
 Requires-Dist: flax>=0.8.4
-Requires-Dist: optax>=0.2.2
 Requires-Dist: jax>=0.4.28
+Requires-Dist: optax>=0.2.2
 Requires-Dist: orbax
+Requires-Dist: numpy
 Requires-Dist: clu
-
-
-
-
-
-
+Requires-Dist: einops
+Requires-Dist: tqdm
+Requires-Dist: grain
+Requires-Dist: termcolor
+Requires-Dist: augmax
+Requires-Dist: albumentations
+Requires-Dist: rich
+Requires-Dist: python-dotenv
 
 # 
 
flaxdiff-0.1.36.2/src/flaxdiff.egg-info/SOURCES.txt
@@ -0,0 +1,50 @@
+README.md
+pyproject.toml
+src/__init__.py
+src/utils.py
+src/data/__init__.py
+src/data/dataset_map.py
+src/data/datasets.py
+src/data/online_loader.py
+src/data/sources/gcs.py
+src/data/sources/tfds.py
+src/flaxdiff.egg-info/PKG-INFO
+src/flaxdiff.egg-info/SOURCES.txt
+src/flaxdiff.egg-info/dependency_links.txt
+src/flaxdiff.egg-info/requires.txt
+src/flaxdiff.egg-info/top_level.txt
+src/metrics/inception.py
+src/metrics/utils.py
+src/models/__init__.py
+src/models/attention.py
+src/models/common.py
+src/models/favor_fastattn.py
+src/models/simple_unet.py
+src/models/simple_vit.py
+src/models/autoencoder/__init__.py
+src/models/autoencoder/autoencoder.py
+src/models/autoencoder/diffusers.py
+src/models/autoencoder/simple_autoenc.py
+src/predictors/__init__.py
+src/samplers/__init__.py
+src/samplers/common.py
+src/samplers/ddim.py
+src/samplers/ddpm.py
+src/samplers/euler.py
+src/samplers/heun_sampler.py
+src/samplers/multistep_dpm.py
+src/samplers/rk4_sampler.py
+src/schedulers/__init__.py
+src/schedulers/common.py
+src/schedulers/continuous.py
+src/schedulers/cosine.py
+src/schedulers/discrete.py
+src/schedulers/exp.py
+src/schedulers/karras.py
+src/schedulers/linear.py
+src/schedulers/sqrt.py
+src/trainer/__init__.py
+src/trainer/autoencoder_trainer.py
+src/trainer/diffusion_trainer.py
+src/trainer/simple_trainer.py
+src/trainer/video_diffusion_trainer.py
flaxdiff-0.1.36.2/src/metrics/inception.py
@@ -0,0 +1,658 @@
+# Mostly derived from
+# https://github.com/matthias-wright/jax-fid
+
+import jax
+from jax import lax
+from jax.nn import initializers
+import jax.numpy as jnp
+import flax
+from flax.linen.module import merge_param
+import flax.linen as nn
+from typing import Callable, Iterable, Optional, Tuple, Union, Any
+import functools
+import pickle
+from . import utils
+
+PRNGKey = Any
+Array = Any
+Shape = Tuple[int]
+Dtype = Any
+
+
+class InceptionV3(nn.Module):
+    """
+    InceptionV3 network.
+    Reference: https://arxiv.org/abs/1512.00567
+    Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py
+
+    Attributes:
+        include_head (bool): If True, include classifier head.
+        num_classes (int): Number of classes.
+        pretrained (bool): If True, use pretrained weights.
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet.
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+        dtype (str): Data type.
+    """
+    include_head: bool=False
+    num_classes: int=1000
+    pretrained: bool=False
+    transform_input: bool=False
+    aux_logits: bool=False
+    ckpt_path: str='https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1'
+    dtype: str='float32'
+
+    def setup(self):
+        if self.pretrained:
+            ckpt_file = utils.download(self.ckpt_path)
+            self.params_dict = pickle.load(open(ckpt_file, 'rb'))
+            self.num_classes_ = 1000
+        else:
+            self.params_dict = None
+            self.num_classes_ = self.num_classes
+
+    @nn.compact
+    def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
+        """
+        Args:
+            x (tensor): Input image, shape [B, H, W, C].
+            train (bool): If True, training mode.
+            rng (jax.random.PRNGKey): Random seed.
+        """
+        x = self._transform_input(x)
+        x = BasicConv2d(out_channels=32,
+                        kernel_size=(3, 3),
+                        strides=(2, 2),
+                        params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
+                        dtype=self.dtype)(x, train)
+        x = BasicConv2d(out_channels=32,
+                        kernel_size=(3, 3),
+                        params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
+                        dtype=self.dtype)(x, train)
+        x = BasicConv2d(out_channels=64,
+                        kernel_size=(3, 3),
+                        padding=((1, 1), (1, 1)),
+                        params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
+                        dtype=self.dtype)(x, train)
+        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+        x = BasicConv2d(out_channels=80,
+                        kernel_size=(1, 1),
+                        params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
+                        dtype=self.dtype)(x, train)
+        x = BasicConv2d(out_channels=192,
+                        kernel_size=(3, 3),
+                        params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
+                        dtype=self.dtype)(x, train)
+        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+        x = InceptionA(pool_features=32,
+                       params_dict=utils.get(self.params_dict, 'Mixed_5b'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionA(pool_features=64,
+                       params_dict=utils.get(self.params_dict, 'Mixed_5c'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionA(pool_features=64,
+                       params_dict=utils.get(self.params_dict, 'Mixed_5d'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionC(channels_7x7=128,
+                       params_dict=utils.get(self.params_dict, 'Mixed_6b'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionC(channels_7x7=160,
+                       params_dict=utils.get(self.params_dict, 'Mixed_6c'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionC(channels_7x7=160,
+                       params_dict=utils.get(self.params_dict, 'Mixed_6d'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionC(channels_7x7=192,
+                       params_dict=utils.get(self.params_dict, 'Mixed_6e'),
+                       dtype=self.dtype)(x, train)
+        aux = None
+        if self.aux_logits and train:
+            aux = InceptionAux(num_classes=self.num_classes_,
+                               params_dict=utils.get(self.params_dict, 'AuxLogits'),
+                               dtype=self.dtype)(x, train)
+        x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
+                       dtype=self.dtype)(x, train)
+        x = InceptionE(avg_pool, params_dict=utils.get(self.params_dict, 'Mixed_7b'),
+                       dtype=self.dtype)(x, train)
+        # Following the implementation by @mseitzer, we use max pooling instead
+        # of average pooling here.
+        # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
+        x = InceptionE(nn.max_pool, params_dict=utils.get(self.params_dict, 'Mixed_7c'),
+                       dtype=self.dtype)(x, train)
+        x = jnp.mean(x, axis=(1, 2), keepdims=True)
+        if not self.include_head:
+            return x
+        x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
+        x = jnp.reshape(x, newshape=(x.shape[0], -1))
+        x = Dense(features=self.num_classes_,
+                  params_dict=utils.get(self.params_dict, 'fc'),
+                  dtype=self.dtype)(x)
+        if self.aux_logits:
+            return x, aux
+        return x
+
+    def _transform_input(self, x):
+        if self.transform_input:
+            x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
+        return x
+
+
+class Dense(nn.Module):
+    features: int
+    kernel_init: functools.partial=nn.initializers.lecun_normal()
+    bias_init: functools.partial=nn.initializers.zeros
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Dense(features=self.features,
+                     kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['kernel']),
+                     bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['bias']))(x)
+        return x
+
+
+class BasicConv2d(nn.Module):
+    out_channels: int
+    kernel_size: Union[int, Iterable[int]]=(3, 3)
+    strides: Optional[Iterable[int]]=(1, 1)
+    padding: Union[str, Iterable[Tuple[int, int]]]='valid'
+    use_bias: bool=False
+    kernel_init: functools.partial=nn.initializers.lecun_normal()
+    bias_init: functools.partial=nn.initializers.zeros
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        x = nn.Conv(features=self.out_channels,
+                    kernel_size=self.kernel_size,
+                    strides=self.strides,
+                    padding=self.padding,
+                    use_bias=self.use_bias,
+                    kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['kernel']),
+                    bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['bias']),
+                    dtype=self.dtype)(x)
+        if self.params_dict is None:
+            x = BatchNorm(epsilon=0.001,
+                          momentum=0.1,
+                          use_running_average=not train,
+                          dtype=self.dtype)(x)
+        else:
+            x = BatchNorm(epsilon=0.001,
+                          momentum=0.1,
+                          bias_init=lambda *_ : jnp.array(self.params_dict['bn']['bias']),
+                          scale_init=lambda *_ : jnp.array(self.params_dict['bn']['scale']),
+                          mean_init=lambda *_ : jnp.array(self.params_dict['bn']['mean']),
+                          var_init=lambda *_ : jnp.array(self.params_dict['bn']['var']),
+                          use_running_average=not train,
+                          dtype=self.dtype)(x)
+        x = jax.nn.relu(x)
+        return x
+
+
+class InceptionA(nn.Module):
+    pool_features: int
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        branch1x1 = BasicConv2d(out_channels=64,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch1x1'),
+                                dtype=self.dtype)(x, train)
+        branch5x5 = BasicConv2d(out_channels=48,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch5x5_1'),
+                                dtype=self.dtype)(x, train)
+        branch5x5 = BasicConv2d(out_channels=64,
+                                kernel_size=(5, 5),
+                                padding=((2, 2), (2, 2)),
+                                params_dict=utils.get(self.params_dict, 'branch5x5_2'),
+                                dtype=self.dtype)(branch5x5, train)
+
+        branch3x3dbl = BasicConv2d(out_channels=64,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
+                                   dtype=self.dtype)(x, train)
+        branch3x3dbl = BasicConv2d(out_channels=96,
+                                   kernel_size=(3, 3),
+                                   padding=((1, 1), (1, 1)),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
+                                   dtype=self.dtype)(branch3x3dbl, train)
+        branch3x3dbl = BasicConv2d(out_channels=96,
+                                   kernel_size=(3, 3),
+                                   padding=((1, 1), (1, 1)),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
+                                   dtype=self.dtype)(branch3x3dbl, train)
+
+        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
+        branch_pool = BasicConv2d(out_channels=self.pool_features,
+                                  kernel_size=(1, 1),
+                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
+                                  dtype=self.dtype)(branch_pool, train)
+
+        output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
+        return output
+
+
+class InceptionB(nn.Module):
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        branch3x3 = BasicConv2d(out_channels=384,
+                                kernel_size=(3, 3),
+                                strides=(2, 2),
+                                params_dict=utils.get(self.params_dict, 'branch3x3'),
+                                dtype=self.dtype)(x, train)
+
+        branch3x3dbl = BasicConv2d(out_channels=64,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
+                                   dtype=self.dtype)(x, train)
+        branch3x3dbl = BasicConv2d(out_channels=96,
+                                   kernel_size=(3, 3),
+                                   padding=((1, 1), (1, 1)),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
+                                   dtype=self.dtype)(branch3x3dbl, train)
+        branch3x3dbl = BasicConv2d(out_channels=96,
+                                   kernel_size=(3, 3),
+                                   strides=(2, 2),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
+                                   dtype=self.dtype)(branch3x3dbl, train)
+
+        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+
+        output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
+        return output
+
+
+class InceptionC(nn.Module):
+    channels_7x7: int
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        branch1x1 = BasicConv2d(out_channels=192,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch1x1'),
+                                dtype=self.dtype)(x, train)
+
+        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch7x7_1'),
+                                dtype=self.dtype)(x, train)
+        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
+                                kernel_size=(1, 7),
+                                padding=((0, 0), (3, 3)),
+                                params_dict=utils.get(self.params_dict, 'branch7x7_2'),
+                                dtype=self.dtype)(branch7x7, train)
+        branch7x7 = BasicConv2d(out_channels=192,
+                                kernel_size=(7, 1),
+                                padding=((3, 3), (0, 0)),
+                                params_dict=utils.get(self.params_dict, 'branch7x7_3'),
+                                dtype=self.dtype)(branch7x7, train)
+
+        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
+                                   dtype=self.dtype)(x, train)
+        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                   kernel_size=(7, 1),
+                                   padding=((3, 3), (0, 0)),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
+                                   dtype=self.dtype)(branch7x7dbl, train)
+        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                   kernel_size=(1, 7),
+                                   padding=((0, 0), (3, 3)),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
+                                   dtype=self.dtype)(branch7x7dbl, train)
+        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                   kernel_size=(7, 1),
+                                   padding=((3, 3), (0, 0)),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
+                                   dtype=self.dtype)(branch7x7dbl, train)
+        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
+                                   kernel_size=(1, 7),
+                                   padding=((0, 0), (3, 3)),
+                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
+                                   dtype=self.dtype)(branch7x7dbl, train)
+
+        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
+        branch_pool = BasicConv2d(out_channels=192,
+                                  kernel_size=(1, 1),
+                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
+                                  dtype=self.dtype)(branch_pool, train)
+
+        output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
+        return output
+
+
+class InceptionD(nn.Module):
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        branch3x3 = BasicConv2d(out_channels=192,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
+                                dtype=self.dtype)(x, train)
+        branch3x3 = BasicConv2d(out_channels=320,
+                                kernel_size=(3, 3),
+                                strides=(2, 2),
+                                params_dict=utils.get(self.params_dict, 'branch3x3_2'),
+                                dtype=self.dtype)(branch3x3, train)
+
+        branch7x7x3 = BasicConv2d(out_channels=192,
+                                  kernel_size=(1, 1),
+                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
+                                  dtype=self.dtype)(x, train)
+        branch7x7x3 = BasicConv2d(out_channels=192,
+                                  kernel_size=(1, 7),
+                                  padding=((0, 0), (3, 3)),
+                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
+                                  dtype=self.dtype)(branch7x7x3, train)
+        branch7x7x3 = BasicConv2d(out_channels=192,
+                                  kernel_size=(7, 1),
+                                  padding=((3, 3), (0, 0)),
+                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
+                                  dtype=self.dtype)(branch7x7x3, train)
+        branch7x7x3 = BasicConv2d(out_channels=192,
+                                  kernel_size=(3, 3),
+                                  strides=(2, 2),
+                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
+                                  dtype=self.dtype)(branch7x7x3, train)
+
+        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
+
+        output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
+        return output
+
+
+class InceptionE(nn.Module):
+    pooling: Callable
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        branch1x1 = BasicConv2d(out_channels=320,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch1x1'),
+                                dtype=self.dtype)(x, train)
+
+        branch3x3 = BasicConv2d(out_channels=384,
+                                kernel_size=(1, 1),
+                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
+                                dtype=self.dtype)(x, train)
+        branch3x3_a = BasicConv2d(out_channels=384,
+                                  kernel_size=(1, 3),
+                                  padding=((0, 0), (1, 1)),
+                                  params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
+                                  dtype=self.dtype)(branch3x3, train)
+        branch3x3_b = BasicConv2d(out_channels=384,
+                                  kernel_size=(3, 1),
+                                  padding=((1, 1), (0, 0)),
+                                  params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
+                                  dtype=self.dtype)(branch3x3, train)
+        branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)
+
+        branch3x3dbl = BasicConv2d(out_channels=448,
+                                   kernel_size=(1, 1),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
+                                   dtype=self.dtype)(x, train)
+        branch3x3dbl = BasicConv2d(out_channels=384,
+                                   kernel_size=(3, 3),
+                                   padding=((1, 1), (1, 1)),
+                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
+                                   dtype=self.dtype)(branch3x3dbl, train)
+        branch3x3dbl_a = BasicConv2d(out_channels=384,
+                                     kernel_size=(1, 3),
+                                     padding=((0, 0), (1, 1)),
+                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
+                                     dtype=self.dtype)(branch3x3dbl, train)
+        branch3x3dbl_b = BasicConv2d(out_channels=384,
+                                     kernel_size=(3, 1),
+                                     padding=((1, 1), (0, 0)),
+                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
+                                     dtype=self.dtype)(branch3x3dbl, train)
+        branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)
+
+        branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
+        branch_pool = BasicConv2d(out_channels=192,
+                                  kernel_size=(1, 1),
+                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
+                                  dtype=self.dtype)(branch_pool, train)
+
+        output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
+        return output
+
+
+class InceptionAux(nn.Module):
+    num_classes: int
+    kernel_init: functools.partial=nn.initializers.lecun_normal()
+    bias_init: functools.partial=nn.initializers.zeros
+    params_dict: dict=None
+    dtype: str='float32'
+
+    @nn.compact
+    def __call__(self, x, train=True):
+        x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
+        x = BasicConv2d(out_channels=128,
+                        kernel_size=(1, 1),
+                        params_dict=utils.get(self.params_dict, 'conv0'),
+                        dtype=self.dtype)(x, train)
+        x = BasicConv2d(out_channels=768,
+                        kernel_size=(5, 5),
+                        params_dict=utils.get(self.params_dict, 'conv1'),
+                        dtype=self.dtype)(x, train)
+        x = jnp.mean(x, axis=(1, 2))
+        x = jnp.reshape(x, newshape=(x.shape[0], -1))
+        x = Dense(features=self.num_classes,
+                  params_dict=utils.get(self.params_dict, 'fc'),
+                  dtype=self.dtype)(x)
+        return x
+
+def _absolute_dims(rank, dims):
+    return tuple([rank + dim if dim < 0 else dim for dim in dims])
+
+
+class BatchNorm(nn.Module):
+    """BatchNorm Module.
+    Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
+    Attributes:
+        use_running_average: if True, the statistics stored in batch_stats
+            will be used instead of computing the batch statistics on the input.
+        axis: the feature or non-batch axis of the input.
+        momentum: decay rate for the exponential moving average of the batch statistics.
+        epsilon: a small float added to variance to avoid dividing by zero.
+        dtype: the dtype of the computation (default: float32).
+        use_bias: if True, bias (beta) is added.
+        use_scale: if True, multiply by scale (gamma).
+            When the next layer is linear (also e.g. nn.relu), this can be disabled
+            since the scaling will be done by the next layer.
+        bias_init: initializer for bias, by default, zero.
+        scale_init: initializer for scale, by default, one.
+        axis_name: the axis name used to combine batch statistics from multiple
+            devices. See `jax.pmap` for a description of axis names (default: None).
+        axis_index_groups: groups of axis indices within that named axis
+            representing subsets of devices to reduce over (default: None). For
+            example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+            the examples on the first two and last two devices. See `jax.lax.psum`
+            for more details.
+    """
+    use_running_average: Optional[bool] = None
+    axis: int = -1
+    momentum: float = 0.99
+    epsilon: float = 1e-5
+    dtype: Dtype = jnp.float32
+    use_bias: bool = True
+    use_scale: bool = True
+    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
+    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
+    mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
+    var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
+    axis_name: Optional[str] = None
+    axis_index_groups: Any = None
+
+    @nn.compact
+    def __call__(self, x, use_running_average: Optional[bool] = None):
+        """Normalizes the input using batch statistics.
+
+        NOTE:
+        During initialization (when parameters are mutable) the running average
+        of the batch statistics will not be updated. Therefore, the inputs
+        fed during initialization don't need to match that of the actual input
+        distribution and the reduction axis (set with `axis_name`) does not have
+        to exist.
+        Args:
+            x: the input to be normalized.
+            use_running_average: if true, the statistics stored in batch_stats
+                will be used instead of computing the batch statistics on the input.
+        Returns:
+            Normalized inputs (the same shape as inputs).
+        """
+        use_running_average = merge_param(
+            'use_running_average', self.use_running_average, use_running_average)
+        x = jnp.asarray(x, jnp.float32)
+        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
+        axis = _absolute_dims(x.ndim, axis)
+        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
+        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
+        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
+
+        # see NOTE above on initialization behavior
+        initializing = self.is_mutable_collection('params')
+
+        ra_mean = self.variable('batch_stats', 'mean',
+                                self.mean_init,
+                                reduced_feature_shape)
+        ra_var = self.variable('batch_stats', 'var',
+                               self.var_init,
+                               reduced_feature_shape)
+
+        if use_running_average:
+            mean, var = ra_mean.value, ra_var.value
+        else:
+            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
+            mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
+            if self.axis_name is not None and not initializing:
+                concatenated_mean = jnp.concatenate([mean, mean2])
+                mean, mean2 = jnp.split(
+                    lax.pmean(
+                        concatenated_mean,
+                        axis_name=self.axis_name,
+                        axis_index_groups=self.axis_index_groups), 2)
+            var = mean2 - lax.square(mean)
+
+            if not initializing:
+                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
+                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
+
+        y = x - mean.reshape(feature_shape)
+        mul = lax.rsqrt(var + self.epsilon)
+        if self.use_scale:
+            scale = self.param('scale',
+                               self.scale_init,
+                               reduced_feature_shape).reshape(feature_shape)
+            mul = mul * scale
+        y = y * mul
+        if self.use_bias:
+            bias = self.param('bias',
+                              self.bias_init,
+                              reduced_feature_shape).reshape(feature_shape)
+            y = y + bias
+        return jnp.asarray(y, self.dtype)
+
+
+def pool(inputs, init, reduce_fn, window_shape, strides, padding):
+    """
+    Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py
+
+    Helper function to define pooling functions.
+    Pooling functions are implemented using the ReduceWindow XLA op.
+    NOTE: Be aware that pooling is not generally differentiable.
+    That means providing a reduce_fn that is differentiable does not imply
+    that pool is differentiable.
+    Args:
+        inputs: input data with dimensions (batch, window dims..., features).
+        init: the initial value for the reduction
+        reduce_fn: a reduce function of the form `(T, T) -> T`.
+        window_shape: a shape tuple defining the window to reduce over.
+        strides: a sequence of `n` integers, representing the inter-window
+            strides.
+        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
+            of `n` `(low, high)` integer pairs that give the padding to apply before
+            and after each spatial dimension.
+    Returns:
+        The output of the reduction for each window slice.
+    """
+    strides = strides or (1,) * len(window_shape)
+    assert len(window_shape) == len(strides), (
+        f"len({window_shape}) == len({strides})")
+    strides = (1,) + strides + (1,)
+    dims = (1,) + window_shape + (1,)
+
+    is_single_input = False
+    if inputs.ndim == len(dims) - 1:
+        # add singleton batch dimension because lax.reduce_window always
+        # needs a batch dimension.
+        inputs = inputs[None]
+        is_single_input = True
+
+    assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
+    if not isinstance(padding, str):
+        padding = tuple(map(tuple, padding))
+        assert(len(padding) == len(window_shape)), (
+            f"padding {padding} must specify pads for same number of dims as "
+            f"window_shape {window_shape}")
+        assert(all([len(x) == 2 for x in padding])), (
+            f"each entry in padding {padding} must be length 2")
+        padding = ((0,0),) + padding + ((0,0),)
+    y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
+    if is_single_input:
+        y = jnp.squeeze(y, axis=0)
+    return y
+
+
+def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
+    """
+    Pools the input by taking the average over a window.
+
+    In comparison to flax.linen.avg_pool, this pooling operation does not
+    consider the padded zero's for the average computation.
+
+    Args:
+        inputs: input data with dimensions (batch, window dims..., features).
+        window_shape: a shape tuple defining the window to reduce over.
+        strides: a sequence of `n` integers, representing the inter-window
+            strides (default: `(1, ..., 1)`).
+        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
+            of `n` `(low, high)` integer pairs that give the padding to apply before
+            and after each spatial dimension (default: `'VALID'`).
+    Returns:
+        The average for each window slice.
+    """
+    assert inputs.ndim == 4
+    assert len(window_shape) == 2
+
+    y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
+    ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
+    counts = jax.lax.conv_general_dilated(ones,
+                                          jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)),
+                                          window_strides=(1, 1),
+                                          padding=((1, 1), (1, 1)),
+                                          dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
+                                          feature_group_count=1)
+    y = y / counts
+    return y
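
Note: with include_head=False (the default) the network stops after global average pooling of the Mixed_7c activations, the 2048-dimensional features used for FID. A minimal sketch of feature extraction; the import path is an assumption based on the new source layout, and inputs are NHWC images scaled to [-1, 1]:

import jax
import jax.numpy as jnp
from metrics.inception import InceptionV3  # hypothetical install path

model = InceptionV3(pretrained=True)  # setup() downloads the weights pickle
dummy = jnp.ones((1, 256, 256, 3))
variables = model.init(jax.random.PRNGKey(0), dummy, train=False)

@jax.jit
def extract_features(x):
    # include_head=False returns shape (B, 1, 1, 2048); drop the spatial dims.
    return model.apply(variables, x, train=False).squeeze(axis=(1, 2))

features = extract_features(jnp.zeros((8, 256, 256, 3)))  # -> (8, 2048)
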
flaxdiff-0.1.36.2/src/metrics/utils.py
@@ -0,0 +1,49 @@
+# Mostly derived from
+# https://github.com/matthias-wright/jax-fid
+
+import jax
+import flax
+import numpy as np
+from tqdm import tqdm
+import requests
+import os
+import tempfile
+
+
+def download(url, ckpt_dir=None):
+    name = url[url.rfind('/') + 1 : url.rfind('?')]
+    if ckpt_dir is None:
+        ckpt_dir = tempfile.gettempdir()
+    ckpt_dir = os.path.join(ckpt_dir, 'jax_fid')
+    ckpt_file = os.path.join(ckpt_dir, name)
+    if not os.path.exists(ckpt_file):
+        print(f'Downloading: "{url[:url.rfind("?")]}" to {ckpt_file}')
+        if not os.path.exists(ckpt_dir):
+            os.makedirs(ckpt_dir)
+
+        response = requests.get(url, stream=True)
+        total_size_in_bytes = int(response.headers.get('content-length', 0))
+        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+
+        # first create temp file, in case the download fails
+        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
+        with open(ckpt_file_temp, 'wb') as file:
+            for data in response.iter_content(chunk_size=1024):
+                progress_bar.update(len(data))
+                file.write(data)
+        progress_bar.close()
+
+        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+            print('An error occured while downloading, please try again.')
+            if os.path.exists(ckpt_file_temp):
+                os.remove(ckpt_file_temp)
+        else:
+            # if download was successful, rename the temp file
+            os.rename(ckpt_file_temp, ckpt_file)
+    return ckpt_file
+
+
+def get(dictionary, key):
+    if dictionary is None or key not in dictionary:
+        return None
+    return dictionary[key]
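
Note: download caches files under `<tempdir>/jax_fid/`, writing to a `.temp` file first so a failed transfer never leaves a corrupt checkpoint behind, and get is a None-tolerant dictionary lookup; InceptionV3.setup relies on both. A short usage sketch, using the checkpoint URL baked into InceptionV3:

import pickle

ckpt_file = download(
    'https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1')
with open(ckpt_file, 'rb') as f:
    params_dict = pickle.load(f)

conv0 = get(params_dict, 'Conv2d_1a_3x3')    # nested weight dict for that layer
missing = get(params_dict, 'no_such_layer')  # None instead of a KeyError
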
{flaxdiff-0.1.36/flaxdiff → flaxdiff-0.1.36.2/src}/trainer/simple_trainer.py
@@ -219,6 +219,8 @@ class SimpleTrainer:
 
     def checkpoint_path(self):
         path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
+        # Convert the path to an absolute path
+        path = os.path.abspath(path)
         if not os.path.exists(path):
             os.makedirs(path)
         return path
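
Note: this is the only code change in the trainer. The checkpoint directory is resolved with os.path.abspath before it is created, so the location no longer depends on the working directory the training script was launched from (checkpointing libraries such as orbax generally require absolute paths). For an illustrative trainer named "my model" with a hypothetical base path of "checkpoints":

import os

path = os.path.join("checkpoints", "my model".replace(' ', '_').lower())
print(os.path.abspath(path))  # e.g. /home/user/project/checkpoints/my_model
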
flaxdiff-0.1.36/flaxdiff.egg-info/SOURCES.txt (DELETED)
@@ -1,46 +0,0 @@
-README.md
-setup.py
-flaxdiff/__init__.py
-flaxdiff/utils.py
-flaxdiff.egg-info/PKG-INFO
-flaxdiff.egg-info/SOURCES.txt
-flaxdiff.egg-info/dependency_links.txt
-flaxdiff.egg-info/requires.txt
-flaxdiff.egg-info/top_level.txt
-flaxdiff/data/__init__.py
-flaxdiff/data/dataset_map.py
-flaxdiff/data/datasets.py
-flaxdiff/data/online_loader.py
-flaxdiff/models/__init__.py
-flaxdiff/models/attention.py
-flaxdiff/models/common.py
-flaxdiff/models/favor_fastattn.py
-flaxdiff/models/simple_unet.py
-flaxdiff/models/simple_vit.py
-flaxdiff/models/autoencoder/__init__.py
-flaxdiff/models/autoencoder/autoencoder.py
-flaxdiff/models/autoencoder/diffusers.py
-flaxdiff/models/autoencoder/simple_autoenc.py
-flaxdiff/predictors/__init__.py
-flaxdiff/samplers/__init__.py
-flaxdiff/samplers/common.py
-flaxdiff/samplers/ddim.py
-flaxdiff/samplers/ddpm.py
-flaxdiff/samplers/euler.py
-flaxdiff/samplers/heun_sampler.py
-flaxdiff/samplers/multistep_dpm.py
-flaxdiff/samplers/rk4_sampler.py
-flaxdiff/schedulers/__init__.py
-flaxdiff/schedulers/common.py
-flaxdiff/schedulers/continuous.py
-flaxdiff/schedulers/cosine.py
-flaxdiff/schedulers/discrete.py
-flaxdiff/schedulers/exp.py
-flaxdiff/schedulers/karras.py
-flaxdiff/schedulers/linear.py
-flaxdiff/schedulers/sqrt.py
-flaxdiff/trainer/__init__.py
-flaxdiff/trainer/autoencoder_trainer.py
-flaxdiff/trainer/diffusion_trainer.py
-flaxdiff/trainer/simple_trainer.py
-flaxdiff/trainer/video_diffusion_trainer.py
flaxdiff-0.1.36/flaxdiff.egg-info/top_level.txt (DELETED)
@@ -1 +0,0 @@
-flaxdiff
flaxdiff-0.1.36/setup.py (DELETED)
@@ -1,21 +0,0 @@
-from setuptools import find_packages, setup
-
-required_packages=[
-    'flax>=0.8.4',
-    'optax>=0.2.2',
-    'jax>=0.4.28',
-    'orbax',
-    'clu',
-]
-
-setup(
-    name='flaxdiff',
-    packages=find_packages(),
-    version='0.1.36',
-    description='A versatile and easy to understand Diffusion library',
-    long_description=open('README.md').read(),
-    long_description_content_type='text/markdown',
-    author='Ashish Kumar Singh',
-    author_email='ashishkmr472@gmail.com',
-    install_requires=required_packages,
-)