keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +12 -0
- keras_hub/api/models/__init__.py +32 -0
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/rms_normalization.py +34 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
- keras_hub/src/layers/preprocessing/image_converter.py +5 -0
- keras_hub/src/models/albert/albert_presets.py +0 -8
- keras_hub/src/models/bart/bart_presets.py +0 -6
- keras_hub/src/models/bert/bert_presets.py +0 -20
- keras_hub/src/models/bloom/bloom_presets.py +0 -16
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +0 -6
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
- keras_hub/src/models/efficientnet/mbconv.py +52 -21
- keras_hub/src/models/electra/electra_presets.py +0 -12
- keras_hub/src/models/f_net/f_net_presets.py +0 -4
- keras_hub/src/models/falcon/falcon_presets.py +0 -2
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +494 -0
- keras_hub/src/models/flux/flux_maths.py +218 -0
- keras_hub/src/models/flux/flux_model.py +231 -0
- keras_hub/src/models/flux/flux_presets.py +14 -0
- keras_hub/src/models/flux/flux_text_to_image.py +142 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_presets.py +0 -40
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_to_image.py +16 -10
- keras_hub/src/models/inpaint.py +20 -13
- keras_hub/src/models/llama/llama_backbone.py +1 -1
- keras_hub/src/models/llama/llama_presets.py +5 -15
- keras_hub/src/models/llama3/llama3_presets.py +0 -8
- keras_hub/src/models/mistral/mistral_presets.py +0 -6
- keras_hub/src/models/mit/mit_backbone.py +41 -27
- keras_hub/src/models/mit/mit_layers.py +9 -7
- keras_hub/src/models/mit/mit_presets.py +12 -24
- keras_hub/src/models/opt/opt_presets.py +0 -8
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
- keras_hub/src/models/phi3/phi3_presets.py +0 -4
- keras_hub/src/models/resnet/resnet_presets.py +10 -42
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
- keras_hub/src/models/roberta/roberta_presets.py +0 -4
- keras_hub/src/models/sam/sam_backbone.py +0 -1
- keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
- keras_hub/src/models/sam/sam_presets.py +0 -6
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +163 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +124 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +41 -13
- keras_hub/src/models/text_to_image.py +13 -5
- keras_hub/src/models/vgg/vgg_backbone.py +1 -1
- keras_hub/src/models/vgg/vgg_presets.py +0 -8
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
- keras_hub/src/models/whisper/whisper_presets.py +0 -20
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/utils/preset_utils.py +17 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/mmdit.py

@@ -354,6 +354,8 @@ class DismantledBlock(layers.Layer):
         mlp_ratio: float. The expansion ratio of `MLP`.
         use_projection: bool. Whether to use an attention projection layer at
             the end of the block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
     """
@@ -364,6 +366,7 @@ class DismantledBlock(layers.Layer):
         hidden_dim,
         mlp_ratio=4.0,
         use_projection=True,
+        qk_norm=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -371,6 +374,7 @@ class DismantledBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_projection = use_projection
+        self.qk_norm = qk_norm
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -391,6 +395,18 @@ class DismantledBlock(layers.Layer):
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
+        if qk_norm is not None and qk_norm == "rms_norm":
+            self.q_norm = layers.LayerNormalization(
+                epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
+            )
+            self.k_norm = layers.LayerNormalization(
+                epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
+            )
+        elif qk_norm is not None:
+            raise NotImplementedError(
+                "Supported `qk_norm` are `'rms_norm'` and `None`. "
+                f"Received: qk_norm={qk_norm}."
+            )
         if use_projection:
             self.attention_proj = layers.Dense(
                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -413,6 +429,10 @@ class DismantledBlock(layers.Layer):
     def build(self, inputs_shape, timestep_embedding):
         self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
+        if self.qk_norm is not None:
+            # [batch_size, sequence_length, num_heads, head_dim]
+            self.q_norm.build([None, None, self.num_heads, self.head_dim])
+            self.k_norm.build([None, None, self.num_heads, self.head_dim])
         if self.use_projection:
             self.attention_proj.build(inputs_shape)
             self.norm2.build(inputs_shape)
@@ -435,6 +455,9 @@ class DismantledBlock(layers.Layer):
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = self.q_norm(q, training=training)
+                k = self.k_norm(k, training=training)
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
             x = self.ada_layer_norm(
@@ -445,6 +468,9 @@ class DismantledBlock(layers.Layer):
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = self.q_norm(q, training=training)
+                k = self.k_norm(k, training=training)
             return (q, k, v)
 
     def _compute_post_attention(
@@ -494,6 +520,7 @@ class DismantledBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_projection": self.use_projection,
+                "qk_norm": self.qk_norm,
             }
         )
         return config
@@ -513,6 +540,8 @@ class MMDiTBlock(layers.Layer):
         mlp_ratio: float. The expansion ratio of `MLP`.
         use_context_projection: bool. Whether to use an attention projection
             layer at the end of the context block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
 
@@ -527,6 +556,7 @@ class MMDiTBlock(layers.Layer):
         hidden_dim,
         mlp_ratio=4.0,
         use_context_projection=True,
+        qk_norm=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -534,6 +564,7 @@ class MMDiTBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_context_projection = use_context_projection
+        self.qk_norm = qk_norm
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -544,6 +575,7 @@ class MMDiTBlock(layers.Layer):
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=True,
+            qk_norm=qk_norm,
             dtype=self.dtype_policy,
             name="x_block",
         )
@@ -552,6 +584,7 @@ class MMDiTBlock(layers.Layer):
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=use_context_projection,
+            qk_norm=qk_norm,
             dtype=self.dtype_policy,
             name="context_block",
         )
@@ -629,6 +662,7 @@ class MMDiTBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_context_projection": self.use_context_projection,
+                "qk_norm": self.qk_norm,
             }
         )
         return config
@@ -705,6 +739,9 @@ class MMDiT(Backbone):
         latent_shape: tuple. The shape of the latent image.
         context_shape: tuple. The shape of the context.
         pooled_projection_shape: tuple. The shape of the pooled projection.
+        qk_norm: Optional str. Whether to normalize the query and key tensors in
+            the intermediate blocks. Available options are `None` and
+            `"rms_norm"`. Defaults to `None`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -729,6 +766,7 @@ class MMDiT(Backbone):
         latent_shape=(64, 64, 16),
         context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
+        qk_norm=None,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -782,6 +820,7 @@ class MMDiT(Backbone):
                 hidden_dim,
                 mlp_ratio,
                 use_context_projection=not (i == num_layers - 1),
+                qk_norm=qk_norm,
                 dtype=dtype,
                 name=f"joint_block_{i}",
             )
@@ -851,6 +890,7 @@ class MMDiT(Backbone):
         self.latent_shape = latent_shape
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape
+        self.qk_norm = qk_norm
 
     def get_config(self):
         config = super().get_config()
@@ -865,6 +905,7 @@ class MMDiT(Backbone):
                 "latent_shape": self.latent_shape,
                 "context_shape": self.context_shape,
                 "pooled_projection_shape": self.pooled_projection_shape,
+                "qk_norm": self.qk_norm,
             }
         )
         return config
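The `qk_norm="rms_norm"` path above applies an RMS-style normalization (Keras `LayerNormalization` with `rms_scaling=True`) to the per-head query and key tensors before attention. A minimal standalone sketch of that behavior follows; the shapes and values are illustrative and not taken from the package.

```python
import keras
from keras import layers, ops

# Illustrative shapes; the real block derives these from hidden_dim // num_heads.
batch_size, seq_len, num_heads, head_dim = 2, 16, 4, 8
q = keras.random.normal((batch_size, seq_len, num_heads, head_dim))
k = keras.random.normal((batch_size, seq_len, num_heads, head_dim))

# Same layer configuration as the diff: RMS scaling, computed in float32.
q_norm = layers.LayerNormalization(epsilon=1e-6, rms_scaling=True, dtype="float32")
k_norm = layers.LayerNormalization(epsilon=1e-6, rms_scaling=True, dtype="float32")

q = q_norm(q)  # normalizes each head vector over the last (head_dim) axis
k = k_norm(k)
print(ops.shape(q))  # (2, 16, 4, 8)
```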
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -202,6 +202,10 @@ class StableDiffusion3Backbone(Backbone):
             transformer in MMDiT.
         mmdit_position_size: int. The size of the height and width for the
             position embedding in MMDiT.
+        mmdit_qk_norm: Optional str. Whether to normalize the query and key
+            tensors for each transformer in MMDiT. Available options are `None`
+            and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
+            and to `"rms_norm" for 3.5 version.
         vae: The VAE used for transformations between pixel space and latent
             space.
         clip_l: The CLIP text encoder for encoding the inputs.
@@ -215,8 +219,8 @@ class StableDiffusion3Backbone(Backbone):
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
             `3.0`.
-        height: ...
-        width: ...
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -248,6 +252,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_hidden_dim=256,
         mmdit_depth=4,
         mmdit_position_size=192,
+        mmdit_qk_norm=None,
         vae=vae,
         clip_l=clip_l,
         clip_g=clip_g,
@@ -262,6 +267,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_layers,
         mmdit_num_heads,
         mmdit_position_size,
+        mmdit_qk_norm,
         vae,
         clip_l,
         clip_g,
@@ -270,23 +276,21 @@ class StableDiffusion3Backbone(Backbone):
         output_channels=3,
         num_train_timesteps=1000,
         shift=3.0,
-        height=None,
-        width=None,
+        image_shape=(1024, 1024, 3),
         data_format=None,
         dtype=None,
         **kwargs,
     ):
-        height = int(height or 1024)
-        width = int(width or 1024)
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                "`height` and `width` must be divisible by 8. "
-                f"Received: height={height}, width={width}"
-            )
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-
+        height = image_shape[0]
+        width = image_shape[1]
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                "height and width in `image_shape` must be divisible by 8. "
+                f"Received: image_shape={image_shape}"
+            )
         latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
@@ -314,6 +318,7 @@ class StableDiffusion3Backbone(Backbone):
             latent_shape=latent_shape,
             context_shape=context_shape,
             pooled_projection_shape=pooled_projection_shape,
+            qk_norm=mmdit_qk_norm,
             data_format=data_format,
             dtype=dtype,
             name="diffuser",
@@ -448,12 +453,12 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_layers = mmdit_num_layers
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
+        self.mmdit_qk_norm = mmdit_qk_norm
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
         self.shift = shift
-        self.height = height
-        self.width = width
+        self.image_shape = image_shape
 
     @property
     def latent_shape(self):
@@ -535,7 +540,7 @@ class StableDiffusion3Backbone(Backbone):
         embeddings,
         step,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
     ):
         step = ops.convert_to_tensor(step)
         next_step = ops.add(step, 1)
@@ -543,9 +548,15 @@ class StableDiffusion3Backbone(Backbone):
         next_sigma, _ = self.scheduler(next_step, num_steps)
 
         # Concatenation for classifier-free guidance.
-        concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
-            latents, *embeddings, timestep
-        )
+        if guidance_scale is not None:
+            concated_latents, contexts, pooled_projs, timesteps = (
+                self.cfg_concat(latents, *embeddings, timestep)
+            )
+        else:
+            timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+            concated_latents = latents
+            contexts = embeddings[0]
+            pooled_projs = embeddings[2]
 
         # Diffusion.
         predicted_noise = self.diffuser(
@@ -559,7 +570,8 @@ class StableDiffusion3Backbone(Backbone):
         )
 
         # Classifier-free guidance.
-        predicted_noise = self.cfg(predicted_noise, guidance_scale)
+        if guidance_scale is not None:
+            predicted_noise = self.cfg(predicted_noise, guidance_scale)
 
         # Euler step.
         return self.euler_step(latents, predicted_noise, sigma, next_sigma)
@@ -577,6 +589,7 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_layers": self.mmdit_num_layers,
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
+                "mmdit_qk_norm": self.mmdit_qk_norm,
                 "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
@@ -585,8 +598,7 @@ class StableDiffusion3Backbone(Backbone):
                 "output_channels": self.output_channels,
                 "num_train_timesteps": self.num_train_timesteps,
                 "shift": self.shift,
-                "height": self.height,
-                "width": self.width,
+                "image_shape": self.image_shape,
             }
         )
         return config
@@ -624,4 +636,9 @@ class StableDiffusion3Backbone(Backbone):
             config["t5"] = layers.deserialize(
                 config["t5"], custom_objects=custom_objects
             )
+
+        # To maintain backward compatibility, we need to ensure that
+        # `mmdit_qk_norm` is included in the config.
+        if "mmdit_qk_norm" not in config:
+            config["mmdit_qk_norm"] = None
         return cls(**config)
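Taken together, the backbone now accepts a single `image_shape` in place of `height`/`width` and threads `mmdit_qk_norm` down to the MMDiT blocks. A hypothetical usage sketch follows; passing `image_shape` as a keyword override through `from_preset` is an assumption, and the values are illustrative.

```python
import keras_hub

# Hypothetical sketch: override the input resolution at load time with the new
# `image_shape` argument (height and width must be divisible by 8).
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium",
    image_shape=(512, 512, 3),
)
print(backbone.image_shape)  # (512, 512, 3)
```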
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py

@@ -27,7 +27,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     Use `generate()` to do image generation.
     ```python
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
-        "stable_diffusion_3_medium",
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
@@ -158,14 +158,14 @@ class StableDiffusion3ImageToImage(ImageToImage):
         self,
         inputs,
         num_steps=50,
-        guidance_scale=7.0,
         strength=0.8,
+        guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
             num_steps=num_steps,
-            guidance_scale=guidance_scale,
             strength=strength,
+            guidance_scale=guidance_scale,
             seed=seed,
         )
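The image-to-image task moves `strength` ahead of `guidance_scale` in its `generate()` signature. A hypothetical call sketch mirroring the docstring example above; the `{"images": ..., "prompts": ...}` input layout and the argument values are illustrative assumptions.

```python
import numpy as np
import keras_hub

# Hypothetical sketch; input dict layout and values are illustrative.
image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
reference_image = np.ones((512, 512, 3), dtype="float32")
image_to_image.generate(
    {"images": reference_image, "prompts": "a watercolor landscape"},
    num_steps=50,
    strength=0.8,
    guidance_scale=7.0,
)
```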
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py

@@ -29,7 +29,7 @@ class StableDiffusion3Inpaint(Inpaint):
     reference_image = np.ones((1024, 1024, 3), dtype="float32")
     reference_mask = np.ones((1024, 1024), dtype="float32")
     inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
-        "stable_diffusion_3_medium",
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     inpaint.generate(
         reference_image,
@@ -181,14 +181,14 @@ class StableDiffusion3Inpaint(Inpaint):
         self,
         inputs,
         num_steps=50,
-        guidance_scale=7.0,
         strength=0.6,
+        guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
             num_steps=num_steps,
-            guidance_scale=guidance_scale,
             strength=strength,
+            guidance_scale=guidance_scale,
             seed=seed,
         )
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py

@@ -9,10 +9,34 @@ backbone_presets = {
                 "Developed by Stability AI."
             ),
             "params": 2987080931,
-            "official_name": "StableDiffusion3",
             "path": "stable_diffusion_3",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/
-    }
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
+    },
+    "stable_diffusion_3.5_large": {
+        "metadata": {
+            "description": (
+                "9 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT generative model, and VAE autoencoder. "
+                "Developed by Stability AI."
+            ),
+            "params": 9048410595,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/1",
+    },
+    "stable_diffusion_3.5_large_turbo": {
+        "metadata": {
+            "description": (
+                "9 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT generative model, and VAE autoencoder. "
+                "A timestep-distilled version that eliminates classifier-free "
+                "guidance and uses fewer steps for generation. "
+                "Developed by Stability AI."
+            ),
+            "params": 9048410595,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/1",
+    },
 }
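The new 3.5 presets pair with the optional `guidance_scale` handling shown elsewhere in this diff: the turbo checkpoint is described as eliminating classifier-free guidance. A hypothetical sketch follows; the low step count and `guidance_scale=None` are assumptions about how a distilled model would typically be run, not values from the package.

```python
import keras_hub

# Hypothetical sketch using the newly added preset names.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3.5_large_turbo", image_shape=(512, 512, 3)
)
text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    num_steps=4,          # assumption: distilled models typically use few steps
    guidance_scale=None,  # turbo variant skips classifier-free guidance
)
```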
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py

@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
     Use `generate()` to do image generation.
     ```python
     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-        "stable_diffusion_3_medium",
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
keras_hub/src/models/t5/t5_backbone.py

@@ -42,11 +42,12 @@ class T5Backbone(Backbone):
             projections in the multi-head attention layers. Defaults to
             hidden_dim / num_heads.
         dropout: float. Dropout probability for the Transformer layers.
-        activation: activation function ...
-            ...
-            Transformer layers. Defaults to `"relu"`.
+        activation: string. The activation function to use in the dense blocks
+            of the Transformer Layers.
         use_gated_activation: boolean. Whether to use activation gating in
-            the inner dense blocks of the Transformer layers.
+            the inner dense blocks of the Transformer layers. When used with
+            the GELU activation function, this is referred to as GEGLU
+            (gated GLU) from https://arxiv.org/pdf/2002.05202.
             The original T5 architecture didn't use gating, but more
             recent versions do. Defaults to `True`.
         layer_norm_epsilon: float. Epsilon factor to be used in the
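The reworded `use_gated_activation` docstring describes a GEGLU-style feed-forward: one projection is passed through the activation and multiplied with a second, linear projection. A minimal sketch of that gating; the layer names and dimensions are illustrative, not the T5 implementation's.

```python
import keras
from keras import layers, ops

hidden_dim, intermediate_dim = 64, 128
x = keras.random.normal((2, 10, hidden_dim))

# Two parallel projections: one gated by the activation, one linear.
gate_proj = layers.Dense(intermediate_dim, use_bias=False)
value_proj = layers.Dense(intermediate_dim, use_bias=False)
output_proj = layers.Dense(hidden_dim, use_bias=False)

hidden = ops.gelu(gate_proj(x)) * value_proj(x)  # GEGLU when the activation is GELU
y = output_proj(hidden)
print(ops.shape(y))  # (2, 10, 64)
```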
keras_hub/src/models/t5/t5_presets.py

@@ -1,4 +1,4 @@
-"""
+"""T5 model preset configurations."""
 
 backbone_presets = {
     "t5_small_multi": {
@@ -8,12 +8,18 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/2",
     },
+    "t5_1.1_small": {
+        "metadata": {
+            "description": (""),
+            "params": 60511616,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_small/1",
+    },
     "t5_base_multi": {
         "metadata": {
             "description": (
@@ -21,12 +27,18 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/2",
     },
+    "t5_1.1_base": {
+        "metadata": {
+            "description": (""),
+            "params": 247577856,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_base/1",
+    },
     "t5_large_multi": {
         "metadata": {
             "description": (
@@ -34,12 +46,34 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/2",
     },
+    "t5_1.1_large": {
+        "metadata": {
+            "description": (""),
+            "params": 750251008,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_large/1",
+    },
+    "t5_1.1_xl": {
+        "metadata": {
+            "description": (""),
+            "params": 2849757184,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xl/1",
+    },
+    "t5_1.1_xxl": {
+        "metadata": {
+            "description": (""),
+            "params": 11135332352,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xxl/1",
+    },
     "flan_small_multi": {
         "metadata": {
             "description": (
@@ -47,9 +81,7 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/t5/keras/flan_small_multi/2",
     },
@@ -60,9 +92,7 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/t5/keras/flan_base_multi/2",
     },
@@ -73,9 +103,7 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/t5/keras/flan_large_multi/2",
     },
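The added T5 1.1 entries follow the same loading path as the existing presets. A hypothetical sketch; the preset name comes from the dictionary above, and the parameter count is the metadata value rather than a measured result.

```python
import keras_hub

# Hypothetical sketch: load one of the newly listed T5 1.1 checkpoints by name.
backbone = keras_hub.models.T5Backbone.from_preset("t5_1.1_base")
print(backbone.count_params())  # roughly 247,577,856 per the preset metadata
```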
keras_hub/src/models/text_to_image.py

@@ -249,7 +249,7 @@ class TextToImage(Task):
         self,
         inputs,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
         seed=None,
     ):
         """Generate image based on the provided `inputs`.
@@ -283,15 +283,23 @@ class TextToImage(Task):
                 - A `tf.data.Dataset` with "prompts" and/or "negative_prompts"
                     keys
             num_steps: int. The number of diffusion steps to take.
-            guidance_scale: float. The classifier free guidance scale
-                [Classifier-Free Diffusion Guidance](
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
                 https://arxiv.org/abs/2207.12598). A higher scale encourages
                 generating images more closely related to the prompts, typically
-                at the cost of lower image quality.
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
             seed: optional int. Used as a random seed.
         """
+        num_steps = int(num_steps)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
         num_steps = ops.convert_to_tensor(num_steps, "int32")
-        guidance_scale = ops.convert_to_tensor(guidance_scale)
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(guidance_scale)
+        else:
+            guidance_scale = None
 
         # Setup our three main passes.
         # 1. Preprocessing strings to dense integer tensors.
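The new `guidance_scale` handling in `TextToImage.generate()` amounts to a small normalization step: `None` (or any value of 1.0 or less) disables classifier-free guidance. The same logic restated as a standalone helper for clarity; `normalize_guidance_scale` is not part of the package.

```python
from keras import ops

def normalize_guidance_scale(guidance_scale):
    # Mirrors the logic in the hunk above: cast to float when given, then
    # treat values of 1.0 or less as disabled guidance.
    guidance_scale = (
        float(guidance_scale) if guidance_scale is not None else None
    )
    if guidance_scale is not None and guidance_scale > 1.0:
        return ops.convert_to_tensor(guidance_scale)
    return None

print(normalize_guidance_scale(7.0))   # backend tensor holding 7.0
print(normalize_guidance_scale(None))  # None
print(normalize_guidance_scale(1.0))   # None
```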
keras_hub/src/models/vgg/vgg_backbone.py

@@ -27,7 +27,7 @@ class VGGBackbone(Backbone):
     input_data = np.ones((2, 224, 224, 3), dtype="float32")
 
     # Pretrained VGG backbone.
-    model = keras_hub.models.VGGBackbone.from_preset("
+    model = keras_hub.models.VGGBackbone.from_preset("vgg_16_imagenet")
     model(input_data)
 
     # Randomly initialized VGG backbone with a custom config.
keras_hub/src/models/vgg/vgg_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 9220480,
-            "official_name": "vgg",
             "path": "vgg",
-            "model_card": "https://arxiv.org/abs/1409.1556",
         },
         "kaggle_handle": "kaggle://keras/vgg/keras/vgg_11_imagenet/1",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 9404992,
-            "official_name": "vgg",
             "path": "vgg",
-            "model_card": "https://arxiv.org/abs/1409.1556",
         },
         "kaggle_handle": "kaggle://keras/vgg/keras/vgg_13_imagenet/1",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 14714688,
-            "official_name": "vgg",
             "path": "vgg",
-            "model_card": "https://arxiv.org/abs/1409.1556",
         },
         "kaggle_handle": "kaggle://keras/vgg/keras/vgg_16_imagenet/1",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 20024384,
-            "official_name": "vgg",
             "path": "vgg",
-            "model_card": "https://arxiv.org/abs/1409.1556",
         },
         "kaggle_handle": "kaggle://keras/vgg/keras/vgg_19_imagenet/1",
     },
keras_hub/src/models/whisper/whisper_audio_converter.py

@@ -39,7 +39,7 @@ class WhisperAudioConverter(AudioConverter):
     audio_tensor = tf.ones((8000,), dtype="float32")
 
     # Compute the log-mel spectrogram.
-    audio_converter = keras_hub.
+    audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
         "whisper_base_en",
     )
     audio_converter(audio_tensor)
|