keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/task.py
CHANGED
@@ -1,19 +1,20 @@
-import os
-
 import keras
 from rich import console as rich_console
 from rich import markup
 from rich import table as rich_table
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.keras_utils import print_msg
 from keras_hub.src.utils.pipeline_model import PipelineModel
-from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
-from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -58,10 +59,15 @@ class Task(PipelineModel):
         self.compile()
 
     def preprocess_samples(self, x, y=None, sample_weight=None):
-        if self.preprocessor is not None:
+        # If `preprocessor` is `None`, return inputs unaltered.
+        if self.preprocessor is None:
+            return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+        # If `preprocessor` is `Preprocessor` subclass, pass labels as a kwarg.
+        if isinstance(self.preprocessor, Preprocessor):
             return self.preprocessor(x, y=y, sample_weight=sample_weight)
-        else:
-            return super().preprocess_samples(x, y, sample_weight)
+        # For other layers and callable, do not pass the label.
+        x = self.preprocessor(x)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
 
     def __setattr__(self, name, value):
         # Work around setattr issues for Keras 2 and Keras 3 torch backend.
@@ -143,7 +149,8 @@ class Task(PipelineModel):
 
         This constructor can be called in one of two ways. Either from a task
         specific base class like `keras_hub.models.CausalLM.from_preset()`, or
-        from a model class like `keras_hub.models.BertTextClassifier.from_preset()`.
+        from a model class like
+        `keras_hub.models.BertTextClassifier.from_preset()`.
         If calling from the a base class, the subclass of the returning object
         will be inferred from the config in the preset directory.
 
@@ -178,7 +185,10 @@ class Task(PipelineModel):
         loader = get_preset_loader(preset)
         backbone_cls = loader.check_backbone_class()
         # Detect the correct subclass if we need to.
-        if cls.backbone_cls != backbone_cls:
+        if (
+            issubclass(backbone_cls, Backbone)
+            and cls.backbone_cls != backbone_cls
+        ):
             cls = find_subclass(preset, cls, backbone_cls)
         # Specifically for classifiers, we never load task weights if
         # num_classes is supplied. We handle this in the task base class because
@@ -232,17 +242,8 @@ class Task(PipelineModel):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        if self.preprocessor is None:
-            raise ValueError(
-                "Cannot save `task` to preset: `Preprocessor` is not initialized."
-            )
-
-        save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
-        if self.has_task_weights():
-            self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE))
-
-        self.preprocessor.save_to_preset(preset_dir)
-        self.backbone.save_to_preset(preset_dir)
+        saver = get_preset_saver(preset_dir)
+        saver.save_task(self)
 
     @property
     def layers(self):
@@ -280,7 +281,7 @@ class Task(PipelineModel):
 
         def highlight_number(x):
             if x is None:
-                f"[color(45)]{x}[/]"
+                return f"[color(45)]{x}[/]"
             return f"[color(34)]{x:,}[/]"  # Format number with commas.
 
         def highlight_symbol(x):
@@ -294,7 +295,8 @@ class Task(PipelineModel):
             return "(" + ", ".join(highlighted) + ")"
 
         if self.preprocessor:
-            # Create a rich console for printing. Capture for non-interactive logging.
+            # Create a rich console for printing. Capture for non-interactive
+            # logging.
             if print_fn:
                 console = rich_console.Console(
                     highlight=False, force_terminal=False, color_system=None
@@ -327,24 +329,30 @@ class Task(PipelineModel):
                 info,
             )
 
-            tokenizer = self.preprocessor.tokenizer
-            if tokenizer:
-                info = "Vocab size: "
-                info += highlight_number(tokenizer.vocabulary_size())
-                add_layer(tokenizer, info)
-            image_converter = self.preprocessor.image_converter
-            if image_converter:
-                info = "Image size: "
-                info += highlight_shape(image_converter.image_size())
-                add_layer(image_converter, info)
-            audio_converter = self.preprocessor.audio_converter
-            if audio_converter:
-                info = "Audio shape: "
-                info += highlight_shape(audio_converter.audio_shape())
-                add_layer(audio_converter, info)
+            # Since the preprocessor might be nested with multiple `Tokenizer`,
+            # `ImageConverter`, `AudioConverter` and even other `Preprocessor`
+            # instances, we should recursively iterate through them.
+            preprocessor = self.preprocessor
+            if preprocessor and isinstance(preprocessor, keras.Layer):
+                for layer in preprocessor._flatten_layers(include_self=False):
+                    if isinstance(layer, Tokenizer):
+                        info = "Vocab size: "
+                        info += highlight_number(layer.vocabulary_size())
+                        add_layer(layer, info)
+                    elif isinstance(layer, ImageConverter):
+                        info = "Image size: "
+                        image_size = layer.image_size
+                        if image_size is None:
+                            image_size = (None, None)
+                        info += highlight_shape(image_size)
+                        add_layer(layer, info)
+                    elif isinstance(layer, AudioConverter):
+                        info = "Audio shape: "
+                        info += highlight_shape(layer.audio_shape())
+                        add_layer(layer, info)
 
             # Print the to the console.
-            preprocessor_name = markup.escape(self.preprocessor.name)
+            preprocessor_name = markup.escape(preprocessor.name)
             console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
             console.print(table)
 
keras_hub/src/models/text_classifier.py
CHANGED
@@ -21,8 +21,8 @@ class TextClassifier(Task):
     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
     labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
 
-    All `TextClassifier` tasks include a `from_preset()` constructor which can
-    used to load a pre-trained config and weights.
+    All `TextClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
 
     Some, but not all, classification presets include classification head
     weights in a `task.weights.h5` file. For these presets, you can omit passing
keras_hub/src/models/text_to_image.py
CHANGED
@@ -56,6 +56,11 @@ class TextToImage(Task):
         # Default compilation.
         self.compile()
 
+    @property
+    def support_negative_prompts(self):
+        """Whether the model supports `negative_prompts` key in `generate()`."""
+        return bool(True)
+
     @property
     def latent_shape(self):
         return tuple(self.backbone.latent_shape)
@@ -171,9 +176,26 @@ class TextToImage(Task):
         This function converts all inputs to tensors, adds a batch dimension if
         necessary, and returns a iterable "dataset like" object (either an
         actual `tf.data.Dataset` or a list with a single batch element).
+
+        The input format must be one of the following:
+        - A single string
+        - A list of strings
+        - A dict with "prompts" and/or "negative_prompts" keys
+        - A tf.data.Dataset with "prompts" and/or "negative_prompts" keys
+
+        The output will be a dict with "prompts" and/or "negative_prompts" keys.
         """
         if tf and isinstance(inputs, tf.data.Dataset):
-            return inputs.as_numpy_iterator(), False
+            _inputs = {
+                "prompts": inputs.map(
+                    lambda x: x["prompts"]
+                ).as_numpy_iterator()
+            }
+            if self.support_negative_prompts:
+                _inputs["negative_prompts"] = inputs.map(
+                    lambda x: x["negative_prompts"]
+                ).as_numpy_iterator()
+            return _inputs, False
 
         def normalize(x):
             if isinstance(x, str):
@@ -182,13 +204,24 @@ class TextToImage(Task):
                 return x[tf.newaxis], True
             return x, False
 
+        def get_dummy_prompts(x):
+            dummy_prompts = [""] * len(x)
+            if tf and isinstance(x, tf.Tensor):
+                return tf.convert_to_tensor(dummy_prompts)
+            else:
+                return dummy_prompts
+
         if isinstance(inputs, dict):
             for key in inputs:
                 inputs[key], input_is_scalar = normalize(inputs[key])
         else:
             inputs, input_is_scalar = normalize(inputs)
+            inputs = {"prompts": inputs}
 
-        return [inputs], input_is_scalar
+        if self.support_negative_prompts and "negative_prompts" not in inputs:
+            inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"])
+
+        return [inputs], input_is_scalar
@@ -199,12 +232,11 @@ class TextToImage(Task):
         """
 
         def normalize(x):
-            outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
+            outputs = ops.concatenate(x, axis=0)
+            outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0)
             outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
-            outputs = ops.convert_to_numpy(outputs)
-            if input_is_scalar:
-                outputs = outputs[0]
-            return outputs
+            outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
+            return ops.convert_to_numpy(outputs)
 
         if isinstance(outputs[0], dict):
             normalized = {}
@@ -216,33 +248,62 @@ class TextToImage(Task):
     def generate(
         self,
         inputs,
-        negative_inputs,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
         seed=None,
     ):
-        """Generate image based on the provided `inputs` and `negative_inputs`.
+        """Generate image based on the provided `inputs`.
+
+        Typically, `inputs` contains a text description (known as a prompt) used
+        to guide the image generation.
+
+        Some models support a `negative_prompts` key, which helps steer the
+        model away from generating certain styles and elements. To enable this,
+        pass `prompts` and `negative_prompts` as a dict:
+
+        ```python
+        prompt = (
+            "Astronaut in a jungle, cold color palette, muted colors, "
+            "detailed, 8k"
+        )
+        text_to_image.generate(
+            {
+                "prompts": prompt,
+                "negative_prompts": "green color",
+            }
+        )
+        ```
 
         If `inputs` are a `tf.data.Dataset`, outputs will be generated
         "batch-by-batch" and concatenated. Otherwise, all inputs will be
         processed as batches.
 
         Args:
-            inputs: python data, tensor data, or a `tf.data.Dataset`.
-            negative_inputs: python data, tensor data, or a `tf.data.Dataset`.
-                Unlike `inputs`, these are used as negative inputs to guide
-                the generation. If not provided, it defaults to `""` for each
-                input in `inputs`.
+            inputs: python data, tensor data, or a `tf.data.Dataset`. The format
+                must be one of the following:
+                - A single string
+                - A list of strings
+                - A dict with "prompts" and/or "negative_prompts" keys
+                - A `tf.data.Dataset` with "prompts" and/or "negative_prompts"
+                    keys
             num_steps: int. The number of diffusion steps to take.
-            guidance_scale: float. The classifier free guidance scale
-                [Classifier-Free Diffusion Guidance](
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
                 https://arxiv.org/abs/2207.12598). A higher scale encourages
                 generating images more closely related to the prompts, typically
-                at the cost of lower image quality.
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
             seed: optional int. Used as a random seed.
         """
+        num_steps = int(num_steps)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
         num_steps = ops.convert_to_tensor(num_steps, "int32")
-        guidance_scale = ops.convert_to_tensor(guidance_scale)
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(guidance_scale)
+        else:
+            guidance_scale = None
 
         # Setup our three main passes.
         # 1. Preprocessing strings to dense integer tensors.
@@ -251,32 +312,36 @@ class TextToImage(Task):
         generate_function = self.make_generate_function()
 
         def preprocess(x):
-            return self.preprocessor.generate_preprocess(x)
+            if self.preprocessor is not None:
+                return self.preprocessor.generate_preprocess(x)
+            else:
+                return x
+
+        def generate(x):
+            token_ids = x[0] if self.support_negative_prompts else x
+
+            # Initialize latents.
+            if isinstance(token_ids, dict):
+                arbitrary_key = list(token_ids.keys())[0]
+                batch_size = ops.shape(token_ids[arbitrary_key])[0]
+            else:
+                batch_size = ops.shape(token_ids)[0]
+            latent_shape = (batch_size,) + self.latent_shape[1:]
+            latents = random.normal(latent_shape, dtype="float32", seed=seed)
+
+            return generate_function(latents, x, num_steps, guidance_scale)
 
         # Normalize and preprocess inputs.
         inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
-        if negative_inputs is None:
-            num_prompts = 1 if input_is_scalar else len(inputs)
-            negative_inputs = [""] * num_prompts
-        negative_inputs, _ = self._normalize_generate_inputs(negative_inputs)
-
-        inputs = preprocess(inputs)
-        negative_inputs = preprocess(negative_inputs)
-        if isinstance(inputs, dict):
-            batch_size = len(inputs[list(inputs.keys())[0]])
+        if self.support_negative_prompts:
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            negative_token_ids = [
+                preprocess(x["negative_prompts"]) for x in inputs
+            ]
+            inputs = [x for x in zip(token_ids, negative_token_ids)]
         else:
-            batch_size = len(inputs)
-
-        # Initialize random latents.
-        latent_shape = (batch_size,) + self.latent_shape[1:]
-        latents = random.normal(latent_shape, dtype="float32", seed=seed)
+            inputs = [preprocess(x["prompts"]) for x in inputs]
 
         # Text-to-image.
-        outputs = generate_function(
-            latents,
-            inputs,
-            negative_inputs,
-            num_steps,
-            guidance_scale,
-        )
+        outputs = [generate(x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
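With the normalization above, `generate()` can consume a `tf.data.Dataset` of dicts and produce outputs batch-by-batch. A minimal sketch of that path, assuming the Stable Diffusion 3 preset shipped in this release (`"stable_diffusion_3_medium"`) and illustrative prompt strings and step counts:

```python
import tensorflow as tf
import keras_hub

text_to_image = keras_hub.models.TextToImage.from_preset(
    "stable_diffusion_3_medium"
)

# Both keys are mapped out of the dataset because this model reports
# `support_negative_prompts=True`.
ds = tf.data.Dataset.from_tensor_slices(
    {
        "prompts": ["a red car", "a watercolor mountain"],
        "negative_prompts": ["blurry", "blurry"],
    }
).batch(2)

images = text_to_image.generate(ds, num_steps=28, guidance_scale=7.0)
```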
keras_hub/src/models/vae/__init__.py
ADDED
@@ -0,0 +1 @@
+from keras_hub.src.models.vae.vae_backbone import VAEBackbone
keras_hub/src/models/vae/vae_backbone.py
ADDED
@@ -0,0 +1,184 @@
+import keras
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.vae.vae_layers import (
+    DiagonalGaussianDistributionSampler,
+)
+from keras_hub.src.models.vae.vae_layers import VAEDecoder
+from keras_hub.src.models.vae.vae_layers import VAEEncoder
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class VAEBackbone(Backbone):
+    """Variational Autoencoder(VAE) backbone used in latent diffusion models.
+
+    When encoding, this model generates mean and log variance of the input
+    images. When decoding, it reconstructs images from the latent space.
+
+    Args:
+        encoder_num_filters: list of ints. The number of filters for each
+            block in encoder.
+        encoder_num_blocks: list of ints. The number of blocks for each block in
+            encoder.
+        decoder_num_filters: list of ints. The number of filters for each
+            block in decoder.
+        decoder_num_blocks: list of ints. The number of blocks for each block in
+            decoder.
+        sampler_method: str. The method of the sampler for the intermediate
+            output. Available methods are `"sample"` and `"mode"`. `"sample"`
+            draws from the distribution using both the mean and log variance.
+            `"mode"` draws from the distribution using the mean only. Defaults
+            to `sample`.
+        input_channels: int. The number of channels in the input.
+        sample_channels: int. The number of channels in the sample. Typically,
+            this indicates the intermediate output of VAE, which is mean and
+            log variance.
+        output_channels: int. The number of channels in the output.
+        scale: float. The scaling factor applied to the latent space to ensure
+            it has unit variance during training of the diffusion model.
+            Defaults to `1.5305`, which is the value used in Stable Diffusion 3.
+        shift: float. The shift factor applied to the latent space to ensure it
+            has zero mean during training of the diffusion model. Defaults to
+            `0.0609`, which is the value used in Stable Diffusion 3.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+
+    Example:
+    ```Python
+    backbone = VAEBackbone(
+        encoder_num_filters=[32, 32, 32, 32],
+        encoder_num_blocks=[1, 1, 1, 1],
+        decoder_num_filters=[32, 32, 32, 32],
+        decoder_num_blocks=[1, 1, 1, 1],
+    )
+    input_data = ops.ones((2, self.height, self.width, 3))
+    output = backbone(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        encoder_num_filters,
+        encoder_num_blocks,
+        decoder_num_filters,
+        decoder_num_blocks,
+        sampler_method="sample",
+        input_channels=3,
+        sample_channels=32,
+        output_channels=3,
+        scale=1.5305,
+        shift=0.0609,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        data_format = standardize_data_format(data_format)
+        if data_format == "channels_last":
+            image_shape = (None, None, input_channels)
+            channel_axis = -1
+        else:
+            image_shape = (input_channels, None, None)
+            channel_axis = 1
+
+        # === Layers ===
+        self.encoder = VAEEncoder(
+            encoder_num_filters,
+            encoder_num_blocks,
+            output_channels=sample_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="encoder",
+        )
+        # Use `sample()` to define the functional model.
+        self.distribution_sampler = DiagonalGaussianDistributionSampler(
+            method=sampler_method,
+            axis=channel_axis,
+            dtype=dtype,
+            name="distribution_sampler",
+        )
+        self.decoder = VAEDecoder(
+            decoder_num_filters,
+            decoder_num_blocks,
+            output_channels=output_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="decoder",
+        )
+
+        # === Functional Model ===
+        image_input = keras.Input(shape=image_shape)
+        sample = self.encoder(image_input)
+        latent = self.distribution_sampler(sample)
+        image_output = self.decoder(latent)
+        super().__init__(
+            inputs=image_input,
+            outputs=image_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.encoder_num_filters = encoder_num_filters
+        self.encoder_num_blocks = encoder_num_blocks
+        self.decoder_num_filters = decoder_num_filters
+        self.decoder_num_blocks = decoder_num_blocks
+        self.sampler_method = sampler_method
+        self.input_channels = input_channels
+        self.sample_channels = sample_channels
+        self.output_channels = output_channels
+        self._scale = scale
+        self._shift = shift
+
+    @property
+    def scale(self):
+        """The scaling factor for the latent space.
+
+        This is used to scale the latent space to have unit variance when
+        training the diffusion model.
+        """
+        return self._scale
+
+    @property
+    def shift(self):
+        """The shift factor for the latent space.
+
+        This is used to shift the latent space to have zero mean when
+        training the diffusion model.
+        """
+        return self._shift
+
+    def encode(self, inputs, **kwargs):
+        """Encode the input images into latent space."""
+        sample = self.encoder(inputs, **kwargs)
+        return self.distribution_sampler(sample)
+
+    def decode(self, inputs, **kwargs):
+        """Decode the input latent space into images."""
+        return self.decoder(inputs, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "encoder_num_filters": self.encoder_num_filters,
+                "encoder_num_blocks": self.encoder_num_blocks,
+                "decoder_num_filters": self.decoder_num_filters,
+                "decoder_num_blocks": self.decoder_num_blocks,
+                "sampler_method": self.sampler_method,
+                "input_channels": self.input_channels,
+                "sample_channels": self.sample_channels,
+                "output_channels": self.output_channels,
+                "scale": self.scale,
+                "shift": self.shift,
+            }
+        )
+        return config