keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/inpaint.py

@@ -0,0 +1,520 @@
+import itertools
+from functools import partial
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+@keras_hub_export("keras_hub.models.Inpaint")
+class Inpaint(Task):
+    """Base class for image-to-image tasks.
+
+    `Inpaint` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    generation and generative fine-tuning.
+
+    `Inpaint` tasks provide an additional, high-level `generate()` function
+    which can be used to generate image by token with a (image, mask, string)
+    in, image out signature.
+
+    All `Inpaint` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+
+    ```python
+    # Load a Stable Diffusion 3 backbone with pre-trained weights.
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    reference_mask = np.ones((1024, 1024), dtype="float32")
+    inpaint = keras_hub.models.Inpaint.from_preset(
+        "stable_diffusion_3_medium",
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Load a Stable Diffusion 3 backbone at bfloat16 precision.
+    inpaint = keras_hub.models.Inpaint.from_preset(
+        "stable_diffusion_3_medium",
+        dtype="bfloat16",
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    @property
+    def support_negative_prompts(self):
+        """Whether the model supports `negative_prompts` key in `generate()`."""
+        return bool(True)
+
+    @property
+    def image_shape(self):
+        return tuple(self.backbone.image_shape)
+
+    @property
+    def latent_shape(self):
+        return tuple(self.backbone.latent_shape)
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `Inpaint` task for training.
+
+        The `Inpaint` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.MeanSquaredError` loss will be applied. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.MeanSquaredError` will be applied to
+                track the loss of the model during training. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
+        if optimizer == "auto":
+            optimizer = keras.optimizers.AdamW(
+                1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
+            )
+        if loss == "auto":
+            loss = keras.losses.MeanSquaredError()
+        if metrics == "auto":
+            metrics = [keras.metrics.MeanSquaredError()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
+        self.generate_function = None
+
+    def generate_step(self, *args, **kwargs):
+        """Run generation on batches of input."""
+        raise NotImplementedError
+
+    def make_generate_function(self):
+        """Create or return the compiled generation function."""
+        if self.generate_function is not None:
+            return self.generate_function
+
+        self.generate_function = self.generate_step
+        if keras.config.backend() == "torch":
+            import torch
+
+            def wrapped_function(*args, **kwargs):
+                with torch.no_grad():
+                    return self.generate_step(*args, **kwargs)
+
+            self.generate_function = wrapped_function
+        elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
+            self.generate_function = tf.function(
+                self.generate_step, jit_compile=self.jit_compile
+            )
+        elif keras.config.backend() == "jax" and not self.run_eagerly:
+            import jax
+
+            @partial(jax.jit)
+            def compiled_function(state, *args, **kwargs):
+                (
+                    trainable_variables,
+                    non_trainable_variables,
+                ) = state
+                mapping = itertools.chain(
+                    zip(self.trainable_variables, trainable_variables),
+                    zip(self.non_trainable_variables, non_trainable_variables),
+                )
+
+                with keras.StatelessScope(state_mapping=mapping):
+                    outputs = self.generate_step(*args, **kwargs)
+                return outputs
+
+            def wrapped_function(*args, **kwargs):
+                # Create an explicit tuple of all variable state.
+                state = (
+                    # Use the explicit variable.value to preserve the
+                    # sharding spec of distribution.
+                    [v.value for v in self.trainable_variables],
+                    [v.value for v in self.non_trainable_variables],
+                )
+                outputs = compiled_function(state, *args, **kwargs)
+                return outputs
+
+            self.generate_function = wrapped_function
+        return self.generate_function
+
+    def _normalize_generate_images(self, inputs):
+        """Normalize user image to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns a iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            return inputs.as_numpy_iterator(), False
+
+        def normalize(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 4:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.image.resize(
+                x,
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            return x, input_is_scalar
+
+        if isinstance(inputs, dict):
+            for key in inputs:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+        else:
+            inputs, input_is_scalar = normalize(inputs)
+
+        return inputs, input_is_scalar
+
+    def _normalize_generate_masks(self, inputs):
+        """Normalize user masks to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns a iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            return inputs.as_numpy_iterator(), False
+
+        def normalize(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 3:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.expand_dims(x, axis=-1)
+            if keras.backend.standardize_dtype(x.dtype) == "bool":
+                x = ops.cast(x, "float32")
+            x = ops.image.resize(
+                x,
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            x = ops.squeeze(x, axis=-1)
+            return x, input_is_scalar
+
+        if isinstance(inputs, dict):
+            for key in inputs:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+        else:
+            inputs, input_is_scalar = normalize(inputs)
+
+        return inputs, input_is_scalar
+
+    def _normalize_generate_inputs(self, inputs):
+        """Normalize user input to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns a iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+
+        The input format must be one of the following:
+        - A dict with "images", "masks", "prompts" and/or "negative_prompts"
+            keys
+        - A tf.data.Dataset with "images", "masks", "prompts" and/or
+            "negative_prompts" keys
+
+        The output will be a dict with "images", "masks", "prompts" and/or
+        "negative_prompts" keys.
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            _inputs = {
+                "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(),
+                "masks": inputs.map(lambda x: x["masks"]).as_numpy_iterator(),
+                "prompts": inputs.map(
+                    lambda x: x["prompts"]
+                ).as_numpy_iterator(),
+            }
+            if self.support_negative_prompts:
+                _inputs["negative_prompts"] = inputs.map(
+                    lambda x: x["negative_prompts"]
+                ).as_numpy_iterator()
+            return _inputs, False
+
+        def normalize(x):
+            if isinstance(x, str):
+                return [x], True
+            if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
+                return x[tf.newaxis], True
+            return x, False
+
+        def normalize_images(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 4:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.image.resize(
+                x,
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            return x, input_is_scalar
+
+        def normalize_masks(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 3:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.expand_dims(x, axis=-1)
+            if keras.backend.standardize_dtype(x.dtype) == "bool":
+                x = ops.cast(x, "float32")
+            x = ops.image.resize(
+                x,
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            x = ops.squeeze(x, axis=-1)
+            return x, input_is_scalar
+
+        def get_dummy_prompts(x):
+            dummy_prompts = [""] * len(x)
+            if tf and isinstance(x, tf.Tensor):
+                return tf.convert_to_tensor(dummy_prompts)
+            else:
+                return dummy_prompts
+
+        for key in inputs:
+            if key == "images":
+                inputs[key], input_is_scalar = normalize_images(inputs[key])
+            elif key == "masks":
+                inputs[key], input_is_scalar = normalize_masks(inputs[key])
+            else:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+
+        if self.support_negative_prompts and "negative_prompts" not in inputs:
+            inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"])
+
+        return [inputs], input_is_scalar
+
+    def _normalize_generate_outputs(self, outputs, input_is_scalar):
+        """Normalize user output from the generate function.
+
+        This function converts all output to numpy with a value range of
+        `[0, 255]`. If a batch dimension was added to the input, it is removed
+        from the output.
+        """
+
+        def normalize(x):
+            outputs = ops.concatenate(x, axis=0)
+            outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0)
+            outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
+            outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
+            return ops.convert_to_numpy(outputs)
+
+        if isinstance(outputs[0], dict):
+            normalized = {}
+            for key in outputs[0]:
+                normalized[key] = normalize([x[key] for x in outputs])
+            return normalized
+        return normalize([x for x in outputs])
+
+    def generate(
+        self,
+        inputs,
+        num_steps,
+        strength,
+        guidance_scale=None,
+        seed=None,
+    ):
+        """Generate image based on the provided `inputs`.
+
+        Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"`
+        keys. `"images"` are reference images within a value range of
+        `[-1.0, 1.0]`, which will be resized to height and width from
+        `self.backbone.image_shape`, then encoded into latent space by the VAE
+        encoder. `"masks"` are mask images with a boolean dtype, where white
+        pixels are repainted while black pixels are preserved. `"prompts"` are
+        strings that will be tokenized and encoded by the text encoder.
+
+        Some models support a `"negative_prompts"` key, which helps steer the
+        model away from generating certain styles and elements. To enable this,
+        add `"negative_prompts"` to the input dict.
+
+        If `inputs` are a `tf.data.Dataset`, outputs will be generated
+        "batch-by-batch" and concatenated. Otherwise, all inputs will be
+        processed as batches.
+
+        Args:
+            inputs: python data, tensor data, or a `tf.data.Dataset`. The format
+                must be one of the following:
+                - A dict with `"images"`, `"masks"`, `"prompts"` and/or
+                    `"negative_prompts"` keys.
+                - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"`
+                    and/or `"negative_prompts"` keys.
+            num_steps: int. The number of diffusion steps to take.
+            strength: float. Indicates the extent to which the reference
+                `images` are transformed. Must be between `0.0` and `1.0`. When
+                `strength=1.0`, `images` is essentially ignore and added noise
+                is maximum and the denoising process runs for the full number of
+                iterations specified in `num_steps`.
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely related to the prompts, typically
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
+            seed: optional int. Used as a random seed.
+        """
+        num_steps = int(num_steps)
+        strength = float(strength)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
+        if strength < 0.0 or strength > 1.0:
+            raise ValueError(
+                "`strength` must be between `0.0` and `1.0`. "
+                f"Received strength={strength}."
+            )
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(guidance_scale)
+        else:
+            guidance_scale = None
+        starting_step = int(num_steps * (1.0 - strength))
+        starting_step = ops.convert_to_tensor(starting_step, "int32")
+        num_steps = ops.convert_to_tensor(num_steps, "int32")
+        guidance_scale = ops.convert_to_tensor(guidance_scale)
+
+        # Check `inputs` format.
+        required_keys = ["images", "masks", "prompts"]
+        if tf and isinstance(inputs, tf.data.Dataset):
+            spec = inputs.element_spec
+            if not all(key in spec for key in required_keys):
+                raise ValueError(
+                    "Expected a `tf.data.Dataset` with the following keys:"
+                    f"{required_keys}. Received: inputs.element_spec={spec}"
+                )
+        else:
+            if not isinstance(inputs, dict):
+                raise ValueError(
+                    "Expected a `dict` or `tf.data.Dataset`. "
+                    f"Received: inputs={inputs} of type {type(inputs)}."
+                )
+            if not all(key in inputs for key in required_keys):
+                raise ValueError(
+                    "Expected a `dict` with the following keys:"
+                    f"{required_keys}. "
+                    f"Received: inputs.keys={list(inputs.keys())}"
+                )
+
+        # Setup our three main passes.
+        # 1. Preprocessing strings to dense integer tensors.
+        # 2. Generate outputs via a compiled function on dense tensors.
+        # 3. Postprocess dense tensors to a value range of `[0, 255]`.
+        generate_function = self.make_generate_function()
+
+        def preprocess(x):
+            if self.preprocessor is not None:
+                return self.preprocessor.generate_preprocess(x)
+            else:
+                return x
+
+        def generate(images, masks, x):
+            token_ids = x[0] if self.support_negative_prompts else x
+
+            # Initialize noises.
+            if isinstance(token_ids, dict):
+                arbitrary_key = list(token_ids.keys())[0]
+                batch_size = ops.shape(token_ids[arbitrary_key])[0]
+            else:
+                batch_size = ops.shape(token_ids)[0]
+            noise_shape = (batch_size,) + self.latent_shape[1:]
+            noises = random.normal(noise_shape, dtype="float32", seed=seed)
+
+            return generate_function(
+                images,
+                masks,
+                noises,
+                x,
+                starting_step,
+                num_steps,
+                guidance_scale,
+            )
+
+        # Normalize and preprocess inputs.
+        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
+        if self.support_negative_prompts:
+            images = [x["images"] for x in inputs]
+            masks = [x["masks"] for x in inputs]
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            negative_token_ids = [
+                preprocess(x["negative_prompts"]) for x in inputs
+            ]
+            # Tuple format: (images, masks, (token_ids, negative_token_ids)).
+            inputs = [
+                x
+                for x in zip(images, masks, zip(token_ids, negative_token_ids))
+            ]
+        else:
+            images = [x["images"] for x in inputs]
+            masks = [x["masks"] for x in inputs]
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            # Tuple format: (images, masks, token_ids).
+            inputs = [x for x in zip(images, masks, token_ids)]
+
+        # Inpaint.
+        outputs = [generate(*x) for x in inputs]
+        return self._normalize_generate_outputs(outputs, input_is_scalar)
keras_hub/src/models/llama/llama_backbone.py

@@ -34,17 +34,18 @@ class LlamaBackbone(Backbone):
         num_layers (int): The number of transformer layers.
         num_query_heads (int): The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
-
-
-
-
-
-
-
-
-
-
+        hidden_dim (int): The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in
+            a three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads
+            for each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for
+            calculation of roatary embedding. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at

@@ -59,7 +60,7 @@ class LlamaBackbone(Backbone):
     }

     # Pretrained Llama decoder.
-    model = keras_hub.models.LlamaBackbone.from_preset("
+    model = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en")
     model(input_data)

     # Randomly initialized Llama decoder with custom config.

@@ -175,3 +176,128 @@ class LlamaBackbone(Backbone):
             }
         )
         return config
+
+    @staticmethod
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
+        """Get a `keras.distribution.LayoutMap` for model parallel distribution.
+
+        The returned `LayoutMap` contains the sharding spec for the Llama
+        backbone weights, so that you can use it to distribute weights across
+        the accelerators.
+
+        Example:
+        ```
+        # Feel free to change the mesh shape to balance data and model
+        # parallelism
+        mesh = keras.distribution.DeviceMesh(
+            shape=(1, 8),
+            axis_names=('batch', 'model'),
+            devices=keras.distribution.list_devices(),
+        )
+        layout_map = LlamaBackbone.get_layout_map(
+            mesh,
+            model_parallel_dim_name="model",
+        )
+
+        distribution = keras.distribution.ModelParallel(
+            layout_map=layout_map,
+            batch_dim_name='batch',
+        )
+
+        with distribution.scope():
+            llama_model = keras_hub.models.LlamaCausalLM.from_preset()
+        ```
+
+        To see how the layout map was applied, load the model then run
+        (for one decoder block):
+        ```
+        embedding_layer = llama_model.backbone.get_layer("token_embedding")
+        decoder_block_1 = llama_model.backbone.get_layer('transformer_layer_0')
+        for variable in embedding_layer.weights + decoder_block_1.weights:
+            print(
+                f'{variable.path:<58} {str(variable.shape):<16} '
+                f'{str(variable.value.sharding.spec)}'
+            )
+        ```
+
+        Args:
+            device_mesh: The `keras.distribution.DeviceMesh` instance for
+                distribution.
+            model_parallel_dim_name: The axis name of the device mesh, where
+                the weights should be partition on.
+            data_parallel_dim_name: The axis name of the device mesh, where
+                the data should be partition on.
+        Return:
+            `keras.distribution.LayoutMap` that contains the sharding spec
+            for all the model weights.
+        """
+        # The weight path and shape of the Llama backbone is like below
+        # token_embedding/embeddings (128256, 2048)
+        # repeat block for decoder
+        # transformer_layer_0/self_attention/query/kernel (2048, 32, 64)
+        # transformer_layer_0/self_attention/key/kernel (2048, 8, 64)
+        # transformer_layer_0/self_attention/value/kernel (2048, 8, 64)
+        # transformer_layer_0/self_attention/attention_output/kernel
+        # (32, 64, 2048)
+        # transformer_layer_0/self_attention_layernorm/scale (2048,)
+        # transformer_layer_0/feedforward_intermediate_dense/kernel
+        # (2048, 8192)
+        # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192)
+        # transformer_layer_0/feedforward_output_dense/kerne (8192, 2048)
+        # transformer_layer_0/feedforward_layernorm/scale (2048,)
+
+        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
+            raise ValueError(
+                "Invalid device_mesh type. Expected "
+                f"`keras.distribution.Device`, got {type(device_mesh)}"
+            )
+        if model_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{model_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_name=}"
+            )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_name=}"
+            )
+        # Note that it is possible to further config the mesh to be 3D, eg
+        # (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
+        model_dim = model_parallel_dim_name
+        # The sharding config is based on the Gemma team training config.
+        # See https://arxiv.org/abs/2403.08295
+        layout_map = keras.distribution.LayoutMap(device_mesh)
+        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
+        layout_map[
+            "transformer_layer.*self_attention.*(query|key|value).kernel"
+        ] = (
+            model_dim,
+            data_dim,
+            None,
+        )
+        layout_map["transformer_layer.*attention_output.kernel"] = (
+            model_dim,
+            None,
+            data_dim,
+        )
+        layout_map[
+            "transformer_layer.*feedforward_intermediate_dense.kernel"
+        ] = (
+            data_dim,
+            model_dim,
+        )
+        layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = (
+            data_dim,
+            model_dim,
+        )
+        layout_map["transformer_layer.*feedforward_output_dense.kernel"] = (
+            model_dim,
+            data_dim,
+        )
+
+        return layout_map