keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_pc_image_segmenter.py
@@ -0,0 +1,282 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_segmenter import ImageSegmenter
+ from keras_hub.src.models.sam3.sam3_pc_backbone import (
+     SAM3PromptableConceptBackbone,
+ )
+ from keras_hub.src.models.sam3.sam3_pc_image_segmenter_preprocessor import (
+     SAM3PromptableConceptImageSegmenterPreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.SAM3PromptableConceptImageSegmenter")
+ class SAM3PromptableConceptImageSegmenter(ImageSegmenter):
+     """The Segment Anything 3 (SAM3) promptable concept image segmenter Model.
+
+     SAM3 promptable concept segmentation (PCS) segments objects in images based
+     on concept prompts, which could be short noun phrases
+     (e.g., “yellow school bus”), image exemplars, or a combination of both.
+     SAM3 PCS takes such prompts and returns segmentation masks and unique
+     identities for all matching object instances.
+
+     There are two ways to prompt:
+     1. Text prompt: A short noun phrase describing the concept to segment.
+     2. Box prompt: A box tells the model which part/crop of the image to
+         segment.
+
+     These prompts can be used individually or together, but at least one of the
+     prompts must be present. To turn off a particular prompt, simply exclude it
+     from the inputs to the model.
+
+     Args:
+         backbone: A `keras_hub.models.SAM3PromptableConceptBackbone` instance.
+         preprocessor: Optional. An instance of
+             `SAM3PromptableConceptImageSegmenterPreprocessor` for input data
+             preprocessing.
+
+     Example:
+
+     Load pretrained model using `from_preset`.
+
+     ```python
+     image_size = 128
+     batch_size = 2
+     input_data = {
+         "images": np.ones(
+             (batch_size, image_size, image_size, 3), dtype="float32",
+         ),
+         "prompts": ["ear", "head"],
+         "boxes": np.ones((batch_size, 1, 4), dtype="float32"),  # XYXY format.
+         "box_labels": np.ones((batch_size, 1), dtype="float32"),
+     }
+     sam3_pcs = keras_hub.models.SAM3PromptableConceptImageSegmenter.from_preset(
+         "sam3_pcs"
+     )
+     outputs = sam3_pcs.predict(input_data)
+     scores = outputs["scores"]  # [B, num_queries]
+     boxes = outputs["boxes"]  # [B, num_queries, 4]
+     masks = outputs["masks"]  # [B, num_queries, H, W]
+     ```
+
+     Load pretrained model with custom image shape.
+
+     ```python
+     input_image_size = 128
+     batch_size = 1
+     model_image_size = 336
+     input_data = {
+         "images": np.ones(
+             (batch_size, input_image_size, input_image_size, 3),
+             dtype="float32",
+         ),
+         "prompts": ["ear", "head"],
+         "boxes": np.ones((batch_size, 1, 4), dtype="float32"),  # XYXY format.
+         "box_labels": np.ones((batch_size, 1), dtype="float32"),
+     }
+     sam3_backbone = keras_hub.models.SAM3PromptableConceptBackbone.from_preset(
+         "sam3_pcs", image_shape=(model_image_size, model_image_size, 3)
+     )
+     sam3_preprocessor = keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
+         "sam3_pcs"
+     )
+     sam3_preprocessor.image_size = (model_image_size, model_image_size)
+     sam3_pcs = keras_hub.models.SAM3PromptableConceptImageSegmenter(
+         backbone=sam3_backbone, preprocessor=sam3_preprocessor
+     )
+     outputs = sam3_pcs.predict(input_data)
+     scores = outputs["scores"]  # [B, num_queries]
+     boxes = outputs["boxes"]  # [B, num_queries, 4]
+     masks = outputs["masks"]  # [B, num_queries, H, W]
+     ```
+
+     Load SAM3PromptableConceptImageSegmenter with custom backbone.
+
+     ```python
+     vision_encoder = keras_hub.layers.SAM3VisionEncoder(
+         image_shape=(224, 224, 3),
+         patch_size=14,
+         num_layers=2,
+         hidden_dim=32,
+         intermediate_dim=128,
+         num_heads=2,
+         fpn_hidden_dim=32,
+         fpn_scale_factors=[4.0, 2.0, 1.0, 0.5],
+         pretrain_image_shape=(112, 112, 3),
+         window_size=2,
+         global_attn_indexes=[1, 2],
+     )
+     text_encoder = keras_hub.layers.SAM3TextEncoder(
+         vocabulary_size=1024,
+         embedding_dim=32,
+         hidden_dim=32,
+         num_layers=2,
+         num_heads=2,
+         intermediate_dim=128,
+     )
+     geometry_encoder = keras_hub.layers.SAM3GeometryEncoder(
+         num_layers=3,
+         hidden_dim=32,
+         intermediate_dim=128,
+         num_heads=2,
+         roi_size=7,
+     )
+     detr_encoder = keras_hub.layers.SAM3DetrEncoder(
+         num_layers=3,
+         hidden_dim=32,
+         intermediate_dim=128,
+         num_heads=2,
+     )
+     detr_decoder = keras_hub.layers.SAM3DetrDecoder(
+         image_shape=(224, 224, 3),
+         patch_size=14,
+         num_layers=2,
+         hidden_dim=32,
+         intermediate_dim=128,
+         num_heads=2,
+         num_queries=100,
+     )
+     mask_decoder = keras_hub.layers.SAM3MaskDecoder(
+         num_upsampling_stages=3,
+         hidden_dim=32,
+         num_heads=2,
+     )
+     backbone = keras_hub.models.SAM3PromptableConceptBackbone(
+         vision_encoder=vision_encoder,
+         text_encoder=text_encoder,
+         geometry_encoder=geometry_encoder,
+         detr_encoder=detr_encoder,
+         detr_decoder=detr_decoder,
+         mask_decoder=mask_decoder,
+     )
+     preprocessor = keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
+         "sam3_pcs"
+     )
+     sam3_pcs = keras_hub.models.SAM3PromptableConceptImageSegmenter(
+         backbone=backbone, preprocessor=preprocessor
+     )
+     ```
+
+     For example, to pass in all the prompts, do:
+
+     ```python
+     image_size = 128
+     batch_size = 2
+     images = np.ones(
+         (batch_size, image_size, image_size, 3), dtype="float32",
+     )
+     prompts = ["ear", "head"]
+     # Box prompt in XYXY format
+     boxes = np.array(
+         [[[100.0, 100.0, 150.0, 150.0]], [[50.0, 50.0, 80.0, 80.0]]],
+         dtype="float32",
+     )
+     # Box labels: 1 means positive box, 0 means negative box, -10 is for
+     # padding boxes.
+     box_labels = np.array([[1], [1]], dtype="int32")
+     # Prepare an input dictionary:
+     inputs = {
+         "images": images,
+         "prompts": prompts,
+         "boxes": boxes,
+         "box_labels": box_labels,
+     }
+     outputs = sam3_pcs.predict(inputs)
+     scores = outputs["scores"]  # [B, num_queries]
+     boxes = outputs["boxes"]  # [B, num_queries, 4]
+     masks = outputs["masks"]  # [B, num_queries, H, W]
+     ```
+
+     Now, in case of only text prompts, simply exclude the box prompts:
+
+     ```python
+     inputs = {
+         "images": images,
+         "prompts": prompts,
+     }
+     outputs = sam3_pcs.predict(inputs)
+     scores = outputs["scores"]  # [B, num_queries]
+     boxes = outputs["boxes"]  # [B, num_queries, 4]
+     masks = outputs["masks"]  # [B, num_queries, H, W]
+     ```
+     """  # noqa: E501
+
+     backbone_cls = SAM3PromptableConceptBackbone
+     preprocessor_cls = SAM3PromptableConceptImageSegmenterPreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         preprocessor=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = self.backbone.input
+         outputs = self.backbone(inputs)
+         super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError(
+             "SAM3PromptableConceptImageSegmenter only supports inference for "
+             "now. Training the model isn't supported yet."
+         )
+
+     def post_process_prediction(self, predictions):
+         """Post-processes the raw model predictions.
+
+         This method converts the raw model predictions into the scores, boxes
+         and masks.
+
+         The output format is as follows:
+         - scores: A float tensor of shape `[batch_size, num_queries]`
+             representing the confidence score of each object instance. The
+             score is in the range [0, 1].
+         - boxes: A float tensor of shape `[batch_size, num_queries, 4]`
+             representing the bounding boxes of each object instance in
+             `[x_min, y_min, x_max, y_max]` format. The box coordinates are
+             normalized to the range [0, 1].
+         - masks: A boolean tensor of shape
+             `[batch_size, num_queries, height, width]` representing the binary
+             masks of each object instance.
+         """
+         pred_logits = predictions["pred_logits"]
+         pred_boxes = predictions["pred_boxes"]
+         pred_masks = predictions["pred_masks"]
+         presence_logits = predictions["presence_logits"]
+
+         pred_scores = keras.ops.sigmoid(pred_logits)
+         presence_scores = keras.ops.sigmoid(presence_logits)
+         scores = keras.ops.multiply(pred_scores, presence_scores)
+
+         masks = keras.ops.sigmoid(pred_masks)
+         masks = keras.ops.transpose(masks, [0, 3, 1, 2])
+         return {
+             "scores": scores,
+             "boxes": pred_boxes,
+             "masks": masks,
+         }
+
+     def predict_step(self, *args):
+         predictions = super().predict_step(*args)
+         if isinstance(predictions, tuple):
+             return self.post_process_prediction(predictions[0]), predictions[1]
+         return self.post_process_prediction(predictions)
+
+     @classmethod
+     def from_config(cls, config):
+         config = config.copy()
+         if "backbone" in config and isinstance(config["backbone"], dict):
+             config["backbone"] = keras.saving.deserialize_keras_object(
+                 config["backbone"]
+             )
+         if "preprocessor" in config and isinstance(
+             config["preprocessor"], dict
+         ):
+             config["preprocessor"] = keras.saving.deserialize_keras_object(
+                 config["preprocessor"]
+             )
+         return cls(**config)
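
The scoring step in `post_process_prediction` above gates each query's class probability with an image-level presence probability, so queries in images that don't contain the prompted concept score near zero. A minimal sketch of just that math, with made-up shapes (the `(batch, 1)` presence shape is an assumption chosen for broadcasting):

```python
import numpy as np
import keras

# Toy logits: batch of 2 images, 5 object queries each (shapes are made up).
pred_logits = np.random.randn(2, 5).astype("float32")
presence_logits = np.random.randn(2, 1).astype("float32")

# Per-query probability, gated by the image-level presence probability.
scores = keras.ops.multiply(
    keras.ops.sigmoid(pred_logits), keras.ops.sigmoid(presence_logits)
)
print(keras.ops.shape(scores))  # (2, 5), every value in [0, 1]
```
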
keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py
@@ -0,0 +1,336 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+ from keras_hub.src.models.preprocessor import Preprocessor
+ from keras_hub.src.models.sam3.sam3_image_converter import SAM3ImageConverter
+ from keras_hub.src.models.sam3.sam3_pc_backbone import (
+     SAM3PromptableConceptBackbone,
+ )
+ from keras_hub.src.models.sam3.sam3_tokenizer import SAM3Tokenizer
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+ try:
+     import tensorflow as tf
+ except ImportError:
+     tf = None
+
+
+ @keras_hub_export(
+     "keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor"
+ )
+ class SAM3PromptableConceptImageSegmenterPreprocessor(Preprocessor):
+     """SAM3 Promptable Concept Image Segmenter preprocessor.
+
+     This preprocessing layer is meant for use with
+     `keras_hub.models.SAM3PromptableConceptImageSegmenter`.
+
+     Args:
+         tokenizer: A `keras_hub.models.SAM3Tokenizer` instance.
+         image_converter: A `keras_hub.layers.SAM3ImageConverter` instance.
+         sequence_length: The length of the packed token_ids. Defaults to `32`.
+         add_start_token: If `True`, the preprocessor will prepend the tokenizer
+             start token to each input sequence. Defaults to `True`.
+         add_end_token: If `True`, the preprocessor will append the tokenizer
+             end token to each input sequence. Defaults to `True`.
+         point_pad_value: int. The padding value for box prompts. Defaults to
+             `-10`.
+
+     Call arguments:
+         x: A dictionary with the following keys:
+             - images: A single image or a batch of images, of shape
+                 `(height, width, 3)` or `(batch_size, height, width, 3)`.
+             - prompts: (optional) A string or a batch of strings containing the
+                 text prompts. If not provided, a default prompt will be used.
+             - boxes: (optional) A tensor of shape `(num_boxes, 4)` or
+                 `(batch_size, num_boxes, 4)` containing box coordinates in
+                 `(x_min, y_min, x_max, y_max)` format. Coordinates should be in
+                 absolute pixel values. If not provided, no box prompts will be
+                 used. `-10` is used as the padding value.
+             - box_labels: (optional) A tensor of shape `(num_boxes,)` or
+                 `(batch_size, num_boxes)` containing box labels. If not
+                 provided, no box labels will be used. `-10` is used as the
+                 padding value.
+
+     Examples:
+
+     ```python
+     # Load the preprocessor from a preset.
+     preprocessor = keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
+         "sam3_pcs"
+     )
+
+     # Unbatched inputs, with one image and one text prompt.
+     preprocessor(
+         {
+             "prompts": "ear",
+             "images": np.ones((896, 896, 3), dtype="float32")
+         }
+     )
+
+     # Unbatched inputs, with one image and one box prompt.
+     preprocessor(
+         {
+             "boxes": [[0, 0, 300, 300]],
+             "box_labels": [1],
+             "images": np.ones((896, 896, 3), dtype="float32")
+         }
+     )
+
+     # Batched inputs, one image per text prompt.
+     preprocessor(
+         {
+             "prompts": [
+                 "ear",
+                 "head"
+             ],
+             "images": [
+                 np.ones((896, 896, 3), dtype="float32"),
+                 np.ones((896, 896, 3), dtype="float32")
+             ]
+         }
+     )
+
+     # Batched inputs, one image per box prompt.
+     preprocessor(
+         {
+             "boxes": [
+                 [[0, 0, 300, 300]],
+                 [[50, 50, 100, 100]]
+             ],
+             "box_labels": [
+                 [1],
+                 [1]
+             ],
+             "images": [
+                 np.ones((896, 896, 3), dtype="float32"),
+                 np.ones((896, 896, 3), dtype="float32")
+             ]
+         }
+     )
+
+     # Different number of box prompts in every sample.
+     preprocessor(
+         {
+             "boxes": [
+                 [[0, 0, 300, 300]],
+                 [[50, 50, 100, 100], [150, 150, 200, 200]]
+             ],
+             "box_labels": [
+                 [1],
+                 [1, 1]
+             ],
+             "images": [
+                 np.ones((896, 896, 3), dtype="float32"),
+                 np.ones((896, 896, 3), dtype="float32")
+             ]
+         }
+     )
+
+     # Apply preprocessing to a `tf.data.Dataset`.
+     inputs = {
+         "prompts": [
+             "ear",
+             "head",
+         ],
+         "images": np.ones((2, 896, 896, 3), dtype="float32")
+     }
+     ds = tf.data.Dataset.from_tensor_slices(inputs)
+     ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+     ```
+     """  # noqa: E501
+
+     backbone_cls = SAM3PromptableConceptBackbone
+     tokenizer_cls = SAM3Tokenizer
+     image_converter_cls = SAM3ImageConverter
+
+     def __init__(
+         self,
+         tokenizer,
+         image_converter,
+         sequence_length=32,
+         add_start_token=True,
+         add_end_token=True,
+         point_pad_value=-10,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.tokenizer = tokenizer
+         self.packer = None
+         self.image_converter = image_converter
+         self.sequence_length = sequence_length
+         self.add_start_token = add_start_token
+         self.add_end_token = add_end_token
+         self.point_pad_value = point_pad_value
+
+     def build(self, input_shape):
+         # Defer packer creation to `build()` so that we can be sure tokenizer
+         # assets have loaded when restoring a saved model.
+         self.packer = StartEndPacker(
+             start_value=self.tokenizer.start_token_id,
+             end_value=self.tokenizer.end_token_id,
+             pad_value=self.tokenizer.pad_token_id,
+             sequence_length=self.sequence_length,
+             return_padding_mask=True,
+         )
+         self.built = True
+
+     def _preprocess_boxes(self, boxes, box_labels, height, width):
+         if isinstance(boxes, tf.RaggedTensor):
+             max_num_boxes = tf.reduce_max(boxes.row_lengths(axis=1))
+             boxes = boxes.to_tensor(
+                 shape=[None, max_num_boxes, 4],
+                 default_value=self.point_pad_value,
+             )
+             box_labels = box_labels.to_tensor(
+                 shape=[None, max_num_boxes],
+                 default_value=self.point_pad_value,
+             )
+         box_dtype = keras.backend.standardize_dtype(boxes.dtype)
+         normalized_boxes = tf.stack(
+             [
+                 boxes[..., 0] / tf.cast(width, box_dtype),
+                 boxes[..., 1] / tf.cast(height, box_dtype),
+                 boxes[..., 2] / tf.cast(width, box_dtype),
+                 boxes[..., 3] / tf.cast(height, box_dtype),
+             ],
+             axis=-1,
+         )
+         boxes = tf.where(
+             tf.equal(tf.expand_dims(box_labels, axis=-1), self.point_pad_value),
+             tf.fill(
+                 tf.shape(normalized_boxes),
+                 tf.cast(self.point_pad_value, normalized_boxes.dtype),
+             ),
+             normalized_boxes,
+         )
+         # XYXY to CXCYWH.
+         boxes = tf.stack(
+             [
+                 (boxes[..., 0] + boxes[..., 2]) / 2.0,
+                 (boxes[..., 1] + boxes[..., 3]) / 2.0,
+                 boxes[..., 2] - boxes[..., 0],
+                 boxes[..., 3] - boxes[..., 1],
+             ],
+             axis=-1,
+         )
+         # Add batch indices.
+         batch_size = tf.shape(boxes)[0]
+         batch_indices = tf.range(batch_size, dtype=boxes.dtype)
+         batch_indices = tf.reshape(batch_indices, (batch_size, 1, 1))
+         batch_indices = tf.tile(batch_indices, (1, tf.shape(boxes)[1], 1))
+         boxes = tf.concat([batch_indices, boxes], axis=-1)
+         return boxes, box_labels
+
+     @preprocessing_function
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         sequence_length=None,
+     ):
+         sequence_length = sequence_length or self.sequence_length
+
+         images = x["images"]
+         prompts = x.get("prompts", None)
+         boxes, box_labels = x.get("boxes", None), x.get("box_labels", None)
+
+         # Convert to batched inputs.
+         if len(images.shape) == 3:
+             is_batched = False
+             images = tf.expand_dims(images, axis=0)
+             if prompts is not None and len(prompts.shape) == 0:
+                 prompts = tf.expand_dims(prompts, axis=0)
+             if boxes is not None and len(boxes.shape) == 2:
+                 boxes = tf.expand_dims(boxes, axis=0)
+                 box_labels = tf.expand_dims(box_labels, axis=0)
+         else:
+             is_batched = True
+
+         batch_size = tf.shape(images)[0]
+         height = tf.shape(images)[1]
+         width = tf.shape(images)[2]
+
+         # Add placeholders if not provided.
+         if prompts is None:
+             prompts = tf.convert_to_tensor("visual")
+             prompts = tf.tile(prompts[None], [batch_size])
+         if boxes is None:
+             boxes = tf.zeros((batch_size, 0, 4), dtype="float32")
+             box_labels = tf.zeros((batch_size, 0), dtype="int32")
+
+         # Tokenise the prompts.
+         prompts = self.tokenizer(prompts)
+         token_ids, padding_mask = self.packer(
+             prompts,
+             sequence_length=sequence_length + 1,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+
+         # Resize and normalize the images.
+         pixel_values = self.image_converter(images)
+         if keras.config.backend() == "torch" and not isinstance(
+             images, tf.Tensor
+         ):
+             images = images.cpu()
+
+         # Normalize the boxes.
+         boxes, box_labels = self._preprocess_boxes(
+             boxes, box_labels, height, width
+         )
+
+         if not is_batched:
+             token_ids = tf.squeeze(token_ids, axis=0)
+             padding_mask = tf.squeeze(padding_mask, axis=0)
+             pixel_values = tf.squeeze(pixel_values, axis=0)
+             boxes = tf.squeeze(boxes, axis=0)
+             box_labels = tf.squeeze(box_labels, axis=0)
+
+         x = {
+             "pixel_values": pixel_values,
+             "token_ids": token_ids[..., :-1],
+             "padding_mask": padding_mask[..., :-1],
+             "boxes": boxes,
+             "box_labels": box_labels,
+         }
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "sequence_length": self.sequence_length,
+                 "add_start_token": self.add_start_token,
+                 "add_end_token": self.add_end_token,
+             }
+         )
+         return config
+
+     @property
+     def sequence_length(self):
+         """The padded length of model input sequences."""
+         return self._sequence_length
+
+     @sequence_length.setter
+     def sequence_length(self, value):
+         self._sequence_length = value
+         if self.packer is not None:
+             self.packer.sequence_length = value
+
+     @property
+     def image_size(self):
+         """Settable tuple of `(height, width)` ints. The output image shape."""
+         if self.image_converter.resizing.height is None:
+             return None
+         return (
+             self.image_converter.resizing.height,
+             self.image_converter.resizing.width,
+         )
+
+     @image_size.setter
+     def image_size(self, value):
+         if value is None:
+             value = (None, None)
+         self.image_converter.resizing.height = value[0]
+         self.image_converter.resizing.width = value[1]
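
The `_preprocess_boxes` helper above does three things: it normalizes absolute-pixel XYXY coordinates by the image size, converts them to center-plus-size (CXCYWH) form, and prepends a batch index to each box. A toy sketch of just the coordinate math (numbers chosen arbitrarily; padding and batch indices omitted):

```python
import numpy as np

height, width = 896.0, 896.0
box = np.array([100.0, 100.0, 150.0, 150.0])  # XYXY, absolute pixels.

# Normalize to [0, 1]: divide x coordinates by width, y coordinates by height.
x_min, y_min, x_max, y_max = box / np.array([width, height, width, height])

# XYXY to CXCYWH: box center plus box width/height.
cxcywh = [
    (x_min + x_max) / 2.0,
    (y_min + y_max) / 2.0,
    x_max - x_min,
    y_max - y_min,
]
print(np.round(cxcywh, 4))  # [0.1395 0.1395 0.0558 0.0558]
```
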
keras_hub/src/models/sam3/sam3_presets.py
@@ -0,0 +1,16 @@
+ """SAM3 model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "sam3_pcs": {
+         "metadata": {
+             "description": (
+                 "30 million parameter Promptable Concept Segmentation (PCS) "
+                 "SAM model."
+             ),
+             "params": 30000000,
+             "path": "sam3",
+         },
+         "kaggle_handle": "kaggle://keras/sam3/keras/sam3_pcs/1",
+     },
+ }
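
With this preset registered, the segmenter, preprocessor, and weights load together through `from_preset`. A sketch of the end-to-end flow, mirroring the docstring examples above (the 896×896 input size and prompt string are illustrative):

```python
import numpy as np
import keras_hub

# Loads the backbone, preprocessor, and weights for the "sam3_pcs" preset.
segmenter = keras_hub.models.SAM3PromptableConceptImageSegmenter.from_preset(
    "sam3_pcs"
)
outputs = segmenter.predict(
    {
        "images": np.ones((1, 896, 896, 3), dtype="float32"),
        "prompts": ["ear"],
    }
)
masks = outputs["masks"]  # [batch, num_queries, height, width]
```
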