keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/rwkv7/rwkv7_presets.py
@@ -0,0 +1,26 @@
+ """RWKV7 model preset configurations."""
+
+ backbone_presets = {
+     "rwkv7_g1a_0.1b_en": {
+         "metadata": {
+             "description": (
+                 "150 million parameter RWKV7 model. Optimized for edge "
+                 "devices and mobile deployment."
+             ),
+             "params": 150000000,
+             "path": "rwkv7",
+         },
+         "kaggle_handle": "kaggle://keras/rwkv7/keras/rwkv7_g1a_0.1b/1",
+     },
+     "rwkv7_g1a_0.3b_en": {
+         "metadata": {
+             "description": (
+                 "400 million parameter RWKV7 model. Small variant balancing "
+                 "speed and instruction following."
+             ),
+             "params": 400000000,
+             "path": "rwkv7",
+         },
+         "kaggle_handle": "kaggle://keras/rwkv7/keras/rwkv7_g1a_0.3b/1",
+     },
+ }
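The two presets above point at Kaggle-hosted weights. Assuming they are registered against the RWKV7 task classes listed earlier in this diff (the `RWKV7CausalLM` symbol below is inferred from `rwkv7_causal_lm.py` in the file list, not confirmed by this hunk), loading one would look roughly like this sketch:

```python
import keras_hub

# Sketch only: preset names come from rwkv7_presets.py above; the exact
# exported task class name is an assumption based on the file list.
causal_lm = keras_hub.models.RWKV7CausalLM.from_preset("rwkv7_g1a_0.1b_en")
causal_lm.generate("The RWKV architecture", max_length=64)
```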
keras_hub/src/models/rwkv7/rwkv7_tokenizer.py
@@ -0,0 +1,495 @@
+ import os
+
+ import keras
+ import numpy as np
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+ from keras_hub.src.tokenizers import tokenizer
+ from keras_hub.src.utils.tensor_utils import is_int_dtype
+ from keras_hub.src.utils.tensor_utils import is_string_dtype
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+ from keras_hub.src.utils.tensor_utils import tensor_to_list
+ from keras_hub.src.utils.tensor_utils import tf
+
+ # Vocabulary file name constant
+ VOCAB_FILENAME = "vocabulary.txt"
+
+
+ class TRIE:
+     """Byte-level Trie structure for longest prefix matching.
+
+     This class implements a trie data structure that stores byte
+     sequences and allows efficient longest prefix matching.
+     """
+
+     __slots__ = tuple("ch,children,values,parent".split(","))
+
+     def __init__(self, parent=None, ch=None):
+         """Initialize a TRIE node.
+
+         Args:
+             parent: Parent node reference.
+             ch: Byte value for this node.
+         """
+         self.ch = ch
+         self.children = [None for _ in range(256)]
+         self.values = set()
+         self.parent = parent
+
+     def __repr__(self):
+         """String representation of the TRIE node."""
+         current_node = self
+         ret = []
+         while current_node is not None:
+             if current_node.ch is not None:
+                 ret.append(current_node.ch)
+             current_node = current_node.parent
+         return "<TRIE %s %s>" % (ret[::-1], self.values)
+
+     def add(self, key, idx=0, val=None):
+         """Add a key-value pair to the trie.
+
+         Args:
+             key: Byte sequence to add.
+             idx: Current index in key processing.
+             val: Value to store (defaults to key).
+
+         Returns:
+             Final node where key was inserted.
+         """
+         if idx == len(key):
+             if val is None:
+                 val = key
+             self.values.add(val)
+             return self
+
+         ch = key[idx]
+         if self.children[ch] is None:
+             self.children[ch] = TRIE(parent=self, ch=ch)
+
+         return self.children[ch].add(key, idx + 1, val)
+
+     def find_longest(self, key, idx=0):
+         """Find longest match in trie for given key.
+
+         Args:
+             key: Byte sequence to search for.
+             idx: Starting index for search.
+
+         Returns:
+             Tuple of (end_index, node, values) for match.
+         """
+         current_node = self
+         ch = key[idx]
+         ret = None
+
+         while current_node.children[ch] is not None:
+             current_node = current_node.children[ch]
+             idx += 1
+             if current_node.values:
+                 ret = idx, current_node, current_node.values
+             if idx == len(key):
+                 break
+             ch = key[idx]
+         if ret is None:
+             raise ValueError(f"No valid token found in trie for key: {key}")
+         return ret
+
+
+ class RWKVTokenizerBase:
+     """RWKV tokenizer implementation using byte-level trie.
+
+     Implements tokenization using a fixed vocabulary and greedy
+     longest-match algorithm on byte sequences.
+     """
+
+     def __init__(self, vocabs):
+         """Initialize tokenizer with vocabulary.
+
+         Args:
+             vocabs: List of vocabulary entries in format
+                 "<idx> <repr> <len>".
+         """
+         self.idx2token = {}
+         for line in vocabs:
+             idx = int(line[: line.index(" ")])
+             x = eval(line[line.index(" ") : line.rindex(" ")])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+             assert len(x) == int(line[line.rindex(" ") :])
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k, v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         self.root = TRIE()
+         for token, token_id in self.token2idx.items():
+             _ = self.root.add(token, val=(token, token_id))
+
+     def encodeBytes(self, src):
+         """Encode byte sequence to token IDs.
+
+         Args:
+             src: Byte sequence to encode.
+
+         Returns:
+             List of token IDs.
+         """
+         idx = 0
+         tokens = []
+         while idx < len(src):
+             prev_idx = idx
+             idx, _, values = self.root.find_longest(src, idx)
+             assert idx != prev_idx
+             _, token = next(iter(values))
+             tokens.append(token)
+         return tokens
+
+     def decodeBytes(self, tokens):
+         """Decode token IDs to byte sequence.
+
+         Args:
+             tokens: List of token IDs.
+
+         Returns:
+             Decoded byte sequence.
+         """
+         return b"".join(map(lambda i: self.idx2token[int(i)], tokens))
+
+     def encode(self, src):
+         """Encode text to token IDs.
+
+         Args:
+             src: Text string or list of strings.
+
+         Returns:
+             Token IDs or list of token ID lists.
+         """
+         if isinstance(src, str):
+             return self.encodeBytes(src.encode("utf-8"))
+         else:
+             return [self.encodeBytes(s.encode("utf-8")) for s in src]
+
+     def decode(self, tokens):
+         """Decode token IDs to text.
+
+         Args:
+             tokens: Token IDs or list of token ID lists.
+
+         Returns:
+             List of decoded text strings.
+         """
+         return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
+         # try:
+         #     return self.decodeBytes(tokens).decode('utf-8')
+         # except:
+         #     return '\ufffd' # bad utf-8
+
+     def printTokens(self, tokens):
+         """Print tokens with their string representations.
+
+         Args:
+             tokens: List of token IDs to print.
+         """
+         for token_id in tokens:
+             token = self.idx2token[token_id]
+             try:
+                 token = token.decode("utf-8")
+             except BaseException:
+                 pass
+             print(f"{repr(token)}{token_id}", end=" ")
+         print()
+
+
+ @keras_hub_export("keras_hub.tokenizers.RWKVTokenizer")
+ class RWKVTokenizer(tokenizer.Tokenizer):
+     """RWKV byte-level tokenizer with longest-match trie search.
+
+     This tokenizer maps raw text to a sequence of integer token ids
+     using a fixed vocabulary and a greedy longest-match algorithm.
+
+     Args:
+         vocabulary: list of strings, each line formatted as
+             "<idx> <repr> <len>".
+         dtype: output dtype for tensor operations. Must be integer
+             or string type.
+
+     Examples:
+     ```python
+     vocab = ["0 ' ' 1", "1 '\\n' 1", "2 'the' 3", "3 'hello' 5"]
+     tok = RWKVTokenizer(vocabulary=vocab)
+     tok("hello the")
+     ```
+
+     Output:
+     [3, 0, 2]
+     """
+
+     backbone_cls = RWKV7Backbone
+
+     def __init__(
+         self,
+         vocabulary=None,
+         dtype="int32",
+         pad_token_id=0,
+         start_token_id=None,
+         end_token_id=None,
+         **kwargs,
+     ):
+         """Initialize RWKV tokenizer.
+
+         Args:
+             vocabulary: Vocabulary list.
+             dtype: Output data type.
+             **kwargs: Additional keyword arguments.
+         """
+         if not is_int_dtype(dtype) and not is_string_dtype(dtype):
+             raise ValueError(
+                 "Output dtype must be an integer type or a string. "
+                 f"Received: dtype={dtype}"
+             )
+
+         super().__init__(dtype=dtype, **kwargs)
+
+         self.vocabulary = None
+         self.pad_token_id = pad_token_id
+         self.start_token_id = start_token_id
+         self.end_token_id = end_token_id or self.pad_token_id
+         if vocabulary is not None:
+             self.set_vocabulary(vocabulary)
+         self.file_assets = [VOCAB_FILENAME]
+
+     def set_vocabulary(self, vocabulary):
+         """Set the tokenizer vocabulary.
+
+         Args:
+             vocabulary: Vocabulary list to set.
+         """
+         self.vocabulary = vocabulary
+         self._tokenizer = RWKVTokenizerBase(vocabulary)
+         if self.end_token_id is None or self.end_token_id == self.pad_token_id:
+             for line in vocabulary:
+                 idx = int(line[: line.index(" ")])
+                 repr_str = eval(line[line.index(" ") : line.rindex(" ")])
+                 if repr_str == "\n\n":
+                     self.end_token_id = idx
+                     break
+
+     def save_assets(self, dir_path):
+         """Save vocabulary to directory.
+
+         Args:
+             dir_path: Directory path to save to.
+         """
+         path = os.path.join(dir_path, VOCAB_FILENAME)
+         with open(path, "w", encoding="utf-8") as file:
+             file.write("".join(self.vocabulary))
+
+     def load_assets(self, dir_path=""):
+         """Load vocabulary from directory.
+
+         Args:
+             dir_path: Directory path to load from.
+         """
+         path = os.path.join(dir_path, VOCAB_FILENAME)
+         with open(path, "r", encoding="utf-8") as f:
+             vocabulary = f.readlines()
+         self.set_vocabulary(vocabulary)
+
+     def _check_vocabulary(self):
+         """Check if vocabulary is set, raise error if not."""
+         if self.vocabulary is None:
+             raise ValueError(
+                 "No vocabulary has been set for RWKVTokenizer. Make "
+                 "sure to pass a `vocabulary` argument when creating the layer."
+             )
+
+     def vocabulary_size(self):
+         """Get the size of the vocabulary.
+
+         Returns:
+             Number of tokens in vocabulary.
+         """
+         self._check_vocabulary()
+         return int(len(self.vocabulary))
+
+     def get_vocabulary(self):
+         """Get the current vocabulary.
+
+         Returns:
+             Current vocabulary list.
+         """
+         self._check_vocabulary()
+         return tensor_to_list(self.vocabulary)
+
+     def id_to_token(self, id):
+         """Convert token ID to string representation.
+
+         Args:
+             id: Token ID to convert.
+
+         Returns:
+             String representation of token.
+         """
+         self._check_vocabulary()
+         if id >= self.vocabulary_size() or id < 0:
+             raise ValueError(
+                 f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
+                 f"Received: {id}"
+             )
+         return self._tokenizer.idx2token[id]
+
+     def token_to_id(self, token):
+         """Convert a string token to an integer id."""
+         self._check_vocabulary()
+         return int(self._tokenizer.token2idx[token])
+
+     def get_config(self):
+         """Get tokenizer configuration.
+
+         Returns:
+             Configuration dictionary.
+         """
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary": self.vocabulary,
+                 "end_token_id": self.end_token_id,
+                 "pad_token_id": self.pad_token_id,
+                 "start_token_id": self.start_token_id,
+             }
+         )
+         return config
+
+     @preprocessing_function
+     def tokenize(self, inputs):
+         self._check_vocabulary()
+
+         if not tf.executing_eagerly() and tf.is_tensor(inputs):
+
+             def tokenize_wrapper(text_tensor):
+                 text_list = (
+                     text_tensor.numpy()
+                     if hasattr(text_tensor, "numpy")
+                     else text_tensor
+                 )
+                 if isinstance(text_list, bytes):
+                     text_list = [text_list.decode("utf-8")]
+                 elif isinstance(text_list, np.ndarray):
+                     text_list = [x.decode("utf-8") for x in text_list.flatten()]
+
+                 tokens = self._tokenizer.encode(text_list)
+
+                 if is_string_dtype(self.dtype):
+                     result = [
+                         self.id_to_token(i).decode("utf-8", errors="replace")
+                         for i in tokens[0]
+                     ]
+                     return tf.constant(result, dtype=tf.string)
+                 else:
+                     return tf.constant(tokens[0], dtype=self.compute_dtype)
+
+             if inputs.shape.rank == 0:
+                 output = tf.py_function(
+                     tokenize_wrapper,
+                     [inputs],
+                     Tout=tf.string
+                     if is_string_dtype(self.dtype)
+                     else self.compute_dtype,
+                 )
+                 output.set_shape([None])
+                 return output
+             else:
+                 output = tf.map_fn(
+                     lambda x: tf.py_function(
+                         tokenize_wrapper,
+                         [x],
+                         Tout=tf.string
+                         if is_string_dtype(self.dtype)
+                         else self.compute_dtype,
+                     ),
+                     inputs,
+                     fn_output_signature=tf.TensorSpec(
+                         [None],
+                         dtype=tf.string
+                         if is_string_dtype(self.dtype)
+                         else self.compute_dtype,
+                     ),
+                 )
+                 return output
+
+         if tf.is_tensor(inputs):
+             inputs = tensor_to_list(inputs)
+
+         tokens = self._tokenizer.encode(inputs)
+
+         if is_string_dtype(self.dtype):
+
+             def ids_to_str(ids):
+                 return [
+                     self.id_to_token(i).decode("utf-8", errors="replace")
+                     for i in ids
+                 ]
+
+             if isinstance(inputs, str):
+                 return ids_to_str(tokens)
+             return [ids_to_str(ts) for ts in tokens]
+
+         if isinstance(inputs, str):
+             return tf.convert_to_tensor(tokens, dtype=self.compute_dtype)
+         else:
+             return tf.ragged.constant(tokens, dtype=self.compute_dtype)
+
+     @preprocessing_function
+     def detokenize(self, inputs):
+         """Convert tokens back to text.
+
+         Args:
+             inputs: Tokens to convert.
+
+         Returns:
+             Detokenized text.
+         """
+         self._check_vocabulary()
+
+         if tf.is_tensor(inputs):
+             inputs = tensor_to_list(inputs)
+
+         if len(inputs) > 0 and isinstance(inputs[0], (int, np.integer)):
+             inputs = [inputs]
+
+         strip_zero_inputs = []
+         for seq in inputs:
+             if tf.is_tensor(seq):
+                 seq = tensor_to_list(seq)
+             strip_zero_inputs.append([x for x in seq if x != 0])
+
+         result = self._tokenizer.decode(strip_zero_inputs)
+
+         return tf.convert_to_tensor(result, dtype=tf.string)
+
+     def compute_output_spec(self, input_spec):
+         """Compute output specification.
+
+         Args:
+             input_spec: Input specification.
+
+         Returns:
+             Output tensor specification.
+         """
+         return keras.KerasTensor(
+             input_spec.shape + (None,), dtype=self.compute_dtype
+         )
+
+     def call(self, inputs):
+         """Call the tokenizer on inputs.
+
+         Args:
+             inputs: Input text.
+
+         Returns:
+             Tokenized output.
+         """
+         return self.tokenize(inputs)
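The heart of this new tokenizer is the greedy longest-match loop in `TRIE.find_longest` and `RWKVTokenizerBase.encodeBytes`: at each byte offset it walks the trie as far as the input allows and emits the id of the longest vocabulary entry found there. A minimal standalone sketch of that algorithm, using the toy vocabulary from the `RWKVTokenizer` docstring (plain Python, no keras_hub dependency, and a dict standing in for the trie):

```python
# Greedy longest-match byte tokenization, mirroring encodeBytes above.
# The dict replaces the trie; candidates are tried longest-first.
vocab = {b" ": 0, b"\n": 1, b"the": 2, b"hello": 3}


def encode_bytes(src: bytes) -> list[int]:
    tokens = []
    idx = 0
    while idx < len(src):
        for end in range(len(src), idx, -1):  # longest candidate first
            piece = src[idx:end]
            if piece in vocab:
                tokens.append(vocab[piece])
                idx = end
                break
        else:
            raise ValueError(f"no vocabulary entry matches {src[idx:]!r}")
    return tokens


print(encode_bytes(b"hello the"))  # [3, 0, 2], matching the docstring example
```

The real implementation reaches the same result with a single trie walk per position rather than trying every suffix length.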
keras_hub/src/models/sam/sam_backbone.py
@@ -16,7 +16,11 @@ class SAMBackbone(Backbone):
          mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
              generate segmentation masks given the embeddings generated by the
              backbone and the prompt encoder.
-         dtype: The dtype of the layer weights.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for the models computations and weights. Note that some
+             computations, such as softmax and layer normalization will always
+             be done in float32 precision regardless of dtype. Defaults to
+             `bfloat16`.
 
      Example:
      ```python
keras_hub/src/models/sam/sam_prompt_encoder.py
@@ -292,7 +292,7 @@ class SAMPromptEncoder(keras.layers.Layer):
          )
 
          dense_embeddings = ops.cond(
-             ops.equal(ops.size(masks), 0),
+             ops.equal(ops.shape(masks)[1], 0),
              _no_mask_embed,
              _maybe_input_mask_embed,
          )
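For context on the one-line change above: `ops.size` counts every element of `masks`, while `ops.shape(masks)[1]` looks only at the prompt axis. A small sketch, under the assumption that mask prompts are laid out as `(batch, num_masks, height, width, channels)` (the layout itself is not shown in this hunk):

```python
import numpy as np
from keras import ops

# Hypothetical prompt tensor with zero mask prompts on axis 1.
masks = np.zeros((1, 0, 256, 256, 1), dtype="float32")

print(ops.equal(ops.size(masks), 0))      # True: total element count is 0
print(ops.equal(ops.shape(masks)[1], 0))  # True: the prompt axis itself is 0
# The two checks can disagree when some other axis is zero or only known
# symbolically; indexing the prompt axis asks the intended question directly.
```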
keras_hub/src/models/sam3/__init__.py
@@ -0,0 +1,7 @@
+ from keras_hub.src.models.sam3.sam3_pc_backbone import (
+     SAM3PromptableConceptBackbone,
+ )
+ from keras_hub.src.models.sam3.sam3_presets import backbone_presets
+ from keras_hub.src.utils.preset_utils import register_presets
+
+ register_presets(backbone_presets, SAM3PromptableConceptBackbone)
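`register_presets` is what wires the SAM3 preset names (defined in `sam3_presets.py`, which this diff does not show) into the standard `from_preset` flow. Assuming `SAM3PromptableConceptBackbone` is also exported under `keras_hub.models`, usage would look roughly like this sketch:

```python
import keras_hub

# Sketch only: "<sam3_preset_name>" is a placeholder; the real names live in
# keras_hub/src/models/sam3/sam3_presets.py.
backbone = keras_hub.models.SAM3PromptableConceptBackbone.from_preset(
    "<sam3_preset_name>"
)
```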