keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. keras_hub/api/layers/__init__.py +1 -0
  2. keras_hub/api/models/__init__.py +11 -6
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/converters.py +2 -2
  5. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  6. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  7. keras_hub/src/layers/modeling/rms_normalization.py +8 -6
  8. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  9. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  10. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  11. keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
  12. keras_hub/src/metrics/bleu.py +1 -1
  13. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  14. keras_hub/src/models/bart/bart_backbone.py +4 -4
  15. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  16. keras_hub/src/models/bert/bert_presets.py +4 -2
  17. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  18. keras_hub/src/models/causal_lm.py +19 -15
  19. keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
  20. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  21. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  22. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  23. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  24. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  25. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  26. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
  27. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
  28. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
  29. keras_hub/src/models/densenet/densenet_backbone.py +3 -1
  30. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
  31. keras_hub/src/models/densenet/densenet_presets.py +6 -6
  32. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  33. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  34. keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
  35. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  36. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  37. keras_hub/src/models/efficientnet/cba.py +1 -1
  38. keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
  39. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
  40. keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
  41. keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
  42. keras_hub/src/models/efficientnet/mbconv.py +1 -1
  43. keras_hub/src/models/electra/electra_backbone.py +2 -2
  44. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  45. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  46. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  47. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  48. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  49. keras_hub/src/models/flux/flux_layers.py +46 -44
  50. keras_hub/src/models/flux/flux_maths.py +24 -17
  51. keras_hub/src/models/flux/flux_model.py +24 -19
  52. keras_hub/src/models/flux/flux_presets.py +2 -1
  53. keras_hub/src/models/flux/flux_text_to_image.py +7 -3
  54. keras_hub/src/models/gemma/gemma_backbone.py +27 -20
  55. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  56. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  57. keras_hub/src/models/gemma/gemma_presets.py +9 -3
  58. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  59. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  60. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  61. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  62. keras_hub/src/models/image_classifier_preprocessor.py +4 -1
  63. keras_hub/src/models/image_object_detector.py +2 -2
  64. keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
  65. keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
  66. keras_hub/src/models/llama/llama_backbone.py +34 -26
  67. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  68. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  69. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  70. keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
  71. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  72. keras_hub/src/models/mit/mit_backbone.py +4 -3
  73. keras_hub/src/models/mit/mit_layers.py +2 -1
  74. keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
  75. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  76. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
  77. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
  78. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  79. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  80. keras_hub/src/models/preprocessor.py +2 -2
  81. keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
  82. keras_hub/src/models/retinanet/prediction_head.py +2 -2
  83. keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
  84. keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
  85. keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
  86. keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
  87. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  88. keras_hub/src/models/roberta/roberta_presets.py +4 -2
  89. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  90. keras_hub/src/models/sam/sam_backbone.py +2 -2
  91. keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
  92. keras_hub/src/models/sam/sam_layers.py +5 -3
  93. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  94. keras_hub/src/models/sam/sam_transformer.py +5 -4
  95. keras_hub/src/models/segformer/segformer_backbone.py +18 -14
  96. keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
  97. keras_hub/src/models/segformer/segformer_presets.py +24 -12
  98. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  99. keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
  100. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
  101. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
  102. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
  103. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
  104. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  105. keras_hub/src/models/task.py +4 -2
  106. keras_hub/src/models/text_classifier.py +2 -2
  107. keras_hub/src/models/text_to_image.py +5 -1
  108. keras_hub/src/models/vae/vae_layers.py +0 -1
  109. keras_hub/src/models/vit/__init__.py +5 -0
  110. keras_hub/src/models/vit/vit_backbone.py +152 -0
  111. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  112. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  113. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  114. keras_hub/src/models/vit/vit_layers.py +391 -0
  115. keras_hub/src/models/vit/vit_presets.py +49 -0
  116. keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
  117. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  118. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
  119. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  120. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  121. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  122. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  123. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  124. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  125. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  126. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  127. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  128. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  129. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  130. keras_hub/src/samplers/sampler.py +2 -1
  131. keras_hub/src/tests/test_case.py +2 -2
  132. keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
  133. keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
  134. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  135. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
  136. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
  137. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
  138. keras_hub/src/utils/preset_utils.py +25 -18
  139. keras_hub/src/utils/tensor_utils.py +4 -4
  140. keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
  141. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  142. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  143. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  144. keras_hub/src/version_utils.py +1 -1
  145. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
  146. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
  147. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
  148. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
@@ -31,14 +31,15 @@ class TwoWayTransformer(keras.layers.Layer):
     location and type.
 
     Args:
-        num_layers: int, optional. The num_layers of the attention blocks (the number
-            of attention blocks to use). Defaults to `2`.
+        num_layers: int, optional. The num_layers of the attention blocks
+            (the number of attention blocks to use). Defaults to `2`.
         hidden_size: int, optional. The number of features of the input image
             and point embeddings. Defaults to `256`.
         num_heads: int, optional. Number of heads to use in the attention
             layers. Defaults to `8`.
-        intermediate_dim: int, optional. The number of units in the hidden layer of
-            the MLP block used in the attention layers. Defaults to `2048`.
+        intermediate_dim: int, optional. The number of units in the hidden
+            layer of the MLP block used in the attention layers.
+            Defaults to `2048`.
         activation: str, optional. The activation of the MLP block's output
             layer used in the attention layers. Defaults to `"relu"`.
         attention_downsample_rate: int, optional. The downsample rate of the
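For orientation, a minimal construction sketch using the defaults documented in this hunk. `TwoWayTransformer` is internal (non-exported) API, so the import path below is inferred from this wheel's file layout rather than confirmed by the diff:

```python
from keras_hub.src.models.sam.sam_transformer import TwoWayTransformer

# Mirrors the documented defaults; every argument below is optional.
transformer = TwoWayTransformer(
    num_layers=2,
    hidden_size=256,
    num_heads=8,
    intermediate_dim=2048,
    activation="relu",
)
```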
@@ -6,18 +6,19 @@ from keras_hub.src.models.backbone import Backbone
 
 @keras_hub_export("keras_hub.models.SegFormerBackbone")
 class SegFormerBackbone(Backbone):
-    """A Keras model implementing the SegFormer architecture for semantic segmentation.
+    """A Keras model implementing SegFormer for semantic segmentation.
 
-    This class implements the majority of the SegFormer architecture described in
-    [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers]
-    (https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision]
-    (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+    This class implements the majority of the SegFormer architecture described
+    in [SegFormer: Simple and Efficient Design for Semantic Segmentation](https://arxiv.org/abs/2105.15203)
+    and based on the TensorFlow implementation
+    [from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
 
-    SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and
-    and use a very lightweight all-MLP decoder head.
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and use a very lightweight all-MLP decoder head.
 
-    The MiT encoder uses a hierarchical transformer which outputs features at multiple scales,
-    similar to that of the hierarchical outputs typically associated with CNNs.
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.
 
     Args:
         image_encoder: `keras.Model`. The backbone network for the model that is
@@ -50,7 +51,8 @@ class SegFormerBackbone(Backbone):
         strides=[4, 2, 2, 2],
     )
 
-    segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256)
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
     ```
 
     Using the class with a preset `backbone`:
@@ -59,7 +61,8 @@ class SegFormerBackbone(Backbone):
     import keras_hub
 
     backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
-    segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256)
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
     ```
 
     """
@@ -74,9 +77,10 @@ class SegFormerBackbone(Backbone):
             image_encoder, keras.Model
         ):
             raise ValueError(
-                "Argument `image_encoder` must be a `keras.layers.Layer` instance "
-                f" or `keras.Model`. Received instead "
-                f"image_encoder={image_encoder} (of type {type(image_encoder)})."
+                "Argument `image_encoder` must be a `keras.layers.Layer` "
+                f"instance or `keras.Model`. Received instead "
+                f"image_encoder={image_encoder} "
+                f"(of type {type(image_encoder)})."
            )
 
         # === Layers ===
@@ -3,41 +3,43 @@ import keras
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
-from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (  # noqa: E501
     SegFormerImageSegmenterPreprocessor,
 )
 
 
 @keras_hub_export("keras_hub.models.SegFormerImageSegmenter")
 class SegFormerImageSegmenter(ImageSegmenter):
-    """A Keras model implementing the SegFormer architecture for semantic segmentation.
+    """A Keras model implementing SegFormer for semantic segmentation.
 
-    This class implements the segmentation head of the SegFormer architecture described in
-    [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers]
-    (https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision]
+    This class implements the segmentation head of the SegFormer architecture
+    described in [SegFormer: Simple and Efficient Design for Semantic
+    Segmentation with Transformers] (https://arxiv.org/abs/2105.15203) and
+    [based on the TensorFlow implementation from DeepVision]
     (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
 
-    SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and
-    and use a very lightweight all-MLP decoder head.
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and and use a very lightweight all-MLP decoder head.
 
-    The MiT encoder uses a hierarchical transformer which outputs features at multiple scales,
-    similar to that of the hierarchical outputs typically associated with CNNs.
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.
 
     Args:
         image_encoder: `keras.Model`. The backbone network for the model that is
-            used as a feature extractor for the SegFormer encoder.
-            It is *intended* to be used only with the MiT backbone model
-            (`keras_hub.models.MiTBackbone`) which was created
-            specifically for SegFormers.
-            Alternatively, can be a `keras_hub.models.Backbone` a model subclassing
-            `keras_hub.models.FeaturePyramidBackbone`, or a `keras.Model`
-            that has a `pyramid_outputs` property which is
-            a dictionary with keys "P2", "P3", "P4", and "P5" and layer names as values.
+            used as a feature extractor for the SegFormer encoder. It is
+            *intended* to be used only with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created specifically for
+            SegFormers. Alternatively, can be a `keras_hub.models.Backbone` a
+            model subclassing `keras_hub.models.FeaturePyramidBackbone`, or a
+            `keras.Model` that has a `pyramid_outputs` property which is a
+            dictionary with keys "P2", "P3", "P4", and "P5" and layer names as
+            values.
         num_classes: int, the number of classes for the detection model,
             including the background class.
         projection_filters: int, number of filters in the
-            convolution layer projecting the concatenated features into
-            a segmentation map. Defaults to 256`.
+            convolution layer projecting the concatenated features into a
+            segmentation map. Defaults to 256`.
 
 
     Example:
@@ -45,10 +47,9 @@ class SegFormerImageSegmenter(ImageSegmenter):
     Using presets:
 
     ```python
-    import keras_hub
-    import numpy as np
-
-    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset("segformer_b0_ade20k_512")
+    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
+        "segformer_b0_ade20k_512"
+    )
 
     images = np.random.rand(1, 512, 512, 3)
     segformer(images)
@@ -57,17 +58,18 @@ class SegFormerImageSegmenter(ImageSegmenter):
     Using the SegFormer backbone:
 
     ```python
-    encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
-    backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256)
+    encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
     ```
 
     Using the SegFormer backbone with a custom encoder:
 
     ```python
-    import keras
-    import keras_hub
-    import numpy as np
-
     images = np.ones(shape=(1, 96, 96, 3))
     labels = np.zeros(shape=(1, 96, 96, 1))
 
@@ -83,20 +85,31 @@ class SegFormerImageSegmenter(ImageSegmenter):
         strides=[4, 2, 2, 2],
    )
 
-    backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256)
-    segformer = keras_hub.models.SegFormerImageSegmenter(backbone=backbone, num_classes=4)
-
-    segformer(images)
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
+    segformer(images)
     ```
 
     Using the segmentor class with a preset backbone:
 
     ```python
-    import keras_hub
-
-    image_encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
-    backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256)
-    segformer = keras_hub.models.SegFormerImageSegmenter(backbone=backbone, num_classes=4)
+    image_encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
     ```
     """
 
@@ -4,7 +4,8 @@ presets = {
     "segformer_b0_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB0 backbone fine-tuned on ADE20k in 512x512 resolution."
+                "SegFormer model with MiTB0 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 3719027,
             "path": "segformer_b0",
@@ -14,7 +15,8 @@ presets = {
     "segformer_b1_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB1 backbone fine-tuned on ADE20k in 512x512 resolution."
+                "SegFormer model with MiTB1 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 13682643,
             "path": "segformer_b1",
@@ -24,7 +26,8 @@ presets = {
     "segformer_b2_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB2 backbone fine-tuned on ADE20k in 512x512 resolution."
+                "SegFormer model with MiTB2 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 24727507,
             "path": "segformer_b2",
@@ -34,7 +37,8 @@ presets = {
     "segformer_b3_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB3 backbone fine-tuned on ADE20k in 512x512 resolution."
+                "SegFormer model with MiTB3 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 44603347,
             "path": "segformer_b3",
@@ -44,7 +48,8 @@ presets = {
     "segformer_b4_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB4 backbone fine-tuned on ADE20k in 512x512 resolution."
+                "SegFormer model with MiTB4 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 61373907,
             "path": "segformer_b4",
@@ -54,7 +59,8 @@ presets = {
     "segformer_b5_ade20k_640": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB5 backbone fine-tuned on ADE20k in 640x640 resolution."
+                "SegFormer model with MiTB5 backbone fine-tuned on ADE20k in "
+                "640x640 resolution."
             ),
             "params": 81974227,
             "path": "segformer_b5",
@@ -64,7 +70,8 @@ presets = {
     "segformer_b0_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB0 backbone fine-tuned on Cityscapes in 1024x1024 resolution."
+                "SegFormer model with MiTB0 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 3719027,
             "path": "segformer_b0",
@@ -74,7 +81,8 @@ presets = {
     "segformer_b1_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB1 backbone fine-tuned on Cityscapes in 1024x1024 resolution."
+                "SegFormer model with MiTB1 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 13682643,
             "path": "segformer_b1",
@@ -84,7 +92,8 @@ presets = {
     "segformer_b2_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB2 backbone fine-tuned on Cityscapes in 1024x1024 resolution."
+                "SegFormer model with MiTB2 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 24727507,
             "path": "segformer_b2",
@@ -94,7 +103,8 @@ presets = {
     "segformer_b3_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB3 backbone fine-tuned on Cityscapes in 1024x1024 resolution."
+                "SegFormer model with MiTB3 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 44603347,
             "path": "segformer_b3",
@@ -104,7 +114,8 @@ presets = {
     "segformer_b4_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB4 backbone fine-tuned on Cityscapes in 1024x1024 resolution."
+                "SegFormer model with MiTB4 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
            ),
             "params": 61373907,
             "path": "segformer_b4",
@@ -114,7 +125,8 @@ presets = {
     "segformer_b5_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB5 backbone fine-tuned on Cityscapes in 1024x1024 resolution."
+                "SegFormer model with MiTB5 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 81974227,
             "path": "segformer_b5",
@@ -151,7 +151,7 @@ class Seq2SeqLMPreprocessor(Preprocessor):
         # `sequence_length` is an alias for `decoder_sequence_length`
         sequence_length=None,
     ):
-        """Convert encoder and decoder input strings to integer token inputs for generation.
+        """Convert input strings to integer token inputs for generation.
 
         Similar to calling the layer for training, this method takes in a dict
         containing `"encoder_text"` and `"decoder_text"`, with strings or tensor
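A hedged usage sketch of the method this docstring belongs to (`Seq2SeqLMPreprocessor.generate_preprocess`); the `"bart_base_en"` preset is illustrative:

```python
import keras_hub

preprocessor = keras_hub.models.Seq2SeqLMPreprocessor.from_preset(
    "bart_base_en"  # any seq2seq preset works; this one is illustrative
)
# The dict form described above: encoder and decoder strings are
# tokenized and packed into integer token inputs for generation.
token_inputs = preprocessor.generate_preprocess(
    {
        "encoder_text": "The quick brown fox jumped.",
        "decoder_text": "The fast fox",
    }
)
```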
@@ -595,9 +595,28 @@ class MMDiTBlock(layers.Layer):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
+        batch_size = ops.shape(query)[0]
+
+        # Use the fast path when `ops.dot_product_attention` and flash attention
+        # are available.
+        if hasattr(ops, "dot_product_attention") and hasattr(
+            keras.config, "is_flash_attention_enabled"
+        ):
+            # `ops.dot_product_attention` is slower than the vanilla
+            # implementation in the tensorflow backend.
+            encoded = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                scale=self._inverse_sqrt_key_dim,
+                flash_attention=keras.config.is_flash_attention_enabled(),
+            )
+            return ops.reshape(
+                encoded, (batch_size, -1, self.num_heads * self.head_dim)
+            )
+
         # Ref: jax.nn.dot_product_attention
         # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
-        batch_size = ops.shape(query)[0]
         logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
         logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
         probs = self.softmax(logits)
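The new fast path above is guarded by two `hasattr` checks so the package keeps working on older Keras versions. A standalone sketch of the same dispatch pattern; the function name and wiring are illustrative, not part of the package:

```python
import keras
from keras import ops


def scaled_dot_product_attention(query, key, value, scale):
    # Fused path: only when this Keras version exposes the op and the
    # flash-attention config flag, mirroring the guard in the hunk above.
    if hasattr(ops, "dot_product_attention") and hasattr(
        keras.config, "is_flash_attention_enabled"
    ):
        return ops.dot_product_attention(
            query,
            key,
            value,
            scale=scale,
            flash_attention=keras.config.is_flash_attention_enabled(),
        )
    # Vanilla fallback: softmax(Q K^T * scale) V, matching the einsum
    # implementation retained below the fast path.
    logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
    probs = ops.softmax(logits * scale, axis=-1)
    return ops.einsum("BNTS,BSNH->BTNH", probs, value)
```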
@@ -4,7 +4,7 @@ from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
+from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (  # noqa: E501
     FlowMatchEulerDiscreteScheduler,
 )
 from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
@@ -2,10 +2,10 @@ from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_to_image import ImageToImage
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )
 
@@ -26,13 +26,17 @@ class StableDiffusion3ImageToImage(ImageToImage):
 
     Use `generate()` to do image generation.
     ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
         "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
             "images": np.ones((512, 512, 3), dtype="float32"),
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
         }
     )
 
@@ -40,7 +44,10 @@ class StableDiffusion3ImageToImage(ImageToImage):
     image_to_image.generate(
         {
             "images": np.ones((2, 512, 512, 3), dtype="float32"),
-            "prompts": ["cute wallpaper art of a cat", "cute wallpaper art of a dog"],
+            "prompts": [
+                "cute wallpaper art of a cat",
+                "cute wallpaper art of a dog",
+            ],
         }
     )
 
@@ -48,7 +55,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     image_to_image.generate(
         {
             "images": np.ones((512, 512, 3), dtype="float32"),
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
         }
         num_steps=50,
         guidance_scale=5.0,
@@ -59,7 +66,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     text_to_image.generate(
         {
             "images": np.ones((512, 512, 3), dtype="float32"),
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
@@ -2,10 +2,10 @@ from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.inpaint import Inpaint
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )
 
@@ -1,10 +1,10 @@
 from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )
 from keras_hub.src.models.text_to_image import TextToImage
@@ -46,9 +46,13 @@ class StableDiffusion3TextToImage(TextToImage):
     )
 
     # Generate with `negative_prompts`.
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     text_to_image.generate(
         {
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
@@ -3,7 +3,7 @@ from keras import layers
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.preprocessor import Preprocessor
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
 
@@ -149,7 +149,8 @@ class Task(PipelineModel):
 
     This constructor can be called in one of two ways. Either from a task
     specific base class like `keras_hub.models.CausalLM.from_preset()`, or
-    from a model class like `keras_hub.models.BertTextClassifier.from_preset()`.
+    from a model class like
+    `keras_hub.models.BertTextClassifier.from_preset()`.
     If calling from the a base class, the subclass of the returning object
     will be inferred from the config in the preset directory.
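Both call styles the docstring describes are valid; a short sketch with a real preset name:

```python
import keras_hub

# Base-class call: the concrete subclass (here a GPT-2 causal LM) is
# inferred from the config stored with the preset.
causal_lm = keras_hub.models.CausalLM.from_preset("gpt2_base_en")
causal_lm.generate("Keras is a", max_length=30)
```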
 
@@ -294,7 +295,8 @@ class Task(PipelineModel):
             return "(" + ", ".join(highlighted) + ")"
 
         if self.preprocessor:
-            # Create a rich console for printing. Capture for non-interactive logging.
+            # Create a rich console for printing. Capture for non-interactive
+            # logging.
             if print_fn:
                 console = rich_console.Console(
                     highlight=False, force_terminal=False, color_system=None
@@ -21,8 +21,8 @@ class TextClassifier(Task):
     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
     labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
 
-    All `TextClassifier` tasks include a `from_preset()` constructor which can be
-    used to load a pre-trained config and weights.
+    All `TextClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
 
     Some, but not all, classification presets include classification head
     weights in a `task.weights.h5` file. For these presets, you can omit passing
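A sketch of the behavior the (truncated) paragraph above describes, assuming a preset without saved head weights so `num_classes` must be passed:

```python
import keras_hub

# With no task.weights.h5 in the preset, `num_classes` is required and
# the classification head is randomly initialized.
classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
classifier.predict(["What an amazing movie!"])
```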
@@ -262,9 +262,13 @@ class TextToImage(Task):
     pass `prompts` and `negative_prompts` as a dict:
 
     ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     text_to_image.generate(
         {
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
@@ -159,7 +159,6 @@ class ResNetBlock(keras.layers.Layer):
         data_format=None,
         **kwargs,
     ):
-
         super().__init__(**kwargs)
         data_format = standardize_data_format(data_format)
         channel_axis = -1 if data_format == "channels_last" else 1
@@ -0,0 +1,5 @@
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, ViTBackbone)
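Importing the new `vit` package runs `register_presets`, which is what makes name-based `from_preset` lookup work for the added backbone. A hedged usage sketch; the preset name below is hypothetical, since the actual identifiers live in the newly added `vit_presets.py`, which this diff does not display:

```python
import keras_hub

# Hypothetical preset name; see vit_presets.py for the real identifiers.
backbone = keras_hub.models.ViTBackbone.from_preset("vit_base_patch16_224")
```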