keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409270338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -13
- keras_hub/api/__init__.py +0 -13
- keras_hub/api/bounding_box/__init__.py +0 -13
- keras_hub/api/layers/__init__.py +3 -13
- keras_hub/api/metrics/__init__.py +0 -13
- keras_hub/api/models/__init__.py +16 -13
- keras_hub/api/samplers/__init__.py +0 -13
- keras_hub/api/tokenizers/__init__.py +1 -13
- keras_hub/api/utils/__init__.py +0 -13
- keras_hub/src/__init__.py +0 -13
- keras_hub/src/api_export.py +0 -14
- keras_hub/src/bounding_box/__init__.py +0 -13
- keras_hub/src/bounding_box/converters.py +0 -13
- keras_hub/src/bounding_box/formats.py +0 -13
- keras_hub/src/bounding_box/iou.py +1 -13
- keras_hub/src/bounding_box/to_dense.py +0 -14
- keras_hub/src/bounding_box/to_ragged.py +0 -13
- keras_hub/src/bounding_box/utils.py +0 -13
- keras_hub/src/bounding_box/validate_format.py +0 -14
- keras_hub/src/layers/__init__.py +0 -13
- keras_hub/src/layers/modeling/__init__.py +0 -13
- keras_hub/src/layers/modeling/alibi_bias.py +0 -13
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +0 -14
- keras_hub/src/layers/modeling/f_net_encoder.py +0 -14
- keras_hub/src/layers/modeling/masked_lm_head.py +0 -14
- keras_hub/src/layers/modeling/position_embedding.py +0 -14
- keras_hub/src/layers/modeling/reversible_embedding.py +0 -14
- keras_hub/src/layers/modeling/rotary_embedding.py +0 -14
- keras_hub/src/layers/modeling/sine_position_encoding.py +0 -14
- keras_hub/src/layers/modeling/token_and_position_embedding.py +0 -14
- keras_hub/src/layers/modeling/transformer_decoder.py +0 -14
- keras_hub/src/layers/modeling/transformer_encoder.py +0 -14
- keras_hub/src/layers/modeling/transformer_layer_utils.py +0 -14
- keras_hub/src/layers/preprocessing/__init__.py +0 -13
- keras_hub/src/layers/preprocessing/audio_converter.py +0 -13
- keras_hub/src/layers/preprocessing/image_converter.py +0 -13
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +0 -15
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +0 -14
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +0 -14
- keras_hub/src/layers/preprocessing/random_deletion.py +0 -14
- keras_hub/src/layers/preprocessing/random_swap.py +0 -14
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -13
- keras_hub/src/layers/preprocessing/start_end_packer.py +0 -15
- keras_hub/src/metrics/__init__.py +0 -13
- keras_hub/src/metrics/bleu.py +0 -14
- keras_hub/src/metrics/edit_distance.py +0 -14
- keras_hub/src/metrics/perplexity.py +0 -14
- keras_hub/src/metrics/rouge_base.py +0 -14
- keras_hub/src/metrics/rouge_l.py +0 -14
- keras_hub/src/metrics/rouge_n.py +0 -14
- keras_hub/src/models/__init__.py +0 -13
- keras_hub/src/models/albert/__init__.py +0 -14
- keras_hub/src/models/albert/albert_backbone.py +0 -14
- keras_hub/src/models/albert/albert_masked_lm.py +0 -14
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/albert/albert_presets.py +0 -14
- keras_hub/src/models/albert/albert_text_classifier.py +0 -14
- keras_hub/src/models/albert/albert_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/albert/albert_tokenizer.py +0 -14
- keras_hub/src/models/backbone.py +0 -14
- keras_hub/src/models/bart/__init__.py +0 -14
- keras_hub/src/models/bart/bart_backbone.py +0 -14
- keras_hub/src/models/bart/bart_presets.py +0 -13
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +0 -15
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +0 -15
- keras_hub/src/models/bart/bart_tokenizer.py +0 -15
- keras_hub/src/models/bert/__init__.py +0 -14
- keras_hub/src/models/bert/bert_backbone.py +0 -14
- keras_hub/src/models/bert/bert_masked_lm.py +0 -14
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/bert/bert_presets.py +0 -13
- keras_hub/src/models/bert/bert_text_classifier.py +0 -14
- keras_hub/src/models/bert/bert_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/bert/bert_tokenizer.py +0 -14
- keras_hub/src/models/bloom/__init__.py +0 -14
- keras_hub/src/models/bloom/bloom_attention.py +0 -13
- keras_hub/src/models/bloom/bloom_backbone.py +0 -14
- keras_hub/src/models/bloom/bloom_causal_lm.py +0 -15
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +0 -15
- keras_hub/src/models/bloom/bloom_decoder.py +0 -13
- keras_hub/src/models/bloom/bloom_presets.py +0 -13
- keras_hub/src/models/bloom/bloom_tokenizer.py +0 -15
- keras_hub/src/models/causal_lm.py +0 -14
- keras_hub/src/models/causal_lm_preprocessor.py +0 -13
- keras_hub/src/models/clip/__init__.py +0 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -15
- keras_hub/src/models/clip/clip_preprocessor.py +134 -0
- keras_hub/src/models/clip/clip_text_encoder.py +139 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +65 -41
- keras_hub/src/models/csp_darknet/__init__.py +0 -13
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -13
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -13
- keras_hub/src/models/deberta_v3/__init__.py +0 -14
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +0 -15
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +0 -15
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -13
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +0 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +0 -15
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +0 -14
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +0 -14
- keras_hub/src/models/deberta_v3/relative_embedding.py +0 -14
- keras_hub/src/models/densenet/__init__.py +5 -13
- keras_hub/src/models/densenet/densenet_backbone.py +11 -21
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -17
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/{stable_diffusion_v3/__init__.py → densenet/densenet_image_converter.py} +10 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/distil_bert/__init__.py +0 -14
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -13
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +0 -15
- keras_hub/src/models/efficientnet/__init__.py +0 -13
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +0 -13
- keras_hub/src/models/efficientnet/fusedmbconv.py +0 -14
- keras_hub/src/models/efficientnet/mbconv.py +0 -14
- keras_hub/src/models/electra/__init__.py +0 -14
- keras_hub/src/models/electra/electra_backbone.py +0 -14
- keras_hub/src/models/electra/electra_presets.py +0 -13
- keras_hub/src/models/electra/electra_tokenizer.py +0 -14
- keras_hub/src/models/f_net/__init__.py +0 -14
- keras_hub/src/models/f_net/f_net_backbone.py +0 -15
- keras_hub/src/models/f_net/f_net_masked_lm.py +0 -15
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/f_net/f_net_presets.py +0 -13
- keras_hub/src/models/f_net/f_net_text_classifier.py +0 -15
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +0 -15
- keras_hub/src/models/f_net/f_net_tokenizer.py +0 -15
- keras_hub/src/models/falcon/__init__.py +0 -14
- keras_hub/src/models/falcon/falcon_attention.py +0 -13
- keras_hub/src/models/falcon/falcon_backbone.py +0 -13
- keras_hub/src/models/falcon/falcon_causal_lm.py +0 -14
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/falcon/falcon_presets.py +0 -13
- keras_hub/src/models/falcon/falcon_tokenizer.py +0 -15
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +0 -13
- keras_hub/src/models/feature_pyramid_backbone.py +0 -13
- keras_hub/src/models/gemma/__init__.py +0 -14
- keras_hub/src/models/gemma/gemma_attention.py +0 -13
- keras_hub/src/models/gemma/gemma_backbone.py +0 -15
- keras_hub/src/models/gemma/gemma_causal_lm.py +0 -15
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/gemma/gemma_decoder_block.py +0 -13
- keras_hub/src/models/gemma/gemma_presets.py +0 -13
- keras_hub/src/models/gemma/gemma_tokenizer.py +0 -14
- keras_hub/src/models/gemma/rms_normalization.py +0 -14
- keras_hub/src/models/gpt2/__init__.py +0 -14
- keras_hub/src/models/gpt2/gpt2_backbone.py +0 -15
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +0 -15
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +0 -15
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -13
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +0 -15
- keras_hub/src/models/gpt_neo_x/__init__.py +0 -13
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +0 -14
- keras_hub/src/models/image_classifier.py +0 -13
- keras_hub/src/models/image_classifier_preprocessor.py +0 -13
- keras_hub/src/models/image_segmenter.py +0 -13
- keras_hub/src/models/llama/__init__.py +0 -14
- keras_hub/src/models/llama/llama_attention.py +0 -13
- keras_hub/src/models/llama/llama_backbone.py +0 -13
- keras_hub/src/models/llama/llama_causal_lm.py +0 -13
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +0 -15
- keras_hub/src/models/llama/llama_decoder.py +0 -13
- keras_hub/src/models/llama/llama_layernorm.py +0 -13
- keras_hub/src/models/llama/llama_presets.py +0 -13
- keras_hub/src/models/llama/llama_tokenizer.py +0 -14
- keras_hub/src/models/llama3/__init__.py +0 -14
- keras_hub/src/models/llama3/llama3_backbone.py +0 -14
- keras_hub/src/models/llama3/llama3_causal_lm.py +0 -13
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/llama3/llama3_presets.py +0 -13
- keras_hub/src/models/llama3/llama3_tokenizer.py +0 -14
- keras_hub/src/models/masked_lm.py +0 -13
- keras_hub/src/models/masked_lm_preprocessor.py +0 -13
- keras_hub/src/models/mistral/__init__.py +0 -14
- keras_hub/src/models/mistral/mistral_attention.py +0 -13
- keras_hub/src/models/mistral/mistral_backbone.py +0 -14
- keras_hub/src/models/mistral/mistral_causal_lm.py +0 -14
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/mistral/mistral_layer_norm.py +0 -13
- keras_hub/src/models/mistral/mistral_presets.py +0 -13
- keras_hub/src/models/mistral/mistral_tokenizer.py +0 -14
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +0 -13
- keras_hub/src/models/mix_transformer/__init__.py +0 -13
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +0 -13
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -13
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +0 -13
- keras_hub/src/models/mobilenet/__init__.py +0 -13
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +0 -13
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -13
- keras_hub/src/models/opt/__init__.py +0 -14
- keras_hub/src/models/opt/opt_backbone.py +0 -15
- keras_hub/src/models/opt/opt_causal_lm.py +0 -15
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +0 -13
- keras_hub/src/models/opt/opt_presets.py +0 -13
- keras_hub/src/models/opt/opt_tokenizer.py +0 -15
- keras_hub/src/models/pali_gemma/__init__.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +0 -14
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +0 -13
- keras_hub/src/models/phi3/__init__.py +0 -14
- keras_hub/src/models/phi3/phi3_attention.py +0 -13
- keras_hub/src/models/phi3/phi3_backbone.py +0 -13
- keras_hub/src/models/phi3/phi3_causal_lm.py +0 -13
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/phi3/phi3_decoder.py +0 -13
- keras_hub/src/models/phi3/phi3_layernorm.py +0 -13
- keras_hub/src/models/phi3/phi3_presets.py +0 -13
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +0 -13
- keras_hub/src/models/phi3/phi3_tokenizer.py +0 -13
- keras_hub/src/models/preprocessor.py +51 -32
- keras_hub/src/models/resnet/__init__.py +0 -14
- keras_hub/src/models/resnet/resnet_backbone.py +0 -13
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -13
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +0 -14
- keras_hub/src/models/resnet/resnet_image_converter.py +0 -13
- keras_hub/src/models/resnet/resnet_presets.py +0 -13
- keras_hub/src/models/retinanet/__init__.py +0 -13
- keras_hub/src/models/retinanet/anchor_generator.py +0 -14
- keras_hub/src/models/retinanet/box_matcher.py +0 -14
- keras_hub/src/models/retinanet/non_max_supression.py +0 -14
- keras_hub/src/models/roberta/__init__.py +0 -14
- keras_hub/src/models/roberta/roberta_backbone.py +0 -15
- keras_hub/src/models/roberta/roberta_masked_lm.py +0 -15
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/roberta/roberta_presets.py +0 -13
- keras_hub/src/models/roberta/roberta_text_classifier.py +0 -15
- keras_hub/src/models/roberta/roberta_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/roberta/roberta_tokenizer.py +0 -15
- keras_hub/src/models/sam/__init__.py +0 -13
- keras_hub/src/models/sam/sam_backbone.py +0 -14
- keras_hub/src/models/sam/sam_image_segmenter.py +0 -14
- keras_hub/src/models/sam/sam_layers.py +0 -14
- keras_hub/src/models/sam/sam_mask_decoder.py +0 -14
- keras_hub/src/models/sam/sam_prompt_encoder.py +0 -14
- keras_hub/src/models/sam/sam_transformer.py +0 -14
- keras_hub/src/models/seq_2_seq_lm.py +0 -13
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +0 -13
- keras_hub/src/models/stable_diffusion_3/__init__.py +9 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +80 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -39
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +631 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +31 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +138 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +83 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -20
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +320 -0
- keras_hub/src/models/t5/__init__.py +0 -14
- keras_hub/src/models/t5/t5_backbone.py +0 -14
- keras_hub/src/models/t5/t5_layer_norm.py +0 -14
- keras_hub/src/models/t5/t5_multi_head_attention.py +0 -14
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -16
- keras_hub/src/models/t5/t5_presets.py +0 -13
- keras_hub/src/models/t5/t5_tokenizer.py +0 -14
- keras_hub/src/models/t5/t5_transformer_layer.py +0 -14
- keras_hub/src/models/task.py +0 -14
- keras_hub/src/models/text_classifier.py +0 -13
- keras_hub/src/models/text_classifier_preprocessor.py +0 -13
- keras_hub/src/models/text_to_image.py +282 -0
- keras_hub/src/models/vgg/__init__.py +0 -13
- keras_hub/src/models/vgg/vgg_backbone.py +0 -13
- keras_hub/src/models/vgg/vgg_image_classifier.py +0 -13
- keras_hub/src/models/vit_det/__init__.py +0 -13
- keras_hub/src/models/vit_det/vit_det_backbone.py +0 -14
- keras_hub/src/models/vit_det/vit_layers.py +0 -15
- keras_hub/src/models/whisper/__init__.py +0 -14
- keras_hub/src/models/whisper/whisper_audio_converter.py +0 -15
- keras_hub/src/models/whisper/whisper_backbone.py +0 -15
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +0 -13
- keras_hub/src/models/whisper/whisper_decoder.py +0 -14
- keras_hub/src/models/whisper/whisper_encoder.py +0 -14
- keras_hub/src/models/whisper/whisper_presets.py +0 -14
- keras_hub/src/models/whisper/whisper_tokenizer.py +0 -14
- keras_hub/src/models/xlm_roberta/__init__.py +0 -14
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -13
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +0 -15
- keras_hub/src/models/xlnet/__init__.py +0 -13
- keras_hub/src/models/xlnet/relative_attention.py +0 -14
- keras_hub/src/models/xlnet/xlnet_backbone.py +0 -14
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +0 -14
- keras_hub/src/models/xlnet/xlnet_encoder.py +0 -14
- keras_hub/src/samplers/__init__.py +0 -13
- keras_hub/src/samplers/beam_sampler.py +0 -14
- keras_hub/src/samplers/contrastive_sampler.py +0 -14
- keras_hub/src/samplers/greedy_sampler.py +0 -14
- keras_hub/src/samplers/random_sampler.py +0 -14
- keras_hub/src/samplers/sampler.py +0 -14
- keras_hub/src/samplers/serialization.py +0 -14
- keras_hub/src/samplers/top_k_sampler.py +0 -14
- keras_hub/src/samplers/top_p_sampler.py +0 -14
- keras_hub/src/tests/__init__.py +0 -13
- keras_hub/src/tests/test_case.py +0 -14
- keras_hub/src/tokenizers/__init__.py +0 -13
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +0 -14
- keras_hub/src/tokenizers/byte_tokenizer.py +0 -14
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +0 -14
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +0 -14
- keras_hub/src/tokenizers/tokenizer.py +23 -27
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +0 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +0 -14
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +0 -15
- keras_hub/src/utils/__init__.py +0 -13
- keras_hub/src/utils/imagenet/__init__.py +0 -13
- keras_hub/src/utils/imagenet/imagenet_utils.py +0 -13
- keras_hub/src/utils/keras_utils.py +0 -14
- keras_hub/src/utils/pipeline_model.py +0 -14
- keras_hub/src/utils/preset_utils.py +32 -76
- keras_hub/src/utils/python_utils.py +0 -13
- keras_hub/src/utils/tensor_utils.py +0 -14
- keras_hub/src/utils/timm/__init__.py +0 -13
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/convert_resnet.py +0 -13
- keras_hub/src/utils/timm/preset_loader.py +3 -13
- keras_hub/src/utils/transformers/__init__.py +0 -13
- keras_hub/src/utils/transformers/convert_albert.py +0 -13
- keras_hub/src/utils/transformers/convert_bart.py +0 -13
- keras_hub/src/utils/transformers/convert_bert.py +0 -13
- keras_hub/src/utils/transformers/convert_distilbert.py +0 -13
- keras_hub/src/utils/transformers/convert_gemma.py +0 -13
- keras_hub/src/utils/transformers/convert_gpt2.py +0 -13
- keras_hub/src/utils/transformers/convert_llama3.py +0 -13
- keras_hub/src/utils/transformers/convert_mistral.py +0 -13
- keras_hub/src/utils/transformers/convert_pali_gemma.py +0 -13
- keras_hub/src/utils/transformers/preset_loader.py +1 -15
- keras_hub/src/utils/transformers/safetensor_utils.py +9 -15
- keras_hub/src/version_utils.py +1 -15
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/METADATA +30 -27
- keras_hub_nightly-0.16.1.dev202409270338.dist-info/RECORD +351 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +0 -149
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- keras_hub_nightly-0.16.1.dev202409250340.dist-info/RECORD +0 -342
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/top_level.txt +0 -0
@@ -1,317 +0,0 @@
|
|
1
|
-
# Copyright 2024 The KerasHub Authors
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
import math
|
15
|
-
|
16
|
-
from keras import layers
|
17
|
-
from keras import models
|
18
|
-
from keras import ops
|
19
|
-
|
20
|
-
from keras_hub.src.utils.keras_utils import gelu_approximate
|
21
|
-
|
22
|
-
|
23
|
-
class DismantledBlock(layers.Layer):
    """One token stream of an MMDiT block, split around the attention op.

    The block is "dismantled" so that a parent layer can run a joint
    attention over several streams itself. The layer is called in two
    phases:

    * `pre_attention=True`: modulates `norm1(inputs)` with a shift/scale
      derived from `timestep_embedding` and projects it to per-head
      query/key/value tensors of shape `(batch, seq, num_heads, head_dim)`.
    * `pre_attention=False`: applies the attention output projection, a
      gated residual, and a gated, modulated MLP residual. This phase is
      only available when `use_projection=True`.

    Args:
        num_heads: int. Number of attention heads.
        hidden_dim: int. Token feature width. Must be divisible by
            `num_heads`.
        mlp_ratio: float. Ratio of the MLP hidden width to `hidden_dim`.
        use_projection: bool. If `True`, the block creates an attention
            output projection and derives 6 modulation vectors (shift, scale
            and gate for attention and for the MLP). If `False`, only 2
            modulation vectors (attention shift and scale) are derived and
            only the pre-attention phase can be run.
        **kwargs: Standard `keras.layers.Layer` arguments.

    Raises:
        ValueError: If `hidden_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        mlp_ratio=4.0,
        use_projection=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if hidden_dim % num_heads != 0:
            # Otherwise `head_dim` is truncated and the qkv reshape below
            # (which uses -1 for the sequence axis) fails or mixes tokens.
            raise ValueError(
                "`hidden_dim` must be divisible by `num_heads`. Received: "
                f"hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.mlp_ratio = mlp_ratio
        self.use_projection = use_projection

        head_dim = hidden_dim // num_heads
        self.head_dim = head_dim
        mlp_hidden_dim = int(hidden_dim * mlp_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        # (shift, scale, gate) x (attention, MLP) with a projection;
        # only the attention (shift, scale) otherwise.
        num_modulations = 6 if use_projection else 2
        self.num_modulations = num_modulations

        # NOTE: sub-layer creation order and names are kept stable so that
        # existing weights/presets keep loading into the same variables.
        self.adaptive_norm_modulation = models.Sequential(
            [
                layers.Activation("silu", dtype=self.dtype_policy),
                layers.Dense(
                    num_modulations * hidden_dim, dtype=self.dtype_policy
                ),
            ],
            name="adaptive_norm_modulation",
        )
        self.norm1 = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype=self.dtype_policy,
            name="norm1",
        )
        self.attention_qkv = layers.Dense(
            hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
        )
        if use_projection:
            self.attention_proj = layers.Dense(
                hidden_dim, dtype=self.dtype_policy, name="attention_proj"
            )
        self.norm2 = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype=self.dtype_policy,
            name="norm2",
        )
        self.mlp = models.Sequential(
            [
                layers.Dense(
                    mlp_hidden_dim,
                    activation=gelu_approximate,
                    dtype=self.dtype_policy,
                ),
                layers.Dense(
                    hidden_dim,
                    dtype=self.dtype_policy,
                ),
            ],
            name="mlp",
        )

    def build(self, inputs_shape, timestep_embedding_shape):
        # Keras 3 requires every argument of a multi-argument `build()` to
        # carry a `_shape` suffix matching a `call()` argument; without it,
        # auto-building this layer on first call raises a ValueError.
        self.adaptive_norm_modulation.build(timestep_embedding_shape)
        self.attention_qkv.build(inputs_shape)
        self.norm1.build(inputs_shape)
        if self.use_projection:
            self.attention_proj.build(inputs_shape)
        self.norm2.build(inputs_shape)
        self.mlp.build(inputs_shape)

    def _modulate(self, inputs, shift, scale):
        """Return `inputs * (1 + scale) + shift`.

        `shift` and `scale` get a new axis at position 1 so that one vector
        per batch element is broadcast over every token.
        """
        shift = ops.expand_dims(shift, axis=1)
        scale = ops.expand_dims(scale, axis=1)
        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)

    def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
        """Compute per-head q/k/v from modulated, normalized `inputs`.

        Returns:
            `(q, k, v)` when `use_projection=False`; otherwise
            `((q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp))`
            where the second tuple feeds the post-attention phase.
        """
        batch_size = ops.shape(inputs)[0]
        # Split the conditioning projection into `num_modulations` vectors
        # of width `hidden_dim` (assumes one embedding row per batch item).
        modulation = self.adaptive_norm_modulation(
            timestep_embedding, training=training
        )
        modulation = ops.reshape(
            modulation, (batch_size, self.num_modulations, self.hidden_dim)
        )
        modulations = ops.unstack(modulation, self.num_modulations, axis=1)
        shift_msa, scale_msa = modulations[0], modulations[1]

        qkv = self.attention_qkv(
            self._modulate(self.norm1(inputs), shift_msa, scale_msa),
            training=training,
        )
        qkv = ops.reshape(
            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
        )
        q, k, v = ops.unstack(qkv, 3, axis=2)

        if not self.use_projection:
            return (q, k, v)
        _, _, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulations
        return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def _compute_post_attention(
        self, inputs, inputs_intermediates, training=None
    ):
        """Apply projection, gated attention residual and gated MLP residual.

        Args:
            inputs: Attention output for this stream's tokens.
            inputs_intermediates: The second tuple returned by the
                pre-attention phase.

        Raises:
            ValueError: If the block was built with `use_projection=False`
                (no output projection exists).
        """
        if not self.use_projection:
            raise ValueError(
                "The post-attention phase requires `use_projection=True`."
            )
        x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
        attn = self.attention_proj(inputs, training=training)
        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
        x = ops.add(
            x,
            ops.multiply(
                ops.expand_dims(gate_mlp, axis=1),
                self.mlp(
                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                    training=training,
                ),
            ),
        )
        return x

    def call(
        self,
        inputs,
        timestep_embedding=None,
        inputs_intermediates=None,
        pre_attention=True,
        training=None,
    ):
        """Run the pre-attention phase, or the post-attention phase when
        `pre_attention=False`."""
        if pre_attention:
            return self._compute_pre_attention(
                inputs, timestep_embedding, training=training
            )
        else:
            return self._compute_post_attention(
                inputs, inputs_intermediates, training=training
            )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "mlp_ratio": self.mlp_ratio,
                "use_projection": self.use_projection,
            }
        )
        return config
|
194
|
-
|
195
|
-
|
196
|
-
class MMDiTBlock(layers.Layer):
|
197
|
-
def __init__(
|
198
|
-
self,
|
199
|
-
num_heads,
|
200
|
-
hidden_dim,
|
201
|
-
mlp_ratio=4.0,
|
202
|
-
use_context_projection=True,
|
203
|
-
**kwargs,
|
204
|
-
):
|
205
|
-
super().__init__(**kwargs)
|
206
|
-
self.num_heads = num_heads
|
207
|
-
self.hidden_dim = hidden_dim
|
208
|
-
self.mlp_ratio = mlp_ratio
|
209
|
-
self.use_context_projection = use_context_projection
|
210
|
-
|
211
|
-
head_dim = hidden_dim // num_heads
|
212
|
-
self.head_dim = head_dim
|
213
|
-
self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
|
214
|
-
self._dot_product_equation = "aecd,abcd->acbe"
|
215
|
-
self._combine_equation = "acbe,aecd->abcd"
|
216
|
-
|
217
|
-
self.x_block = DismantledBlock(
|
218
|
-
num_heads=num_heads,
|
219
|
-
hidden_dim=hidden_dim,
|
220
|
-
mlp_ratio=mlp_ratio,
|
221
|
-
use_projection=True,
|
222
|
-
dtype=self.dtype_policy,
|
223
|
-
name="x_block",
|
224
|
-
)
|
225
|
-
self.context_block = DismantledBlock(
|
226
|
-
num_heads=num_heads,
|
227
|
-
hidden_dim=hidden_dim,
|
228
|
-
mlp_ratio=mlp_ratio,
|
229
|
-
use_projection=use_context_projection,
|
230
|
-
dtype=self.dtype_policy,
|
231
|
-
name="context_block",
|
232
|
-
)
|
233
|
-
|
234
|
-
def build(self, inputs_shape, context_shape, timestep_embedding_shape):
|
235
|
-
self.x_block.build(inputs_shape, timestep_embedding_shape)
|
236
|
-
self.context_block.build(context_shape, timestep_embedding_shape)
|
237
|
-
|
238
|
-
def _compute_attention(self, query, key, value):
|
239
|
-
query = ops.multiply(
|
240
|
-
query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
|
241
|
-
)
|
242
|
-
attention_scores = ops.einsum(self._dot_product_equation, key, query)
|
243
|
-
attention_scores = ops.nn.softmax(attention_scores, axis=-1)
|
244
|
-
attention_output = ops.einsum(
|
245
|
-
self._combine_equation, attention_scores, value
|
246
|
-
)
|
247
|
-
batch_size = ops.shape(attention_output)[0]
|
248
|
-
attention_output = ops.reshape(
|
249
|
-
attention_output, (batch_size, -1, self.num_heads * self.head_dim)
|
250
|
-
)
|
251
|
-
return attention_output
|
252
|
-
|
253
|
-
def call(self, inputs, context, timestep_embedding, training=None):
|
254
|
-
# Compute pre-attention.
|
255
|
-
x = inputs
|
256
|
-
if self.use_context_projection:
|
257
|
-
context_qkv, context_intermediates = self.context_block(
|
258
|
-
context,
|
259
|
-
timestep_embedding=timestep_embedding,
|
260
|
-
training=training,
|
261
|
-
)
|
262
|
-
else:
|
263
|
-
context_qkv = self.context_block(
|
264
|
-
context,
|
265
|
-
timestep_embedding=timestep_embedding,
|
266
|
-
training=training,
|
267
|
-
)
|
268
|
-
context_len = ops.shape(context_qkv[0])[1]
|
269
|
-
x_qkv, x_intermediates = self.x_block(
|
270
|
-
x, timestep_embedding=timestep_embedding, training=training
|
271
|
-
)
|
272
|
-
q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
|
273
|
-
k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
|
274
|
-
v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
|
275
|
-
|
276
|
-
# Compute attention.
|
277
|
-
attention = self._compute_attention(q, k, v)
|
278
|
-
context_attention = attention[:, :context_len]
|
279
|
-
x_attention = attention[:, context_len:]
|
280
|
-
|
281
|
-
# Compute post-attention.
|
282
|
-
x = self.x_block(
|
283
|
-
x_attention,
|
284
|
-
inputs_intermediates=x_intermediates,
|
285
|
-
pre_attention=False,
|
286
|
-
training=training,
|
287
|
-
)
|
288
|
-
if self.use_context_projection:
|
289
|
-
context = self.context_block(
|
290
|
-
context_attention,
|
291
|
-
inputs_intermediates=context_intermediates,
|
292
|
-
pre_attention=False,
|
293
|
-
training=training,
|
294
|
-
)
|
295
|
-
return x, context
|
296
|
-
else:
|
297
|
-
return x
|
298
|
-
|
299
|
-
def get_config(self):
|
300
|
-
config = super().get_config()
|
301
|
-
config.update(
|
302
|
-
{
|
303
|
-
"num_heads": self.num_heads,
|
304
|
-
"hidden_dim": self.hidden_dim,
|
305
|
-
"mlp_ratio": self.mlp_ratio,
|
306
|
-
"use_context_projection": self.use_context_projection,
|
307
|
-
}
|
308
|
-
)
|
309
|
-
return config
|
310
|
-
|
311
|
-
def compute_output_shape(
    self, inputs_shape, context_shape, timestep_embedding_shape
):
    """Mirror `call`'s return structure at the shape level.

    The `x` stream keeps its input shape. The context stream is only
    returned, also with unchanged shape, when `use_context_projection` is
    set. `timestep_embedding_shape` does not affect the result.
    """
    if not self.use_context_projection:
        return inputs_shape
    return inputs_shape, context_shape
|
@@ -1,126 +0,0 @@
|
|
1
|
-
# Copyright 2024 The KerasHub Authors
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
import math
|
15
|
-
|
16
|
-
from keras import layers
|
17
|
-
from keras import ops
|
18
|
-
|
19
|
-
from keras_hub.src.utils.keras_utils import standardize_data_format
|
20
|
-
|
21
|
-
|
22
|
-
class VAEAttention(layers.Layer):
    """Spatial self-attention block with a residual connection.

    The input is group-normalized and projected to query/key/value with 1x1
    convolutions. It is then attended over all spatial positions with a
    single head, projected back with a 1x1 convolution, and added to the
    input.

    Args:
        filters: int. Channel count of the query/key/value/output
            projections. It must match the input channel count, because the
            output is added to `inputs`.
        groups: int. Number of groups for the group normalization.
            Defaults to `32`.
        data_format: `None`, `"channels_last"` or `"channels_first"`.
            Resolved through `standardize_data_format`.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.data_format = standardize_data_format(data_format)
        gn_axis = -1 if self.data_format == "channels_last" else 1

        self.group_norm = layers.GroupNormalization(
            groups=groups,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        # Softmax runs in float32 regardless of the layer's compute dtype.
        # Its output is cast back in `call`.
        self.softmax = layers.Softmax(dtype="float32")
        self.output_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

        self.groups = groups
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))

    def build(self, input_shape):
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        x = self.group_norm(inputs)
        query = self.query_conv2d(x)
        key = self.key_conv2d(x)
        value = self.value_conv2d(x)

        channels_first = self.data_format == "channels_first"
        if channels_first:
            # Move channels last so each spatial position flattens to a
            # `filters`-length vector below.
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        shape = ops.shape(inputs)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention (N = number of spatial positions).
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, N, C] queries, [B, N, C] keys -> [B, N, N] scores.
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, N, C] values, [B, N, N] scores -> [B, N, C]; sums over keys.
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)

        if channels_first:
            # `attention_output` was flattened from a channels-last
            # [B, H, W, C] grid. Rebuild that grid, then move channels back
            # to axis 1 so `output_conv2d` (channels_first) and the residual
            # add both see [B, C, H, W]. Reshaping directly to the
            # channels-first input shape would scramble the values.
            x = ops.reshape(
                attention_output, (b, shape[2], shape[3], self.filters)
            )
            x = ops.transpose(x, (0, 3, 1, 2))
        else:
            x = ops.reshape(attention_output, shape)

        x = self.output_conv2d(x)
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
                # Needed so channels_first layers round-trip through config.
                "data_format": self.data_format,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
|
@@ -1,186 +0,0 @@
|
|
1
|
-
# Copyright 2024 The KerasHub Authors
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
import keras
|
15
|
-
from keras import layers
|
16
|
-
|
17
|
-
from keras_hub.src.models.stable_diffusion_v3.vae_attention import VAEAttention
|
18
|
-
from keras_hub.src.utils.keras_utils import standardize_data_format
|
19
|
-
|
20
|
-
|
21
|
-
class VAEImageDecoder(keras.Model):
    """Functional decoder that maps latents to images.

    The model is built from the following pieces, in order:
    - A 3x3 input projection.
    - Two residual blocks around one `VAEAttention` block.
    - One stack of residual blocks per entry in `stackwise_num_filters`.
      Every stack except the last is followed by 2x upsampling and a 3x3
      conv.
    - An output GroupNorm, swish activation, and a 3x3 projection to
      `output_channels`.

    Args:
        stackwise_num_filters: list of int. Filters for each stack. The
            first entry also sets the width of the input projection,
            residual blocks and attention block.
        stackwise_num_blocks: list of int. Number of residual blocks in each
            stack. Indexed in parallel with `stackwise_num_filters`.
        output_channels: int. Channels of the decoded image. Defaults to `3`.
        latent_shape: tuple. Shape of the latent input, excluding the batch
            dimension. Defaults to `(None, None, 16)`.
        data_format: `None`, `"channels_last"` or `"channels_first"`.
            Resolved through `standardize_data_format`.
        dtype: Optional dtype or dtype policy for all sublayers and the
            model.
        **kwargs: Forwarded to `keras.Model`.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling after the last stack.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=dtype,
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape
        self.data_format = data_format

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                # Must match the constructor argument name so that
                # `from_config` can rebuild the model.
                "latent_shape": self.latent_shape,
                "data_format": self.data_format,
            }
        )
        return config
|
142
|
-
|
143
|
-
|
144
|
-
def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    """Apply a residual block: (GroupNorm -> swish -> 3x3 conv) twice.

    The result is added to a skip connection. The skip path is projected
    with a 1x1 conv when the input channel count differs from `filters`.

    Args:
        x: Symbolic 4D input tensor. Its channel axis is taken from
            `data_format`.
        filters: int. Output channel count.
        data_format: `None`, `"channels_last"` or `"channels_first"`.
            Resolved through `standardize_data_format`.
        dtype: Optional dtype or dtype policy applied to every sublayer.
        name: str prefix for the named sublayers. NOTE(review): with the
            default `None` the names become "None_norm1", ... and would
            collide if this is called more than once without a name.

    Returns:
        The output tensor with `filters` channels.
    """
    data_format = standardize_data_format(data_format)
    gn_axis = -1 if data_format == "channels_last" else 1
    input_filters = x.shape[gn_axis]

    residual = x
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
    )(x)
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv1",
    )(x)
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
    )(x)
    # Use the same dtype policy as every other sublayer. Previously this
    # activation silently fell back to the global policy.
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv2",
    )(x)
    if input_filters != filters:
        # Match channel counts so the skip connection can be added.
        residual = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(residual)
    x = layers.Add(dtype=dtype)([residual, x])
    return x
|