keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma/gemma_backbone.py

@@ -44,10 +44,10 @@ class GemmaBackbone(Backbone):
             `hidden_dim / num_query_heads`. Defaults to True.
         use_post_ffw_norm: boolean. Whether to normalize after the feedforward
             block. Defaults to False.
-        use_post_attention_norm: boolean. Whether to normalize after the
-            block. Defaults to False.
-        attention_logit_soft_cap: None or int. Soft cap for the attention
-            Defaults to None.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to False.
+        attention_logit_soft_cap: None or int. Soft cap for the attention
+            logits. Defaults to None.
         final_logit_soft_cap: None or int. Soft cap for the final logits.
             Defaults to None.
         use_sliding_window_attention boolean. Whether to use sliding local
@@ -205,7 +205,9 @@ class GemmaBackbone(Backbone):
                 "final_logit_soft_cap": self.final_logit_soft_cap,
                 "attention_logit_soft_cap": self.attention_logit_soft_cap,
                 "sliding_window_size": self.sliding_window_size,
-                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
             }
         )
         return config
@@ -224,7 +226,8 @@ class GemmaBackbone(Backbone):
 
         Example:
         ```
-        # Feel free to change the mesh shape to balance data and model parallelism
+        # Feel free to change the mesh shape to balance data and model
+        # parallelism
         mesh = keras.distribution.DeviceMesh(
             shape=(1, 8), axis_names=('batch', 'model'),
             devices=keras.distribution.list_devices())
@@ -232,11 +235,23 @@ class GemmaBackbone(Backbone):
             mesh, model_parallel_dim_name="model")
 
         distribution = keras.distribution.ModelParallel(
-            mesh, layout_map, batch_dim_name='batch')
+            layout_map=layout_map, batch_dim_name='batch')
         with distribution.scope():
             gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
         ```
 
+        To see how the layout map was applied, load the model then run (for one
+        decoder block):
+        ```
+        embedding_layer = gemma_model.backbone.get_layer("token_embedding")
+        decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
+        for variable in embedding_layer.weights + decoder_block_1.weights:
+            print(
+                f'{variable.path:<58} {str(variable.shape):<16} '
+                f'{str(variable.value.sharding.spec)}'
+            )
+        ```
+
         Args:
             device_mesh: The `keras.distribution.DeviceMesh` instance for
                 distribution.
@@ -246,25 +261,25 @@ class GemmaBackbone(Backbone):
                 the data should be partition on.
         Return:
             `keras.distribution.LayoutMap` that contains the sharding spec
-
+            for all the model weights.
         """
         # The weight path and shape of the Gemma backbone is like below (for 2G)
-        # token_embedding/embeddings, (256128, 2048)
+        # token_embedding/embeddings, (256128, 2048)
         # repeat block for decoder
         # ...
-        # decoder_block_17/pre_attention_norm/scale, (2048,)
-        # decoder_block_17/attention/query/kernel, (8, 2048, 256)
-        # decoder_block_17/attention/key/kernel, (8, 2048, 256)
-        # decoder_block_17/attention/value/kernel, (8, 2048, 256)
-        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048)
-        # decoder_block_17/pre_ffw_norm/scale, (2048,)
-        # decoder_block_17/ffw_gating/kernel, (2048, 16384)
-        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384)
-        # decoder_block_17/ffw_linear/kernel, (16384, 2048)
+        # decoder_block_17/pre_attention_norm/scale, (2048,)
+        # decoder_block_17/attention/query/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/key/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/value/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048)
+        # decoder_block_17/pre_ffw_norm/scale, (2048,)
+        # decoder_block_17/ffw_gating/kernel, (2048, 16384)
+        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384)
+        # decoder_block_17/ffw_linear/kernel, (16384, 2048)
         if not isinstance(device_mesh, keras.distribution.DeviceMesh):
             raise ValueError(
-                "Invalid device_mesh type. Expected `keras.distribution.Device`,"
-                f" got {type(device_mesh)}"
+                "Invalid device_mesh type. Expected "
+                f"`keras.distribution.Device`, got {type(device_mesh)}"
             )
         if model_parallel_dim_name not in device_mesh.axis_names:
             raise ValueError(
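The two distribution hunks above track the Keras 3 change where `keras.distribution.ModelParallel` no longer takes the mesh positionally; the mesh travels inside the layout map. A minimal end-to-end sketch of the updated pattern, assembled from the docstring above; it assumes a JAX backend with eight visible devices and Kaggle access for the `gemma_2b_en` preset:

```python
import keras
import keras_hub

# One data-parallel group of eight model shards.
mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)

# The layout map returned here already carries the mesh, which is why
# `ModelParallel` no longer needs it passed separately.
layout_map = keras_hub.models.GemmaBackbone.get_layout_map(
    mesh, model_parallel_dim_name="model"
)
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
with distribution.scope():
    gemma_model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
```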
keras_hub/src/models/gemma/gemma_causal_lm.py

@@ -187,8 +187,8 @@ class GemmaCausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
 
         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
keras_hub/src/models/gemma/gemma_decoder_block.py

@@ -220,7 +220,9 @@ class GemmaDecoderBlock(keras.layers.Layer):
                 "use_post_ffw_norm": self.use_post_ffw_norm,
                 "use_post_attention_norm": self.use_post_attention_norm,
                 "logit_soft_cap": self.logit_soft_cap,
-                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
                 "sliding_window_size": self.sliding_window_size,
                 "query_head_dim_normalize": self.query_head_dim_normalize,
             }
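Both `get_config` hunks (backbone and decoder block) are pure line wrapping; the serialized config is unchanged. A small sketch that checks the round trip, assuming a tiny randomly initialized backbone whose argument values are illustrative only:

```python
import keras_hub

backbone = keras_hub.models.GemmaBackbone(
    vocabulary_size=256,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    head_dim=16,
)
config = backbone.get_config()
# The wrapped entry still serializes as a plain boolean.
assert config["use_sliding_window_attention"] is False
restored = keras_hub.models.GemmaBackbone.from_config(config)
```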
keras_hub/src/models/gemma/gemma_presets.py

@@ -6,11 +6,9 @@ backbone_presets = {
         "metadata": {
             "description": "2 billion parameter, 18-layer, base Gemma model.",
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/3",
     },
     "gemma_instruct_2b_en": {
         "metadata": {
@@ -18,11 +16,9 @@ backbone_presets = {
                 "2 billion parameter, 18-layer, instruction tuned Gemma model."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/3",
     },
     "gemma_1.1_instruct_2b_en": {
         "metadata": {
@@ -31,11 +27,9 @@ backbone_presets = {
                 "The 1.1 update improves model quality."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/4",
     },
     "code_gemma_1.1_2b_en": {
         "metadata": {
@@ -45,11 +39,9 @@ backbone_presets = {
                 "completion. The 1.1 update improves model quality."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/2",
     },
     "code_gemma_2b_en": {
         "metadata": {
@@ -59,21 +51,17 @@ backbone_presets = {
                 "completion."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_2b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_2b_en/2",
     },
     "gemma_7b_en": {
         "metadata": {
             "description": "7 billion parameter, 28-layer, base Gemma model.",
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/3",
     },
     "gemma_instruct_7b_en": {
         "metadata": {
@@ -81,11 +69,9 @@ backbone_presets = {
                 "7 billion parameter, 28-layer, instruction tuned Gemma model."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/3",
     },
     "gemma_1.1_instruct_7b_en": {
         "metadata": {
@@ -94,11 +80,9 @@ backbone_presets = {
                 "The 1.1 update improves model quality."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/4",
     },
     "code_gemma_7b_en": {
         "metadata": {
@@ -108,11 +92,9 @@ backbone_presets = {
                 "completion."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/2",
     },
     "code_gemma_instruct_7b_en": {
         "metadata": {
@@ -122,11 +104,9 @@ backbone_presets = {
                 "to code."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/2",
     },
     "code_gemma_1.1_instruct_7b_en": {
         "metadata": {
@@ -136,100 +116,86 @@ backbone_presets = {
                 "to code. The 1.1 update improves model quality."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/2",
     },
     "gemma2_2b_en": {
         "metadata": {
             "description": "2 billion parameter, 26-layer, base Gemma model.",
             "params": 2614341888,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_2b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_2b_en/2",
     },
     "gemma2_instruct_2b_en": {
         "metadata": {
-            "description": "2 billion parameter, 26-layer, instruction tuned Gemma model.",
+            "description": (
+                "2 billion parameter, 26-layer, instruction tuned Gemma model."
+            ),
             "params": 2614341888,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/2",
     },
     "gemma2_9b_en": {
         "metadata": {
             "description": "9 billion parameter, 42-layer, base Gemma model.",
             "params": 9241705984,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/3",
     },
     "gemma2_instruct_9b_en": {
         "metadata": {
-            "description": "9 billion parameter, 42-layer, instruction tuned Gemma model.",
+            "description": (
+                "9 billion parameter, 42-layer, instruction tuned Gemma model."
+            ),
             "params": 9241705984,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/3",
     },
     "gemma2_27b_en": {
         "metadata": {
             "description": "27 billion parameter, 42-layer, base Gemma model.",
             "params": 27227128320,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/2",
     },
     "gemma2_instruct_27b_en": {
         "metadata": {
-            "description": "27 billion parameter, 42-layer, instruction tuned Gemma model.",
+            "description": (
+                "27 billion parameter, 42-layer, instruction tuned Gemma model."
+            ),
             "params": 27227128320,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/2",
     },
     "shieldgemma_2b_en": {
         "metadata": {
             "description": "2 billion parameter, 26-layer, ShieldGemma model.",
             "params": 2614341888,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_2b_en/
+        "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_2b_en/2",
     },
     "shieldgemma_9b_en": {
         "metadata": {
             "description": "9 billion parameter, 42-layer, ShieldGemma model.",
             "params": 9241705984,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_9b_en/
+        "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_9b_en/2",
     },
     "shieldgemma_27b_en": {
         "metadata": {
             "description": "27 billion parameter, 42-layer, ShieldGemma model.",
             "params": 27227128320,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
-        "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/
+        "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/2",
     },
 }
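All of the preset hunks above follow one pattern: the redundant `official_name` and `model_card` metadata move out of the per-preset dicts, and each `kaggle_handle` is pinned to a newer artifact version (the old version suffixes are truncated in this view; the same pattern repeats for the GPT-2 presets below). User-facing preset names are untouched, so loading code is unaffected. A sketch, assuming Kaggle credentials are configured:

```python
import keras_hub

# The short preset name resolves to the pinned kaggle_handle internally.
causal_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
print(causal_lm.generate("What is Keras?", max_length=64))
```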
keras_hub/src/models/gpt2/gpt2_causal_lm.py

@@ -172,8 +172,8 @@ class GPT2CausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
 
         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
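The corrected docstring pins down `cache_update_index`: it is the position of the current `token_ids` within the whole generated sequence. An illustrative sketch of a manual two-step decode with `call_with_cache`; it assumes the key/value cache layout used internally by the keras-hub GPT-2 task (`[batch, num_layers, 2, max_length, num_heads, head_dim]`), and the prompt token ids are placeholders:

```python
import numpy as np
import keras_hub
from keras import ops

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
backbone = causal_lm.backbone

prompt = np.array([[15496, 995, 318]])  # placeholder token ids
max_length = 16
cache = np.zeros(
    (
        1,  # batch_size
        backbone.num_layers,
        2,  # one slot each for keys and values
        max_length,
        backbone.num_heads,
        backbone.hidden_dim // backbone.num_heads,
    ),
    dtype="float32",
)

# The prompt starts at index 0 of the whole sequence.
logits, hidden_states, cache = causal_lm.call_with_cache(
    prompt, cache, cache_update_index=0
)

# The next token sits at index 3, right after the 3-token prompt.
next_token = ops.expand_dims(ops.argmax(logits[:, -1, :], axis=-1), 1)
logits, hidden_states, cache = causal_lm.call_with_cache(
    next_token, cache, cache_update_index=3
)
```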
keras_hub/src/models/gpt2/gpt2_presets.py

@@ -9,11 +9,9 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 124439808,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en/
+        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en/3",
     },
     "gpt2_medium_en": {
         "metadata": {
@@ -22,11 +20,9 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 354823168,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_medium_en/
+        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_medium_en/3",
     },
     "gpt2_large_en": {
         "metadata": {
@@ -35,11 +31,9 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 774030080,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_large_en/
+        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_large_en/3",
     },
     "gpt2_extra_large_en": {
         "metadata": {
@@ -48,11 +42,9 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 1557611200,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_extra_large_en/
+        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_extra_large_en/3",
     },
     "gpt2_base_en_cnn_dailymail": {
         "metadata": {
@@ -61,9 +53,8 @@ backbone_presets = {
                 "Finetuned on the CNN/DailyMail summarization dataset."
             ),
             "params": 124439808,
-            "official_name": "GPT-2",
             "path": "gpt2",
         },
-        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/
+        "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/3",
     },
 }
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py

@@ -202,7 +202,8 @@ class GPTNeoXAttention(keras.layers.Layer):
             training=training,
         )
 
-        # Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
+        # Reshape `attention_output` to
+        # `(batch_size, sequence_length, hidden_dim)`.
         attention_output = ops.reshape(
             attention_output,
             [
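The wrapped comment documents the target shape of the reshape. A standalone sketch of the same collapse from per-head outputs to `hidden_dim`, using `keras.ops`; the shapes are illustrative:

```python
import numpy as np
from keras import ops

batch_size, sequence_length, num_heads, head_dim = 2, 5, 8, 16
attention_output = np.zeros(
    (batch_size, sequence_length, num_heads, head_dim), dtype="float32"
)

# Collapse (num_heads, head_dim) into hidden_dim = num_heads * head_dim.
attention_output = ops.reshape(
    attention_output,
    [batch_size, sequence_length, num_heads * head_dim],
)
print(attention_output.shape)  # (2, 5, 128)
```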
keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py

@@ -27,9 +27,9 @@ class GPTNeoXCausalLM(CausalLM):
 
     Args:
         backbone: A `keras_hub.models.GPTNeoXBackbone` instance.
-        preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or
-            If `None`, this model will not apply preprocessing, and
-            should be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
     """
 
     backbone_cls = GPTNeoXBackbone
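The fixed docstring makes the `preprocessor=None` contract explicit. A sketch with a tiny randomly initialized backbone (the sizes are illustrative), where inputs must already be tokenized and padded:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.GPTNeoXBackbone(
    vocabulary_size=100,
    num_layers=2,
    num_heads=4,
    hidden_dim=32,
    intermediate_dim=64,
    max_sequence_length=128,
)
causal_lm = keras_hub.models.GPTNeoXCausalLM(backbone, preprocessor=None)

# With no preprocessor, the model consumes the raw feature dict directly.
features = {
    "token_ids": np.ones((2, 8), dtype="int32"),
    "padding_mask": np.ones((2, 8), dtype="int32"),
}
logits = causal_lm.predict(features)
print(logits.shape)  # (2, 8, 100)
```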
keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py

@@ -16,7 +16,8 @@ class GPTNeoXDecoder(keras.layers.Layer):
 
     This class follows the architecture of the GPT-NeoX decoder layer in the
     paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745).
-    Users can instantiate multiple instances of this class to stack up a decoder.
+    Users can instantiate multiple instances of this class to stack up a
+    decoder.
 
     This layer will always apply a causal mask to the decoder attention layer.
keras_hub/src/models/image_classifier.py

@@ -15,11 +15,156 @@ class ImageClassifier(Task):
 
     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
     labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
+    All `ImageClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
 
-
-
+    Args:
+        backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: `"avg"` or `"max"`. The type of pooling to apply on backbone
+            output. Defaults to average pooling.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `"softmax"`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and train
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    classifier = keras_hub.models.ImageClassifier.from_preset(
+        "resnet_50_imagenet"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    classifier = keras_hub.models.ImageClassifier.from_preset(
+        "resnet_50_imagenet"
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.ImageClassifier.from_preset(
+        "resnet_50_imagenet"
+    )
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    backbone = keras_hub.models.ResNetBackbone(
+        stackwise_num_filters=[64, 64, 64],
+        stackwise_num_blocks=[2, 2, 2],
+        stackwise_num_strides=[1, 2, 2],
+        block_type="basic_block",
+        use_pre_activation=True,
+        pooling="avg",
+    )
+    classifier = keras_hub.models.ImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
     """
 
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        pooling="avg",
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+        data_format = getattr(backbone, "data_format", None)
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        if pooling == "avg":
+            self.pooler = keras.layers.GlobalAveragePooling2D(
+                data_format,
+                dtype=head_dtype,
+                name="pooler",
+            )
+        elif pooling == "max":
+            self.pooler = keras.layers.GlobalMaxPooling2D(
+                data_format,
+                dtype=head_dtype,
+                name="pooler",
+            )
+        else:
+            raise ValueError(
+                "Unknown `pooling` type. Polling should be either `'avg'` or "
+                f"`'max'`. Received: pooling={pooling}."
+            )
+        self.output_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=head_dtype,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        x = self.pooler(x)
+        x = self.output_dropout(x)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = activation
+        self.pooling = pooling
+        self.dropout = dropout
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "pooling": self.pooling,
+                "activation": self.activation,
+                "dropout": self.dropout,
+            }
+        )
+        return config
+
     def compile(
         self,
         optimizer="auto",
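Beyond the docstring examples, the new head exposes `pooling`, `dropout`, and `activation` as constructor arguments, and the added `get_config()` round-trips them. A sketch, assuming the `resnet_50_imagenet` preset used in the docstring above and that `from_preset()` forwards head arguments to the constructor:

```python
import numpy as np
import keras_hub

# Replace the head: 10 classes, max pooling, raw logits.
classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet",
    num_classes=10,
    pooling="max",
    dropout=0.2,
    activation=None,
)
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
logits = classifier.predict(images)  # shape (2, 10)

# The head arguments survive serialization via the new get_config().
config = classifier.get_config()
assert config["pooling"] == "max" and config["dropout"] == 0.2
```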
keras_hub/src/models/image_classifier_preprocessor.py

@@ -38,15 +38,18 @@ class ImageClassifierPreprocessor(Preprocessor):
     )
 
     # Resize a single image for resnet 50.
-    x = np.
+    x = np.random.randint(0, 256, (512, 512, 3))
     x = preprocessor(x)
 
     # Resize a labeled image.
-    x, y = np.
+    x, y = np.random.randint(0, 256, (512, 512, 3)), 1
     x, y = preprocessor(x, y)
 
     # Resize a batch of labeled images.
-    x, y = [
+    x, y = [
+        np.random.randint(0, 256, (512, 512, 3)),
+        np.zeros((512, 512, 3))
+    ], [1, 0]
     x, y = preprocessor(x, y)
 
     # Use a `tf.data.Dataset`.