flaxdiff 0.1.36.3__py3-none-any.whl → 0.1.36.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/sources/tfds.py +12 -0
- flaxdiff/trainer/diffusion_trainer.py +6 -7
- {flaxdiff-0.1.36.3.dist-info → flaxdiff-0.1.36.4.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.36.3.dist-info → flaxdiff-0.1.36.4.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.36.3.dist-info → flaxdiff-0.1.36.4.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.36.3.dist-info → flaxdiff-0.1.36.4.dist-info}/top_level.txt +0 -0
flaxdiff/data/sources/tfds.py
CHANGED
@@ -4,6 +4,8 @@ import grain.python as pygrain
|
|
4
4
|
from flaxdiff.utils import AutoTextTokenizer
|
5
5
|
from typing import Dict
|
6
6
|
import random
|
7
|
+
import augmax
|
8
|
+
import jax
|
7
9
|
|
8
10
|
# -----------------------------------------------------------------------------------------------#
|
9
11
|
# Oxford flowers and other TFDS datasources -----------------------------------------------------#
|
@@ -47,6 +49,15 @@ def tfds_augmenters(image_scale, method):
|
|
47
49
|
interpolation = cv2.INTER_CUBIC
|
48
50
|
else:
|
49
51
|
interpolation = cv2.INTER_AREA
|
52
|
+
|
53
|
+
augments = augmax.Chain(
|
54
|
+
augmax.HorizontalFlip(0.5),
|
55
|
+
augmax.RandomContrast((-0.05, 0.05), 1.),
|
56
|
+
augmax.RandomBrightness((-0.2, 0.2), 1.)
|
57
|
+
)
|
58
|
+
|
59
|
+
augments = jax.jit(augments, backend="cpu")
|
60
|
+
|
50
61
|
class augmenters(pygrain.MapTransform):
|
51
62
|
def __init__(self, *args, **kwargs):
|
52
63
|
super().__init__(*args, **kwargs)
|
@@ -56,6 +67,7 @@ def tfds_augmenters(image_scale, method):
|
|
56
67
|
image = element['image']
|
57
68
|
image = cv2.resize(image, (image_scale, image_scale),
|
58
69
|
interpolation=interpolation)
|
70
|
+
# image = augments(image)
|
59
71
|
# image = (image - 127.5) / 127.5
|
60
72
|
caption = labelizer(element)
|
61
73
|
results = self.tokenize(caption)
|
flaxdiff/trainer/diffusion_trainer.py
CHANGED
@@ -14,6 +14,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
|
|
14
14
|
from ..schedulers import NoiseScheduler
|
15
15
|
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
16
16
|
from ..samplers.common import DiffusionSampler
|
17
|
+
from ..samplers.ddim import DDIMSampler
|
17
18
|
|
18
19
|
from flaxdiff.utils import RandomMarkovState
|
19
20
|
|
@@ -179,9 +180,6 @@ class DiffusionTrainer(SimpleTrainer):
|
|
179
180
|
nloss = loss_fn(preds, expected_output)
|
180
181
|
# Ignore the loss contribution of images with zero standard deviation
|
181
182
|
nloss *= noise_schedule.get_weights(noise_level)
|
182
|
-
# nloss = jnp.mean(nloss, axis=(1,2,3))
|
183
|
-
# nloss = jnp.where(is_non_zero, nloss, 0)
|
184
|
-
# nloss = jnp.mean(nloss, where=nloss != 0)
|
185
183
|
nloss = jnp.mean(nloss)
|
186
184
|
loss = nloss
|
187
185
|
return loss
|
@@ -224,11 +222,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
224
222
|
if distributed_training:
|
225
223
|
train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
|
226
224
|
out_specs=(P(), P(), P()))
|
227
|
-
|
225
|
+
train_step = jax.jit(train_step)
|
228
226
|
|
229
227
|
return train_step
|
230
228
|
|
231
|
-
def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
|
229
|
+
def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
|
232
230
|
model = self.model
|
233
231
|
encoder = self.encoder
|
234
232
|
autoencoder = self.autoencoder
|
@@ -241,7 +239,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
241
239
|
sampler = sampler_class(
|
242
240
|
model=model,
|
243
241
|
params=state.ema_params,
|
244
|
-
noise_schedule=self.noise_schedule,
|
242
|
+
noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
|
245
243
|
model_output_transform=self.model_output_transform,
|
246
244
|
image_size=self.input_shapes['x'][0],
|
247
245
|
null_labels_seq=null_labels_full,
|
@@ -311,10 +309,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
311
309
|
print("Error logging images to wandb", e)
|
312
310
|
traceback.print_exc()
|
313
311
|
|
314
|
-
def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
|
312
|
+
def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
|
315
313
|
local_batch_size = data['local_batch_size']
|
316
314
|
validation_step_args = {
|
317
315
|
"sampler_class": sampler_class,
|
316
|
+
"sampling_noise_schedule": sampling_noise_schedule,
|
318
317
|
}
|
319
318
|
super().fit(
|
320
319
|
data,
|
{flaxdiff-0.1.36.3.dist-info → flaxdiff-0.1.36.4.dist-info}/RECORD
CHANGED
@@ -5,7 +5,7 @@ flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,
|
|
5
5
|
flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
|
6
6
|
flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
|
7
7
|
flaxdiff/data/sources/gcs.py,sha256=11ZuQhvMyJRLg21DgVdzO5qEuae7zgzTXGNOskF-cbs,3380
|
8
|
-
flaxdiff/data/sources/tfds.py,sha256=
|
8
|
+
flaxdiff/data/sources/tfds.py,sha256=7n-uobG_UvkD5mU_1ovPd9kb6xJrbEKFFXdVEHDunts,2781
|
9
9
|
flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
|
10
10
|
flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
|
11
11
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
@@ -38,10 +38,10 @@ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k
|
|
38
38
|
flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
|
39
39
|
flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
|
40
40
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
|
41
|
-
flaxdiff/trainer/diffusion_trainer.py,sha256=
|
41
|
+
flaxdiff/trainer/diffusion_trainer.py,sha256=zde_nRzsC2GD5KNCn5Qjw9ldHi7L_-teJhcUNUDCdcQ,12815
|
42
42
|
flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
|
43
43
|
flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
|
44
|
-
flaxdiff-0.1.36.
|
45
|
-
flaxdiff-0.1.36.
|
46
|
-
flaxdiff-0.1.36.
|
47
|
-
flaxdiff-0.1.36.
|
44
|
+
flaxdiff-0.1.36.4.dist-info/METADATA,sha256=MTgRu4VgbQaGqbGv_S3wXd_dzeNmHXnixRdvs93dWj0,22310
|
45
|
+
flaxdiff-0.1.36.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
46
|
+
flaxdiff-0.1.36.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
47
|
+
flaxdiff-0.1.36.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|