keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0

--- a/keras_hub/src/models/roberta/roberta_presets.py
+++ b/keras_hub/src/models/roberta/roberta_presets.py
@@ -5,26 +5,24 @@ backbone_presets = {
         "metadata": {
             "description": (
                 "12-layer RoBERTa model where case is maintained."
-                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and OpenWebText."
+                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and "
+                "OpenWebText."
             ),
             "params": 124052736,
-            "official_name": "RoBERTa",
             "path": "roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
         },
-        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_base_en/2",
+        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_base_en/3",
     },
     "roberta_large_en": {
         "metadata": {
             "description": (
                 "24-layer RoBERTa model where case is maintained."
-                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and OpenWebText."
+                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and "
+                "OpenWebText."
             ),
             "params": 354307072,
-            "official_name": "RoBERTa",
             "path": "roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
         },
-        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_large_en/2",
+        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_large_en/3",
     },
 }
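
Across this release, preset metadata drops the `official_name` and `model_card` keys and the Kaggle handles are bumped to new versions, but the preset names themselves are unchanged. A minimal loading sketch (the preset name is taken from the diff above; nothing else is assumed):

```python
import keras_hub

# "roberta_base_en" still resolves by name; it now points at version 3
# of the kaggle://keras/roberta handle shown in the diff.
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
```
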

--- a/keras_hub/src/models/roberta/roberta_text_classifier.py
+++ b/keras_hub/src/models/roberta/roberta_text_classifier.py
@@ -38,9 +38,9 @@ class RobertaTextClassifier(TextClassifier):
     Args:
         backbone: A `keras_hub.models.RobertaBackbone` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.RobertaTextClassifierPreprocessor` or `None`. If
-            `None`, this model will not apply preprocessing, and inputs should
-            be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.RobertaTextClassifierPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
         activation: Optional `str` or callable. The activation function to use
             on the model outputs. Set `activation="softmax"` to return output
             probabilities. Defaults to `None`.

--- /dev/null
+++ b/keras_hub/src/models/sam/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.sam.sam_backbone import SAMBackbone
+from keras_hub.src.models.sam.sam_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, SAMBackbone)

--- a/keras_hub/src/models/sam/sam_backbone.py
+++ b/keras_hub/src/models/sam/sam_backbone.py
@@ -9,8 +9,8 @@ class SAMBackbone(Backbone):
     """A backbone for the Segment Anything Model (SAM).
 
     Args:
-        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor for
-            the input images.
+        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor
+            for the input images.
         prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
             compute embeddings for points, box, and mask prompt.
         mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
@@ -68,7 +68,6 @@ class SAMBackbone(Backbone):
         image_encoder=image_encoder,
         prompt_encoder=prompt_encoder,
         mask_decoder=mask_decoder,
-        image_shape=(image_size, image_size, 3),
     )
     backbone(input_data)
     ```

--- a/keras_hub/src/models/sam/sam_image_converter.py
+++ b/keras_hub/src/models/sam/sam_image_converter.py
@@ -1,10 +1,8 @@
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.resizing_image_converter import (
-    ResizingImageConverter,
-)
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 
 
 @keras_hub_export("keras_hub.layers.SAMImageConverter")
-class SAMImageConverter(ResizingImageConverter):
+class SAMImageConverter(ImageConverter):
     backbone_cls = SAMBackbone
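
`ResizingImageConverter` is deleted in this release (file 247 in the list above), so `SAMImageConverter` now derives from the consolidated `ImageConverter`. A hedged sketch of the resulting behavior, assuming `ImageConverter` keeps the resize behavior of the removed class via its `image_size` argument:

```python
import numpy as np
import keras_hub

# Assumption: ImageConverter subsumes ResizingImageConverter's resizing
# through the image_size argument.
converter = keras_hub.layers.SAMImageConverter(image_size=(1024, 1024))
images = np.zeros((2, 512, 512, 3), dtype="float32")
resized = converter(images)  # expected shape: (2, 1024, 1024, 3)
```
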

--- a/keras_hub/src/models/sam/sam_image_segmenter.py
+++ b/keras_hub/src/models/sam/sam_image_segmenter.py
@@ -31,7 +31,7 @@ class SAMImageSegmenter(ImageSegmenter):
 
 
     Args:
-        backbone: A `keras_hub.models.VGGBackbone` instance.
+        backbone: A `keras_hub.models.SAMBackbone` instance.
 
     Example:
     Load pretrained model using `from_preset`.
@@ -51,9 +51,9 @@ class SAMImageSegmenter(ImageSegmenter):
             (batch_size, 0, image_size, image_size, 1)
         ),
     }
-    # todo: update preset name
-    sam = keras_hub.models.SAMImageSegmenter.from_preset(`sam_base`)
-    sam(input_data)
+    sam = keras_hub.models.SAMImageSegmenter.from_preset('sam_base_sa1b')
+    outputs = sam.predict(input_data)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
     ```
 
     Load segment anything image segmenter with custom backbone
@@ -65,7 +65,7 @@ class SAMImageSegmenter(ImageSegmenter):
         (batch_size, image_size, image_size, 3),
         dtype="float32",
     )
-    image_encoder = ViTDetBackbone(
+    image_encoder = keras_hub.models.ViTDetBackbone(
         hidden_size=16,
         num_layers=16,
         intermediate_dim=16 * 4,
@@ -76,7 +76,7 @@ class SAMImageSegmenter(ImageSegmenter):
         window_size=2,
         image_shape=(image_size, image_size, 3),
     )
-    prompt_encoder = SAMPromptEncoder(
+    prompt_encoder = keras_hub.layers.SAMPromptEncoder(
         hidden_size=8,
         image_embedding_size=(8, 8),
         input_image_size=(
@@ -85,7 +85,7 @@ class SAMImageSegmenter(ImageSegmenter):
         ),
         mask_in_channels=16,
     )
-    mask_decoder = SAMMaskDecoder(
+    mask_decoder = keras_hub.layers.SAMMaskDecoder(
         num_layers=2,
         hidden_size=8,
         intermediate_dim=32,
@@ -95,13 +95,12 @@ class SAMImageSegmenter(ImageSegmenter):
         iou_head_depth=3,
         iou_head_hidden_dim=8,
     )
-    backbone = SAMBackbone(
+    backbone = keras_hub.models.SAMBackbone(
         image_encoder=image_encoder,
         prompt_encoder=prompt_encoder,
         mask_decoder=mask_decoder,
-        image_shape=(image_size, image_size, 3),
     )
-    sam = SAMImageSegmenter(
+    sam = keras_hub.models.SAMImageSegmenter(
         backbone=backbone
     )
     ```
@@ -115,7 +114,7 @@ class SAMImageSegmenter(ImageSegmenter):
     labels = np.array([[1., 0.]])
     box = np.array([[[[384., 384.], [640., 640.]]]])
     input_mask = np.ones((1, 1, 256, 256, 1))
-    Prepare an input dictionary:
+    # Prepare an input dictionary:
     inputs = {
         "images": image,
         "points": points,
@@ -201,17 +200,18 @@ class SAMImageSegmenter(ImageSegmenter):
     def _add_placeholder_prompts(self, inputs):
         """Adds placeholder prompt inputs for a call to SAM.
 
-        Because SAM is a functional subclass model, all inputs must be specified in
-        calls to the model. However, prompt inputs are all optional, so we have to
-        add placeholders when they're not specified by the user.
+        Because SAM is a functional subclass model, all inputs must be specified
+        in calls to the model. However, prompt inputs are all optional, so we
+        have to add placeholders when they're not specified by the user.
         """
         inputs = inputs.copy()
 
         # Get the batch shape based on the image input
         batch_size = ops.shape(inputs["images"])[0]
 
-        # The type of the placeholders must match the existing inputs with respect
-        # to whether or not they are tensors (as opposed to Numpy arrays).
+        # The type of the placeholders must match the existing inputs with
+        # respect to whether or not they are tensors (as opposed to Numpy
+        # arrays).
         zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros
 
         # Fill in missing inputs.

--- a/keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py
+++ b/keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py
@@ -1,12 +1,22 @@
+import keras
+
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
 )
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
+from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
-@keras_hub_export("keras_hub.models.SamImageSegmenterPreprocessor")
+@keras_hub_export("keras_hub.models.SAMImageSegmenterPreprocessor")
 class SAMImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
     backbone_cls = SAMBackbone
     image_converter_cls = SAMImageConverter
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None):
+        images = x["images"]
+        if self.image_converter:
+            x["images"] = self.image_converter(images)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
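
The new `call` override converts only the `"images"` entry of the input dict and passes the prompt tensors through untouched. A short usage sketch, assuming the `sam_base_sa1b` preset from the sam_presets.py diff below:

```python
import numpy as np
import keras_hub

preprocessor = keras_hub.models.SAMImageSegmenterPreprocessor.from_preset(
    "sam_base_sa1b"  # preset name taken from the presets diff below
)
x = {
    "images": np.ones((1, 1024, 1024, 3), dtype="float32"),
    "points": np.array([[[512.0, 512.0]]]),
    "labels": np.array([[1.0]]),
}
# Only x["images"] goes through the image converter; "points" and
# "labels" come back unchanged.
x = preprocessor(x)
```
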

--- a/keras_hub/src/models/sam/sam_layers.py
+++ b/keras_hub/src/models/sam/sam_layers.py
@@ -170,8 +170,8 @@ class TwoWayMultiHeadAttention(keras.layers.Layer):
         key_dim: int. Size of each attention head for query, key, and
             value.
         intermediate_dim: int. Number of hidden dims to use in the mlp block.
-        skip_first_layer_pos_embedding: bool. A boolean indicating whether to skip the
-            first layer positional embeddings.
+        skip_first_layer_pos_embedding: bool. A boolean indicating whether to
+            skip the first layer positional embeddings.
         attention_downsample_rate: int, optional. The downsample rate to use
             in the attention layers. Defaults to 2.
         activation: str, optional. The activation for the mlp block's output
@@ -296,7 +296,9 @@ class TwoWayMultiHeadAttention(keras.layers.Layer):
             "num_heads": self.num_heads,
             "key_dim": self.key_dim,
             "intermediate_dim": self.intermediate_dim,
-            "skip_first_layer_pos_embedding": self.skip_first_layer_pos_embedding,
+            "skip_first_layer_pos_embedding": (
+                self.skip_first_layer_pos_embedding
+            ),
             "attention_downsample_rate": self.attention_downsample_rate,
             "activation": self.activation,
         }

--- a/keras_hub/src/models/sam/sam_presets.py
+++ b/keras_hub/src/models/sam/sam_presets.py
@@ -5,30 +5,24 @@ backbone_presets = {
         "metadata": {
             "description": ("The base SAM model trained on the SA1B dataset."),
             "params": 93735728,
-            "official_name": "SAMImageSegmenter",
             "path": "sam",
-            "model_card": "https://arxiv.org/abs/2304.02643",
         },
-        "kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/1",
+        "kaggle_handle": "kaggle://keras/sam/keras/sam_base_sa1b/5",
     },
     "sam_large_sa1b": {
         "metadata": {
             "description": ("The large SAM model trained on the SA1B dataset."),
             "params": 641090864,
-            "official_name": "SAMImageSegmenter",
             "path": "sam",
-            "model_card": "https://arxiv.org/abs/2304.02643",
         },
-        "kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/1",
+        "kaggle_handle": "kaggle://keras/sam/keras/sam_large_sa1b/5",
     },
     "sam_huge_sa1b": {
         "metadata": {
             "description": ("The huge SAM model trained on the SA1B dataset."),
             "params": 312343088,
-            "official_name": "SAMImageSegmenter",
             "path": "sam",
-            "model_card": "https://arxiv.org/abs/2304.02643",
         },
-        "kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/1",
+        "kaggle_handle": "kaggle://keras/sam/keras/sam_huge_sa1b/5",
    },
 }

--- a/keras_hub/src/models/sam/sam_prompt_encoder.py
+++ b/keras_hub/src/models/sam/sam_prompt_encoder.py
@@ -57,7 +57,7 @@ class SAMPromptEncoder(keras.layers.Layer):
         input_image_size=(1024, 1024),
         mask_in_channels=16,
         activation="gelu",
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.hidden_size = hidden_size
@@ -305,7 +305,9 @@ class SAMPromptEncoder(keras.layers.Layer):
         return {
             "prompt_sparse_embeddings": sparse_embeddings,
             "prompt_dense_embeddings": dense_embeddings,
-            "prompt_dense_positional_embeddings": prompt_dense_positional_embeddings,
+            "prompt_dense_positional_embeddings": (
+                prompt_dense_positional_embeddings
+            ),
         }
 
     def get_config(self):

--- a/keras_hub/src/models/sam/sam_transformer.py
+++ b/keras_hub/src/models/sam/sam_transformer.py
@@ -31,14 +31,15 @@ class TwoWayTransformer(keras.layers.Layer):
         location and type.
 
     Args:
-        num_layers: int, optional. The num_layers of the attention blocks (the number
-            of attention blocks to use). Defaults to `2`.
+        num_layers: int, optional. The num_layers of the attention blocks
+            (the number of attention blocks to use). Defaults to `2`.
         hidden_size: int, optional. The number of features of the input image
             and point embeddings. Defaults to `256`.
         num_heads: int, optional. Number of heads to use in the attention
             layers. Defaults to `8`.
-        intermediate_dim: int, optional. The number of units in the hidden layer of
-            the MLP block used in the attention layers. Defaults to `2048`.
+        intermediate_dim: int, optional. The number of units in the hidden
+            layer of the MLP block used in the attention layers.
+            Defaults to `2048`.
         activation: str, optional. The activation of the MLP block's output
             layer used in the attention layers. Defaults to `"relu"`.
         attention_downsample_rate: int, optional. The downsample rate of the

--- /dev/null
+++ b/keras_hub/src/models/segformer/__init__.py
@@ -0,0 +1,8 @@
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_presets import presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(presets, SegFormerImageSegmenter)
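
Because `register_presets` runs at import time, the names defined in `segformer_presets.py` resolve directly on the task class. A minimal sketch using the `segformer_b0_ade20k_512` name that appears in the segmenter docstring below:

```python
import keras_hub

# Resolves via the register_presets(presets, SegFormerImageSegmenter)
# call above.
segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
    "segformer_b0_ade20k_512"
)
```
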

--- /dev/null
+++ b/keras_hub/src/models/segformer/segformer_backbone.py
@@ -0,0 +1,167 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+@keras_hub_export("keras_hub.models.SegFormerBackbone")
+class SegFormerBackbone(Backbone):
+    """A Keras model implementing SegFormer for semantic segmentation.
+
+    This class implements the majority of the SegFormer architecture described
+    in [SegFormer: Simple and Efficient Design for Semantic Segmentation](https://arxiv.org/abs/2105.15203)
+    and based on the TensorFlow implementation
+    [from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and use a very lightweight all-MLP decoder head.
+
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.
+
+    Args:
+        image_encoder: `keras.Model`. The backbone network for the model that is
+            used as a feature extractor for the SegFormer encoder.
+            Should be used with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created
+            specifically for SegFormers.
+        num_classes: int, the number of classes for the detection model,
+            including the background class.
+        projection_filters: int, number of filters in the
+            convolution layer projecting the concatenated features into
+            a segmentation map. Defaults to `256`.
+
+    Example:
+
+    Using the class with a custom `backbone`:
+
+    ```python
+    import keras_hub
+
+    backbone = keras_hub.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        image_shape=(224, 224, 3),
+        hidden_dims=[32, 64, 160, 256],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        max_drop_path_rate=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+    )
+
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
+    ```
+
+    Using the class with a preset `backbone`:
+
+    ```python
+    import keras_hub
+
+    backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        projection_filters,
+        **kwargs,
+    ):
+        if not isinstance(image_encoder, keras.layers.Layer) or not isinstance(
+            image_encoder, keras.Model
+        ):
+            raise ValueError(
+                "Argument `image_encoder` must be a `keras.layers.Layer` "
+                f"instance or `keras.Model`. Received instead "
+                f"image_encoder={image_encoder} "
+                f"(of type {type(image_encoder)})."
+            )
+
+        # === Layers ===
+        inputs = keras.layers.Input(shape=image_encoder.input.shape[1:])
+
+        self.feature_extractor = keras.Model(
+            image_encoder.inputs, image_encoder.pyramid_outputs
+        )
+
+        features = self.feature_extractor(inputs)
+        # Get height and width of level one output
+        _, height, width, _ = features["P1"].shape
+
+        self.mlp_blocks = []
+
+        for feature_dim, feature in zip(image_encoder.hidden_dims, features):
+            self.mlp_blocks.append(
+                keras.layers.Dense(
+                    projection_filters, name=f"linear_{feature_dim}"
+                )
+            )
+
+        self.resizing = keras.layers.Resizing(
+            height, width, interpolation="bilinear"
+        )
+        self.concat = keras.layers.Concatenate(axis=-1)
+        self.linear_fuse = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=projection_filters, kernel_size=1, use_bias=False
+                ),
+                keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9),
+                keras.layers.Activation("relu"),
+            ]
+        )
+
+        # === Functional Model ===
+        # Project all multi-level outputs onto
+        # the same dimensionality and feature map shape
+        multi_layer_outs = []
+        for index, (feature_dim, feature) in enumerate(
+            zip(image_encoder.hidden_dims, features)
+        ):
+            out = self.mlp_blocks[index](features[feature])
+            out = self.resizing(out)
+            multi_layer_outs.append(out)
+
+        # Concat now-equal feature maps
+        concatenated_outs = self.concat(multi_layer_outs[::-1])
+
+        # Fuse concatenated features into a segmentation map
+        seg = self.linear_fuse(concatenated_outs)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=seg,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.projection_filters = projection_filters
+        self.image_encoder = image_encoder
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "projection_filters": self.projection_filters,
+                "image_encoder": keras.saving.serialize_keras_object(
+                    self.image_encoder
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
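
In the functional graph above, each pyramid level is projected to `projection_filters` channels and resized to the `P1` feature map before fusion, so the backbone's output spatial size tracks `P1`. A hedged shape check, assuming MiT's first stage (stride-4 patches) puts `P1` at 1/4 of the input resolution:

```python
import numpy as np
import keras_hub

encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=encoder, projection_filters=256
)
# Assumption: P1 is at stride 4, so a 512x512 input should yield a
# (1, 128, 128, 256) fused feature map.
out = backbone(np.ones((1, 512, 512, 3), dtype="float32"))
print(out.shape)
```
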

--- /dev/null
+++ b/keras_hub/src/models/segformer/segformer_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+
+
+@keras_hub_export("keras_hub.layers.SegFormerImageConverter")
+class SegFormerImageConverter(ImageConverter):
+    backbone_cls = SegFormerBackbone

--- /dev/null
+++ b/keras_hub/src/models/segformer/segformer_image_segmenter.py
@@ -0,0 +1,184 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (  # noqa: E501
+    SegFormerImageSegmenterPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.SegFormerImageSegmenter")
+class SegFormerImageSegmenter(ImageSegmenter):
+    """A Keras model implementing SegFormer for semantic segmentation.
+
+    This class implements the segmentation head of the SegFormer architecture
+    described in [SegFormer: Simple and Efficient Design for Semantic
+    Segmentation with Transformers](https://arxiv.org/abs/2105.15203) and is
+    based on [the TensorFlow implementation from
+    DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and use a very lightweight all-MLP decoder head.
+
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.
+
+    Args:
+        image_encoder: `keras.Model`. The backbone network for the model that is
+            used as a feature extractor for the SegFormer encoder. It is
+            *intended* to be used only with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created specifically for
+            SegFormers. Alternatively, can be a `keras_hub.models.Backbone`, a
+            model subclassing `keras_hub.models.FeaturePyramidBackbone`, or a
+            `keras.Model` that has a `pyramid_outputs` property which is a
+            dictionary with keys "P2", "P3", "P4", and "P5" and layer names as
+            values.
+        num_classes: int, the number of classes for the detection model,
+            including the background class.
+        projection_filters: int, number of filters in the
+            convolution layer projecting the concatenated features into a
+            segmentation map. Defaults to `256`.
+
+
+    Example:
+
+    Using presets:
+
+    ```python
+    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
+        "segformer_b0_ade20k_512"
+    )
+
+    images = np.random.rand(1, 512, 512, 3)
+    segmenter(images)
+    ```
+
+    Using the SegFormer backbone:
+
+    ```python
+    encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    ```
+
+    Using the SegFormer backbone with a custom encoder:
+
+    ```python
+    images = np.ones(shape=(1, 96, 96, 3))
+    labels = np.zeros(shape=(1, 96, 96, 1))
+
+    encoder = keras_hub.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        image_shape=(96, 96, 3),
+        hidden_dims=[32, 64, 160, 256],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        max_drop_path_rate=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+    )
+
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
+    segformer(images)
+    ```
+
+    Using the segmenter class with a preset backbone:
+
+    ```python
+    image_encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=image_encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
+    ```
+    """
+
+    backbone_cls = SegFormerBackbone
+    preprocessor_cls = SegFormerImageSegmenterPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        **kwargs,
+    ):
+        if not isinstance(backbone, keras.layers.Layer) or not isinstance(
+            backbone, keras.Model
+        ):
+            raise ValueError(
+                "Argument `backbone` must be a `keras.layers.Layer` instance "
+                f"or `keras.Model`. Received instead "
+                f"backbone={backbone} (of type {type(backbone)})."
+            )
+
+        # === Layers ===
+        inputs = backbone.input
+
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.dropout = keras.layers.Dropout(0.1)
+        self.output_segmentation_head = keras.layers.Conv2D(
+            filters=num_classes, kernel_size=1, strides=1
+        )
+        self.resizing = keras.layers.Resizing(
+            height=inputs.shape[1],
+            width=inputs.shape[2],
+            interpolation="bilinear",
+        )
+
+        # === Functional Model ===
+        x = self.backbone(inputs)
+        x = self.dropout(x)
+        x = self.output_segmentation_head(x)
+        output = self.resizing(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.backbone = backbone
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "backbone": keras.saving.serialize_keras_object(self.backbone),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
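
Since `SegFormerImageSegmenter` is an `ImageSegmenter` task, it trains with the usual `compile`/`fit` flow. A short, illustrative training sketch built from the custom-encoder example in the docstring above (the optimizer and loss here are assumptions, not from the source):

```python
import numpy as np
import keras
import keras_hub

images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))

encoder = keras_hub.models.MiTBackbone(
    depths=[2, 2, 2, 2],
    image_shape=(96, 96, 3),
    hidden_dims=[32, 64, 160, 256],
    num_layers=4,
    blockwise_num_heads=[1, 2, 5, 8],
    blockwise_sr_ratios=[8, 4, 2, 1],
    max_drop_path_rate=0.1,
    patch_sizes=[7, 3, 3, 3],
    strides=[4, 2, 2, 2],
)
backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=encoder, projection_filters=256
)
segformer = keras_hub.models.SegFormerImageSegmenter(
    backbone=backbone, num_classes=4
)
# The segmentation head emits per-pixel logits at input resolution.
segformer.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
segformer.fit(images, labels, epochs=1)
```
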