keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -44,10 +44,10 @@ class GemmaBackbone(Backbone):
44
44
  `hidden_dim / num_query_heads`. Defaults to True.
45
45
  use_post_ffw_norm: boolean. Whether to normalize after the feedforward
46
46
  block. Defaults to False.
47
- use_post_attention_norm: boolean. Whether to normalize after the attention
48
- block. Defaults to False.
49
- attention_logit_soft_cap: None or int. Soft cap for the attention logits.
50
- Defaults to None.
47
+ use_post_attention_norm: boolean. Whether to normalize after the
48
+ attention block. Defaults to False.
49
+ attention_logit_soft_cap: None or int. Soft cap for the attention
50
+ logits. Defaults to None.
51
51
  final_logit_soft_cap: None or int. Soft cap for the final logits.
52
52
  Defaults to None.
53
53
  use_sliding_window_attention boolean. Whether to use sliding local
@@ -205,7 +205,9 @@ class GemmaBackbone(Backbone):
205
205
  "final_logit_soft_cap": self.final_logit_soft_cap,
206
206
  "attention_logit_soft_cap": self.attention_logit_soft_cap,
207
207
  "sliding_window_size": self.sliding_window_size,
208
- "use_sliding_window_attention": self.use_sliding_window_attention,
208
+ "use_sliding_window_attention": (
209
+ self.use_sliding_window_attention
210
+ ),
209
211
  }
210
212
  )
211
213
  return config
@@ -224,7 +226,8 @@ class GemmaBackbone(Backbone):
224
226
 
225
227
  Example:
226
228
  ```
227
- # Feel free to change the mesh shape to balance data and model parallel
229
+ # Feel free to change the mesh shape to balance data and model
230
+ # parallelism
228
231
  mesh = keras.distribution.DeviceMesh(
229
232
  shape=(1, 8), axis_names=('batch', 'model'),
230
233
  devices=keras.distribution.list_devices())
@@ -232,11 +235,23 @@ class GemmaBackbone(Backbone):
232
235
  mesh, model_parallel_dim_name="model")
233
236
 
234
237
  distribution = keras.distribution.ModelParallel(
235
- mesh, layout_map, batch_dim_name='batch')
238
+ layout_map=layout_map, batch_dim_name='batch')
236
239
  with distribution.scope():
237
240
  gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
238
241
  ```
239
242
 
243
+ To see how the layout map was applied, load the model then run (for one
244
+ decoder block):
245
+ ```
246
+ embedding_layer = gemma_model.backbone.get_layer("token_embedding")
247
+ decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
248
+ for variable in embedding_layer.weights + decoder_block_1.weights:
249
+ print(
250
+ f'{variable.path:<58} {str(variable.shape):<16} '
251
+ f'{str(variable.value.sharding.spec)}'
252
+ )
253
+ ```
254
+
240
255
  Args:
241
256
  device_mesh: The `keras.distribution.DeviceMesh` instance for
242
257
  distribution.
@@ -246,25 +261,25 @@ class GemmaBackbone(Backbone):
246
261
  the data should be partition on.
247
262
  Return:
248
263
  `keras.distribution.LayoutMap` that contains the sharding spec
249
- of all the model weights.
264
+ for all the model weights.
250
265
  """
251
266
  # The weight path and shape of the Gemma backbone is like below (for 2G)
252
- # token_embedding/embeddings, (256128, 2048), 524550144
267
+ # token_embedding/embeddings, (256128, 2048)
253
268
  # repeat block for decoder
254
269
  # ...
255
- # decoder_block_17/pre_attention_norm/scale, (2048,), 2048
256
- # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304
257
- # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304
258
- # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304
259
- # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304
260
- # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048
261
- # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432
262
- # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432
263
- # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432
270
+ # decoder_block_17/pre_attention_norm/scale, (2048,)
271
+ # decoder_block_17/attention/query/kernel, (8, 2048, 256)
272
+ # decoder_block_17/attention/key/kernel, (8, 2048, 256)
273
+ # decoder_block_17/attention/value/kernel, (8, 2048, 256)
274
+ # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048)
275
+ # decoder_block_17/pre_ffw_norm/scale, (2048,)
276
+ # decoder_block_17/ffw_gating/kernel, (2048, 16384)
277
+ # decoder_block_17/ffw_gating_2/kernel, (2048, 16384)
278
+ # decoder_block_17/ffw_linear/kernel, (16384, 2048)
264
279
  if not isinstance(device_mesh, keras.distribution.DeviceMesh):
265
280
  raise ValueError(
266
- "Invalid device_mesh type. Expected `keras.distribution.Device`,"
267
- f" got {type(device_mesh)}"
281
+ "Invalid device_mesh type. Expected "
282
+ f"`keras.distribution.Device`, got {type(device_mesh)}"
268
283
  )
269
284
  if model_parallel_dim_name not in device_mesh.axis_names:
270
285
  raise ValueError(
@@ -187,8 +187,8 @@ class GemmaCausalLM(CausalLM):
187
187
  Args:
188
188
  token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
189
189
  cache: a dense float Tensor, the cache of key and value.
190
- cache_update_index: int, or int Tensor. The index of current inputs in the
191
- whole sequence.
190
+ cache_update_index: int, or int Tensor. The index of current inputs
191
+ in the whole sequence.
192
192
 
193
193
  Returns:
194
194
  A (logits, hidden_states, cache) tuple. Where `logits` is the
@@ -220,7 +220,9 @@ class GemmaDecoderBlock(keras.layers.Layer):
220
220
  "use_post_ffw_norm": self.use_post_ffw_norm,
221
221
  "use_post_attention_norm": self.use_post_attention_norm,
222
222
  "logit_soft_cap": self.logit_soft_cap,
223
- "use_sliding_window_attention": self.use_sliding_window_attention,
223
+ "use_sliding_window_attention": (
224
+ self.use_sliding_window_attention
225
+ ),
224
226
  "sliding_window_size": self.sliding_window_size,
225
227
  "query_head_dim_normalize": self.query_head_dim_normalize,
226
228
  }
@@ -6,11 +6,9 @@ backbone_presets = {
6
6
  "metadata": {
7
7
  "description": "2 billion parameter, 18-layer, base Gemma model.",
8
8
  "params": 2506172416,
9
- "official_name": "Gemma",
10
9
  "path": "gemma",
11
- "model_card": "https://www.kaggle.com/models/google/gemma",
12
10
  },
13
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/2",
11
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/3",
14
12
  },
15
13
  "gemma_instruct_2b_en": {
16
14
  "metadata": {
@@ -18,11 +16,9 @@ backbone_presets = {
18
16
  "2 billion parameter, 18-layer, instruction tuned Gemma model."
19
17
  ),
20
18
  "params": 2506172416,
21
- "official_name": "Gemma",
22
19
  "path": "gemma",
23
- "model_card": "https://www.kaggle.com/models/google/gemma",
24
20
  },
25
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/2",
21
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/3",
26
22
  },
27
23
  "gemma_1.1_instruct_2b_en": {
28
24
  "metadata": {
@@ -31,11 +27,9 @@ backbone_presets = {
31
27
  "The 1.1 update improves model quality."
32
28
  ),
33
29
  "params": 2506172416,
34
- "official_name": "Gemma",
35
30
  "path": "gemma",
36
- "model_card": "https://www.kaggle.com/models/google/gemma",
37
31
  },
38
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/3",
32
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/4",
39
33
  },
40
34
  "code_gemma_1.1_2b_en": {
41
35
  "metadata": {
@@ -45,11 +39,9 @@ backbone_presets = {
45
39
  "completion. The 1.1 update improves model quality."
46
40
  ),
47
41
  "params": 2506172416,
48
- "official_name": "Gemma",
49
42
  "path": "gemma",
50
- "model_card": "https://www.kaggle.com/models/google/gemma",
51
43
  },
52
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/1",
44
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/2",
53
45
  },
54
46
  "code_gemma_2b_en": {
55
47
  "metadata": {
@@ -59,21 +51,17 @@ backbone_presets = {
59
51
  "completion."
60
52
  ),
61
53
  "params": 2506172416,
62
- "official_name": "Gemma",
63
54
  "path": "gemma",
64
- "model_card": "https://www.kaggle.com/models/google/gemma",
65
55
  },
66
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_2b_en/1",
56
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_2b_en/2",
67
57
  },
68
58
  "gemma_7b_en": {
69
59
  "metadata": {
70
60
  "description": "7 billion parameter, 28-layer, base Gemma model.",
71
61
  "params": 8537680896,
72
- "official_name": "Gemma",
73
62
  "path": "gemma",
74
- "model_card": "https://www.kaggle.com/models/google/gemma",
75
63
  },
76
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/2",
64
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/3",
77
65
  },
78
66
  "gemma_instruct_7b_en": {
79
67
  "metadata": {
@@ -81,11 +69,9 @@ backbone_presets = {
81
69
  "7 billion parameter, 28-layer, instruction tuned Gemma model."
82
70
  ),
83
71
  "params": 8537680896,
84
- "official_name": "Gemma",
85
72
  "path": "gemma",
86
- "model_card": "https://www.kaggle.com/models/google/gemma",
87
73
  },
88
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/2",
74
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/3",
89
75
  },
90
76
  "gemma_1.1_instruct_7b_en": {
91
77
  "metadata": {
@@ -94,11 +80,9 @@ backbone_presets = {
94
80
  "The 1.1 update improves model quality."
95
81
  ),
96
82
  "params": 8537680896,
97
- "official_name": "Gemma",
98
83
  "path": "gemma",
99
- "model_card": "https://www.kaggle.com/models/google/gemma",
100
84
  },
101
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/3",
85
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/4",
102
86
  },
103
87
  "code_gemma_7b_en": {
104
88
  "metadata": {
@@ -108,11 +92,9 @@ backbone_presets = {
108
92
  "completion."
109
93
  ),
110
94
  "params": 8537680896,
111
- "official_name": "Gemma",
112
95
  "path": "gemma",
113
- "model_card": "https://www.kaggle.com/models/google/gemma",
114
96
  },
115
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/1",
97
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/2",
116
98
  },
117
99
  "code_gemma_instruct_7b_en": {
118
100
  "metadata": {
@@ -122,11 +104,9 @@ backbone_presets = {
122
104
  "to code."
123
105
  ),
124
106
  "params": 8537680896,
125
- "official_name": "Gemma",
126
107
  "path": "gemma",
127
- "model_card": "https://www.kaggle.com/models/google/gemma",
128
108
  },
129
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/1",
109
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/2",
130
110
  },
131
111
  "code_gemma_1.1_instruct_7b_en": {
132
112
  "metadata": {
@@ -136,100 +116,86 @@ backbone_presets = {
136
116
  "to code. The 1.1 update improves model quality."
137
117
  ),
138
118
  "params": 8537680896,
139
- "official_name": "Gemma",
140
119
  "path": "gemma",
141
- "model_card": "https://www.kaggle.com/models/google/gemma",
142
120
  },
143
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/1",
121
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/2",
144
122
  },
145
123
  "gemma2_2b_en": {
146
124
  "metadata": {
147
125
  "description": "2 billion parameter, 26-layer, base Gemma model.",
148
126
  "params": 2614341888,
149
- "official_name": "Gemma",
150
127
  "path": "gemma",
151
- "model_card": "https://www.kaggle.com/models/google/gemma",
152
128
  },
153
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_2b_en/1",
129
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_2b_en/2",
154
130
  },
155
131
  "gemma2_instruct_2b_en": {
156
132
  "metadata": {
157
- "description": "2 billion parameter, 26-layer, instruction tuned Gemma model.",
133
+ "description": (
134
+ "2 billion parameter, 26-layer, instruction tuned Gemma model."
135
+ ),
158
136
  "params": 2614341888,
159
- "official_name": "Gemma",
160
137
  "path": "gemma",
161
- "model_card": "https://www.kaggle.com/models/google/gemma",
162
138
  },
163
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/1",
139
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/2",
164
140
  },
165
141
  "gemma2_9b_en": {
166
142
  "metadata": {
167
143
  "description": "9 billion parameter, 42-layer, base Gemma model.",
168
144
  "params": 9241705984,
169
- "official_name": "Gemma",
170
145
  "path": "gemma",
171
- "model_card": "https://www.kaggle.com/models/google/gemma",
172
146
  },
173
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/2",
147
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/3",
174
148
  },
175
149
  "gemma2_instruct_9b_en": {
176
150
  "metadata": {
177
- "description": "9 billion parameter, 42-layer, instruction tuned Gemma model.",
151
+ "description": (
152
+ "9 billion parameter, 42-layer, instruction tuned Gemma model."
153
+ ),
178
154
  "params": 9241705984,
179
- "official_name": "Gemma",
180
155
  "path": "gemma",
181
- "model_card": "https://www.kaggle.com/models/google/gemma",
182
156
  },
183
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/2",
157
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/3",
184
158
  },
185
159
  "gemma2_27b_en": {
186
160
  "metadata": {
187
161
  "description": "27 billion parameter, 42-layer, base Gemma model.",
188
162
  "params": 27227128320,
189
- "official_name": "Gemma",
190
163
  "path": "gemma",
191
- "model_card": "https://www.kaggle.com/models/google/gemma",
192
164
  },
193
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/1",
165
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/2",
194
166
  },
195
167
  "gemma2_instruct_27b_en": {
196
168
  "metadata": {
197
- "description": "27 billion parameter, 42-layer, instruction tuned Gemma model.",
169
+ "description": (
170
+ "27 billion parameter, 42-layer, instruction tuned Gemma model."
171
+ ),
198
172
  "params": 27227128320,
199
- "official_name": "Gemma",
200
173
  "path": "gemma",
201
- "model_card": "https://www.kaggle.com/models/google/gemma",
202
174
  },
203
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/1",
175
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/2",
204
176
  },
205
177
  "shieldgemma_2b_en": {
206
178
  "metadata": {
207
179
  "description": "2 billion parameter, 26-layer, ShieldGemma model.",
208
180
  "params": 2614341888,
209
- "official_name": "Gemma",
210
181
  "path": "gemma",
211
- "model_card": "https://www.kaggle.com/models/google/gemma",
212
182
  },
213
- "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_2b_en/1",
183
+ "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_2b_en/2",
214
184
  },
215
185
  "shieldgemma_9b_en": {
216
186
  "metadata": {
217
187
  "description": "9 billion parameter, 42-layer, ShieldGemma model.",
218
188
  "params": 9241705984,
219
- "official_name": "Gemma",
220
189
  "path": "gemma",
221
- "model_card": "https://www.kaggle.com/models/google/gemma",
222
190
  },
223
- "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_9b_en/1",
191
+ "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_9b_en/2",
224
192
  },
225
193
  "shieldgemma_27b_en": {
226
194
  "metadata": {
227
195
  "description": "27 billion parameter, 42-layer, ShieldGemma model.",
228
196
  "params": 27227128320,
229
- "official_name": "Gemma",
230
197
  "path": "gemma",
231
- "model_card": "https://www.kaggle.com/models/google/gemma",
232
198
  },
233
- "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/1",
199
+ "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/2",
234
200
  },
235
201
  }
@@ -172,8 +172,8 @@ class GPT2CausalLM(CausalLM):
172
172
  Args:
173
173
  token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
174
174
  cache: a dense float Tensor, the cache of key and value.
175
- cache_update_index: int, or int Tensor. The index of current inputs in the
176
- whole sequence.
175
+ cache_update_index: int, or int Tensor. The index of current inputs
176
+ in the whole sequence.
177
177
 
178
178
  Returns:
179
179
  A (logits, hidden_states, cache) tuple. Where `logits` is the
@@ -9,11 +9,9 @@ backbone_presets = {
9
9
  "Trained on WebText."
10
10
  ),
11
11
  "params": 124439808,
12
- "official_name": "GPT-2",
13
12
  "path": "gpt2",
14
- "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
15
13
  },
16
- "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en/2",
14
+ "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en/3",
17
15
  },
18
16
  "gpt2_medium_en": {
19
17
  "metadata": {
@@ -22,11 +20,9 @@ backbone_presets = {
22
20
  "Trained on WebText."
23
21
  ),
24
22
  "params": 354823168,
25
- "official_name": "GPT-2",
26
23
  "path": "gpt2",
27
- "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
28
24
  },
29
- "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_medium_en/2",
25
+ "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_medium_en/3",
30
26
  },
31
27
  "gpt2_large_en": {
32
28
  "metadata": {
@@ -35,11 +31,9 @@ backbone_presets = {
35
31
  "Trained on WebText."
36
32
  ),
37
33
  "params": 774030080,
38
- "official_name": "GPT-2",
39
34
  "path": "gpt2",
40
- "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
41
35
  },
42
- "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_large_en/2",
36
+ "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_large_en/3",
43
37
  },
44
38
  "gpt2_extra_large_en": {
45
39
  "metadata": {
@@ -48,11 +42,9 @@ backbone_presets = {
48
42
  "Trained on WebText."
49
43
  ),
50
44
  "params": 1557611200,
51
- "official_name": "GPT-2",
52
45
  "path": "gpt2",
53
- "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
54
46
  },
55
- "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_extra_large_en/2",
47
+ "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_extra_large_en/3",
56
48
  },
57
49
  "gpt2_base_en_cnn_dailymail": {
58
50
  "metadata": {
@@ -61,9 +53,8 @@ backbone_presets = {
61
53
  "Finetuned on the CNN/DailyMail summarization dataset."
62
54
  ),
63
55
  "params": 124439808,
64
- "official_name": "GPT-2",
65
56
  "path": "gpt2",
66
57
  },
67
- "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/2",
58
+ "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/3",
68
59
  },
69
60
  }
@@ -202,7 +202,8 @@ class GPTNeoXAttention(keras.layers.Layer):
202
202
  training=training,
203
203
  )
204
204
 
205
- # Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
205
+ # Reshape `attention_output` to
206
+ # `(batch_size, sequence_length, hidden_dim)`.
206
207
  attention_output = ops.reshape(
207
208
  attention_output,
208
209
  [
@@ -27,9 +27,9 @@ class GPTNeoXCausalLM(CausalLM):
27
27
 
28
28
  Args:
29
29
  backbone: A `keras_hub.models.GPTNeoXBackbone` instance.
30
- preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or `None`.
31
- If `None`, this model will not apply preprocessing, and inputs
32
- should be preprocessed before calling the model.
30
+ preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or
31
+ `None`. If `None`, this model will not apply preprocessing, and
32
+ inputs should be preprocessed before calling the model.
33
33
  """
34
34
 
35
35
  backbone_cls = GPTNeoXBackbone
@@ -16,7 +16,8 @@ class GPTNeoXDecoder(keras.layers.Layer):
16
16
 
17
17
  This class follows the architecture of the GPT-NeoX decoder layer in the
18
18
  paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745).
19
- Users can instantiate multiple instances of this class to stack up a decoder.
19
+ Users can instantiate multiple instances of this class to stack up a
20
+ decoder.
20
21
 
21
22
  This layer will always apply a causal mask to the decoder attention layer.
22
23
 
@@ -15,11 +15,156 @@ class ImageClassifier(Task):
15
15
 
16
16
  To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
17
17
  labels where `x` is an image and `y` is an integer from `[0, num_classes)`.
18
+ All `ImageClassifier` tasks include a `from_preset()` constructor which can
19
+ be used to load a pre-trained config and weights.
18
20
 
19
- All `ImageClassifier` tasks include a `from_preset()` constructor which can be
20
- used to load a pre-trained config and weights.
21
+ Args:
22
+ backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`.
23
+ num_classes: int. The number of classes to predict.
24
+ preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
25
+ a `keras.Layer` instance, or a callable. If `None` no preprocessing
26
+ will be applied to the inputs.
27
+ pooling: `"avg"` or `"max"`. The type of pooling to apply on backbone
28
+ output. Defaults to average pooling.
29
+ activation: `None`, str, or callable. The activation function to use on
30
+ the `Dense` layer. Set `activation=None` to return the output
31
+ logits. Defaults to `None`.
32
+ head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
33
+ dtype to use for the classification head's computations and weights.
34
+
35
+ Examples:
36
+
37
+ Call `predict()` to run inference.
38
+ ```python
39
+ # Load preset and run inference
40
+ images = np.random.randint(0, 256, size=(2, 224, 224, 3))
41
+ classifier = keras_hub.models.ImageClassifier.from_preset(
42
+ "resnet_50_imagenet"
43
+ )
44
+ classifier.predict(images)
45
+ ```
46
+
47
+ Call `fit()` on a single batch.
48
+ ```python
49
+ # Load preset and train
50
+ images = np.random.randint(0, 256, size=(2, 224, 224, 3))
51
+ labels = [0, 3]
52
+ classifier = keras_hub.models.ImageClassifier.from_preset(
53
+ "resnet_50_imagenet"
54
+ )
55
+ classifier.fit(x=images, y=labels, batch_size=2)
56
+ ```
57
+
58
+ Call `fit()` with custom loss, optimizer and a frozen backbone.
59
+ ```python
60
+ classifier = keras_hub.models.ImageClassifier.from_preset(
61
+ "resnet_50_imagenet"
62
+ )
63
+ classifier.compile(
64
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
65
+ optimizer=keras.optimizers.Adam(5e-5),
66
+ )
67
+ classifier.backbone.trainable = False
68
+ classifier.fit(x=images, y=labels, batch_size=2)
69
+ ```
70
+
71
+ Custom backbone.
72
+ ```python
73
+ images = np.random.randint(0, 256, size=(2, 224, 224, 3))
74
+ labels = [0, 3]
75
+ backbone = keras_hub.models.ResNetBackbone(
76
+ stackwise_num_filters=[64, 64, 64],
77
+ stackwise_num_blocks=[2, 2, 2],
78
+ stackwise_num_strides=[1, 2, 2],
79
+ block_type="basic_block",
80
+ use_pre_activation=True,
81
+ pooling="avg",
82
+ )
83
+ classifier = keras_hub.models.ImageClassifier(
84
+ backbone=backbone,
85
+ num_classes=4,
86
+ )
87
+ classifier.fit(x=images, y=labels, batch_size=2)
88
+ ```
21
89
  """
22
90
 
91
+ def __init__(
92
+ self,
93
+ backbone,
94
+ num_classes,
95
+ preprocessor=None,
96
+ pooling="avg",
97
+ activation=None,
98
+ dropout=0.0,
99
+ head_dtype=None,
100
+ **kwargs,
101
+ ):
102
+ head_dtype = head_dtype or backbone.dtype_policy
103
+ data_format = getattr(backbone, "data_format", None)
104
+
105
+ # === Layers ===
106
+ self.backbone = backbone
107
+ self.preprocessor = preprocessor
108
+ if pooling == "avg":
109
+ self.pooler = keras.layers.GlobalAveragePooling2D(
110
+ data_format,
111
+ dtype=head_dtype,
112
+ name="pooler",
113
+ )
114
+ elif pooling == "max":
115
+ self.pooler = keras.layers.GlobalMaxPooling2D(
116
+ data_format,
117
+ dtype=head_dtype,
118
+ name="pooler",
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ "Unknown `pooling` type. Polling should be either `'avg'` or "
123
+ f"`'max'`. Received: pooling={pooling}."
124
+ )
125
+ self.output_dropout = keras.layers.Dropout(
126
+ dropout,
127
+ dtype=head_dtype,
128
+ name="output_dropout",
129
+ )
130
+ self.output_dense = keras.layers.Dense(
131
+ num_classes,
132
+ activation=activation,
133
+ dtype=head_dtype,
134
+ name="predictions",
135
+ )
136
+
137
+ # === Functional Model ===
138
+ inputs = self.backbone.input
139
+ x = self.backbone(inputs)
140
+ x = self.pooler(x)
141
+ x = self.output_dropout(x)
142
+ outputs = self.output_dense(x)
143
+ super().__init__(
144
+ inputs=inputs,
145
+ outputs=outputs,
146
+ **kwargs,
147
+ )
148
+
149
+ # === Config ===
150
+ self.num_classes = num_classes
151
+ self.activation = activation
152
+ self.pooling = pooling
153
+ self.dropout = dropout
154
+
155
+ def get_config(self):
156
+ # Backbone serialized in `super`
157
+ config = super().get_config()
158
+ config.update(
159
+ {
160
+ "num_classes": self.num_classes,
161
+ "pooling": self.pooling,
162
+ "activation": self.activation,
163
+ "dropout": self.dropout,
164
+ }
165
+ )
166
+ return config
167
+
23
168
  def compile(
24
169
  self,
25
170
  optimizer="auto",
@@ -38,15 +38,18 @@ class ImageClassifierPreprocessor(Preprocessor):
38
38
  )
39
39
 
40
40
  # Resize a single image for resnet 50.
41
- x = np.ones((512, 512, 3))
41
+ x = np.random.randint(0, 256, (512, 512, 3))
42
42
  x = preprocessor(x)
43
43
 
44
44
  # Resize a labeled image.
45
- x, y = np.ones((512, 512, 3)), 1
45
+ x, y = np.random.randint(0, 256, (512, 512, 3)), 1
46
46
  x, y = preprocessor(x, y)
47
47
 
48
48
  # Resize a batch of labeled images.
49
- x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
49
+ x, y = [
50
+ np.random.randint(0, 256, (512, 512, 3)),
51
+ np.zeros((512, 512, 3))
52
+ ], [1, 0]
50
53
  x, y = preprocessor(x, y)
51
54
 
52
55
  # Use a `tf.data.Dataset`.