keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py (new file)
@@ -0,0 +1,178 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage")
+class StableDiffusion3ImageToImage(ImageToImage):
+    """An end-to-end Stable Diffusion 3 model for image-to-image generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
+    image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
+    )
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": prompt,
+        }
+    )
+
+    # Generate with batched prompts.
+    image_to_image.generate(
+        {
+            "images": np.ones((2, 512, 512, 3), dtype="float32"),
+            "prompts": [
+                "cute wallpaper art of a cat",
+                "cute wallpaper art of a dog",
+            ],
+        }
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": prompt,
+        }
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": prompt,
+            "negative_prompts": "green color",
+        }
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3ImageToImage`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batched of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Encode images.
+        latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=50,
+        strength=0.8,
+        guidance_scale=7.0,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            strength=strength,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
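A note on the `strength` argument exposed by `generate()` above: image-to-image generation noises the encoded reference image up to an intermediate diffusion step and starts denoising from there, so higher `strength` alters the reference more. The sketch below shows the conventional mapping from `strength` to a starting step; the exact computation lives in the `ImageToImage` base task and may differ, so treat this as an illustrative assumption only.

```python
# Illustrative assumption: a common way `strength` maps to a starting step in
# image-to-image diffusion pipelines. Not copied from keras_hub's source.
def starting_step_for_strength(num_steps, strength):
    if not 0.0 <= strength <= 1.0:
        raise ValueError("`strength` must be in [0, 1].")
    # strength=1.0 -> start near step 0 (reference mostly replaced by noise);
    # strength->0.0 -> start near the last step (reference barely changed).
    return min(round(num_steps * (1.0 - strength)), num_steps - 1)


print(starting_step_for_strength(num_steps=50, strength=0.8))  # 10
print(starting_step_for_strength(num_steps=50, strength=0.6))  # 20
```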
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py (new file)
@@ -0,0 +1,193 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.inpaint import Inpaint
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3Inpaint")
+class StableDiffusion3Inpaint(Inpaint):
+    """An end-to-end Stable Diffusion 3 model for inpaint generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image, mask and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    reference_mask = np.ones((1024, 1024), dtype="float32")
+    inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Generate with batched prompts.
+    reference_images = np.ones((2, 512, 512, 3), dtype="float32")
+    reference_mask = np.ones((2, 1024, 1024), dtype="float32")
+    inpaint.generate(
+        reference_images,
+        reference_mask,
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for `StableDiffusion3Inpaint`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        masks,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batched of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            masks: A (batch_size, image_height, image_width) tensor
+                containing the reference masks.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Get masked images.
+        masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype)
+        masks_latent_size = ops.image.resize(
+            masks,
+            (self.backbone.latent_shape[1], self.backbone.latent_shape[2]),
+            interpolation="nearest",
+        )
+
+        # Encode images.
+        image_latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            image_latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            latents = self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+            # Compute the previous latents x_t -> x_t-1.
+            def true_fn():
+                next_step = ops.add(step, 1)
+                return self.backbone.add_noise_step(
+                    image_latents, noises, next_step, num_steps
+                )
+
+            init_latents = ops.cond(
+                step < ops.subtract(num_steps, 1),
+                true_fn,
+                lambda: ops.cast(image_latents, noises.dtype),
+            )
+            latents = ops.add(
+                ops.multiply(
+                    ops.subtract(1.0, masks_latent_size), init_latents
+                ),
+                ops.multiply(masks_latent_size, latents),
+            )
+            return latents
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=50,
+        strength=0.6,
+        guidance_scale=7.0,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            strength=strength,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
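The inpainting-specific detail in `body_fun` above is the per-step latent blend: inside the mask the freshly denoised latents are kept, while outside it the latents are reset to a re-noised copy of the original image latents, so unmasked regions reproduce the reference image. A minimal NumPy restatement of that blend (shapes and names are illustrative only):

```python
import numpy as np


def blend_latents(denoised, renoised_reference, mask):
    # Keep denoised content where mask == 1, reference content where mask == 0,
    # mirroring `(1 - mask) * init_latents + mask * latents` in the diff above.
    return (1.0 - mask) * renoised_reference + mask * denoised


# Toy shapes: (batch, latent_height, latent_width, channels).
denoised = np.full((1, 4, 4, 1), 2.0, dtype="float32")
reference = np.zeros((1, 4, 4, 1), dtype="float32")
mask = np.zeros((1, 4, 4, 1), dtype="float32")
mask[:, 1:3, 1:3, :] = 1.0  # only the center 2x2 region is inpainted

print(blend_latents(denoised, reference, mask)[0, :, :, 0])
# 2.0 in the center 2x2 block, 0.0 everywhere else.
```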
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -5,14 +5,50 @@ backbone_presets = {
         "metadata": {
             "description": (
                 "3 billion parameter, including CLIP L and CLIP G text "
-                "encoders, MMDiT generative model, and VAE
+                "encoders, MMDiT generative model, and VAE autoencoder. "
                 "Developed by Stability AI."
             ),
-            "params":
-            "
-            "path": "stablediffusion3",
-            "model_card": "https://arxiv.org/abs/2110.00476",
+            "params": 2987080931,
+            "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://
-    }
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
+    },
+    "stable_diffusion_3.5_medium": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT-X generative model, and VAE autoencoder. "
+                "Developed by Stability AI."
+            ),
+            "params": 3371793763,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
+    },
+    "stable_diffusion_3.5_large": {
+        "metadata": {
+            "description": (
+                "9 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT generative model, and VAE autoencoder. "
+                "Developed by Stability AI."
+            ),
+            "params": 9048410595,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/2",
+    },
+    "stable_diffusion_3.5_large_turbo": {
+        "metadata": {
+            "description": (
+                "9 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT generative model, and VAE autoencoder. "
+                "A timestep-distilled version that eliminates classifier-free "
+                "guidance and uses fewer steps for generation. "
+                "Developed by Stability AI."
+            ),
+            "params": 9048410595,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/2",
+    },
 }
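The new Stable Diffusion 3.5 presets registered above load through the same `from_preset` entry points shown in the task docstrings, for example (weights are downloaded from Kaggle on first use):

```python
import keras_hub

# Any preset name registered above should work here, e.g.
# "stable_diffusion_3.5_medium", "stable_diffusion_3.5_large",
# or "stable_diffusion_3.5_large_turbo".
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3.5_medium", image_shape=(512, 512, 3)
)
text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
)
```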
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -1,10 +1,10 @@
 from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
     StableDiffusion3TextToImagePreprocessor,
 )
 from keras_hub.src.models.text_to_image import TextToImage
@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
     Use `generate()` to do image generation.
     ```python
     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-        "stable_diffusion_3_medium",
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
@@ -38,11 +38,23 @@ class StableDiffusion3TextToImage(TextToImage):
         ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
     )
 
-    # Generate with different `num_steps` and `
+    # Generate with different `num_steps` and `guidance_scale`.
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
         num_steps=50,
-
+        guidance_scale=5.0,
+    )
+
+    # Generate with `negative_prompts`.
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
+    text_to_image.generate(
+        {
+            "prompts": prompt,
+            "negative_prompts": "green color",
+        }
     )
     ```
     """
@@ -79,7 +91,6 @@ class StableDiffusion3TextToImage(TextToImage):
         self,
         latents,
         token_ids,
-        negative_token_ids,
         num_steps,
         guidance_scale,
     ):
@@ -92,10 +103,8 @@ class StableDiffusion3TextToImage(TextToImage):
             latents: A (batch_size, height, width, channels) tensor
                 containing the latents to start generation from. Typically, this
                 tensor is sampled from the Gaussian distribution.
-            token_ids: A (batch_size, num_tokens) tensor containing the
-                tokens based on the input prompts.
-            negative_token_ids: A (batch_size, num_tokens) tensor
-                containing the negative tokens based on the input prompts.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
             num_steps: int. The number of diffusion steps to take.
             guidance_scale: float. The classifier free guidance scale defined in
                 [Classifier-Free Diffusion Guidance](
@@ -103,8 +112,12 @@ class StableDiffusion3TextToImage(TextToImage):
                 generate images that are closely linked to prompts, usually at
                 the expense of lower image quality.
         """
-
-
+        token_ids, negative_token_ids = token_ids
+
+        # Encode prompts.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
 
         # Denoise.
         def body_fun(step, latents):
@@ -124,14 +137,12 @@ class StableDiffusion3TextToImage(TextToImage):
     def generate(
         self,
         inputs,
-        negative_inputs=None,
         num_steps=28,
         guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
-            negative_inputs=negative_inputs,
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             seed=seed,
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
@@ -3,7 +3,7 @@ from keras import layers
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.preprocessor import Preprocessor
-from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
     StableDiffusion3Backbone,
 )
 
keras_hub/src/models/t5/t5_backbone.py
@@ -42,11 +42,12 @@ class T5Backbone(Backbone):
             projections in the multi-head attention layers. Defaults to
             hidden_dim / num_heads.
         dropout: float. Dropout probability for the Transformer layers.
-        activation: activation function
-
-            Transformer layers. Defaults to `"relu"`.
+        activation: string. The activation function to use in the dense blocks
+            of the Transformer Layers.
         use_gated_activation: boolean. Whether to use activation gating in
-            the inner dense blocks of the Transformer layers.
+            the inner dense blocks of the Transformer layers. When used with
+            the GELU activation function, this is referred to as GEGLU
+            (gated GLU) from https://arxiv.org/pdf/2002.05202.
             The original T5 architecture didn't use gating, but more
             recent versions do. Defaults to `True`.
         layer_norm_epsilon: float. Epsilon factor to be used in the
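For context on the gating described above: with `use_gated_activation=True` the feed-forward block multiplies an activated projection with a second, linear projection; with a GELU activation this is the GEGLU variant from the paper linked in the docstring. A minimal Keras sketch of the idea (illustrative only, not the T5Backbone implementation itself):

```python
import keras
from keras import layers, ops


class GatedDense(layers.Layer):
    """Gated feed-forward projection: activation(x @ W_gate) * (x @ W_value)."""

    def __init__(self, intermediate_dim, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.gate = layers.Dense(intermediate_dim, use_bias=False)
        self.value = layers.Dense(intermediate_dim, use_bias=False)
        self.activation = keras.activations.get(activation)

    def call(self, x):
        return self.activation(self.gate(x)) * self.value(x)


print(GatedDense(128)(ops.ones((2, 8, 64))).shape)  # (2, 8, 128)
```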
keras_hub/src/models/t5/t5_presets.py
@@ -1,4 +1,4 @@
-"""
+"""T5 model preset configurations."""
 
 backbone_presets = {
     "t5_small_multi": {
@@ -8,11 +8,17 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/3",
+    },
+    "t5_1.1_small": {
+        "metadata": {
+            "description": (""),
+            "params": 60511616,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_small/2",
     },
     "t5_base_multi": {
         "metadata": {
@@ -21,11 +27,17 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/3",
+    },
+    "t5_1.1_base": {
+        "metadata": {
+            "description": (""),
+            "params": 247577856,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_base/2",
     },
     "t5_large_multi": {
         "metadata": {
@@ -34,11 +46,33 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/3",
+    },
+    "t5_1.1_large": {
+        "metadata": {
+            "description": (""),
+            "params": 750251008,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_large/2",
+    },
+    "t5_1.1_xl": {
+        "metadata": {
+            "description": (""),
+            "params": 2849757184,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xl/2",
+    },
+    "t5_1.1_xxl": {
+        "metadata": {
+            "description": (""),
+            "params": 11135332352,
+            "path": "t5",
+        },
+        "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xxl/2",
     },
     "flan_small_multi": {
         "metadata": {
@@ -47,11 +81,9 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/t5/keras/flan_small_multi/
+        "kaggle_handle": "kaggle://keras/t5/keras/flan_small_multi/3",
     },
     "flan_base_multi": {
         "metadata": {
@@ -60,11 +92,9 @@ backbone_presets = {
                 "Corpus (C4)."
            ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/t5/keras/flan_base_multi/
+        "kaggle_handle": "kaggle://keras/t5/keras/flan_base_multi/3",
     },
     "flan_large_multi": {
         "metadata": {
@@ -73,10 +103,8 @@ backbone_presets = {
                 "Corpus (C4)."
             ),
             "params": 0,
-            "official_name": "T5",
             "path": "t5",
-            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/t5/keras/flan_large_multi/
+        "kaggle_handle": "kaggle://keras/t5/keras/flan_large_multi/3",
     },
 }
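The new `t5_1.1_*` presets added above can be loaded directly by name, for example through the backbone class (a small sketch; weights download from Kaggle on first use):

```python
import keras_hub

# Any of "t5_1.1_small", "t5_1.1_base", "t5_1.1_large", "t5_1.1_xl",
# or "t5_1.1_xxl" registered above.
backbone = keras_hub.models.T5Backbone.from_preset("t5_1.1_small")
backbone.summary()
```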