flaxdiff 0.1.30__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +19 -5
- flaxdiff/models/simple_vit.py +12 -12
- {flaxdiff-0.1.30.dist-info → flaxdiff-0.1.32.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.30.dist-info → flaxdiff-0.1.32.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.30.dist-info → flaxdiff-0.1.32.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.30.dist-info → flaxdiff-0.1.32.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -117,7 +117,12 @@ def map_sample(
|
|
117
117
|
# "error": str(e)
|
118
118
|
# })
|
119
119
|
pass
|
120
|
-
|
120
|
+
|
121
|
+
def default_feature_extractor(sample):
|
122
|
+
return {
|
123
|
+
"url": sample["url"],
|
124
|
+
"caption": sample["caption"],
|
125
|
+
}
|
121
126
|
|
122
127
|
def map_batch(
|
123
128
|
batch, num_threads=256, image_shape=(256, 256),
|
@@ -125,6 +130,7 @@ def map_batch(
|
|
125
130
|
timeout=15, retries=3, image_processor=default_image_processor,
|
126
131
|
upscale_interpolation=cv2.INTER_CUBIC,
|
127
132
|
downscale_interpolation=cv2.INTER_AREA,
|
133
|
+
feature_extractor=default_feature_extractor,
|
128
134
|
):
|
129
135
|
try:
|
130
136
|
map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
@@ -132,7 +138,9 @@ def map_batch(
|
|
132
138
|
upscale_interpolation=upscale_interpolation,
|
133
139
|
downscale_interpolation=downscale_interpolation)
|
134
140
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
135
|
-
|
141
|
+
features = feature_extractor(batch)
|
142
|
+
url, caption = features["url"], features["caption"]
|
143
|
+
executor.map(map_sample_fn, url, caption)
|
136
144
|
except Exception as e:
|
137
145
|
print(f"Error maping batch", e)
|
138
146
|
traceback.print_exc()
|
@@ -149,12 +157,14 @@ def parallel_image_loader(
|
|
149
157
|
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
150
158
|
upscale_interpolation=cv2.INTER_CUBIC,
|
151
159
|
downscale_interpolation=cv2.INTER_AREA,
|
160
|
+
feature_extractor=default_feature_extractor,
|
152
161
|
):
|
153
162
|
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
|
154
163
|
min_image_shape=min_image_shape,
|
155
164
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
156
165
|
upscale_interpolation=upscale_interpolation,
|
157
|
-
downscale_interpolation=downscale_interpolation
|
166
|
+
downscale_interpolation=downscale_interpolation,
|
167
|
+
feature_extractor=feature_extractor)
|
158
168
|
shard_len = len(dataset) // num_workers
|
159
169
|
print(f"Local Shard lengths: {shard_len}")
|
160
170
|
with multiprocessing.Pool(num_workers) as pool:
|
@@ -181,6 +191,7 @@ class ImageBatchIterator:
|
|
181
191
|
image_processor=default_image_processor,
|
182
192
|
upscale_interpolation=cv2.INTER_CUBIC,
|
183
193
|
downscale_interpolation=cv2.INTER_AREA,
|
194
|
+
feature_extractor=default_feature_extractor,
|
184
195
|
):
|
185
196
|
self.dataset = dataset
|
186
197
|
self.num_workers = num_workers
|
@@ -191,7 +202,8 @@ class ImageBatchIterator:
|
|
191
202
|
num_workers=num_workers,
|
192
203
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
193
204
|
upscale_interpolation=upscale_interpolation,
|
194
|
-
downscale_interpolation=downscale_interpolation
|
205
|
+
downscale_interpolation=downscale_interpolation,
|
206
|
+
feature_extractor=feature_extractor)
|
195
207
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
196
208
|
self.thread.start()
|
197
209
|
|
@@ -256,6 +268,7 @@ class OnlineStreamingDataLoader():
|
|
256
268
|
image_processor=default_image_processor,
|
257
269
|
upscale_interpolation=cv2.INTER_CUBIC,
|
258
270
|
downscale_interpolation=cv2.INTER_AREA,
|
271
|
+
feature_extractor=default_feature_extractor,
|
259
272
|
):
|
260
273
|
if isinstance(dataset, str):
|
261
274
|
dataset_path = dataset
|
@@ -281,7 +294,8 @@ class OnlineStreamingDataLoader():
|
|
281
294
|
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
282
295
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
283
296
|
upscale_interpolation=upscale_interpolation,
|
284
|
-
downscale_interpolation=downscale_interpolation
|
297
|
+
downscale_interpolation=downscale_interpolation,
|
298
|
+
feature_extractor=feature_extractor)
|
285
299
|
self.batch_size = batch_size
|
286
300
|
|
287
301
|
# Launch a thread to load batches in the background
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -23,7 +23,7 @@ class PatchEmbedding(nn.Module):
|
|
23
23
|
embedding_dim: int
|
24
24
|
dtype: Any = jnp.float32
|
25
25
|
precision: Any = jax.lax.Precision.HIGH
|
26
|
-
kernel_init: Callable = kernel_init
|
26
|
+
kernel_init: Callable = partial(kernel_init, 1.0)
|
27
27
|
|
28
28
|
@nn.compact
|
29
29
|
def __call__(self, x):
|
@@ -34,7 +34,7 @@ class PatchEmbedding(nn.Module):
|
|
34
34
|
kernel_size=(self.patch_size, self.patch_size),
|
35
35
|
strides=(self.patch_size, self.patch_size),
|
36
36
|
dtype=self.dtype,
|
37
|
-
kernel_init=self.kernel_init,
|
37
|
+
kernel_init=self.kernel_init(),
|
38
38
|
precision=self.precision)(x)
|
39
39
|
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
|
40
40
|
return x
|
@@ -67,7 +67,7 @@ class UViT(nn.Module):
|
|
67
67
|
norm_groups:int=8
|
68
68
|
dtype: Optional[Dtype] = None
|
69
69
|
precision: PrecisionLike = None
|
70
|
-
kernel_init: Callable = partial(kernel_init)
|
70
|
+
kernel_init: Callable = partial(kernel_init, scale=1.0)
|
71
71
|
add_residualblock_output: bool = False
|
72
72
|
|
73
73
|
def setup(self):
|
@@ -86,10 +86,10 @@ class UViT(nn.Module):
|
|
86
86
|
|
87
87
|
# Patch embedding
|
88
88
|
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
|
89
|
-
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init
|
89
|
+
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
|
90
90
|
num_patches = x.shape[1]
|
91
91
|
|
92
|
-
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(
|
92
|
+
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
93
93
|
dtype=self.dtype, precision=self.precision)(textcontext)
|
94
94
|
num_text_tokens = textcontext.shape[1]
|
95
95
|
|
@@ -112,7 +112,7 @@ class UViT(nn.Module):
|
|
112
112
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
113
113
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
114
114
|
only_pure_attention=False,
|
115
|
-
kernel_init=self.kernel_init(
|
115
|
+
kernel_init=self.kernel_init())(x)
|
116
116
|
skips.append(x)
|
117
117
|
|
118
118
|
# Middle block
|
@@ -120,24 +120,24 @@ class UViT(nn.Module):
|
|
120
120
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
121
121
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
122
122
|
only_pure_attention=False,
|
123
|
-
kernel_init=self.kernel_init(
|
123
|
+
kernel_init=self.kernel_init())(x)
|
124
124
|
|
125
125
|
# # Out blocks
|
126
126
|
for i in range(self.num_layers // 2):
|
127
127
|
x = jnp.concatenate([x, skips.pop()], axis=-1)
|
128
|
-
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(
|
128
|
+
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
129
129
|
dtype=self.dtype, precision=self.precision)(x)
|
130
130
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
131
131
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
132
132
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
133
133
|
only_pure_attention=False,
|
134
|
-
kernel_init=self.kernel_init(
|
134
|
+
kernel_init=self.kernel_init())(x)
|
135
135
|
|
136
136
|
# print(f'Shape of x after transformer blocks: {x.shape}')
|
137
137
|
x = self.norm()(x)
|
138
138
|
|
139
139
|
patch_dim = self.patch_size ** 2 * self.output_channels
|
140
|
-
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(
|
140
|
+
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
|
141
141
|
x = x[:, 1 + num_text_tokens:, :]
|
142
142
|
x = unpatchify(x, channels=self.output_channels)
|
143
143
|
|
@@ -151,7 +151,7 @@ class UViT(nn.Module):
|
|
151
151
|
kernel_size=(3, 3),
|
152
152
|
strides=(1, 1),
|
153
153
|
# activation=jax.nn.mish
|
154
|
-
kernel_init=self.kernel_init(0.0),
|
154
|
+
kernel_init=self.kernel_init(scale=0.0),
|
155
155
|
dtype=self.dtype,
|
156
156
|
precision=self.precision
|
157
157
|
)(x)
|
@@ -165,7 +165,7 @@ class UViT(nn.Module):
|
|
165
165
|
kernel_size=(3, 3),
|
166
166
|
strides=(1, 1),
|
167
167
|
# activation=jax.nn.mish
|
168
|
-
kernel_init=self.kernel_init(0.0),
|
168
|
+
kernel_init=self.kernel_init(scale=0.0),
|
169
169
|
dtype=self.dtype,
|
170
170
|
precision=self.precision
|
171
171
|
)(x)
|
@@ -1,13 +1,13 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=Yi1thc2izJoXeB9UPTAkfWhWp8m4UYWi6cDwW6I6zUc,11661
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
9
|
flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=NHt6v-teGjiI65fk1l1WN3WqfeqTE7xY9VYqBiYUDgI,7454
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.32.dist-info/METADATA,sha256=yqODOUEAlcE_60tOOfSP7XoH0VTBm3dY9J1ali36VFY,22083
|
38
|
+
flaxdiff-0.1.32.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.32.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.32.dist-info/RECORD,,
|
File without changes
|
File without changes
|