keras_hub-0.25.0.dev0-py3-none-any.whl → keras_hub-0.26.0.dev0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0

keras_hub/src/models/sam3/sam3_pc_image_segmenter.py
@@ -0,0 +1,282 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.sam3.sam3_pc_backbone import (
    SAM3PromptableConceptBackbone,
)
from keras_hub.src.models.sam3.sam3_pc_image_segmenter_preprocessor import (
    SAM3PromptableConceptImageSegmenterPreprocessor,
)


@keras_hub_export("keras_hub.models.SAM3PromptableConceptImageSegmenter")
class SAM3PromptableConceptImageSegmenter(ImageSegmenter):
    """The Segment Anything 3 (SAM3) promptable concept image segmenter model.

    SAM3 promptable concept segmentation (PCS) segments objects in images based
    on concept prompts, which could be short noun phrases
    (e.g., "yellow school bus"), image exemplars, or a combination of both.
    SAM3 PCS takes such prompts and returns segmentation masks and unique
    identities for all matching object instances.

    There are two ways to prompt:
    1. Text prompt: A short noun phrase describing the concept to segment.
    2. Box prompt: A box tells the model which part/crop of the image to
       segment.

    These prompts can be used individually or together, but at least one of the
    prompts must be present. To turn off a particular prompt, simply exclude it
    from the inputs to the model.

    Args:
        backbone: A `keras_hub.models.SAM3PromptableConceptBackbone` instance.
        preprocessor: Optional. An instance of
            `SAM3PromptableConceptImageSegmenterPreprocessor` for input data
            preprocessing.

    Example:

    Load pretrained model using `from_preset`.

    ```python
    image_size = 128
    batch_size = 2
    input_data = {
        "images": np.ones(
            (batch_size, image_size, image_size, 3), dtype="float32",
        ),
        "prompts": ["ear", "head"],
        "boxes": np.ones((batch_size, 1, 4), dtype="float32"),  # XYXY format.
        "box_labels": np.ones((batch_size, 1), dtype="float32"),
    }
    sam3_pcs = keras_hub.models.SAM3PromptableConceptImageSegmenter.from_preset(
        "sam3_pcs"
    )
    outputs = sam3_pcs.predict(input_data)
    scores = outputs["scores"]  # [B, num_queries]
    boxes = outputs["boxes"]  # [B, num_queries, 4]
    masks = outputs["masks"]  # [B, num_queries, H, W]
    ```

    Load pretrained model with custom image shape.

    ```python
    input_image_size = 128
    batch_size = 1
    model_image_size = 336
    input_data = {
        "images": np.ones(
            (batch_size, input_image_size, input_image_size, 3),
            dtype="float32",
        ),
        "prompts": ["ear", "head"],
        "boxes": np.ones((batch_size, 1, 4), dtype="float32"),  # XYXY format.
        "box_labels": np.ones((batch_size, 1), dtype="float32"),
    }
    sam3_backbone = keras_hub.models.SAM3PromptableConceptBackbone.from_preset(
        "sam3_pcs", image_shape=(model_image_size, model_image_size, 3)
    )
    sam3_preprocessor = keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
        "sam3_pcs"
    )
    sam3_preprocessor.image_size = (model_image_size, model_image_size)
    sam3_pcs = keras_hub.models.SAM3PromptableConceptImageSegmenter(
        backbone=sam3_backbone, preprocessor=sam3_preprocessor
    )
    outputs = sam3_pcs.predict(input_data)
    scores = outputs["scores"]  # [B, num_queries]
    boxes = outputs["boxes"]  # [B, num_queries, 4]
    masks = outputs["masks"]  # [B, num_queries, H, W]
    ```

    Load SAM3PromptableConceptImageSegmenter with a custom backbone.

    ```python
    vision_encoder = keras_hub.layers.SAM3VisionEncoder(
        image_shape=(224, 224, 3),
        patch_size=14,
        num_layers=2,
        hidden_dim=32,
        intermediate_dim=128,
        num_heads=2,
        fpn_hidden_dim=32,
        fpn_scale_factors=[4.0, 2.0, 1.0, 0.5],
        pretrain_image_shape=(112, 112, 3),
        window_size=2,
        global_attn_indexes=[1, 2],
    )
    text_encoder = keras_hub.layers.SAM3TextEncoder(
        vocabulary_size=1024,
        embedding_dim=32,
        hidden_dim=32,
        num_layers=2,
        num_heads=2,
        intermediate_dim=128,
    )
    geometry_encoder = keras_hub.layers.SAM3GeometryEncoder(
        num_layers=3,
        hidden_dim=32,
        intermediate_dim=128,
        num_heads=2,
        roi_size=7,
    )
    detr_encoder = keras_hub.layers.SAM3DetrEncoder(
        num_layers=3,
        hidden_dim=32,
        intermediate_dim=128,
        num_heads=2,
    )
    detr_decoder = keras_hub.layers.SAM3DetrDecoder(
        image_shape=(224, 224, 3),
        patch_size=14,
        num_layers=2,
        hidden_dim=32,
        intermediate_dim=128,
        num_heads=2,
        num_queries=100,
    )
    mask_decoder = keras_hub.layers.SAM3MaskDecoder(
        num_upsampling_stages=3,
        hidden_dim=32,
        num_heads=2,
    )
    backbone = keras_hub.models.SAM3PromptableConceptBackbone(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        geometry_encoder=geometry_encoder,
        detr_encoder=detr_encoder,
        detr_decoder=detr_decoder,
        mask_decoder=mask_decoder,
    )
    preprocessor = keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
        "sam3_pcs"
    )
    sam3_pcs = keras_hub.models.SAM3PromptableConceptImageSegmenter(
        backbone=backbone, preprocessor=preprocessor
    )
    ```

    For example, to pass in all the prompts, do:

    ```python
    image_size = 128
    batch_size = 2
    images = np.ones(
        (batch_size, image_size, image_size, 3), dtype="float32",
    )
    prompts = ["ear", "head"]
    # Box prompt in XYXY format.
    boxes = np.array(
        [[[100.0, 100.0, 150.0, 150.0]], [[50.0, 50.0, 80.0, 80.0]]],
        dtype="float32",
    )
    # Box labels: 1 means positive box, 0 means negative box, -10 is for
    # padding boxes.
    box_labels = np.array([[1], [1]], dtype="int32")
    # Prepare an input dictionary:
    inputs = {
        "images": images,
        "prompts": prompts,
        "boxes": boxes,
        "box_labels": box_labels,
    }
    outputs = sam3_pcs.predict(inputs)
    scores = outputs["scores"]  # [B, num_queries]
    boxes = outputs["boxes"]  # [B, num_queries, 4]
    masks = outputs["masks"]  # [B, num_queries, H, W]
    ```

    Now, in case of only text prompts, simply exclude the box prompts:

    ```python
    inputs = {
        "images": images,
        "prompts": prompts,
    }
    outputs = sam3_pcs.predict(inputs)
    scores = outputs["scores"]  # [B, num_queries]
    boxes = outputs["boxes"]  # [B, num_queries, 4]
    masks = outputs["masks"]  # [B, num_queries, H, W]
    ```
    """  # noqa: E501

    backbone_cls = SAM3PromptableConceptBackbone
    preprocessor_cls = SAM3PromptableConceptImageSegmenterPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Functional Model ===
        inputs = self.backbone.input
        outputs = self.backbone(inputs)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    def fit(self, *args, **kwargs):
        raise NotImplementedError(
            "SAM3PromptableConceptImageSegmenter only supports inference for "
            "now. Training the model isn't supported yet."
        )

    def post_process_prediction(self, predictions):
        """Post-processes the raw model predictions.

        This method converts the raw model predictions into scores, boxes,
        and masks.

        The output format is as follows:
        - scores: A float tensor of shape `[batch_size, num_queries]`
          representing the confidence score of each object instance. The score
          is in the range [0, 1].
        - boxes: A float tensor of shape `[batch_size, num_queries, 4]`
          representing the bounding boxes of each object instance in
          `[x_min, y_min, x_max, y_max]` format. The box coordinates are
          normalized to the range [0, 1].
        - masks: A boolean tensor of shape
          `[batch_size, num_queries, height, width]` representing the binary
          masks of each object instance.
        """
        pred_logits = predictions["pred_logits"]
        pred_boxes = predictions["pred_boxes"]
        pred_masks = predictions["pred_masks"]
        presence_logits = predictions["presence_logits"]

        pred_scores = keras.ops.sigmoid(pred_logits)
        presence_scores = keras.ops.sigmoid(presence_logits)
        scores = keras.ops.multiply(pred_scores, presence_scores)

        masks = keras.ops.sigmoid(pred_masks)
        masks = keras.ops.transpose(masks, [0, 3, 1, 2])
        return {
            "scores": scores,
            "boxes": pred_boxes,
            "masks": masks,
        }

    def predict_step(self, *args):
        predictions = super().predict_step(*args)
        if isinstance(predictions, tuple):
            return self.post_process_prediction(predictions[0]), predictions[1]
        return self.post_process_prediction(predictions)

    @classmethod
    def from_config(cls, config):
        config = config.copy()
        if "backbone" in config and isinstance(config["backbone"], dict):
            config["backbone"] = keras.saving.deserialize_keras_object(
                config["backbone"]
            )
        if "preprocessor" in config and isinstance(
            config["preprocessor"], dict
        ):
            config["preprocessor"] = keras.saving.deserialize_keras_object(
                config["preprocessor"]
            )
        return cls(**config)
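
Because `post_process_prediction` returns sigmoid probabilities rather than hard decisions, callers typically threshold the scores and binarize the masks themselves. Below is a minimal consumer-side sketch assuming the `sam3_pcs` preset and the output shapes documented above; the 0.5 thresholds and variable names are illustrative and not part of the API.

```python
import numpy as np
import keras_hub

# Illustrative thresholds; tune for your use case.
SCORE_THRESHOLD = 0.5
MASK_THRESHOLD = 0.5

segmenter = keras_hub.models.SAM3PromptableConceptImageSegmenter.from_preset(
    "sam3_pcs"
)
outputs = segmenter.predict(
    {
        "images": np.ones((1, 128, 128, 3), dtype="float32"),
        "prompts": ["ear"],
    }
)
# Keep confident queries and binarize the per-pixel mask probabilities.
keep = outputs["scores"][0] > SCORE_THRESHOLD  # [num_queries] booleans
kept_boxes = outputs["boxes"][0][keep]  # [num_kept, 4], normalized XYXY
kept_masks = outputs["masks"][0][keep] > MASK_THRESHOLD  # [num_kept, H, W]
```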
keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py
@@ -0,0 +1,336 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.models.sam3.sam3_image_converter import SAM3ImageConverter
from keras_hub.src.models.sam3.sam3_pc_backbone import (
    SAM3PromptableConceptBackbone,
)
from keras_hub.src.models.sam3.sam3_tokenizer import SAM3Tokenizer
from keras_hub.src.utils.tensor_utils import preprocessing_function

try:
    import tensorflow as tf
except ImportError:
    tf = None


@keras_hub_export(
    "keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor"
)
class SAM3PromptableConceptImageSegmenterPreprocessor(Preprocessor):
    """SAM3 Promptable Concept Image Segmenter preprocessor.

    This preprocessing layer is meant for use with
    `keras_hub.models.SAM3PromptableConceptImageSegmenter`.

    Args:
        tokenizer: A `keras_hub.models.SAM3Tokenizer` instance.
        image_converter: A `keras_hub.layers.SAM3ImageConverter` instance.
        sequence_length: The length of the packed token_ids. Defaults to `32`.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. Defaults to `True`.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. Defaults to `True`.
        point_pad_value: int. The padding value for box prompts. Defaults to
            `-10`.

    Call arguments:
        x: A dictionary with the following keys:
            - images: A single image or a batch of images, of shape
              `(height, width, 3)` or `(batch_size, height, width, 3)`.
            - prompts: (optional) A string or a batch of strings containing the
              text prompts. If not provided, a default prompt will be used.
            - boxes: (optional) A tensor of shape `(num_boxes, 4)` or
              `(batch_size, num_boxes, 4)` containing box coordinates in
              `(x_min, y_min, x_max, y_max)` format. Coordinates should be in
              absolute pixel values. If not provided, no box prompts will be
              used. `-10` is used as the padding value.
            - box_labels: (optional) A tensor of shape `(num_boxes,)` or
              `(batch_size, num_boxes)` containing box labels. If not provided,
              no box labels will be used. `-10` is used as the padding value.

    Examples:

    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
        "sam3_pcs"
    )

    # Unbatched inputs, with one image and one text prompt.
    preprocessor(
        {
            "prompts": "ear",
            "images": np.ones((896, 896, 3), dtype="float32")
        }
    )

    # Unbatched inputs, with one image and one box prompt.
    preprocessor(
        {
            "boxes": [[0, 0, 300, 300]],
            "box_labels": [1],
            "images": np.ones((896, 896, 3), dtype="float32")
        }
    )

    # Batched inputs, one image per text prompt.
    preprocessor(
        {
            "prompts": [
                "ear",
                "head"
            ],
            "images": [
                np.ones((896, 896, 3), dtype="float32"),
                np.ones((896, 896, 3), dtype="float32")
            ]
        }
    )

    # Batched inputs, one image per box prompt.
    preprocessor(
        {
            "boxes": [
                [[0, 0, 300, 300]],
                [[50, 50, 100, 100]]
            ],
            "box_labels": [
                [1],
                [1]
            ],
            "images": [
                np.ones((896, 896, 3), dtype="float32"),
                np.ones((896, 896, 3), dtype="float32")
            ]
        }
    )

    # Different number of box prompts in every sample.
    preprocessor(
        {
            "boxes": [
                [[0, 0, 300, 300]],
                [[50, 50, 100, 100], [150, 150, 200, 200]]
            ],
            "box_labels": [
                [1],
                [1, 1]
            ],
            "images": [
                np.ones((896, 896, 3), dtype="float32"),
                np.ones((896, 896, 3), dtype="float32")
            ]
        }
    )

    # Apply preprocessing to a `tf.data.Dataset`.
    inputs = {
        "prompts": [
            "ear",
            "head",
        ],
        "images": np.ones((2, 896, 896, 3), dtype="float32")
    }
    ds = tf.data.Dataset.from_tensor_slices(inputs)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """  # noqa: E501

    backbone_cls = SAM3PromptableConceptBackbone
    tokenizer_cls = SAM3Tokenizer
    image_converter_cls = SAM3ImageConverter

    def __init__(
        self,
        tokenizer,
        image_converter,
        sequence_length=32,
        add_start_token=True,
        add_end_token=True,
        point_pad_value=-10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.packer = None
        self.image_converter = image_converter
        self.sequence_length = sequence_length
        self.add_start_token = add_start_token
        self.add_end_token = add_end_token
        self.point_pad_value = point_pad_value

    def build(self, input_shape):
        # Defer packer creation to `build()` so that we can be sure tokenizer
        # assets have loaded when restoring a saved model.
        self.packer = StartEndPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    def _preprocess_boxes(self, boxes, box_labels, height, width):
        if isinstance(boxes, tf.RaggedTensor):
            max_num_boxes = tf.reduce_max(boxes.row_lengths(axis=1))
            boxes = boxes.to_tensor(
                shape=[None, max_num_boxes, 4],
                default_value=self.point_pad_value,
            )
            box_labels = box_labels.to_tensor(
                shape=[None, max_num_boxes],
                default_value=self.point_pad_value,
            )
        box_dtype = keras.backend.standardize_dtype(boxes.dtype)
        normalized_boxes = tf.stack(
            [
                boxes[..., 0] / tf.cast(width, box_dtype),
                boxes[..., 1] / tf.cast(height, box_dtype),
                boxes[..., 2] / tf.cast(width, box_dtype),
                boxes[..., 3] / tf.cast(height, box_dtype),
            ],
            axis=-1,
        )
        boxes = tf.where(
            tf.equal(tf.expand_dims(box_labels, axis=-1), self.point_pad_value),
            tf.fill(
                tf.shape(normalized_boxes),
                tf.cast(self.point_pad_value, normalized_boxes.dtype),
            ),
            normalized_boxes,
        )
        # XYXY to CXCYWH.
        boxes = tf.stack(
            [
                (boxes[..., 0] + boxes[..., 2]) / 2.0,
                (boxes[..., 1] + boxes[..., 3]) / 2.0,
                boxes[..., 2] - boxes[..., 0],
                boxes[..., 3] - boxes[..., 1],
            ],
            axis=-1,
        )
        # Add batch indices.
        batch_size = tf.shape(boxes)[0]
        batch_indices = tf.range(batch_size, dtype=boxes.dtype)
        batch_indices = tf.reshape(batch_indices, (batch_size, 1, 1))
        batch_indices = tf.tile(batch_indices, (1, tf.shape(boxes)[1], 1))
        boxes = tf.concat([batch_indices, boxes], axis=-1)
        return boxes, box_labels

    @preprocessing_function
    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        sequence_length = sequence_length or self.sequence_length

        images = x["images"]
        prompts = x.get("prompts", None)
        boxes, box_labels = x.get("boxes", None), x.get("box_labels", None)

        # Convert to batched inputs.
        if len(images.shape) == 3:
            is_batched = False
            images = tf.expand_dims(images, axis=0)
            if prompts is not None and len(prompts.shape) == 0:
                prompts = tf.expand_dims(prompts, axis=0)
            if boxes is not None and len(boxes.shape) == 2:
                boxes = tf.expand_dims(boxes, axis=0)
                box_labels = tf.expand_dims(box_labels, axis=0)
        else:
            is_batched = True

        batch_size = tf.shape(images)[0]
        height = tf.shape(images)[1]
        width = tf.shape(images)[2]

        # Add placeholders if not provided.
        if prompts is None:
            prompts = tf.convert_to_tensor("visual")
            prompts = tf.tile(prompts[None], [batch_size])
        if boxes is None:
            boxes = tf.zeros((batch_size, 0, 4), dtype="float32")
            box_labels = tf.zeros((batch_size, 0), dtype="int32")

        # Tokenise the prompts.
        prompts = self.tokenizer(prompts)
        token_ids, padding_mask = self.packer(
            prompts,
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )

        # Resize and normalize the images.
        pixel_values = self.image_converter(images)
        if keras.config.backend() == "torch" and not isinstance(
            images, tf.Tensor
        ):
            images = images.cpu()

        # Normalize the boxes.
        boxes, box_labels = self._preprocess_boxes(
            boxes, box_labels, height, width
        )

        if not is_batched:
            token_ids = tf.squeeze(token_ids, axis=0)
            padding_mask = tf.squeeze(padding_mask, axis=0)
            pixel_values = tf.squeeze(pixel_values, axis=0)
            boxes = tf.squeeze(boxes, axis=0)
            box_labels = tf.squeeze(box_labels, axis=0)

        x = {
            "pixel_values": pixel_values,
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
            "boxes": boxes,
            "box_labels": box_labels,
        }
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "add_start_token": self.add_start_token,
                "add_end_token": self.add_end_token,
            }
        )
        return config

    @property
    def sequence_length(self):
        """The padded length of model input sequences."""
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        if self.packer is not None:
            self.packer.sequence_length = value

    @property
    def image_size(self):
        """Settable tuple of `(height, width)` ints. The output image shape."""
        if self.image_converter.resizing.height is None:
            return None
        return (
            self.image_converter.resizing.height,
            self.image_converter.resizing.width,
        )

    @image_size.setter
    def image_size(self, value):
        if value is None:
            value = (None, None)
        self.image_converter.resizing.height = value[0]
        self.image_converter.resizing.width = value[1]
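
For reference, the box path in `call` above goes: absolute XYXY pixel coordinates → divide by image width/height → convert to center-x, center-y, width, height → prepend the batch index, with `-10` used for padding. The following is a small NumPy sketch of that arithmetic for one valid box; the standalone helper is hypothetical and only mirrors `_preprocess_boxes` for illustration.

```python
import numpy as np


def normalize_box_xyxy_to_cxcywh(box, height, width, batch_index=0):
    """Mirror of the `_preprocess_boxes` arithmetic for a single valid box."""
    x_min, y_min, x_max, y_max = box
    # Scale absolute pixel coordinates to [0, 1].
    x_min, x_max = x_min / width, x_max / width
    y_min, y_max = y_min / height, y_max / height
    # XYXY -> CXCYWH.
    cx, cy = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
    w, h = x_max - x_min, y_max - y_min
    # The preprocessor prepends the batch index of the box.
    return np.array([batch_index, cx, cy, w, h], dtype="float32")


# A 100x100 pixel box at the top-left of an 896x896 image.
print(normalize_box_xyxy_to_cxcywh([0.0, 0.0, 100.0, 100.0], 896, 896))
# -> approximately [0.0, 0.0558, 0.0558, 0.1116, 0.1116]
```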
keras_hub/src/models/sam3/sam3_presets.py
@@ -0,0 +1,16 @@
"""SAM3 model preset configurations."""

# Metadata for loading pretrained model weights.
backbone_presets = {
    "sam3_pcs": {
        "metadata": {
            "description": (
                "30 million parameter Promptable Concept Segmentation (PCS) "
                "SAM model."
            ),
            "params": 30000000,
            "path": "sam3",
        },
        "kaggle_handle": "kaggle://keras/sam3/keras/sam3_pcs/1",
    },
}
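
The `sam3_pcs` key above is the name users pass to `from_preset`, and the `kaggle_handle` points at the hosted weights. A brief usage sketch, reusing the image-shape override from the segmenter docstring earlier in this diff:

```python
import keras_hub

# "sam3_pcs" resolves to the Kaggle handle registered above.
backbone = keras_hub.models.SAM3PromptableConceptBackbone.from_preset(
    "sam3_pcs", image_shape=(336, 336, 3)
)
preprocessor = (
    keras_hub.models.SAM3PromptableConceptImageSegmenterPreprocessor.from_preset(
        "sam3_pcs"
    )
)
preprocessor.image_size = (336, 336)
segmenter = keras_hub.models.SAM3PromptableConceptImageSegmenter(
    backbone=backbone, preprocessor=preprocessor
)
```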