keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +1 -0
- keras_hub/api/models/__init__.py +11 -6
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/converters.py +2 -2
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/rms_normalization.py +8 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
- keras_hub/src/metrics/bleu.py +1 -1
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/bert/bert_presets.py +4 -2
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/causal_lm.py +19 -15
- keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
- keras_hub/src/models/densenet/densenet_backbone.py +3 -1
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +6 -6
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/cba.py +1 -1
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
- keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
- keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
- keras_hub/src/models/efficientnet/mbconv.py +1 -1
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/flux/flux_layers.py +46 -44
- keras_hub/src/models/flux/flux_maths.py +24 -17
- keras_hub/src/models/flux/flux_model.py +24 -19
- keras_hub/src/models/flux/flux_presets.py +2 -1
- keras_hub/src/models/flux/flux_text_to_image.py +7 -3
- keras_hub/src/models/gemma/gemma_backbone.py +27 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +9 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier_preprocessor.py +4 -1
- keras_hub/src/models/image_object_detector.py +2 -2
- keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
- keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
- keras_hub/src/models/llama/llama_backbone.py +34 -26
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/mit_backbone.py +4 -3
- keras_hub/src/models/mit/mit_layers.py +2 -1
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +2 -2
- keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
- keras_hub/src/models/retinanet/prediction_head.py +2 -2
- keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
- keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
- keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +4 -2
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/sam_backbone.py +2 -2
- keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/segformer_backbone.py +18 -14
- keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
- keras_hub/src/models/segformer/segformer_presets.py +24 -12
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/task.py +4 -2
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +5 -1
- keras_hub/src/models/vae/vae_layers.py +0 -1
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +49 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +2 -2
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
- keras_hub/src/utils/preset_utils.py +25 -18
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
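Before the per-file hunks, one note on the largest addition in this nightly: an entirely new ViT model family (`vit_backbone.py`, `vit_image_classifier.py`, `vit_layers.py`, `vit_presets.py`) plus a Hugging Face conversion path (`convert_vit.py`). A minimal sketch of how the new task would be used, assuming the class is exported as `keras_hub.models.ViTImageClassifier` and using a hypothetical preset name; the actual identifiers live in `vit_presets.py`, which this listing does not show:

```python
import numpy as np
import keras_hub

# Hypothetical preset name; see keras_hub/src/models/vit/vit_presets.py
# for the identifiers actually registered in this release.
classifier = keras_hub.models.ViTImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet"
)
images = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
preds = classifier.predict(images)
```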
keras_hub/src/models/sam/sam_transformer.py
CHANGED
@@ -31,14 +31,15 @@ class TwoWayTransformer(keras.layers.Layer):
         location and type.

     Args:
-        num_layers: int, optional. The num_layers of the attention blocks
-            of attention blocks to use). Defaults to `2`.
+        num_layers: int, optional. The num_layers of the attention blocks
+            (the number of attention blocks to use). Defaults to `2`.
         hidden_size: int, optional. The number of features of the input image
             and point embeddings. Defaults to `256`.
         num_heads: int, optional. Number of heads to use in the attention
             layers. Defaults to `8`.
-        intermediate_dim: int, optional. The number of units in the hidden
-            the MLP block used in the attention layers.
+        intermediate_dim: int, optional. The number of units in the hidden
+            layer of the MLP block used in the attention layers.
+            Defaults to `2048`.
         activation: str, optional. The activation of the MLP block's output
             layer used in the attention layers. Defaults to `"relu"`.
         attention_downsample_rate: int, optional. The downsample rate of the
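The repaired docstring now spells out all the defaults. As a quick cross-check, here is a minimal sketch that builds the layer with exactly those documented values; the internal import path is an assumption based on the file list above, not something shown in this diff:

```python
from keras_hub.src.models.sam.sam_transformer import TwoWayTransformer

# Mirrors the defaults documented above: 2 attention blocks, 256
# features, 8 attention heads, and a 2048-unit MLP hidden layer.
transformer = TwoWayTransformer(
    num_layers=2,
    hidden_size=256,
    num_heads=8,
    intermediate_dim=2048,
)
```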
keras_hub/src/models/segformer/segformer_backbone.py
CHANGED
@@ -6,18 +6,19 @@ from keras_hub.src.models.backbone import Backbone

 @keras_hub_export("keras_hub.models.SegFormerBackbone")
 class SegFormerBackbone(Backbone):
-    """A Keras model implementing
+    """A Keras model implementing SegFormer for semantic segmentation.

-    This class implements the majority of the SegFormer architecture described
-    [SegFormer: Simple and Efficient Design for Semantic Segmentation
-
-    (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+    This class implements the majority of the SegFormer architecture described
+    in [SegFormer: Simple and Efficient Design for Semantic Segmentation](https://arxiv.org/abs/2105.15203)
+    and based on the TensorFlow implementation
+    [from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).

-    SegFormers are meant to be used with the MixTransformer (MiT) encoder
-    and use a very lightweight all-MLP decoder head.
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and use a very lightweight all-MLP decoder head.

-    The MiT encoder uses a hierarchical transformer which outputs features at
-    similar to that of the hierarchical outputs typically
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.

     Args:
         image_encoder: `keras.Model`. The backbone network for the model that is
@@ -50,7 +51,8 @@ class SegFormerBackbone(Backbone):
         strides=[4, 2, 2, 2],
     )

-    segformer_backbone = keras_hub.models.SegFormerBackbone(
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
     ```

     Using the class with a preset `backbone`:
@@ -59,7 +61,8 @@ class SegFormerBackbone(Backbone):
     import keras_hub

     backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
-    segformer_backbone = keras_hub.models.SegFormerBackbone(
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
     ```

     """
@@ -74,9 +77,10 @@ class SegFormerBackbone(Backbone):
             image_encoder, keras.Model
         ):
             raise ValueError(
-                "Argument `image_encoder` must be a `keras.layers.Layer`
-                f" or `keras.Model`. Received instead "
-                f"image_encoder={image_encoder}
+                "Argument `image_encoder` must be a `keras.layers.Layer` "
+                f"instance or `keras.Model`. Received instead "
+                f"image_encoder={image_encoder} "
+                f"(of type {type(image_encoder)})."
             )

         # === Layers ===
keras_hub/src/models/segformer/segformer_image_segmenter.py
CHANGED
@@ -3,41 +3,43 @@ import keras
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
-from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (  # noqa: E501
     SegFormerImageSegmenterPreprocessor,
 )


 @keras_hub_export("keras_hub.models.SegFormerImageSegmenter")
 class SegFormerImageSegmenter(ImageSegmenter):
-    """A Keras model implementing
+    """A Keras model implementing SegFormer for semantic segmentation.

-    This class implements the segmentation head of the SegFormer architecture
-    [SegFormer: Simple and Efficient Design for Semantic
-    (https://arxiv.org/abs/2105.15203) and
+    This class implements the segmentation head of the SegFormer architecture
+    described in [SegFormer: Simple and Efficient Design for Semantic
+    Segmentation with Transformers] (https://arxiv.org/abs/2105.15203) and
+    [based on the TensorFlow implementation from DeepVision]
     (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).

-    SegFormers are meant to be used with the MixTransformer (MiT) encoder
-    and use a very lightweight all-MLP decoder head.
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and and use a very lightweight all-MLP decoder head.

-    The MiT encoder uses a hierarchical transformer which outputs features at
-    similar to that of the hierarchical outputs typically
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.

     Args:
         image_encoder: `keras.Model`. The backbone network for the model that is
-            used as a feature extractor for the SegFormer encoder.
-
-            (`keras_hub.models.MiTBackbone`) which was created
-
-
-            `
-
-
+            used as a feature extractor for the SegFormer encoder. It is
+            *intended* to be used only with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created specifically for
+            SegFormers. Alternatively, can be a `keras_hub.models.Backbone` a
+            model subclassing `keras_hub.models.FeaturePyramidBackbone`, or a
+            `keras.Model` that has a `pyramid_outputs` property which is a
+            dictionary with keys "P2", "P3", "P4", and "P5" and layer names as
+            values.
         num_classes: int, the number of classes for the detection model,
             including the background class.
         projection_filters: int, number of filters in the
-            convolution layer projecting the concatenated features into
-
+            convolution layer projecting the concatenated features into a
+            segmentation map. Defaults to 256`.


     Example:
@@ -45,10 +47,9 @@ class SegFormerImageSegmenter(ImageSegmenter):
     Using presets:

     ```python
-
-
-
-    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset("segformer_b0_ade20k_512")
+    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
+        "segformer_b0_ade20k_512"
+    )

     images = np.random.rand(1, 512, 512, 3)
     segformer(images)
@@ -57,17 +58,18 @@ class SegFormerImageSegmenter(ImageSegmenter):
     Using the SegFormer backbone:

     ```python
-    encoder = keras_hub.models.MiTBackbone.from_preset(
-
+    encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
     ```

     Using the SegFormer backbone with a custom encoder:

     ```python
-    import keras
-    import keras_hub
-    import numpy as np
-
     images = np.ones(shape=(1, 96, 96, 3))
     labels = np.zeros(shape=(1, 96, 96, 1))

@@ -83,20 +85,31 @@ class SegFormerImageSegmenter(ImageSegmenter):
         strides=[4, 2, 2, 2],
     )

-    backbone = keras_hub.models.SegFormerBackbone(
-
-
-
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
+    segformer(images)
     ```

     Using the segmentor class with a preset backbone:

     ```python
-
-
-
-    backbone = keras_hub.models.SegFormerBackbone(
-
+    image_encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
     ```
     """

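Taken together, the two repaired docstrings describe a single pipeline. Here is a condensed sketch stitching their examples into one runnable flow, using the preset name the docstrings themselves reference; note that the last docstring example above still passes `image_encoder=encoder` after naming the variable `image_encoder`, which this sketch normalizes:

```python
import numpy as np
import keras_hub

# MiT encoder from a preset, wrapped by the SegFormer backbone and head.
image_encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=image_encoder,
    projection_filters=256,
)
segformer = keras_hub.models.SegFormerImageSegmenter(
    backbone=backbone,
    num_classes=4,
)
segmentation_map = segformer(np.ones((1, 512, 512, 3), dtype="float32"))
```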
keras_hub/src/models/segformer/segformer_presets.py
CHANGED
@@ -4,7 +4,8 @@ presets = {
     "segformer_b0_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB0 backbone fine-tuned on ADE20k in
+                "SegFormer model with MiTB0 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 3719027,
             "path": "segformer_b0",
@@ -14,7 +15,8 @@ presets = {
     "segformer_b1_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB1 backbone fine-tuned on ADE20k in
+                "SegFormer model with MiTB1 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 13682643,
             "path": "segformer_b1",
@@ -24,7 +26,8 @@ presets = {
     "segformer_b2_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB2 backbone fine-tuned on ADE20k in
+                "SegFormer model with MiTB2 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 24727507,
             "path": "segformer_b2",
@@ -34,7 +37,8 @@ presets = {
     "segformer_b3_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB3 backbone fine-tuned on ADE20k in
+                "SegFormer model with MiTB3 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 44603347,
             "path": "segformer_b3",
@@ -44,7 +48,8 @@ presets = {
     "segformer_b4_ade20k_512": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB4 backbone fine-tuned on ADE20k in
+                "SegFormer model with MiTB4 backbone fine-tuned on ADE20k in "
+                "512x512 resolution."
             ),
             "params": 61373907,
             "path": "segformer_b4",
@@ -54,7 +59,8 @@ presets = {
     "segformer_b5_ade20k_640": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB5 backbone fine-tuned on ADE20k in
+                "SegFormer model with MiTB5 backbone fine-tuned on ADE20k in "
+                "640x640 resolution."
             ),
             "params": 81974227,
             "path": "segformer_b5",
@@ -64,7 +70,8 @@ presets = {
     "segformer_b0_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB0 backbone fine-tuned on Cityscapes
+                "SegFormer model with MiTB0 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 3719027,
             "path": "segformer_b0",
@@ -74,7 +81,8 @@ presets = {
     "segformer_b1_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB1 backbone fine-tuned on Cityscapes
+                "SegFormer model with MiTB1 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 13682643,
             "path": "segformer_b1",
@@ -84,7 +92,8 @@ presets = {
     "segformer_b2_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB2 backbone fine-tuned on Cityscapes
+                "SegFormer model with MiTB2 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 24727507,
             "path": "segformer_b2",
@@ -94,7 +103,8 @@ presets = {
     "segformer_b3_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB3 backbone fine-tuned on Cityscapes
+                "SegFormer model with MiTB3 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 44603347,
             "path": "segformer_b3",
@@ -104,7 +114,8 @@ presets = {
     "segformer_b4_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB4 backbone fine-tuned on Cityscapes
+                "SegFormer model with MiTB4 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 61373907,
             "path": "segformer_b4",
@@ -114,7 +125,8 @@ presets = {
     "segformer_b5_cityscapes_1024": {
         "metadata": {
             "description": (
-                "SegFormer model with MiTB5 backbone fine-tuned on Cityscapes
+                "SegFormer model with MiTB5 backbone fine-tuned on Cityscapes "
+                "in 1024x1024 resolution."
             ),
             "params": 81974227,
             "path": "segformer_b5",
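These metadata strings are what surface when presets are listed programmatically. A sketch of one way to inspect them, assuming the task class exposes the standard `presets` mapping with the `"metadata"` structure shown above (an assumption; this diff only shows the data, not the accessor):

```python
import keras_hub

# Each entry should mirror one of the metadata blocks above.
for name, preset in keras_hub.models.SegFormerImageSegmenter.presets.items():
    print(name, "-", preset["metadata"]["description"])
```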
keras_hub/src/models/seq_2_seq_lm_preprocessor.py
CHANGED
@@ -151,7 +151,7 @@ class Seq2SeqLMPreprocessor(Preprocessor):
         # `sequence_length` is an alias for `decoder_sequence_length`
         sequence_length=None,
     ):
-        """Convert
+        """Convert input strings to integer token inputs for generation.

         Similar to calling the layer for training, this method takes in a dict
         containing `"encoder_text"` and `"decoder_text"`, with strings or tensor
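The completed docstring belongs to the preprocessor's generation entry point. A minimal sketch of the call it documents, assuming the method is `generate_preprocess` and using a BART preset as a stand-in (both names are assumptions, not shown in this hunk):

```python
import keras_hub

preprocessor = keras_hub.models.Seq2SeqLMPreprocessor.from_preset(
    "bart_base_en"
)
# Takes a dict of encoder/decoder strings and returns integer token
# inputs; `sequence_length` aliases `decoder_sequence_length` above.
token_inputs = preprocessor.generate_preprocess(
    {
        "encoder_text": "The quick brown fox jumped.",
        "decoder_text": "The fast fox.",
    },
    sequence_length=64,
)
```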
keras_hub/src/models/stable_diffusion_3/mmdit.py
CHANGED
@@ -595,9 +595,28 @@ class MMDiTBlock(layers.Layer):
         self.context_block.build(context_shape, timestep_embedding_shape)

     def _compute_attention(self, query, key, value):
+        batch_size = ops.shape(query)[0]
+
+        # Use the fast path when `ops.dot_product_attention` and flash attention
+        # are available.
+        if hasattr(ops, "dot_product_attention") and hasattr(
+            keras.config, "is_flash_attention_enabled"
+        ):
+            # `ops.dot_product_attention` is slower than the vanilla
+            # implementation in the tensorflow backend.
+            encoded = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                scale=self._inverse_sqrt_key_dim,
+                flash_attention=keras.config.is_flash_attention_enabled(),
+            )
+            return ops.reshape(
+                encoded, (batch_size, -1, self.num_heads * self.head_dim)
+            )
+
         # Ref: jax.nn.dot_product_attention
         # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
-        batch_size = ops.shape(query)[0]
         logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
         logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
         probs = self.softmax(logits)
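The new fast path is deliberately defensive: `ops.dot_product_attention` and `keras.config.is_flash_attention_enabled` only exist in sufficiently recent Keras versions, so both are feature-detected before use and the einsum-softmax fallback is kept. A standalone sketch of the same guard pattern, with a hypothetical free function in place of the layer method (shapes follow the `BTNH` convention in the hunk):

```python
import keras
from keras import ops

def attention(query, key, value, scale):
    # Feature-detect the fused kernel; fall back to einsum-softmax on
    # Keras versions (or backends) that lack it.
    if hasattr(ops, "dot_product_attention") and hasattr(
        keras.config, "is_flash_attention_enabled"
    ):
        return ops.dot_product_attention(
            query,
            key,
            value,
            scale=scale,
            flash_attention=keras.config.is_flash_attention_enabled(),
        )
    logits = ops.einsum("BTNH,BSNH->BNTS", query, key) * scale
    probs = ops.softmax(logits, axis=-1)
    return ops.einsum("BNTS,BSNH->BTNH", probs, value)
```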
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
CHANGED
@@ -4,7 +4,7 @@ from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
+from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (  # noqa: E501
     FlowMatchEulerDiscreteScheduler,
 )
 from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
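The only change here (repeated across the Stable Diffusion 3 modules below) appends `# noqa: E501`, the flake8/ruff directive that exempts a single line from the line-too-long rule; the long module paths cannot be wrapped, so suppression is the idiomatic fix:

```python
# E501 is the "line too long" lint rule; the trailing directive exempts
# just this one line, leaving all other checks active.
from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (  # noqa: E501
    FlowMatchEulerDiscreteScheduler,
)
```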
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
CHANGED
@@ -2,10 +2,10 @@ from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_to_image import ImageToImage
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )

@@ -26,13 +26,17 @@ class StableDiffusion3ImageToImage(ImageToImage):

     Use `generate()` to do image generation.
     ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
         "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
             "images": np.ones((512, 512, 3), dtype="float32"),
-            "prompts":
+            "prompts": prompt,
         }
     )

@@ -40,7 +44,10 @@ class StableDiffusion3ImageToImage(ImageToImage):
     image_to_image.generate(
         {
             "images": np.ones((2, 512, 512, 3), dtype="float32"),
-            "prompts": [
+            "prompts": [
+                "cute wallpaper art of a cat",
+                "cute wallpaper art of a dog",
+            ],
         }
     )

@@ -48,7 +55,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     image_to_image.generate(
         {
             "images": np.ones((512, 512, 3), dtype="float32"),
-            "prompts":
+            "prompts": prompt,
         }
         num_steps=50,
         guidance_scale=5.0,
@@ -59,7 +66,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     text_to_image.generate(
         {
             "images": np.ones((512, 512, 3), dtype="float32"),
-            "prompts":
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
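The recurring fix in these docstrings relies on Python's implicit concatenation of adjacent string literals: the long example prompt is split across two source lines inside parentheses, which satisfies the line-length lint without changing the resulting string. A two-line demonstration:

```python
prompt = (
    "Astronaut in a jungle, cold color palette, muted colors, "
    "detailed, 8k"
)
# Adjacent literals fuse at compile time into a single string.
assert prompt == (
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
)
```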
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
CHANGED
@@ -2,10 +2,10 @@ from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.inpaint import Inpaint
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
CHANGED
@@ -1,10 +1,10 @@
 from keras import ops

 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )
 from keras_hub.src.models.text_to_image import TextToImage
@@ -46,9 +46,13 @@ class StableDiffusion3TextToImage(TextToImage):
     )

     # Generate with `negative_prompts`.
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     text_to_image.generate(
         {
-            "prompts":
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
CHANGED
@@ -3,7 +3,7 @@ from keras import layers

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.preprocessor import Preprocessor
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )

keras_hub/src/models/task.py
CHANGED
@@ -149,7 +149,8 @@ class Task(PipelineModel):

     This constructor can be called in one of two ways. Either from a task
     specific base class like `keras_hub.models.CausalLM.from_preset()`, or
-    from a model class like
+    from a model class like
+    `keras_hub.models.BertTextClassifier.from_preset()`.
     If calling from the a base class, the subclass of the returning object
     will be inferred from the config in the preset directory.

@@ -294,7 +295,8 @@ class Task(PipelineModel):
            return "(" + ", ".join(highlighted) + ")"

        if self.preprocessor:
-            # Create a rich console for printing. Capture for non-interactive
+            # Create a rich console for printing. Capture for non-interactive
+            # logging.
            if print_fn:
                console = rich_console.Console(
                    highlight=False, force_terminal=False, color_system=None
keras_hub/src/models/text_classifier.py
CHANGED
@@ -21,8 +21,8 @@ class TextClassifier(Task):
     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
     labels where `x` is a string and `y` is a integer from `[0, num_classes)`.

-    All `TextClassifier` tasks include a `from_preset()` constructor which can
-    used to load a pre-trained config and weights.
+    All `TextClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.

     Some, but not all, classification presets include classification head
     weights in a `task.weights.h5` file. For these presets, you can omit passing
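A sketch of the constructor the corrected sentence describes, using a BERT preset as a stand-in (the preset name is an assumption, chosen from the BERT presets touched elsewhere in this diff):

```python
import keras_hub

# Loads pre-trained config and weights; `num_classes` sizes the new
# classification head when the preset does not ship task weights.
classifier = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)
preds = classifier.predict(["What an amazing movie!"])
```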
keras_hub/src/models/text_to_image.py
CHANGED
@@ -262,9 +262,13 @@ class TextToImage(Task):
     pass `prompts` and `negative_prompts` as a dict:

     ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     text_to_image.generate(
         {
-            "prompts":
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )