flaxdiff 0.1.35.5__tar.gz → 0.1.36__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/PKG-INFO +8 -2
- flaxdiff-0.1.36/flaxdiff/data/dataset_map.py +71 -0
- flaxdiff-0.1.36/flaxdiff/data/datasets.py +169 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/data/online_loader.py +69 -42
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/attention.py +1 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/simple_unet.py +11 -11
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/simple_vit.py +1 -1
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/common.py +72 -20
- flaxdiff-0.1.36/flaxdiff/samplers/ddim.py +10 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/ddpm.py +5 -11
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/euler.py +7 -10
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/heun_sampler.py +3 -4
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/multistep_dpm.py +2 -3
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/rk4_sampler.py +9 -9
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/autoencoder_trainer.py +1 -1
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/diffusion_trainer.py +124 -32
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/simple_trainer.py +187 -91
- flaxdiff-0.1.36/flaxdiff/trainer/video_diffusion_trainer.py +62 -0
- flaxdiff-0.1.36/flaxdiff/utils.py +192 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/PKG-INFO +8 -2
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/SOURCES.txt +4 -1
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/setup.py +1 -1
- flaxdiff-0.1.35.5/flaxdiff/samplers/ddim.py +0 -10
- flaxdiff-0.1.35.5/flaxdiff/utils.py +0 -89
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/README.md +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: flaxdiff
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.36
|
4
4
|
Summary: A versatile and easy to understand Diffusion library
|
5
5
|
Author: Ashish Kumar Singh
|
6
6
|
Author-email: ashishkmr472@gmail.com
|
@@ -10,6 +10,12 @@ Requires-Dist: optax>=0.2.2
|
|
10
10
|
Requires-Dist: jax>=0.4.28
|
11
11
|
Requires-Dist: orbax
|
12
12
|
Requires-Dist: clu
|
13
|
+
Dynamic: author
|
14
|
+
Dynamic: author-email
|
15
|
+
Dynamic: description
|
16
|
+
Dynamic: description-content-type
|
17
|
+
Dynamic: requires-dist
|
18
|
+
Dynamic: summary
|
13
19
|
|
14
20
|
# 
|
15
21
|
|
@@ -0,0 +1,71 @@
|
|
1
|
+
from .sources.tfds import data_source_tfds, tfds_augmenters
|
2
|
+
from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
|
3
|
+
|
4
|
+
# Configure the following for your datasets.
#
# datasetMap: dataset name -> {"source": ..., "augmenter": ...}.
# get_dataset_grain looks an entry up by name and calls
# entry["source"](dataset_source) and entry["augmenter"](image_scale, method).
# NOTE: the data_source_tfds / data_source_gcs / data_source_combined_gcs calls
# below are executed when this module is imported, not lazily.
# NOTE(review): some paths appear more than once inside a combined list —
# presumably to oversample those datasets; confirm this is intentional.
datasetMap = {
    "oxford_flowers102": {
        "source": data_source_tfds("oxford_flowers102", use_tf=False),
        "augmenter": tfds_augmenters,
    },
    "cc12m": {
        "source": data_source_gcs('arrayrecord2/cc12m'),
        "augmenter": gcs_augmenters,
    },
    "laiona_coco": {
        "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
        "augmenter": gcs_augmenters,
    },
    "aesthetic_coyo": {
        "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
        "augmenter": gcs_augmenters,
    },
    "combined_aesthetic": {
        "source": data_source_combined_gcs([
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
            'arrayrecord2/cc12m',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
        ]),
        "augmenter": gcs_augmenters,
    },
    "laiona_coco_coyo": {
        "source": data_source_combined_gcs([
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
        ]),
        "augmenter": gcs_augmenters,
    },
    "combined_30m": {
        "source": data_source_combined_gcs([
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecord2/cc12m',
            'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
            "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
        ]),
        "augmenter": gcs_augmenters,
    }
}

# onlineDatasetMap: dataset name -> {"source": [list of dataset paths]}.
# get_dataset_online passes entry["source"] straight to OnlineStreamingDataLoader.
# NOTE(review): repeated paths (e.g. leonardo-liked-1.8m, cc3m) presumably act as
# sampling weights; confirm against the loader's mixing behaviour.
onlineDatasetMap = {
    "combined_online": {
        "source": [
            # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
            # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
            # "dclure/laion-aesthetics-12m-umap",
            "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
            "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/cc12m",
            "gs://flaxdiff-datasets-regional/datasets/playground-liked",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/cc3m",
            "gs://flaxdiff-datasets-regional/datasets/cc3m",
            "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
            # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
        ]
    }
}
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import grain.python as pygrain
|
3
|
+
from typing import Dict
|
4
|
+
import numpy as np
|
5
|
+
import jax
|
6
|
+
from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
|
7
|
+
from .dataset_map import datasetMap, onlineDatasetMap
|
8
|
+
import traceback
|
9
|
+
from .online_loader import OnlineStreamingDataLoader
|
10
|
+
import queue
|
11
|
+
from jax.sharding import Mesh
|
12
|
+
import threading
|
13
|
+
|
14
|
+
def batch_mesh_map(mesh):
    """Return a grain ``MapTransform`` subclass bound to ``mesh``.

    Each batch passing through an instance of the returned class is handed to
    ``convert_to_global_tree(mesh, batch)`` and the result is emitted instead.
    The class itself (not an instance) is returned.
    """
    class augmenters(pygrain.MapTransform):
        # No custom __init__: construction is delegated entirely to MapTransform.

        def map(self, batch) -> Dict[str, jnp.array]:
            # Globalize the host-local batch onto the captured mesh.
            return convert_to_global_tree(mesh, batch)

    return augmenters
|
22
|
+
|
23
|
+
def get_dataset_grain(
    data_name="cc12m",
    batch_size=64,
    image_scale=256,
    count=None,
    num_epochs=None,
    method=jax.image.ResizeMethod.LANCZOS3,
    worker_count=32,
    read_thread_count=64,
    read_buffer_size=50,
    worker_buffer_size=20,
    seed=0,
    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Describe a grain-based training pipeline for a ``datasetMap`` entry.

    Args:
        data_name: key into ``datasetMap``.
        batch_size: global batch size; each JAX process gets
            ``batch_size // jax.process_count()`` samples per batch.
        image_scale, method: forwarded to the entry's augmenter factory.
        count: if given, the sampler uses this many records instead of
            ``len(data_source)``.
        num_epochs, seed: forwarded to ``pygrain.IndexSampler``.
        worker_count, read_thread_count, read_buffer_size, worker_buffer_size:
            forwarded to ``pygrain.DataLoader`` / ``pygrain.ReadOptions``.
        dataset_source: passed to the entry's ``"source"`` callable.

    Returns:
        dict with ``"train"`` (a zero-arg factory building the DataLoader),
        ``"train_len"`` (``len(data_source)``), ``"local_batch_size"`` and
        ``"global_batch_size"``.
    """
    entry = datasetMap[data_name]
    source = entry["source"](dataset_source)
    make_augmenter = entry["augmenter"](image_scale, method)

    per_process_batch = batch_size // jax.process_count()

    # Records are shuffled and split across JAX processes by the sampler.
    index_sampler = pygrain.IndexSampler(
        num_records=count if count is not None else len(source),
        shuffle=True,
        seed=seed,
        num_epochs=num_epochs,
        shard_options=pygrain.ShardByJaxProcess(),
    )

    def get_trainset():
        # Built lazily so worker processes are only spawned when requested.
        operations = [
            make_augmenter(),
            pygrain.Batch(per_process_batch, drop_remainder=True),
        ]
        read_opts = pygrain.ReadOptions(read_thread_count, read_buffer_size)
        return pygrain.DataLoader(
            data_source=source,
            sampler=index_sampler,
            operations=operations,
            worker_count=worker_count,
            read_options=read_opts,
            worker_buffer_size=worker_buffer_size,
        )

    return {
        "train": get_trainset,
        "train_len": len(source),
        "local_batch_size": per_process_batch,
        "global_batch_size": batch_size,
    }
|
83
|
+
|
84
|
+
def generate_collate_fn(tokenizer=None):
    """Build a collate function that tokenizes captions and stacks images.

    Args:
        tokenizer: optional callable taking a list of caption strings and
            returning a mapping with ``'input_ids'`` and ``'attention_mask'``.
            Defaults to ``AutoTextTokenizer(tensor_type="np")``.

    Returns:
        ``default_collate(batch)``. It takes a list of samples, each a mapping
        with ``"caption"`` and ``"image"``, and returns a dict with
        ``"image"`` (images stacked along a new leading axis),
        ``"input_ids"`` and ``"attention_mask"``.

    Raises (from ``default_collate``):
        Whatever the tokenizer / ``np.stack`` raised (e.g. ``ValueError`` for
        mismatched image shapes, ``KeyError`` for a missing field). The error
        is logged first and then re-raised instead of silently returning
        ``None``. A ``None`` "batch" would otherwise be queued downstream as if
        it were data. Callers such as ``OnlineStreamingDataLoader`` already
        catch collate errors and drop the batch.
    """
    auto_tokenize = AutoTextTokenizer(tensor_type="np") if tokenizer is None else tokenizer

    def _image_shape(sample):
        # Best-effort shape lookup so the diagnostic cannot itself fail and
        # mask the original error.
        try:
            return np.shape(sample["image"])
        except Exception:
            return None

    def default_collate(batch):
        try:
            captions = [sample["caption"] for sample in batch]
            results = auto_tokenize(captions)
            images = np.stack([sample["image"] for sample in batch], axis=0)
            return {
                "image": images,
                "input_ids": results['input_ids'],
                "attention_mask": results['attention_mask'],
            }
        except Exception as e:
            print("Error in collate function", e, [_image_shape(sample) for sample in batch])
            traceback.print_exc()
            raise

    return default_collate
|
102
|
+
|
103
|
+
class _DataLoaderWithMesh:
    """Iterator that re-shards host-local batches onto a device mesh.

    A background daemon thread pulls batches from ``dataloader``, converts each
    one with ``convert_to_global_tree(mesh, batch)`` and buffers up to
    ``buffer_size`` converted batches. Batches whose conversion fails are
    logged and skipped. When the wrapped loader is exhausted, iteration ends
    with ``StopIteration`` instead of blocking forever on an empty queue.
    """

    # Sentinel placed on the queue once the wrapped loader stops yielding.
    _END = object()

    def __init__(self, dataloader, mesh, buffer_size):
        self.dataloader = dataloader
        self.mesh = mesh
        self.tmp_queue = queue.Queue(buffer_size)
        self._exhausted = False
        # Daemon so a still-running producer cannot keep the process alive.
        self.loader_thread = threading.Thread(target=self._batch_loader, daemon=True)
        self.loader_thread.start()

    def _batch_loader(self):
        try:
            for batch in self.dataloader:
                try:
                    self.tmp_queue.put(convert_to_global_tree(self.mesh, batch))
                except Exception as e:
                    print("Error processing batch", e)
        finally:
            # Always signal end-of-stream, even if the wrapped loader raised.
            self.tmp_queue.put(self._END)

    def __iter__(self):
        return self

    def __next__(self):
        if self._exhausted:
            raise StopIteration
        item = self.tmp_queue.get()
        if item is self._END:
            self._exhausted = True
            raise StopIteration
        return item


def get_dataset_online(
    data_name="combined_online",
    batch_size=64,
    image_scale=256,
    count=None,
    num_epochs=None,
    method=jax.image.ResizeMethod.LANCZOS3,
    worker_count=32,
    read_thread_count=64,
    read_buffer_size=50,
    worker_buffer_size=20,
    seed=0,
    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Describe a streaming training pipeline for an ``onlineDatasetMap`` entry.

    The ``OnlineStreamingDataLoader`` is created immediately, sharded per JAX
    process. ``count``, ``num_epochs``, ``method``, ``read_buffer_size``,
    ``seed`` and ``dataset_source`` are accepted for signature parity with
    ``get_dataset_grain`` but are not used here.

    Returns:
        dict with:
            ``"train"``: ``get_trainset(mesh=None)``. Without a mesh it returns
                the loader itself. With a mesh it returns an iterator yielding
                batches converted via ``convert_to_global_tree``.
            ``"train_len"``: the loader's (per-process shard) length times
                ``jax.process_count()``.
            ``"local_batch_size"``: ``batch_size // jax.process_count()``.
            ``"global_batch_size"``: ``batch_size``.
    """
    local_batch_size = batch_size // jax.process_count()

    sources = onlineDatasetMap[data_name]["source"]
    dataloader = OnlineStreamingDataLoader(
        sources,
        batch_size=local_batch_size,
        num_workers=worker_count,
        num_threads=read_thread_count,
        image_shape=(image_scale, image_scale),
        global_process_count=jax.process_count(),
        global_process_index=jax.process_index(),
        prefetch=worker_buffer_size,
        collate_fn=generate_collate_fn(),
        default_split="train",
    )

    def get_trainset(mesh: Mesh = None):
        if mesh is None:
            return dataloader
        # Each call starts its own producer thread over the shared dataloader.
        return _DataLoaderWithMesh(dataloader, mesh, worker_buffer_size)

    return {
        "train": get_trainset,
        "train_len": len(dataloader) * jax.process_count(),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
    }
|
@@ -45,36 +45,43 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
45
45
|
|
46
46
|
|
47
47
|
def default_image_processor(
    image, image_shape,
    min_image_shape=(128, 128),
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
):
    """Validate one image and fit it into ``image_shape``.

    The image is resized so its longest side equals ``max(image_shape)``, then
    padded with white to at least ``image_shape`` (height, width).

    Returns:
        ``(processed, original_height, original_width)``. ``processed`` is
        ``None`` when the image is rejected:
        - not 3-channel, or any exception raised: returns ``(None, 0, 0)``;
        - shorter side below ``min(min_image_shape)``, long/short side ratio
          above 2.4, or near-zero pixel std: returns ``(None, h, w)``.
    """
    try:
        pixels = np.array(image)
        if pixels.ndim != 3 or pixels.shape[2] != 3:
            return None, 0, 0

        height, width = pixels.shape[:2]
        short_side = min(height, width)
        long_side = max(height, width)

        # Short-circuit order matters: the size check runs before the ratio
        # division, and the (costlier) std check runs last.
        if (short_side < min(min_image_shape)
                or long_side / short_side > 2.4
                or np.std(pixels) < 1e-5):
            return None, height, width

        target = max(image_shape)
        if long_side > target:
            interpolation = downscale_interpolation
        else:
            interpolation = upscale_interpolation

        resized = A.longest_max_size(pixels, target, interpolation=interpolation)
        padded = A.pad(
            resized,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        return padded, height, width
    except Exception:
        # Deliberate best-effort: any failure just rejects this image.
        return None, 0, 0
|
78
85
|
|
79
86
|
|
80
87
|
def map_sample(
|
@@ -120,14 +127,36 @@ def map_sample(
|
|
120
127
|
# })
|
121
128
|
pass
|
122
129
|
|
123
|
-
|
124
130
|
def default_feature_extractor(sample):
    """Pick the image URL and caption out of a raw dataset row.

    Common column spellings are tried in priority order and the first key
    present in ``sample`` wins. If none is present, a message is printed and
    that field is ``None`` (the sample itself is still returned).

    Returns:
        dict with keys ``"url"`` and ``"caption"``.
    """
    url_key = next((k for k in ("url", "URL", "image_url") if k in sample), None)
    if url_key is None:
        print("No url found in sample, skipping", sample.keys())
        url = None
    else:
        url = sample[url_key]

    caption_key = next(
        (k for k in ("caption", "CAPTION", "txt", "TEXT", "text") if k in sample),
        None,
    )
    if caption_key is None:
        print("No caption found in sample, skipping", sample.keys())
        caption = None
    else:
        caption = sample[caption_key]

    return {
        "url": url,
        "caption": caption,
    }
|
129
|
-
|
130
|
-
|
159
|
+
|
131
160
|
def map_batch(
|
132
161
|
batch, num_threads=256, image_shape=(256, 256),
|
133
162
|
min_image_shape=(128, 128),
|
@@ -301,15 +330,13 @@ class OnlineStreamingDataLoader():
|
|
301
330
|
self.dataset = dataset.shard(
|
302
331
|
num_shards=global_process_count, index=global_process_index)
|
303
332
|
print(f"Dataset length: {len(dataset)}")
|
304
|
-
self.iterator = ImageBatchIterator(
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
feature_extractor=feature_extractor
|
312
|
-
)
|
333
|
+
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
|
334
|
+
min_image_shape=min_image_shape,
|
335
|
+
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
336
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
337
|
+
upscale_interpolation=upscale_interpolation,
|
338
|
+
downscale_interpolation=downscale_interpolation,
|
339
|
+
feature_extractor=feature_extractor)
|
313
340
|
self.batch_size = batch_size
|
314
341
|
|
315
342
|
# Launch a thread to load batches in the background
|
@@ -320,7 +347,7 @@ class OnlineStreamingDataLoader():
|
|
320
347
|
try:
|
321
348
|
self.batch_queue.put(collate_fn(batch))
|
322
349
|
except Exception as e:
|
323
|
-
print("Error
|
350
|
+
print("Error collating batch", e)
|
324
351
|
|
325
352
|
self.loader_thread = threading.Thread(target=batch_loader)
|
326
353
|
self.loader_thread.start()
|
@@ -333,4 +360,4 @@ class OnlineStreamingDataLoader():
|
|
333
360
|
# return self.collate_fn(next(self.iterator))
|
334
361
|
|
335
362
|
def __len__(self):
|
336
|
-
return len(self.dataset)
|
363
|
+
return len(self.dataset)
|
@@ -50,7 +50,7 @@ class Unet(nn.Module):
|
|
50
50
|
features=self.feature_depths[0],
|
51
51
|
kernel_size=(3, 3),
|
52
52
|
strides=(1, 1),
|
53
|
-
kernel_init=self.kernel_init(1.0),
|
53
|
+
kernel_init=self.kernel_init(scale=1.0),
|
54
54
|
dtype=self.dtype,
|
55
55
|
precision=self.precision
|
56
56
|
)(x)
|
@@ -65,7 +65,7 @@ class Unet(nn.Module):
|
|
65
65
|
down_conv_type,
|
66
66
|
name=f"down_{i}_residual_{j}",
|
67
67
|
features=dim_in,
|
68
|
-
kernel_init=self.kernel_init(1.0),
|
68
|
+
kernel_init=self.kernel_init(scale=1.0),
|
69
69
|
kernel_size=(3, 3),
|
70
70
|
strides=(1, 1),
|
71
71
|
activation=self.activation,
|
@@ -85,7 +85,7 @@ class Unet(nn.Module):
|
|
85
85
|
force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
|
86
86
|
norm_inputs=attention_config.get("norm_inputs", True),
|
87
87
|
explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
|
88
|
-
kernel_init=self.kernel_init(1.0),
|
88
|
+
kernel_init=self.kernel_init(scale=1.0),
|
89
89
|
name=f"down_{i}_attention_{j}")(x, textcontext)
|
90
90
|
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
|
91
91
|
downs.append(x)
|
@@ -108,7 +108,7 @@ class Unet(nn.Module):
|
|
108
108
|
middle_conv_type,
|
109
109
|
name=f"middle_res1_{j}",
|
110
110
|
features=middle_dim_out,
|
111
|
-
kernel_init=self.kernel_init(1.0),
|
111
|
+
kernel_init=self.kernel_init(scale=1.0),
|
112
112
|
kernel_size=(3, 3),
|
113
113
|
strides=(1, 1),
|
114
114
|
activation=self.activation,
|
@@ -129,13 +129,13 @@ class Unet(nn.Module):
|
|
129
129
|
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
|
130
130
|
norm_inputs=middle_attention.get("norm_inputs", True),
|
131
131
|
explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
|
132
|
-
kernel_init=self.kernel_init(1.0),
|
132
|
+
kernel_init=self.kernel_init(scale=1.0),
|
133
133
|
name=f"middle_attention_{j}")(x, textcontext)
|
134
134
|
x = ResidualBlock(
|
135
135
|
middle_conv_type,
|
136
136
|
name=f"middle_res2_{j}",
|
137
137
|
features=middle_dim_out,
|
138
|
-
kernel_init=self.kernel_init(1.0),
|
138
|
+
kernel_init=self.kernel_init(scale=1.0),
|
139
139
|
kernel_size=(3, 3),
|
140
140
|
strides=(1, 1),
|
141
141
|
activation=self.activation,
|
@@ -157,7 +157,7 @@ class Unet(nn.Module):
|
|
157
157
|
up_conv_type,# if j == 0 else "separable",
|
158
158
|
name=f"up_{i}_residual_{j}",
|
159
159
|
features=dim_out,
|
160
|
-
kernel_init=self.kernel_init(1.0),
|
160
|
+
kernel_init=self.kernel_init(scale=1.0),
|
161
161
|
kernel_size=kernel_size,
|
162
162
|
strides=(1, 1),
|
163
163
|
activation=self.activation,
|
@@ -177,7 +177,7 @@ class Unet(nn.Module):
|
|
177
177
|
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
|
178
178
|
norm_inputs=attention_config.get("norm_inputs", True),
|
179
179
|
explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
|
180
|
-
kernel_init=self.kernel_init(1.0),
|
180
|
+
kernel_init=self.kernel_init(scale=1.0),
|
181
181
|
name=f"up_{i}_attention_{j}")(x, textcontext)
|
182
182
|
# print("Upscaling ", i, x.shape)
|
183
183
|
if i != len(feature_depths) - 1:
|
@@ -196,7 +196,7 @@ class Unet(nn.Module):
|
|
196
196
|
features=self.feature_depths[0],
|
197
197
|
kernel_size=(3, 3),
|
198
198
|
strides=(1, 1),
|
199
|
-
kernel_init=self.kernel_init(1.0),
|
199
|
+
kernel_init=self.kernel_init(scale=1.0),
|
200
200
|
dtype=self.dtype,
|
201
201
|
precision=self.precision
|
202
202
|
)(x)
|
@@ -207,7 +207,7 @@ class Unet(nn.Module):
|
|
207
207
|
conv_type,
|
208
208
|
name="final_residual",
|
209
209
|
features=self.feature_depths[0],
|
210
|
-
kernel_init=self.kernel_init(1.0),
|
210
|
+
kernel_init=self.kernel_init(scale=1.0),
|
211
211
|
kernel_size=(3,3),
|
212
212
|
strides=(1, 1),
|
213
213
|
activation=self.activation,
|
@@ -226,7 +226,7 @@ class Unet(nn.Module):
|
|
226
226
|
kernel_size=(3, 3),
|
227
227
|
strides=(1, 1),
|
228
228
|
# activation=jax.nn.mish
|
229
|
-
kernel_init=self.kernel_init(0.0),
|
229
|
+
kernel_init=self.kernel_init(scale=0.0),
|
230
230
|
dtype=self.dtype,
|
231
231
|
precision=self.precision
|
232
232
|
)(x)
|
@@ -70,7 +70,7 @@ class UViT(nn.Module):
|
|
70
70
|
kernel_init: Callable = partial(kernel_init, scale=1.0)
|
71
71
|
add_residualblock_output: bool = False
|
72
72
|
norm_inputs: bool = False
|
73
|
-
explicitly_add_residual: bool =
|
73
|
+
explicitly_add_residual: bool = True
|
74
74
|
|
75
75
|
def setup(self):
|
76
76
|
if self.norm_groups > 0:
|