keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/api/layers/__init__.py
@@ -14,6 +14,7 @@ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
+from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.layers.modeling.sine_position_encoding import (
     SinePositionEncoding,
@@ -34,12 +35,16 @@ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
 from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
     DeepLabV3ImageConverter,
 )
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter,
 )
+from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
+    EfficientNetImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
 from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
     PaliGemmaImageConverter,
@@ -47,9 +52,16 @@ from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
 from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
+from keras_hub.src.models.segformer.segformer_image_converter import (
+    SegFormerImageConverter,
+)
 from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
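
These hunks only widen the public `keras_hub.layers` namespace. A minimal sketch of what the new re-exports make importable, assuming only what this diff shows (the `RMSNormalization` constructor appears in rms_normalization.py further down; the converter classes' constructor arguments are not part of this diff):

import numpy as np

import keras_hub

# RMSNormalization is now reachable from the public layers namespace.
norm = keras_hub.layers.RMSNormalization(input_dim=8)
y = norm(np.ones((2, 8), dtype="float32"))  # shape preserved, per-row RMS ~ 1

# AnchorGenerator, CLIPImageConverter, EfficientNetImageConverter,
# RetinaNetImageConverter, and SegFormerImageConverter are exported the same
# way; their constructor signatures are not shown in this diff.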
keras_hub/api/models/__init__.py
@@ -53,8 +53,11 @@ from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import (
 from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
 from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
 from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder
 from keras_hub.src.models.csp_darknet.csp_darknet_backbone import (
     CSPDarkNetBackbone,
 )
@@ -128,6 +131,12 @@ from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
 from keras_hub.src.models.efficientnet.efficientnet_backbone import (
     EfficientNetBackbone,
 )
+from keras_hub.src.models.efficientnet.efficientnet_image_classifier import (
+    EfficientNetImageClassifier,
+)
+from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import (
+    EfficientNetImageClassifierPreprocessor,
+)
 from keras_hub.src.models.electra.electra_backbone import ElectraBackbone
 from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone
@@ -153,6 +162,11 @@ from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import (
 )
 from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage
+from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
+    FluxTextToImagePreprocessor,
+)
 from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
 from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
@@ -176,6 +190,10 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
@@ -243,6 +261,13 @@ from keras_hub.src.models.resnet.resnet_image_classifier import (
 from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
     ResNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_object_detector import (
+    RetinaNetObjectDetector,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
 from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
 from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
@@ -266,6 +291,13 @@ from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
     SAMImageSegmenterPreprocessor,
 )
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+    SegFormerImageSegmenterPreprocessor,
+)
 from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
 from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
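
On the models side, the new task families (RetinaNet object detection, SegFormer segmentation, Flux text-to-image) become reachable from `keras_hub.models`. A hedged sketch of the intended entry points; both preset names below are placeholders, since the real identifiers live in the `*_presets.py` files listed above but are not shown in this diff:

import keras_hub

# Assumed preset names, for illustration only.
detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"  # hypothetical name
)
segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
    "segformer_b0_ade20k_512"  # hypothetical name
)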
keras_hub/src/bounding_box/__init__.py
@@ -0,0 +1,2 @@
+# TODO: Once all bounding boxes are moved to keras repostory remove the
+# bounding box folder.
keras_hub/src/bounding_box/converters.py
@@ -20,29 +20,74 @@ class RequiresImagesException(Exception):
 ALL_AXES = 4
 
 
-def _encode_box_to_deltas(
+def encode_box_to_deltas(
     anchors,
     boxes,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoding_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from `center_yxhw` to delta format."""
+    """Encodes bounding boxes relative to anchors as deltas.
+
+    This function calculates the deltas that represent the difference between
+    bounding boxes and provided anchors. Deltas encode the offsets and scaling
+    factors to apply to anchors to obtain the target boxes.
+
+    Boxes and anchors are first converted to the specified `encoding_format`
+    (defaulting to `center_yxhw`) for consistent delta representation.
+
+    Args:
+        anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the
+            number of anchors.
+        boxes: `Tensors` Bounding boxes to encode. Boxes can be of shape
+            `(B, N, 4)` or `(N, 4)`.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., "xyxy", "xywh", etc.).
+        box_format: str. The format of the input `boxes`
+            (e.g., "xyxy", "xywh", etc.).
+        encoding_format: str. The intermediate format to which boxes and anchors
+            are converted before delta calculation. Defaults to "center_yxhw".
+        variance: `List[float]`. A 4-element array/tensor representing variance
+            factors to scale the box deltas. If provided, the calculated deltas
+            are divided by the variance. Defaults to None.
+        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
+            When using relative bounding box format for `box_format` the
+            `image_shape` is used for normalization.
+    Returns:
+        Encoded box deltas. The return type matches the `encode_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoding_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
 
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
+
+    if encoding_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            "`encoding_format` should be one of 'center_xywh' or 'center_yxhw', "
+            f"got {encoding_format}"
+        )
+
     encoded_anchors = convert_format(
         anchors,
         source=anchor_format,
-        target="center_yxhw",
+        target=encoding_format,
         image_shape=image_shape,
     )
     boxes = convert_format(
-        boxes, source=box_format, target="center_yxhw", image_shape=image_shape
+        boxes,
+        source=box_format,
+        target=encoding_format,
+        image_shape=image_shape,
     )
     anchor_dimensions = ops.maximum(
         encoded_anchors[..., 2:], keras.backend.epsilon()
@@ -61,15 +106,54 @@ def _encode_box_to_deltas(
     return boxes_delta
 
 
-def _decode_deltas_to_boxes(
+def decode_deltas_to_boxes(
     anchors,
     boxes_delta,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoded_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from delta format to `center_yxhw`."""
+    """Converts bounding boxes from delta format to the specified `box_format`.
+
+    This function decodes bounding box deltas relative to anchors to obtain the
+    final bounding box coordinates. The boxes are encoded in a specific
+    `encoded_format` (center_yxhw by default) during the decoding process.
+    This allows flexibility in how the deltas are applied to the anchors.
+
+    Args:
+        anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level
+            indices and values are corresponding anchor boxes.
+            The shape of the array/tensor should be `(N, 4)` where N is the
+            number of anchors.
+        boxes_delta: Can be `Tensors` or `Dict[Tensors]` Bounding box deltas
+            must have the same type and structure as `anchors`. The
+            shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is
+            the number of boxes.
+        anchor_format: str. The format of the input `anchors`.
+            (e.g., `"xyxy"`, `"xywh"`, etc.)
+        box_format: str. The desired format for the output boxes.
+            (e.g., `"xyxy"`, `"xywh"`, etc.)
+        encoded_format: str. Raw output format from regression head. Defaults
+            to `"center_yxhw"`.
+        variance: `List[floats]`. A 4-element array/tensor representing
+            variance factors to scale the box deltas. If provided, the deltas
+            are multiplied by the variance before being applied to the anchors.
+            Defaults to None.
+        image_shape: The shape of the image (height, width). This is needed
+            if normalization to image size is required when converting between
+            formats. Defaults to None.
+
+    Returns:
+        Decoded box coordinates. The return type matches the `box_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoded_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
@@ -77,11 +161,17 @@ def _decode_deltas_to_boxes(
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
 
+    if encoded_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            f"`encoded_format` should be 'center_xywh' or 'center_yxhw', "
+            f"but got '{encoded_format}'."
+        )
+
    def decode_single_level(anchor, box_delta):
        encoded_anchor = convert_format(
            anchor,
            source=anchor_format,
-            target="center_yxhw",
+            target=encoded_format,
            image_shape=image_shape,
        )
        if variance is not None:
@@ -97,7 +187,7 @@ def _decode_deltas_to_boxes(
         )
         box = convert_format(
             box,
-            source="center_yxhw",
+            source=encoded_format,
             target=box_format,
             image_shape=image_shape,
         )
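
Because `encode_box_to_deltas` and `decode_deltas_to_boxes` are now public and take an explicit intermediate format, the natural sanity check is a round trip: deltas encoded against a set of anchors should decode back to the original boxes when `variance` is left as None on both sides. A minimal sketch under the signatures shown above (the `keras_hub.src.bounding_box.converters` import path is the one in this diff and may move, per the TODO in bounding_box/__init__.py):

import numpy as np
from keras import ops

from keras_hub.src.bounding_box.converters import (
    decode_deltas_to_boxes,
    encode_box_to_deltas,
)

# Two anchors and two target boxes, both in "xyxy" format.
anchors = np.array(
    [[10.0, 10.0, 50.0, 50.0], [20.0, 20.0, 80.0, 100.0]], dtype="float32"
)
boxes = np.array(
    [[12.0, 8.0, 48.0, 55.0], [25.0, 18.0, 85.0, 95.0]], dtype="float32"
)

deltas = encode_box_to_deltas(
    anchors,
    boxes,
    anchor_format="xyxy",
    box_format="xyxy",
    encoding_format="center_xywh",  # the new knob added in this release
)
decoded = decode_deltas_to_boxes(
    anchors,
    deltas,
    anchor_format="xyxy",
    box_format="xyxy",
    encoded_format="center_xywh",
)
# Round trip recovers the original boxes.
np.testing.assert_allclose(ops.convert_to_numpy(decoded), boxes, rtol=1e-5)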
keras_hub/src/layers/modeling/rms_normalization.py
@@ -0,0 +1,34 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+
+@keras_hub_export("keras_hub.layers.RMSNormalization")
+class RMSNormalization(keras.layers.Layer):
+    """
+    Root Mean Square (RMS) Normalization layer.
+    This layer normalizes the input tensor based on its RMS value and applies
+    a learned scaling factor.
+    Args:
+        input_dim: int. The dimensionality of the input tensor.
+    """
+
+    def __init__(self, input_dim):
+        super().__init__()
+        self.scale = self.add_weight(
+            name="scale", shape=(input_dim,), initializer="ones"
+        )
+
+    def call(self, x):
+        """
+        Applies RMS normalization to the input tensor.
+        Args:
+            x: KerasTensor. Input tensor of shape (batch_size, input_dim).
+        Returns:
+            KerasTensor: The RMS-normalized tensor of the same shape (batch_size, input_dim),
+            scaled by the learned `scale` parameter.
+        """
+        x = ops.cast(x, float)
+        rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6)
+        return (x * rrms) * self.scale
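
A quick check of the layer above: the call computes x * rsqrt(mean(x**2) + 1e-6) * scale, so with the default ones-initialized `scale` every row of the output has an RMS of approximately 1.

import numpy as np

from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization

layer = RMSNormalization(input_dim=4)
x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype="float32")
y = layer(x)  # x / sqrt(mean(x**2) + 1e-6); the row RMS of y is ~1.0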
keras_hub/src/layers/modeling/transformer_encoder.py
@@ -170,7 +170,12 @@ class TransformerEncoder(keras.layers.Layer):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -185,6 +190,7 @@ class TransformerEncoder(keras.layers.Layer):
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -200,12 +206,23 @@ class TransformerEncoder(keras.layers.Layer):
         residual = x
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
-        x = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            training=training,
-        )
+
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
+
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -222,6 +239,9 @@ class TransformerEncoder(keras.layers.Layer):
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
 
+        if return_attention_scores:
+            return x, attention_scores
+
         return x
 
     def get_config(self):
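
With this change the encoder can surface attention weights without a custom subclass. A short sketch, reusing the existing `TransformerEncoder(intermediate_dim, num_heads)` constructor from this file:

import numpy as np

from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(intermediate_dim=64, num_heads=4)
x = np.random.rand(2, 10, 16).astype("float32")  # (batch, seq_len, hidden)

outputs, scores = encoder(x, return_attention_scores=True)
# outputs: (2, 10, 16); scores: (2, 4, 10, 10), one map per attention head.
# Without the flag the call returns just the (2, 10, 16) output, as before.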
keras_hub/src/layers/preprocessing/image_converter.py
@@ -164,6 +164,11 @@ class ImageConverter(PreprocessingLayer):
         # If inputs are not a tensor type, return a numpy array.
         # This might happen when running under tf.data.
         if ops.is_tensor(inputs):
+            # preprocessing decorator moves tensors to cpu in torch backend and
+            # processed on CPU, and then converted back to the appropriate
+            # device (potentially GPU) after preprocessing.
+            if keras.backend.backend() == "torch" and self.image_size is None:
+                return ops.expand_dims(value, broadcast_dims).cpu()
             return ops.expand_dims(value, broadcast_dims)
         else:
             return np.expand_dims(value, broadcast_dims)
keras_hub/src/models/albert/albert_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 11683584,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/albert/keras/albert_base_en_uncased/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 17683968,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/albert/keras/albert_large_en_uncased/2",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 58724864,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_large_en_uncased/2",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
            ),
             "params": 222595584,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/2",
     },
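
The same two keys (`official_name` and `model_card`) are dropped from every preset in this release; the BART, BERT, and BLOOM hunks below repeat the pattern. Reconstructed from the context lines above, a preset entry now reduces to:

backbone_presets = {
    "albert_base_en_uncased": {
        "metadata": {
            "description": (
                "..."  # first description line is outside this diff's context
                "Trained on English Wikipedia + BooksCorpus."
            ),
            "params": 11683584,
            "path": "albert",
        },
        "kaggle_handle": "kaggle://keras/albert/keras/albert_base_en_uncased/2",
    },
}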
keras_hub/src/models/bart/bart_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on BookCorpus, English Wikipedia and CommonCrawl."
             ),
             "params": 139417344,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "kaggle_handle": "kaggle://keras/bart/keras/bart_base_en/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on BookCorpus, English Wikipedia and CommonCrawl."
             ),
             "params": 406287360,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "config": {
             "vocabulary_size": 50265,
@@ -43,9 +39,7 @@ backbone_presets = {
                 "summarization dataset."
             ),
             "params": 406287360,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "config": {
             "vocabulary_size": 50264,
keras_hub/src/models/bert/bert_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 4385920,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 28763648,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_small_en_uncased/2",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 41373184,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_medium_en_uncased/2",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 109482240,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_en_uncased/2",
     },
@@ -60,9 +52,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 108310272,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_en/2",
     },
@@ -72,9 +62,7 @@ backbone_presets = {
                 "12-layer BERT model. Trained on Chinese Wikipedia."
             ),
             "params": 102267648,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_zh/2",
     },
@@ -84,9 +72,7 @@ backbone_presets = {
                 "12-layer BERT model where case is maintained. Trained on trained on Wikipedias of 104 languages"
             ),
             "params": 177853440,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_multi/2",
     },
@@ -97,9 +83,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 335141888,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en_uncased/2",
     },
@@ -110,9 +94,7 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 333579264,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
     },
@@ -122,9 +104,7 @@ backbone_presets = {
                 "The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset."
             ),
             "params": 4385920,
-            "official_name": "BERT",
             "path": "bert",
-            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
     },
keras_hub/src/models/bloom/bloom_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "trained on 45 natural languages and 12 programming languages."
             ),
             "params": 559214592,
-            "official_name": "BLOOM",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloom-560m",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/3",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "trained on 45 natural languages and 12 programming languages."
             ),
             "params": 1065314304,
-            "official_name": "BLOOM",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloom-1b1",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.1b_multi/1",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "trained on 45 natural languages and 12 programming languages."
             ),
             "params": 1722408960,
-            "official_name": "BLOOM",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloom-1b7",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.7b_multi/1",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "trained on 45 natural languages and 12 programming languages."
             ),
             "params": 3002557440,
-            "official_name": "BLOOM",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloom-3b",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloom_3b_multi/1",
     },
@@ -60,9 +52,7 @@ backbone_presets = {
                 "finetuned on crosslingual task mixture (xP3) dataset."
             ),
             "params": 559214592,
-            "official_name": "BLOOMZ",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloomz-560m",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_560m_multi/1",
     },
@@ -73,9 +63,7 @@ backbone_presets = {
                 "finetuned on crosslingual task mixture (xP3) dataset."
             ),
             "params": 1065314304,
-            "official_name": "BLOOMZ",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloomz-1b1",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.1b_multi/1",
     },
@@ -86,9 +74,7 @@ backbone_presets = {
                 "finetuned on crosslingual task mixture (xP3) dataset."
             ),
             "params": 1722408960,
-            "official_name": "BLOOMZ",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloomz-1b7",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.7b_multi/1",
     },
@@ -99,9 +85,7 @@ backbone_presets = {
                 "finetuned on crosslingual task mixture (xP3) dataset."
             ),
             "params": 3002557440,
-            "official_name": "BLOOMZ",
             "path": "bloom",
-            "model_card": "https://huggingface.co/bigscience/bloomz-3b",
         },
         "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_3b_multi/1",
     },
keras_hub/src/models/clip/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
+from keras_hub.src.models.clip.clip_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, CLIPBackbone)
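
This `__init__.py` is the standard keras-hub registration hook: importing the package binds every entry of `clip_presets.py` (listed above but not shown in this diff) to `CLIPBackbone`, which makes them loadable by name. A hedged sketch of what that enables; the preset identifier below is a placeholder:

import keras_hub

# "clip_vit_base_patch32" is an assumed preset name for illustration; the
# real names are defined in keras_hub/src/models/clip/clip_presets.py.
backbone = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")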