keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/layers/preprocessing/image_converter.py
@@ -1,47 +1,190 @@
+import math
+
+import keras
+import numpy as np
+from keras import ops
+
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
+from keras_hub.src.utils.keras_utils import standardize_data_format
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
+from keras_hub.src.utils.tensor_utils import preprocessing_function


 @keras_hub_export("keras_hub.layers.ImageConverter")
 class ImageConverter(PreprocessingLayer):
-    """Convert raw image for models that support image input.
+    """Preprocess raw images into model ready inputs.
+
+    This class converts from raw images to model ready inputs. This conversion
+    proceeds in the following steps:

-    This class converts from raw images of any size, to preprocessed
-    images for pretrained model inputs. It is meant to be a convenient way to
-    write custom preprocessing code that is not model specific. This layer
-    should be instantiated via the `from_preset()` constructor, which will
-    create the correct subclass of this layer for the model preset.
+    1. Resize the image using to `image_size`. If `image_size` is `None`, this
+       step will be skipped.
+    2. Rescale the image by multiplying by `scale`, which can be either global
+       or per channel. If `scale` is `None`, this step will be skipped.
+    3. Offset the image by adding `offset`, which can be either global
+       or per channel. If `offset` is `None`, this step will be skipped.

     The layer will take as input a raw image tensor in the channels last or
     channels first format, and output a preprocessed image input for modeling.
-    The exact structure of the output will vary per model, though in most cases
-    this layer will simply resize the image to the size needed by the model
-    input.
+    This tensor can be batched (rank 4), or unbatched (rank 3).
+
+    This layer can be used with the `from_preset()` constructor to load a layer
+    that will rescale and resize an image for a specific pretrained model.
+    Using the layer this way allows writing preprocessing code that does not
+    need updating when switching between model checkpoints.
+
+    Args:
+        image_size: `(int, int)` tuple or `None`. The output size of the image,
+            not including the channels axis. If `None`, the input will not be
+            resized.
+        scale: float, tuple of floats, or `None`. The scale to apply to the
+            inputs. If `scale` is a single float, the entire input will be
+            multiplied by `scale`. If `scale` is a tuple, it's assumed to
+            contain per-channel scale value multiplied against each channel of
+            the input images. If `scale` is `None`, no scaling is applied.
+        offset: float, tuple of floats, or `None`. The offset to apply to the
+            inputs. If `offset` is a single float, the entire input will be
+            summed with `offset`. If `offset` is a tuple, it's assumed to
+            contain per-channel offset value summed against each channel of the
+            input images. If `offset` is `None`, no scaling is applied.
+        crop_to_aspect_ratio: If `True`, resize the images without aspect
+            ratio distortion. When the original aspect ratio differs
+            from the target aspect ratio, the output image will be
+            cropped so as to return the
+            largest possible window in the image (of size `(height, width)`)
+            that matches the target aspect ratio. By default
+            (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.
+        interpolation: String, the interpolation method.
+            Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
+            `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`.
+        data_format: String, either `"channels_last"` or `"channels_first"`.
+            The ordering of the dimensions in the inputs. `"channels_last"`
+            corresponds to inputs with shape `(batch, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.

     Examples:
     ```python
-    # Resize images for `"pali_gemma_3b_224"`.
-    converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_224")
-    converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3)
-    # Resize images for `"pali_gemma_3b_448"`.
-    converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_448")
-    converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3)
+    # Resize raw images and scale them to [0, 1].
+    converter = keras_hub.layers.ImageConverter(
+        image_size=(128, 128),
+        scale=1. / 255,
+    )
+    converter(np.random.randint(0, 256, size=(2, 512, 512, 3)))
+
+    # Resize images to the specific size needed for a PaliGemma preset.
+    converter = keras_hub.layers.ImageConverter.from_preset(
+        "pali_gemma_3b_224"
+    )
+    converter(np.random.randint(0, 256, size=(2, 512, 512, 3)))
     ```
     """

     backbone_cls = None

+    def __init__(
+        self,
+        image_size=None,
+        scale=None,
+        offset=None,
+        crop_to_aspect_ratio=True,
+        interpolation="bilinear",
+        data_format=None,
+        **kwargs,
+    ):
+        # TODO: old arg names. Delete this block after resaving Kaggle assets.
+        if "height" in kwargs and "width" in kwargs:
+            image_size = (kwargs.pop("height"), kwargs.pop("width"))
+        if "variance" in kwargs and "mean" in kwargs:
+            std = [math.sqrt(v) for v in kwargs.pop("variance")]
+            scale = [scale / s for s in std]
+            offset = [-m / s for m, s in zip(kwargs.pop("mean"), std)]
+
+        super().__init__(**kwargs)
+
+        # Create the `Resizing` layer here even if it's not being used. That
+        # allows us to make `image_size` a settable property.
+        self.resizing = keras.layers.Resizing(
+            height=image_size[0] if image_size else None,
+            width=image_size[1] if image_size else None,
+            crop_to_aspect_ratio=crop_to_aspect_ratio,
+            interpolation=interpolation,
+            data_format=data_format,
+            dtype=self.dtype_policy,
+            name="resizing",
+        )
+        self.scale = scale
+        self.offset = offset
+        self.crop_to_aspect_ratio = crop_to_aspect_ratio
+        self.interpolation = interpolation
+        self.data_format = standardize_data_format(data_format)
+
+    @property
     def image_size(self):
-        """Returns the default size of a single image."""
-        return (None, None)
+        """Settable tuple of `(height, width)` ints. The output image shape."""
+        if self.resizing.height is None:
+            return None
+        return (self.resizing.height, self.resizing.width)
+
+    @image_size.setter
+    def image_size(self, value):
+        if value is None:
+            value = (None, None)
+        self.resizing.height = value[0]
+        self.resizing.width = value[1]
+
+    @preprocessing_function
+    def call(self, inputs):
+        x = inputs
+        if self.image_size is not None:
+            x = self.resizing(x)
+        if self.scale is not None:
+            x = x * self._expand_non_channel_dims(self.scale, x)
+        if self.offset is not None:
+            x = x + self._expand_non_channel_dims(self.offset, x)
+        return x
+
+    def _expand_non_channel_dims(self, value, inputs):
+        unbatched = len(ops.shape(inputs)) == 3
+        channels_first = self.data_format == "channels_first"
+        if unbatched:
+            broadcast_dims = (1, 2) if channels_first else (0, 1)
+        else:
+            broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
+        # If inputs are not a tensor type, return a numpy array.
+        # This might happen when running under tf.data.
+        if ops.is_tensor(inputs):
+            # preprocessing decorator moves tensors to cpu in torch backend and
+            # processed on CPU, and then converted back to the appropriate
+            # device (potentially GPU) after preprocessing.
+            if keras.backend.backend() == "torch" and self.image_size is None:
+                return ops.expand_dims(value, broadcast_dims).cpu()
+            return ops.expand_dims(value, broadcast_dims)
+        else:
+            return np.expand_dims(value, broadcast_dims)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_size": self.image_size,
+                "scale": self.scale,
+                "offset": self.offset,
+                "interpolation": self.interpolation,
+                "crop_to_aspect_ratio": self.crop_to_aspect_ratio,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):
@@ -69,13 +212,6 @@ class ImageConverter(PreprocessingLayer):
         You can run `cls.presets.keys()` to list all built-in presets available
         on the class.

-        This constructor can be called in one of two ways. Either from the base
-        class like `keras_hub.models.ImageConverter.from_preset()`, or from a
-        model class like
-        `keras_hub.models.PaliGemmaImageConverter.from_preset()`. If calling
-        from the base class, the subclass of the returning object will be
-        inferred from the config in the preset directory.
-
         Args:
             preset: string. A built-in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
@@ -85,17 +221,20 @@ class ImageConverter(PreprocessingLayer):

         Examples:
         ```python
+        batch = np.random.randint(0, 256, size=(2, 512, 512, 3))
+
         # Resize images for `"pali_gemma_3b_224"`.
         converter = keras_hub.layers.ImageConverter.from_preset(
             "pali_gemma_3b_224"
         )
-        converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3)
-        # Override arguments on the base class.
+        converter(batch) # Output shape: (2, 224, 224, 3)
+
+        # Resize images for `"pali_gemma_3b_448"` without cropping.
         converter = keras_hub.layers.ImageConverter.from_preset(
             "pali_gemma_3b_448",
             crop_to_aspect_ratio=False,
         )
-        converter(np.ones(2, 512, 512, 3)) # (2, 448, 448, 3)
+        converter(batch) # Output shape: (2, 448, 448, 3)
         ```
         """
         loader = get_preset_loader(preset)
@@ -110,8 +249,5 @@ class ImageConverter(PreprocessingLayer):

         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=IMAGE_CONVERTER_CONFIG_FILE,
-        )
+        saver = get_preset_saver(preset_dir)
+        saver.save_image_converter(self)
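
With the rewrite above, `ImageConverter` exposes `image_size`, `scale`, and `offset` directly instead of the old `height`/`width`/`mean`/`variance` arguments. A minimal usage sketch of expressing a per-channel mean/std normalization with the new arguments; the mean/std values below are the common ImageNet statistics, used purely as an illustration and not taken from this package:

```python
import numpy as np

import keras_hub

# Illustrative per-channel statistics (ImageNet-style, in the [0, 1] range).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# (x / 255 - mean) / std  ==  x * scale + offset, which is the same algebra
# as the mean/variance back-compat conversion in `__init__` above.
converter = keras_hub.layers.ImageConverter(
    image_size=(224, 224),
    scale=[1.0 / (255.0 * s) for s in std],
    offset=[-m / s for m, s in zip(mean, std)],
)
images = np.random.randint(0, 256, size=(2, 512, 512, 3))
outputs = converter(images)  # Output shape: (2, 224, 224, 3), normalized per channel
```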
keras_hub/src/metrics/bleu.py
@@ -164,7 +164,7 @@ class Bleu(keras.metrics.Metric):
         return inputs

     def _get_ngrams(self, segment, max_order):
-        """Extracts all n-grams up to a given maximum order from an input segment.
+        """Extracts all n-grams up to a given maximum order from an input.

         Uses Python ops. Inspired from
         https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
@@ -329,8 +329,9 @@ class Bleu(keras.metrics.Metric):
                 return tf.squeeze(inputs, axis=-1)
             else:
                 raise ValueError(
-                    f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
-                    f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
+                    f"{tensor_name} must be of rank {base_rank}, "
+                    f"{base_rank + 1}, or {base_rank + 2}. "
+                    f"Found rank: {inputs.shape.rank}"
                 )

         y_true = validate_and_fix_rank(y_true, "y_true", 1)
keras_hub/src/models/albert/albert_presets.py
@@ -8,11 +8,9 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 11683584,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_base_en_uncased/2",
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_base_en_uncased/5",
     },
     "albert_large_en_uncased": {
         "metadata": {
@@ -21,11 +19,9 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 17683968,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_large_en_uncased/2",
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_large_en_uncased/3",
     },
     "albert_extra_large_en_uncased": {
         "metadata": {
@@ -34,11 +30,9 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 58724864,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_large_en_uncased/2",
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_large_en_uncased/3",
     },
     "albert_extra_extra_large_en_uncased": {
         "metadata": {
@@ -47,10 +41,8 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 222595584,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/2",
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/3",
     },
 }
keras_hub/src/models/albert/albert_text_classifier.py
@@ -20,10 +20,10 @@ from keras_hub.src.models.text_classifier import TextClassifier
 class AlbertTextClassifier(TextClassifier):
     """An end-to-end ALBERT model for classification tasks

-    This model attaches a classification head to a `keras_hub.model.AlbertBackbone`
-    backbone, mapping from the backbone outputs to logit output suitable for
-    a classification task. For usage of this model with pre-trained weights, see
-    the `from_preset()` method.
+    This model attaches a classification head to a
+    `keras_hub.model.AlbertBackbone` backbone, mapping from the backbone outputs
+    to logit output suitable for a classification task. For usage of this model
+    with pre-trained weights, see the `from_preset()` method.

     This model can optionally be configured with a `preprocessor` layer, in
     which case it will automatically apply preprocessing to raw inputs during
@@ -36,9 +36,9 @@ class AlbertTextClassifier(TextClassifier):
     Args:
         backbone: A `keras_hub.models.AlertBackbone` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.AlbertTextClassifierPreprocessor` or `None`. If
-            `None`, this model will not apply preprocessing, and inputs should
-            be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.AlbertTextClassifierPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
         activation: Optional `str` or callable. The
             activation function to use on the model outputs. Set
             `activation="softmax"` to return output probabilities.
@@ -1,15 +1,9 @@
1
- import os
2
-
3
1
  import keras
4
2
 
5
3
  from keras_hub.src.api_export import keras_hub_export
6
- from keras_hub.src.utils.keras_utils import assert_quantization_support
7
- from keras_hub.src.utils.preset_utils import CONFIG_FILE
8
- from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
9
4
  from keras_hub.src.utils.preset_utils import builtin_presets
10
5
  from keras_hub.src.utils.preset_utils import get_preset_loader
11
- from keras_hub.src.utils.preset_utils import save_metadata
12
- from keras_hub.src.utils.preset_utils import save_serialized_object
6
+ from keras_hub.src.utils.preset_utils import get_preset_saver
13
7
  from keras_hub.src.utils.python_utils import classproperty
14
8
 
15
9
 
@@ -88,10 +82,6 @@ class Backbone(keras.Model):
88
82
  def token_embedding(self, value):
89
83
  self._token_embedding = value
90
84
 
91
- def quantize(self, mode, **kwargs):
92
- assert_quantization_support()
93
- return super().quantize(mode, **kwargs)
94
-
95
85
  def get_config(self):
96
86
  # Don't chain to super here. `get_config()` for functional models is
97
87
  # a nested layer config and cannot be passed to Backbone constructors.
@@ -193,9 +183,8 @@ class Backbone(keras.Model):
193
183
  Args:
194
184
  preset_dir: The path to the local model preset directory.
195
185
  """
196
- save_serialized_object(self, preset_dir, config_file=CONFIG_FILE)
197
- self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE))
198
- save_metadata(self, preset_dir)
186
+ saver = get_preset_saver(preset_dir)
187
+ saver.save_backbone(self)
199
188
 
200
189
  def enable_lora(self, rank):
201
190
  """Enable Lora on the backbone.
keras_hub/src/models/bart/bart_backbone.py
@@ -22,9 +22,9 @@ class BartBackbone(Backbone):
     described in
     ["BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension"](https://arxiv.org/abs/1910.13461).

-    The default constructor gives a fully customizable, randomly initialized BART
-    model with any number of layers, heads, and embedding dimensions. To load
-    preset architectures and weights, use the `from_preset` constructor.
+    The default constructor gives a fully customizable, randomly initialized
+    BART model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset` constructor.

     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind. The underlying model is provided by a
@@ -78,7 +78,7 @@ class BartBackbone(Backbone):
     )
     output = model(input_data)
     ```
-    """
+    """  # noqa: E501

     def __init__(
         self,
keras_hub/src/models/bart/bart_presets.py
@@ -8,11 +8,9 @@ backbone_presets = {
                 "Trained on BookCorpus, English Wikipedia and CommonCrawl."
             ),
             "params": 139417344,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
-        "kaggle_handle": "kaggle://keras/bart/keras/bart_base_en/2",
+        "kaggle_handle": "kaggle://keras/bart/keras/bart_base_en/3",
     },
     "bart_large_en": {
         "metadata": {
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on BookCorpus, English Wikipedia and CommonCrawl."
             ),
             "params": 406287360,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "config": {
             "vocabulary_size": 50265,
@@ -34,7 +30,7 @@ backbone_presets = {
             "dropout": 0.1,
             "max_sequence_length": 1024,
         },
-        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en/2",
+        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en/3",
     },
     "bart_large_en_cnn": {
         "metadata": {
@@ -43,9 +39,7 @@ backbone_presets = {
                 "summarization dataset."
             ),
             "params": 406287360,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "config": {
             "vocabulary_size": 50264,
@@ -56,6 +50,6 @@ backbone_presets = {
             "dropout": 0.1,
             "max_sequence_length": 1024,
         },
-        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en_cnn/2",
+        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en_cnn/3",
     },
 }
keras_hub/src/models/bart/bart_seq_2_seq_lm.py
@@ -60,7 +60,8 @@ class BartSeq2SeqLM(Seq2SeqLM):
     bart_lm.generate("The quick brown fox", max_length=30)
     ```

-    Use `generate()` with encoder inputs and an incomplete decoder input (prompt).
+    Use `generate()` with encoder inputs and an incomplete decoder input
+    (prompt).
     ```python
     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
     bart_lm.generate(
@@ -79,10 +80,10 @@ class BartSeq2SeqLM(Seq2SeqLM):
     prompt = {
         "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
         "encoder_padding_mask": np.array(
-            [[True, True, True, True, True, True, False, False]]
+            [[1, 1, 1, 1, 1, 1, 0, 0]]
         ),
         "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
-        "decoder_padding_mask": np.array([[True, True, True, True, False, False]])
+        "decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0]])
     }

     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
@@ -95,7 +96,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
     Call `fit()` on a single batch.
     ```python
     features = {
-        "encoder_text": ["The quick brown fox jumped.", "I forgot my homework."],
+        "encoder_text": ["The quick fox jumped.", "I forgot my homework."],
         "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
     }
     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
@@ -195,7 +196,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
         cross_attention_cache=None,
         cross_attention_cache_update_index=None,
     ):
-        """Forward pass with a key/value caches for generative decoding..
+        """Forward pass with a key/value caches for generative decoding.

         `call_decoder_with_cache` adds an additional inference-time forward pass
         for the model for seq2seq text generation. Unlike calling the model
@@ -241,7 +242,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
             key/value cache in the decoder's self-attention layer and
             `cross_attention_cache` is the key/value cache in the decoder's
             cross-attention layer.
-        """
+        """  # noqa: E501
         # Embedding layers.
         tokens = self.backbone.token_embedding(decoder_token_ids)
         positions = self.backbone.decoder_position_embedding(
@@ -331,7 +332,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
     def _build_cache(
         self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
     ):
-        """Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
+        """Builds the self-attention cache and the cross-attention cache."""
         encoder_hidden_states = self.call_encoder(
             token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
         )
@@ -417,7 +418,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
         prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])

         def repeat_tensor(x):
-            """Repeats tensors along batch axis to match dim for beam search."""
+            """Repeats along batch axis to match dim for beam search."""
             if ops.shape(x)[0] == num_samples:
                 return x
             return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
keras_hub/src/models/basnet/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
+from keras_hub.src.models.basnet.basnet_presets import basnet_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(basnet_presets, BASNetBackbone)
keras_hub/src/models/basnet/basnet.py
@@ -0,0 +1,122 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
+from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+
+
+@keras_hub_export("keras_hub.models.BASNetImageSegmenter")
+class BASNetImageSegmenter(ImageSegmenter):
+    """BASNet image segmentation task.
+
+    Args:
+        backbone: A `keras_hub.models.BASNetBackbone` instance.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+
+    Example:
+    ```python
+    import keras_hub
+
+    images = np.ones(shape=(1, 288, 288, 3))
+    labels = np.zeros(shape=(1, 288, 288, 1))
+
+    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
+        "resnet_18_imagenet",
+        load_weights=False
+    )
+    backbone = keras_hub.models.BASNetBackbone(
+        image_encoder,
+        num_classes=1,
+        image_shape=[288, 288, 3]
+    )
+    model = keras_hub.models.BASNetImageSegmenter(backbone)
+
+    # Evaluate the model
+    pred_labels = model(images)
+
+    # Train the model
+    model.compile(
+        optimizer="adam",
+        loss=keras.losses.BinaryCrossentropy(from_logits=False),
+        metrics=["accuracy"],
+    )
+    model.fit(images, labels, epochs=3)
+    ```
+    """
+
+    backbone_cls = BASNetBackbone
+    preprocessor_cls = BASNetPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
+        # === Functional Model ===
+        x = backbone.input
+        outputs = backbone(x)
+        # only return the refinement module's output as final prediction
+        outputs = outputs["refine_out"]
+        super().__init__(inputs=x, outputs=outputs, **kwargs)
+
+        # === Config ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+    def compute_loss(self, x, y, y_pred, *args, **kwargs):
+        # train BASNet's prediction and refinement module outputs against the
+        # same ground truth data
+        outputs = self.backbone(x)
+        losses = []
+        for output in outputs.values():
+            losses.append(super().compute_loss(x, y, output, *args, **kwargs))
+        return keras.ops.sum(losses, axis=0)
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `BASNet` task for training.
+
+        `BASNet` extends the default compilation signature
+        of `keras.Model.compile` with defaults for `optimizer` and `loss`. To
+        override these defaults, pass any value to these arguments during
+        compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for `BASNet`. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer`
+                values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, in which case the default loss
+                computation of `BASNet` will be applied.
+                See `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.Accuracy` will be applied to track the
+                accuracy of the model during training.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if loss == "auto":
+            loss = keras.losses.BinaryCrossentropy()
+        if metrics == "auto":
+            metrics = [keras.metrics.Accuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
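
The `compile` override above resolves `loss="auto"` to `keras.losses.BinaryCrossentropy()` and `metrics="auto"` to `[keras.metrics.Accuracy()]`. A brief sketch of overriding only the optimizer while keeping those defaults, reusing the constructor calls from the docstring example; the learning rate is an arbitrary illustrative value:

```python
import numpy as np
import keras

import keras_hub

# Build the segmenter as in the docstring example above.
image_encoder = keras_hub.models.ResNetBackbone.from_preset(
    "resnet_18_imagenet", load_weights=False
)
backbone = keras_hub.models.BASNetBackbone(
    image_encoder, num_classes=1, image_shape=[288, 288, 3]
)
model = keras_hub.models.BASNetImageSegmenter(backbone)

# Explicit optimizer; "auto" keeps the BinaryCrossentropy loss and
# Accuracy metric defaults shown in `compile` above.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss="auto",
    metrics="auto",
)
model.fit(np.ones((1, 288, 288, 3)), np.zeros((1, 288, 288, 1)), epochs=1)
```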