keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409270338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -13
- keras_hub/api/__init__.py +0 -13
- keras_hub/api/bounding_box/__init__.py +0 -13
- keras_hub/api/layers/__init__.py +3 -13
- keras_hub/api/metrics/__init__.py +0 -13
- keras_hub/api/models/__init__.py +16 -13
- keras_hub/api/samplers/__init__.py +0 -13
- keras_hub/api/tokenizers/__init__.py +1 -13
- keras_hub/api/utils/__init__.py +0 -13
- keras_hub/src/__init__.py +0 -13
- keras_hub/src/api_export.py +0 -14
- keras_hub/src/bounding_box/__init__.py +0 -13
- keras_hub/src/bounding_box/converters.py +0 -13
- keras_hub/src/bounding_box/formats.py +0 -13
- keras_hub/src/bounding_box/iou.py +1 -13
- keras_hub/src/bounding_box/to_dense.py +0 -14
- keras_hub/src/bounding_box/to_ragged.py +0 -13
- keras_hub/src/bounding_box/utils.py +0 -13
- keras_hub/src/bounding_box/validate_format.py +0 -14
- keras_hub/src/layers/__init__.py +0 -13
- keras_hub/src/layers/modeling/__init__.py +0 -13
- keras_hub/src/layers/modeling/alibi_bias.py +0 -13
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +0 -14
- keras_hub/src/layers/modeling/f_net_encoder.py +0 -14
- keras_hub/src/layers/modeling/masked_lm_head.py +0 -14
- keras_hub/src/layers/modeling/position_embedding.py +0 -14
- keras_hub/src/layers/modeling/reversible_embedding.py +0 -14
- keras_hub/src/layers/modeling/rotary_embedding.py +0 -14
- keras_hub/src/layers/modeling/sine_position_encoding.py +0 -14
- keras_hub/src/layers/modeling/token_and_position_embedding.py +0 -14
- keras_hub/src/layers/modeling/transformer_decoder.py +0 -14
- keras_hub/src/layers/modeling/transformer_encoder.py +0 -14
- keras_hub/src/layers/modeling/transformer_layer_utils.py +0 -14
- keras_hub/src/layers/preprocessing/__init__.py +0 -13
- keras_hub/src/layers/preprocessing/audio_converter.py +0 -13
- keras_hub/src/layers/preprocessing/image_converter.py +0 -13
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +0 -15
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +0 -14
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +0 -14
- keras_hub/src/layers/preprocessing/random_deletion.py +0 -14
- keras_hub/src/layers/preprocessing/random_swap.py +0 -14
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -13
- keras_hub/src/layers/preprocessing/start_end_packer.py +0 -15
- keras_hub/src/metrics/__init__.py +0 -13
- keras_hub/src/metrics/bleu.py +0 -14
- keras_hub/src/metrics/edit_distance.py +0 -14
- keras_hub/src/metrics/perplexity.py +0 -14
- keras_hub/src/metrics/rouge_base.py +0 -14
- keras_hub/src/metrics/rouge_l.py +0 -14
- keras_hub/src/metrics/rouge_n.py +0 -14
- keras_hub/src/models/__init__.py +0 -13
- keras_hub/src/models/albert/__init__.py +0 -14
- keras_hub/src/models/albert/albert_backbone.py +0 -14
- keras_hub/src/models/albert/albert_masked_lm.py +0 -14
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/albert/albert_presets.py +0 -14
- keras_hub/src/models/albert/albert_text_classifier.py +0 -14
- keras_hub/src/models/albert/albert_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/albert/albert_tokenizer.py +0 -14
- keras_hub/src/models/backbone.py +0 -14
- keras_hub/src/models/bart/__init__.py +0 -14
- keras_hub/src/models/bart/bart_backbone.py +0 -14
- keras_hub/src/models/bart/bart_presets.py +0 -13
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +0 -15
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +0 -15
- keras_hub/src/models/bart/bart_tokenizer.py +0 -15
- keras_hub/src/models/bert/__init__.py +0 -14
- keras_hub/src/models/bert/bert_backbone.py +0 -14
- keras_hub/src/models/bert/bert_masked_lm.py +0 -14
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/bert/bert_presets.py +0 -13
- keras_hub/src/models/bert/bert_text_classifier.py +0 -14
- keras_hub/src/models/bert/bert_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/bert/bert_tokenizer.py +0 -14
- keras_hub/src/models/bloom/__init__.py +0 -14
- keras_hub/src/models/bloom/bloom_attention.py +0 -13
- keras_hub/src/models/bloom/bloom_backbone.py +0 -14
- keras_hub/src/models/bloom/bloom_causal_lm.py +0 -15
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +0 -15
- keras_hub/src/models/bloom/bloom_decoder.py +0 -13
- keras_hub/src/models/bloom/bloom_presets.py +0 -13
- keras_hub/src/models/bloom/bloom_tokenizer.py +0 -15
- keras_hub/src/models/causal_lm.py +0 -14
- keras_hub/src/models/causal_lm_preprocessor.py +0 -13
- keras_hub/src/models/clip/__init__.py +0 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -15
- keras_hub/src/models/clip/clip_preprocessor.py +134 -0
- keras_hub/src/models/clip/clip_text_encoder.py +139 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +65 -41
- keras_hub/src/models/csp_darknet/__init__.py +0 -13
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -13
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -13
- keras_hub/src/models/deberta_v3/__init__.py +0 -14
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +0 -15
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +0 -15
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -13
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +0 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +0 -15
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +0 -14
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +0 -14
- keras_hub/src/models/deberta_v3/relative_embedding.py +0 -14
- keras_hub/src/models/densenet/__init__.py +5 -13
- keras_hub/src/models/densenet/densenet_backbone.py +11 -21
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -17
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/{stable_diffusion_v3/__init__.py → densenet/densenet_image_converter.py} +10 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/distil_bert/__init__.py +0 -14
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -13
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +0 -15
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +0 -15
- keras_hub/src/models/efficientnet/__init__.py +0 -13
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +0 -13
- keras_hub/src/models/efficientnet/fusedmbconv.py +0 -14
- keras_hub/src/models/efficientnet/mbconv.py +0 -14
- keras_hub/src/models/electra/__init__.py +0 -14
- keras_hub/src/models/electra/electra_backbone.py +0 -14
- keras_hub/src/models/electra/electra_presets.py +0 -13
- keras_hub/src/models/electra/electra_tokenizer.py +0 -14
- keras_hub/src/models/f_net/__init__.py +0 -14
- keras_hub/src/models/f_net/f_net_backbone.py +0 -15
- keras_hub/src/models/f_net/f_net_masked_lm.py +0 -15
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/f_net/f_net_presets.py +0 -13
- keras_hub/src/models/f_net/f_net_text_classifier.py +0 -15
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +0 -15
- keras_hub/src/models/f_net/f_net_tokenizer.py +0 -15
- keras_hub/src/models/falcon/__init__.py +0 -14
- keras_hub/src/models/falcon/falcon_attention.py +0 -13
- keras_hub/src/models/falcon/falcon_backbone.py +0 -13
- keras_hub/src/models/falcon/falcon_causal_lm.py +0 -14
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/falcon/falcon_presets.py +0 -13
- keras_hub/src/models/falcon/falcon_tokenizer.py +0 -15
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +0 -13
- keras_hub/src/models/feature_pyramid_backbone.py +0 -13
- keras_hub/src/models/gemma/__init__.py +0 -14
- keras_hub/src/models/gemma/gemma_attention.py +0 -13
- keras_hub/src/models/gemma/gemma_backbone.py +0 -15
- keras_hub/src/models/gemma/gemma_causal_lm.py +0 -15
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/gemma/gemma_decoder_block.py +0 -13
- keras_hub/src/models/gemma/gemma_presets.py +0 -13
- keras_hub/src/models/gemma/gemma_tokenizer.py +0 -14
- keras_hub/src/models/gemma/rms_normalization.py +0 -14
- keras_hub/src/models/gpt2/__init__.py +0 -14
- keras_hub/src/models/gpt2/gpt2_backbone.py +0 -15
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +0 -15
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +0 -15
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -13
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +0 -15
- keras_hub/src/models/gpt_neo_x/__init__.py +0 -13
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +0 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +0 -14
- keras_hub/src/models/image_classifier.py +0 -13
- keras_hub/src/models/image_classifier_preprocessor.py +0 -13
- keras_hub/src/models/image_segmenter.py +0 -13
- keras_hub/src/models/llama/__init__.py +0 -14
- keras_hub/src/models/llama/llama_attention.py +0 -13
- keras_hub/src/models/llama/llama_backbone.py +0 -13
- keras_hub/src/models/llama/llama_causal_lm.py +0 -13
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +0 -15
- keras_hub/src/models/llama/llama_decoder.py +0 -13
- keras_hub/src/models/llama/llama_layernorm.py +0 -13
- keras_hub/src/models/llama/llama_presets.py +0 -13
- keras_hub/src/models/llama/llama_tokenizer.py +0 -14
- keras_hub/src/models/llama3/__init__.py +0 -14
- keras_hub/src/models/llama3/llama3_backbone.py +0 -14
- keras_hub/src/models/llama3/llama3_causal_lm.py +0 -13
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/llama3/llama3_presets.py +0 -13
- keras_hub/src/models/llama3/llama3_tokenizer.py +0 -14
- keras_hub/src/models/masked_lm.py +0 -13
- keras_hub/src/models/masked_lm_preprocessor.py +0 -13
- keras_hub/src/models/mistral/__init__.py +0 -14
- keras_hub/src/models/mistral/mistral_attention.py +0 -13
- keras_hub/src/models/mistral/mistral_backbone.py +0 -14
- keras_hub/src/models/mistral/mistral_causal_lm.py +0 -14
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/mistral/mistral_layer_norm.py +0 -13
- keras_hub/src/models/mistral/mistral_presets.py +0 -13
- keras_hub/src/models/mistral/mistral_tokenizer.py +0 -14
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +0 -13
- keras_hub/src/models/mix_transformer/__init__.py +0 -13
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +0 -13
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -13
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +0 -13
- keras_hub/src/models/mobilenet/__init__.py +0 -13
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +0 -13
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -13
- keras_hub/src/models/opt/__init__.py +0 -14
- keras_hub/src/models/opt/opt_backbone.py +0 -15
- keras_hub/src/models/opt/opt_causal_lm.py +0 -15
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +0 -13
- keras_hub/src/models/opt/opt_presets.py +0 -13
- keras_hub/src/models/opt/opt_tokenizer.py +0 -15
- keras_hub/src/models/pali_gemma/__init__.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +0 -14
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +0 -13
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +0 -13
- keras_hub/src/models/phi3/__init__.py +0 -14
- keras_hub/src/models/phi3/phi3_attention.py +0 -13
- keras_hub/src/models/phi3/phi3_backbone.py +0 -13
- keras_hub/src/models/phi3/phi3_causal_lm.py +0 -13
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +0 -14
- keras_hub/src/models/phi3/phi3_decoder.py +0 -13
- keras_hub/src/models/phi3/phi3_layernorm.py +0 -13
- keras_hub/src/models/phi3/phi3_presets.py +0 -13
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +0 -13
- keras_hub/src/models/phi3/phi3_tokenizer.py +0 -13
- keras_hub/src/models/preprocessor.py +51 -32
- keras_hub/src/models/resnet/__init__.py +0 -14
- keras_hub/src/models/resnet/resnet_backbone.py +0 -13
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -13
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +0 -14
- keras_hub/src/models/resnet/resnet_image_converter.py +0 -13
- keras_hub/src/models/resnet/resnet_presets.py +0 -13
- keras_hub/src/models/retinanet/__init__.py +0 -13
- keras_hub/src/models/retinanet/anchor_generator.py +0 -14
- keras_hub/src/models/retinanet/box_matcher.py +0 -14
- keras_hub/src/models/retinanet/non_max_supression.py +0 -14
- keras_hub/src/models/roberta/__init__.py +0 -14
- keras_hub/src/models/roberta/roberta_backbone.py +0 -15
- keras_hub/src/models/roberta/roberta_masked_lm.py +0 -15
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/roberta/roberta_presets.py +0 -13
- keras_hub/src/models/roberta/roberta_text_classifier.py +0 -15
- keras_hub/src/models/roberta/roberta_text_classifier_preprocessor.py +0 -14
- keras_hub/src/models/roberta/roberta_tokenizer.py +0 -15
- keras_hub/src/models/sam/__init__.py +0 -13
- keras_hub/src/models/sam/sam_backbone.py +0 -14
- keras_hub/src/models/sam/sam_image_segmenter.py +0 -14
- keras_hub/src/models/sam/sam_layers.py +0 -14
- keras_hub/src/models/sam/sam_mask_decoder.py +0 -14
- keras_hub/src/models/sam/sam_prompt_encoder.py +0 -14
- keras_hub/src/models/sam/sam_transformer.py +0 -14
- keras_hub/src/models/seq_2_seq_lm.py +0 -13
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +0 -13
- keras_hub/src/models/stable_diffusion_3/__init__.py +9 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +80 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -39
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +631 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +31 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +138 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +83 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -20
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +320 -0
- keras_hub/src/models/t5/__init__.py +0 -14
- keras_hub/src/models/t5/t5_backbone.py +0 -14
- keras_hub/src/models/t5/t5_layer_norm.py +0 -14
- keras_hub/src/models/t5/t5_multi_head_attention.py +0 -14
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -16
- keras_hub/src/models/t5/t5_presets.py +0 -13
- keras_hub/src/models/t5/t5_tokenizer.py +0 -14
- keras_hub/src/models/t5/t5_transformer_layer.py +0 -14
- keras_hub/src/models/task.py +0 -14
- keras_hub/src/models/text_classifier.py +0 -13
- keras_hub/src/models/text_classifier_preprocessor.py +0 -13
- keras_hub/src/models/text_to_image.py +282 -0
- keras_hub/src/models/vgg/__init__.py +0 -13
- keras_hub/src/models/vgg/vgg_backbone.py +0 -13
- keras_hub/src/models/vgg/vgg_image_classifier.py +0 -13
- keras_hub/src/models/vit_det/__init__.py +0 -13
- keras_hub/src/models/vit_det/vit_det_backbone.py +0 -14
- keras_hub/src/models/vit_det/vit_layers.py +0 -15
- keras_hub/src/models/whisper/__init__.py +0 -14
- keras_hub/src/models/whisper/whisper_audio_converter.py +0 -15
- keras_hub/src/models/whisper/whisper_backbone.py +0 -15
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +0 -13
- keras_hub/src/models/whisper/whisper_decoder.py +0 -14
- keras_hub/src/models/whisper/whisper_encoder.py +0 -14
- keras_hub/src/models/whisper/whisper_presets.py +0 -14
- keras_hub/src/models/whisper/whisper_tokenizer.py +0 -14
- keras_hub/src/models/xlm_roberta/__init__.py +0 -14
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +0 -14
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -13
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +0 -15
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +0 -15
- keras_hub/src/models/xlnet/__init__.py +0 -13
- keras_hub/src/models/xlnet/relative_attention.py +0 -14
- keras_hub/src/models/xlnet/xlnet_backbone.py +0 -14
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +0 -14
- keras_hub/src/models/xlnet/xlnet_encoder.py +0 -14
- keras_hub/src/samplers/__init__.py +0 -13
- keras_hub/src/samplers/beam_sampler.py +0 -14
- keras_hub/src/samplers/contrastive_sampler.py +0 -14
- keras_hub/src/samplers/greedy_sampler.py +0 -14
- keras_hub/src/samplers/random_sampler.py +0 -14
- keras_hub/src/samplers/sampler.py +0 -14
- keras_hub/src/samplers/serialization.py +0 -14
- keras_hub/src/samplers/top_k_sampler.py +0 -14
- keras_hub/src/samplers/top_p_sampler.py +0 -14
- keras_hub/src/tests/__init__.py +0 -13
- keras_hub/src/tests/test_case.py +0 -14
- keras_hub/src/tokenizers/__init__.py +0 -13
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +0 -14
- keras_hub/src/tokenizers/byte_tokenizer.py +0 -14
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +0 -14
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +0 -14
- keras_hub/src/tokenizers/tokenizer.py +23 -27
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +0 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +0 -14
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +0 -15
- keras_hub/src/utils/__init__.py +0 -13
- keras_hub/src/utils/imagenet/__init__.py +0 -13
- keras_hub/src/utils/imagenet/imagenet_utils.py +0 -13
- keras_hub/src/utils/keras_utils.py +0 -14
- keras_hub/src/utils/pipeline_model.py +0 -14
- keras_hub/src/utils/preset_utils.py +32 -76
- keras_hub/src/utils/python_utils.py +0 -13
- keras_hub/src/utils/tensor_utils.py +0 -14
- keras_hub/src/utils/timm/__init__.py +0 -13
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/convert_resnet.py +0 -13
- keras_hub/src/utils/timm/preset_loader.py +3 -13
- keras_hub/src/utils/transformers/__init__.py +0 -13
- keras_hub/src/utils/transformers/convert_albert.py +0 -13
- keras_hub/src/utils/transformers/convert_bart.py +0 -13
- keras_hub/src/utils/transformers/convert_bert.py +0 -13
- keras_hub/src/utils/transformers/convert_distilbert.py +0 -13
- keras_hub/src/utils/transformers/convert_gemma.py +0 -13
- keras_hub/src/utils/transformers/convert_gpt2.py +0 -13
- keras_hub/src/utils/transformers/convert_llama3.py +0 -13
- keras_hub/src/utils/transformers/convert_mistral.py +0 -13
- keras_hub/src/utils/transformers/convert_pali_gemma.py +0 -13
- keras_hub/src/utils/transformers/preset_loader.py +1 -15
- keras_hub/src/utils/transformers/safetensor_utils.py +9 -15
- keras_hub/src/version_utils.py +1 -15
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/METADATA +30 -27
- keras_hub_nightly-0.16.1.dev202409270338.dist-info/RECORD +351 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +0 -149
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- keras_hub_nightly-0.16.1.dev202409250340.dist-info/RECORD +0 -342
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,3 @@
|
|
1
|
-
# Copyright 2024 The KerasHub Authors
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
1
|
import math
|
15
2
|
|
16
3
|
import keras
|
@@ -19,7 +6,8 @@ from keras import models
|
|
19
6
|
from keras import ops
|
20
7
|
|
21
8
|
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
22
|
-
from keras_hub.src.models.
|
9
|
+
from keras_hub.src.models.backbone import Backbone
|
10
|
+
from keras_hub.src.utils.keras_utils import gelu_approximate
|
23
11
|
from keras_hub.src.utils.keras_utils import standardize_data_format
|
24
12
|
|
25
13
|
|
@@ -79,8 +67,8 @@ class AdjustablePositionEmbedding(PositionEmbedding):
|
|
79
67
|
width = width or self.width
|
80
68
|
shape = ops.shape(inputs)
|
81
69
|
feature_length = shape[-1]
|
82
|
-
top = ops.floor_divide(self.height - height, 2)
|
83
|
-
left = ops.floor_divide(self.width - width, 2)
|
70
|
+
top = ops.cast(ops.floor_divide(self.height - height, 2), "int32")
|
71
|
+
left = ops.cast(ops.floor_divide(self.width - width, 2), "int32")
|
84
72
|
position_embedding = ops.convert_to_tensor(self.position_embeddings)
|
85
73
|
position_embedding = ops.reshape(
|
86
74
|
position_embedding, (self.height, self.width, feature_length)
|
@@ -166,6 +154,305 @@ class TimestepEmbedding(layers.Layer):
|
|
166
154
|
return output_shape
|
167
155
|
|
168
156
|
|
157
|
+
class DismantledBlock(layers.Layer):
    """A single MMDiT stream, split into pre- and post-attention halves.

    `MMDiTBlock` runs joint attention over the concatenation of two streams
    (an image stream and a context stream). To make that possible, this
    layer never computes attention itself:

    - `call(..., pre_attention=True)` applies adaptive-norm modulation derived
      from `timestep_embedding` and returns this stream's per-head
      `(q, k, v)` tensors. When `use_projection=True` it also returns the
      intermediates needed by the post-attention step.
    - `call(..., pre_attention=False)` takes this stream's slice of the
      attention output. It applies the output projection, the gated attention
      residual and the gated MLP residual.

    Args:
        num_heads: int. Number of attention heads. Must divide `hidden_dim`.
        hidden_dim: int. Width of the stream's hidden states.
        mlp_ratio: float. The MLP hidden width is
            `int(hidden_dim * mlp_ratio)`.
        use_projection: bool. If `True`, the block owns an attention output
            projection and an MLP, and supports the post-attention step. If
            `False`, only the pre-attention step is available, and it returns
            just `(q, k, v)`.
        **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`, `dtype`).

    Raises:
        ValueError: If `hidden_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        mlp_ratio=4.0,
        use_projection=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The qkv reshape below assumes hidden_dim == num_heads * head_dim.
        # Fail early instead of producing a mis-shaped reshape later.
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "`hidden_dim` must be divisible by `num_heads`. Received: "
                f"hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.mlp_ratio = mlp_ratio
        self.use_projection = use_projection

        self.head_dim = hidden_dim // num_heads
        self.mlp_hidden_dim = int(hidden_dim * mlp_ratio)
        # With projection: shift/scale/gate for attention plus shift/scale/
        # gate for the MLP (6). Without projection: only the shift/scale of
        # the pre-attention norm (2).
        self.num_modulations = 6 if use_projection else 2

        self.adaptive_norm_modulation = models.Sequential(
            [
                layers.Activation("silu", dtype=self.dtype_policy),
                layers.Dense(
                    self.num_modulations * hidden_dim,
                    dtype=self.dtype_policy,
                ),
            ],
            name="adaptive_norm_modulation",
        )
        # Parameter-free norm (no center/scale); its affine transform comes
        # from the modulation above. Runs in float32, presumably for
        # stability under mixed precision.
        self.norm1 = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype="float32",
            name="norm1",
        )
        self.attention_qkv = layers.Dense(
            hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
        )
        if use_projection:
            self.attention_proj = layers.Dense(
                hidden_dim, dtype=self.dtype_policy, name="attention_proj"
            )
            self.norm2 = layers.LayerNormalization(
                epsilon=1e-6,
                center=False,
                scale=False,
                dtype="float32",
                name="norm2",
            )
            self.mlp = models.Sequential(
                [
                    layers.Dense(
                        self.mlp_hidden_dim,
                        activation=gelu_approximate,
                        dtype=self.dtype_policy,
                    ),
                    layers.Dense(
                        hidden_dim,
                        dtype=self.dtype_policy,
                    ),
                ],
                name="mlp",
            )

    def build(self, inputs_shape, timestep_embedding):
        """Build all sublayers.

        Args:
            inputs_shape: Shape of this stream's hidden states.
            timestep_embedding: The *shape* of the timestep embedding. The
                name is kept as-is for backward compatibility.
        """
        # NOTE(review): Keras' automatic build matches build() arguments named
        # `<call_arg>_shape`. This block relies on its parent calling build()
        # explicitly, as `MMDiTBlock.build` does. Confirm before calling this
        # layer stand-alone.
        self.adaptive_norm_modulation.build(timestep_embedding)
        self.attention_qkv.build(inputs_shape)
        self.norm1.build(inputs_shape)
        if self.use_projection:
            self.attention_proj.build(inputs_shape)
            self.norm2.build(inputs_shape)
            self.mlp.build(inputs_shape)

    def _modulate(self, inputs, shift, scale):
        """Return `inputs * (1 + scale) + shift`.

        `shift` and `scale` are per-sample `(batch, hidden_dim)` vectors. They
        are broadcast over the sequence axis (axis 1).
        """
        shift = ops.expand_dims(shift, axis=1)
        scale = ops.expand_dims(scale, axis=1)
        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)

    def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
        """Modulate, normalize and project `inputs` into per-head q/k/v.

        Returns:
            `(q, k, v)` when `use_projection` is `False`. Otherwise returns
            `((q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp))`,
            where the second tuple feeds `_compute_post_attention`. Each of
            q/k/v has shape `(batch, seq, num_heads, head_dim)`.
        """
        batch_size = ops.shape(inputs)[0]
        modulation = self.adaptive_norm_modulation(
            timestep_embedding, training=training
        )
        modulation = ops.reshape(
            modulation, (batch_size, self.num_modulations, self.hidden_dim)
        )
        # Order: shift_msa, scale_msa[, gate_msa, shift_mlp, scale_mlp,
        # gate_mlp].
        modulations = ops.unstack(modulation, self.num_modulations, axis=1)
        shift_msa, scale_msa = modulations[0], modulations[1]

        qkv = self.attention_qkv(
            self._modulate(self.norm1(inputs), shift_msa, scale_msa),
            training=training,
        )
        qkv = ops.reshape(
            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
        )
        q, k, v = ops.unstack(qkv, 3, axis=2)

        if not self.use_projection:
            return (q, k, v)
        gate_msa, shift_mlp, scale_mlp, gate_mlp = modulations[2:]
        return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def _compute_post_attention(
        self, inputs, inputs_intermediates, training=None
    ):
        """Apply the projection, gated residual and gated MLP residual.

        Args:
            inputs: This stream's slice of the joint attention output.
            inputs_intermediates: The second tuple returned by
                `_compute_pre_attention`.
            training: Optional training flag forwarded to sublayers.

        Raises:
            ValueError: If the block was created with `use_projection=False`,
                in which case it has no projection/MLP sublayers.
        """
        if not self.use_projection:
            raise ValueError(
                "`DismantledBlock` was created with `use_projection=False`, "
                "so it has no post-attention step (attention projection and "
                "MLP)."
            )
        x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
        attn = self.attention_proj(inputs, training=training)
        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
        mlp_output = self.mlp(
            self._modulate(self.norm2(x), shift_mlp, scale_mlp),
            training=training,
        )
        x = ops.add(
            x, ops.multiply(ops.expand_dims(gate_mlp, axis=1), mlp_output)
        )
        return x

    def call(
        self,
        inputs,
        timestep_embedding=None,
        inputs_intermediates=None,
        pre_attention=True,
        training=None,
    ):
        """Run the pre-attention or the post-attention half of the block.

        `timestep_embedding` is used only when `pre_attention=True`.
        `inputs_intermediates` is used only when `pre_attention=False`.
        """
        if pre_attention:
            return self._compute_pre_attention(
                inputs, timestep_embedding, training=training
            )
        return self._compute_post_attention(
            inputs, inputs_intermediates, training=training
        )

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "mlp_ratio": self.mlp_ratio,
                "use_projection": self.use_projection,
            }
        )
        return config
|
328
|
+
|
329
|
+
|
330
|
+
class MMDiTBlock(layers.Layer):
|
331
|
+
def __init__(
    self,
    num_heads,
    hidden_dim,
    mlp_ratio=4.0,
    use_context_projection=True,
    **kwargs,
):
    """Create the image stream, the context stream and the shared softmax.

    Args:
        num_heads: int. Number of attention heads shared by both streams.
        hidden_dim: int. Hidden width of both streams.
        mlp_ratio: float. MLP width ratio forwarded to both streams.
        use_context_projection: bool. Whether the context stream gets its
            own post-attention projection/MLP and is returned from `call`.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.num_heads = num_heads
    self.hidden_dim = hidden_dim
    self.mlp_ratio = mlp_ratio
    self.use_context_projection = use_context_projection

    self.head_dim = hidden_dim // num_heads
    self._inverse_sqrt_key_dim = 1.0 / math.sqrt(self.head_dim)
    # q/k/v come from the streams as (batch, seq, heads, head_dim).
    # Scores: key "aecd" x query "abcd" -> (batch, heads, q_len, k_len).
    self._dot_product_equation = "aecd,abcd->acbe"
    # Output: scores "acbe" x value "aecd" -> (batch, q_len, heads, head_dim).
    self._combine_equation = "acbe,aecd->abcd"

    stream_kwargs = dict(
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_ratio=mlp_ratio,
        dtype=self.dtype_policy,
    )
    # Creation order (x_block, then context_block) is kept stable so that
    # sublayer tracking order does not change.
    self.x_block = DismantledBlock(
        use_projection=True, name="x_block", **stream_kwargs
    )
    self.context_block = DismantledBlock(
        use_projection=use_context_projection,
        name="context_block",
        **stream_kwargs,
    )
    self.softmax = layers.Softmax(dtype="float32")
|
368
|
+
|
369
|
+
def build(self, inputs_shape, context_shape, timestep_embedding_shape):
    """Build each stream from its own input shape and the shared timestep
    embedding shape (image stream first, then context stream)."""
    streams = (
        (self.x_block, inputs_shape),
        (self.context_block, context_shape),
    )
    for stream, stream_shape in streams:
        stream.build(stream_shape, timestep_embedding_shape)
|
372
|
+
|
373
|
+
def _compute_attention(self, query, key, value):
    """Scaled dot-product attention over the joint (context + image) sequence.

    Per the reshapes in `DismantledBlock`, q/k/v are expected as
    `(batch, seq, num_heads, head_dim)`. The heads are merged in the result,
    which has shape `(batch, seq, num_heads * head_dim)`.
    """
    scaled_query = ops.multiply(
        query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
    )
    scores = ops.einsum(self._dot_product_equation, key, scaled_query)
    # Softmax runs in float32; cast back to the layer's compute dtype.
    probabilities = ops.cast(self.softmax(scores), self.compute_dtype)
    combined = ops.einsum(self._combine_equation, probabilities, value)
    merged_width = self.num_heads * self.head_dim
    return ops.reshape(combined, (ops.shape(combined)[0], -1, merged_width))
|
388
|
+
|
389
|
+
def call(self, inputs, context, timestep_embedding, training=None):
|
390
|
+
# Compute pre-attention.
|
391
|
+
x = inputs
|
392
|
+
if self.use_context_projection:
|
393
|
+
context_qkv, context_intermediates = self.context_block(
|
394
|
+
context,
|
395
|
+
timestep_embedding=timestep_embedding,
|
396
|
+
training=training,
|
397
|
+
)
|
398
|
+
else:
|
399
|
+
context_qkv = self.context_block(
|
400
|
+
context,
|
401
|
+
timestep_embedding=timestep_embedding,
|
402
|
+
training=training,
|
403
|
+
)
|
404
|
+
context_len = ops.shape(context_qkv[0])[1]
|
405
|
+
x_qkv, x_intermediates = self.x_block(
|
406
|
+
x, timestep_embedding=timestep_embedding, training=training
|
407
|
+
)
|
408
|
+
q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
|
409
|
+
k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
|
410
|
+
v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
|
411
|
+
|
412
|
+
# Compute attention.
|
413
|
+
attention = self._compute_attention(q, k, v)
|
414
|
+
context_attention = attention[:, :context_len]
|
415
|
+
x_attention = attention[:, context_len:]
|
416
|
+
|
417
|
+
# Compute post-attention.
|
418
|
+
x = self.x_block(
|
419
|
+
x_attention,
|
420
|
+
inputs_intermediates=x_intermediates,
|
421
|
+
pre_attention=False,
|
422
|
+
training=training,
|
423
|
+
)
|
424
|
+
if self.use_context_projection:
|
425
|
+
context = self.context_block(
|
426
|
+
context_attention,
|
427
|
+
inputs_intermediates=context_intermediates,
|
428
|
+
pre_attention=False,
|
429
|
+
training=training,
|
430
|
+
)
|
431
|
+
return x, context
|
432
|
+
else:
|
433
|
+
return x
|
434
|
+
|
435
|
+
def get_config(self):
|
436
|
+
config = super().get_config()
|
437
|
+
config.update(
|
438
|
+
{
|
439
|
+
"num_heads": self.num_heads,
|
440
|
+
"hidden_dim": self.hidden_dim,
|
441
|
+
"mlp_ratio": self.mlp_ratio,
|
442
|
+
"use_context_projection": self.use_context_projection,
|
443
|
+
}
|
444
|
+
)
|
445
|
+
return config
|
446
|
+
|
447
|
+
def compute_output_shape(
|
448
|
+
self, inputs_shape, context_shape, timestep_embedding_shape
|
449
|
+
):
|
450
|
+
if self.use_context_projection:
|
451
|
+
return inputs_shape, context_shape
|
452
|
+
else:
|
453
|
+
return inputs_shape
|
454
|
+
|
455
|
+
|
169
456
|
class OutputLayer(layers.Layer):
|
170
457
|
def __init__(self, hidden_dim, output_dim, **kwargs):
|
171
458
|
super().__init__(**kwargs)
|
@@ -186,11 +473,11 @@ class OutputLayer(layers.Layer):
|
|
186
473
|
epsilon=1e-6,
|
187
474
|
center=False,
|
188
475
|
scale=False,
|
189
|
-
dtype=
|
476
|
+
dtype="float32",
|
190
477
|
name="norm",
|
191
478
|
)
|
192
479
|
self.output_dense = layers.Dense(
|
193
|
-
output_dim,
|
480
|
+
output_dim,
|
194
481
|
use_bias=True,
|
195
482
|
dtype=self.dtype_policy,
|
196
483
|
name="output_dense",
|
@@ -227,6 +514,11 @@ class OutputLayer(layers.Layer):
|
|
227
514
|
)
|
228
515
|
return config
|
229
516
|
|
517
|
+
def compute_output_shape(self, inputs_shape):
|
518
|
+
outputs_shape = list(inputs_shape)
|
519
|
+
outputs_shape[-1] = self.output_dim
|
520
|
+
return outputs_shape
|
521
|
+
|
230
522
|
|
231
523
|
class Unpatch(layers.Layer):
|
232
524
|
def __init__(self, patch_size, output_dim, **kwargs):
|
@@ -263,18 +555,48 @@ class Unpatch(layers.Layer):
|
|
263
555
|
return [inputs_shape[0], None, None, self.output_dim]
|
264
556
|
|
265
557
|
|
266
|
-
class MMDiT(
|
558
|
+
class MMDiT(Backbone):
|
559
|
+
"""Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
|
560
|
+
|
561
|
+
MMDiT is introduced in [
|
562
|
+
Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
|
563
|
+
https://arxiv.org/abs/2403.03206).
|
564
|
+
|
565
|
+
Args:
|
566
|
+
patch_size: int. The size of each square patch in the input image.
|
567
|
+
hidden_dim: int. The size of the transformer hidden state at the end
|
568
|
+
of each transformer layer.
|
569
|
+
num_layers: int. The number of transformer layers.
|
570
|
+
num_heads: int. The number of attention heads for each transformer.
|
571
|
+
position_size: int. The size of the height and width for the position
|
572
|
+
embedding.
|
573
|
+
mlp_ratio: float. The ratio of the mlp hidden dim to the transformer
|
574
|
+
latent_shape: tuple. The shape of the latent image.
|
575
|
+
context_shape: tuple. The shape of the context.
|
576
|
+
pooled_projection_shape: tuple. The shape of the pooled projection.
|
577
|
+
data_format: `None` or str. If specified, either `"channels_last"` or
|
578
|
+
`"channels_first"`. The ordering of the dimensions in the
|
579
|
+
inputs. `"channels_last"` corresponds to inputs with shape
|
580
|
+
`(batch_size, height, width, channels)`
|
581
|
+
while `"channels_first"` corresponds to inputs with shape
|
582
|
+
`(batch_size, channels, height, width)`. It defaults to the
|
583
|
+
`image_data_format` value found in your Keras config file at
|
584
|
+
`~/.keras/keras.json`. If you never set it, then it will be
|
585
|
+
`"channels_last"`.
|
586
|
+
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
|
587
|
+
to use for the model's computations and weights.
|
588
|
+
"""
|
589
|
+
|
267
590
|
def __init__(
|
268
591
|
self,
|
269
592
|
patch_size,
|
270
|
-
num_heads,
|
271
593
|
hidden_dim,
|
272
|
-
|
594
|
+
num_layers,
|
595
|
+
num_heads,
|
273
596
|
position_size,
|
274
|
-
output_dim,
|
275
597
|
mlp_ratio=4.0,
|
276
598
|
latent_shape=(64, 64, 16),
|
277
|
-
context_shape=(
|
599
|
+
context_shape=(None, 4096),
|
278
600
|
pooled_projection_shape=(2048,),
|
279
601
|
data_format=None,
|
280
602
|
dtype=None,
|
@@ -287,6 +609,7 @@ class MMDiT(keras.Model):
|
|
287
609
|
)
|
288
610
|
image_height = latent_shape[0] // patch_size
|
289
611
|
image_width = latent_shape[1] // patch_size
|
612
|
+
output_dim = latent_shape[-1]
|
290
613
|
output_dim_in_final = patch_size**2 * output_dim
|
291
614
|
data_format = standardize_data_format(data_format)
|
292
615
|
if data_format != "channels_last":
|
@@ -331,11 +654,11 @@ class MMDiT(keras.Model):
|
|
331
654
|
num_heads,
|
332
655
|
hidden_dim,
|
333
656
|
mlp_ratio,
|
334
|
-
use_context_projection=not (i ==
|
657
|
+
use_context_projection=not (i == num_layers - 1),
|
335
658
|
dtype=dtype,
|
336
659
|
name=f"joint_block_{i}",
|
337
660
|
)
|
338
|
-
for i in range(
|
661
|
+
for i in range(num_layers)
|
339
662
|
]
|
340
663
|
self.output_layer = OutputLayer(
|
341
664
|
hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
|
@@ -391,33 +714,22 @@ class MMDiT(keras.Model):
|
|
391
714
|
self.patch_size = patch_size
|
392
715
|
self.num_heads = num_heads
|
393
716
|
self.hidden_dim = hidden_dim
|
394
|
-
self.
|
717
|
+
self.num_layers = num_layers
|
395
718
|
self.position_size = position_size
|
396
|
-
self.output_dim = output_dim
|
397
719
|
self.mlp_ratio = mlp_ratio
|
398
720
|
self.latent_shape = latent_shape
|
399
721
|
self.context_shape = context_shape
|
400
722
|
self.pooled_projection_shape = pooled_projection_shape
|
401
723
|
|
402
|
-
if dtype is not None:
|
403
|
-
try:
|
404
|
-
self.dtype_policy = keras.dtype_policies.get(dtype)
|
405
|
-
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
|
406
|
-
except AttributeError:
|
407
|
-
if isinstance(dtype, keras.DTypePolicy):
|
408
|
-
dtype = dtype.name
|
409
|
-
self.dtype_policy = keras.DTypePolicy(dtype)
|
410
|
-
|
411
724
|
def get_config(self):
|
412
725
|
config = super().get_config()
|
413
726
|
config.update(
|
414
727
|
{
|
415
728
|
"patch_size": self.patch_size,
|
416
|
-
"num_heads": self.num_heads,
|
417
729
|
"hidden_dim": self.hidden_dim,
|
418
|
-
"
|
730
|
+
"num_layers": self.num_layers,
|
731
|
+
"num_heads": self.num_heads,
|
419
732
|
"position_size": self.position_size,
|
420
|
-
"output_dim": self.output_dim,
|
421
733
|
"mlp_ratio": self.mlp_ratio,
|
422
734
|
"latent_shape": self.latent_shape,
|
423
735
|
"context_shape": self.context_shape,
|