keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/rwkv7/rwkv7_presets.py (new file):

````diff
@@ -0,0 +1,26 @@
+"""RWKV7 model preset configurations."""
+
+backbone_presets = {
+    "rwkv7_g1a_0.1b_en": {
+        "metadata": {
+            "description": (
+                "150 million parameter RWKV7 model. Optimized for edge "
+                "devices and mobile deployment."
+            ),
+            "params": 150000000,
+            "path": "rwkv7",
+        },
+        "kaggle_handle": "kaggle://keras/rwkv7/keras/rwkv7_g1a_0.1b/1",
+    },
+    "rwkv7_g1a_0.3b_en": {
+        "metadata": {
+            "description": (
+                "400 million parameter RWKV7 model. Small variant balancing "
+                "speed and instruction following."
+            ),
+            "params": 400000000,
+            "path": "rwkv7",
+        },
+        "kaggle_handle": "kaggle://keras/rwkv7/keras/rwkv7_g1a_0.3b/1",
+    },
+}
````
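For context, preset entries like these are what the keras-hub `from_preset` constructors resolve by name. A minimal usage sketch, assuming the standard keras-hub loading API and that the Kaggle handles above are published (the RWKV7 classes themselves are new in this release):

```python
import keras_hub

# Sketch only: load the smaller RWKV7 preset registered above. The preset
# name comes from the `backbone_presets` dict in this diff; weights resolve
# through the listed Kaggle handle.
backbone = keras_hub.models.Backbone.from_preset("rwkv7_g1a_0.1b_en")

# The matching tokenizer (rwkv7_tokenizer.py, also added in this release)
# would normally be loaded the same way.
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset("rwkv7_g1a_0.1b_en")
```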
keras_hub/src/models/rwkv7/rwkv7_tokenizer.py (new file):

````diff
@@ -0,0 +1,495 @@
+import os
+
+import keras
+import numpy as np
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+from keras_hub.src.tokenizers import tokenizer
+from keras_hub.src.utils.tensor_utils import is_int_dtype
+from keras_hub.src.utils.tensor_utils import is_string_dtype
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+from keras_hub.src.utils.tensor_utils import tensor_to_list
+from keras_hub.src.utils.tensor_utils import tf
+
+# Vocabulary file name constant
+VOCAB_FILENAME = "vocabulary.txt"
+
+
+class TRIE:
+    """Byte-level Trie structure for longest prefix matching.
+
+    This class implements a trie data structure that stores byte
+    sequences and allows efficient longest prefix matching.
+    """
+
+    __slots__ = tuple("ch,children,values,parent".split(","))
+
+    def __init__(self, parent=None, ch=None):
+        """Initialize a TRIE node.
+
+        Args:
+            parent: Parent node reference.
+            ch: Byte value for this node.
+        """
+        self.ch = ch
+        self.children = [None for _ in range(256)]
+        self.values = set()
+        self.parent = parent
+
+    def __repr__(self):
+        """String representation of the TRIE node."""
+        current_node = self
+        ret = []
+        while current_node is not None:
+            if current_node.ch is not None:
+                ret.append(current_node.ch)
+            current_node = current_node.parent
+        return "<TRIE %s %s>" % (ret[::-1], self.values)
+
+    def add(self, key, idx=0, val=None):
+        """Add a key-value pair to the trie.
+
+        Args:
+            key: Byte sequence to add.
+            idx: Current index in key processing.
+            val: Value to store (defaults to key).
+
+        Returns:
+            Final node where key was inserted.
+        """
+        if idx == len(key):
+            if val is None:
+                val = key
+            self.values.add(val)
+            return self
+
+        ch = key[idx]
+        if self.children[ch] is None:
+            self.children[ch] = TRIE(parent=self, ch=ch)
+
+        return self.children[ch].add(key, idx + 1, val)
+
+    def find_longest(self, key, idx=0):
+        """Find longest match in trie for given key.
+
+        Args:
+            key: Byte sequence to search for.
+            idx: Starting index for search.
+
+        Returns:
+            Tuple of (end_index, node, values) for match.
+        """
+        current_node = self
+        ch = key[idx]
+        ret = None
+
+        while current_node.children[ch] is not None:
+            current_node = current_node.children[ch]
+            idx += 1
+            if current_node.values:
+                ret = idx, current_node, current_node.values
+            if idx == len(key):
+                break
+            ch = key[idx]
+        if ret is None:
+            raise ValueError(f"No valid token found in trie for key: {key}")
+        return ret
+
+
+class RWKVTokenizerBase:
+    """RWKV tokenizer implementation using byte-level trie.
+
+    Implements tokenization using a fixed vocabulary and greedy
+    longest-match algorithm on byte sequences.
+    """
+
+    def __init__(self, vocabs):
+        """Initialize tokenizer with vocabulary.
+
+        Args:
+            vocabs: List of vocabulary entries in format
+                "<idx> <repr> <len>".
+        """
+        self.idx2token = {}
+        for line in vocabs:
+            idx = int(line[: line.index(" ")])
+            x = eval(line[line.index(" ") : line.rindex(" ")])
+            x = x.encode("utf-8") if isinstance(x, str) else x
+            assert isinstance(x, bytes)
+            assert len(x) == int(line[line.rindex(" ") :])
+            self.idx2token[idx] = x
+
+        self.token2idx = {}
+        for k, v in self.idx2token.items():
+            self.token2idx[v] = int(k)
+
+        self.root = TRIE()
+        for token, token_id in self.token2idx.items():
+            _ = self.root.add(token, val=(token, token_id))
+
+    def encodeBytes(self, src):
+        """Encode byte sequence to token IDs.
+
+        Args:
+            src: Byte sequence to encode.
+
+        Returns:
+            List of token IDs.
+        """
+        idx = 0
+        tokens = []
+        while idx < len(src):
+            prev_idx = idx
+            idx, _, values = self.root.find_longest(src, idx)
+            assert idx != prev_idx
+            _, token = next(iter(values))
+            tokens.append(token)
+        return tokens
+
+    def decodeBytes(self, tokens):
+        """Decode token IDs to byte sequence.
+
+        Args:
+            tokens: List of token IDs.
+
+        Returns:
+            Decoded byte sequence.
+        """
+        return b"".join(map(lambda i: self.idx2token[int(i)], tokens))
+
+    def encode(self, src):
+        """Encode text to token IDs.
+
+        Args:
+            src: Text string or list of strings.
+
+        Returns:
+            Token IDs or list of token ID lists.
+        """
+        if isinstance(src, str):
+            return self.encodeBytes(src.encode("utf-8"))
+        else:
+            return [self.encodeBytes(s.encode("utf-8")) for s in src]
+
+    def decode(self, tokens):
+        """Decode token IDs to text.
+
+        Args:
+            tokens: Token IDs or list of token ID lists.
+
+        Returns:
+            List of decoded text strings.
+        """
+        return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
+        # try:
+        #     return self.decodeBytes(tokens).decode('utf-8')
+        # except:
+        #     return '\ufffd' # bad utf-8
+
+    def printTokens(self, tokens):
+        """Print tokens with their string representations.
+
+        Args:
+            tokens: List of token IDs to print.
+        """
+        for token_id in tokens:
+            token = self.idx2token[token_id]
+            try:
+                token = token.decode("utf-8")
+            except BaseException:
+                pass
+            print(f"{repr(token)}{token_id}", end=" ")
+        print()
+
+
+@keras_hub_export("keras_hub.tokenizers.RWKVTokenizer")
+class RWKVTokenizer(tokenizer.Tokenizer):
+    """RWKV byte-level tokenizer with longest-match trie search.
+
+    This tokenizer maps raw text to a sequence of integer token ids
+    using a fixed vocabulary and a greedy longest-match algorithm.
+
+    Args:
+        vocabulary: list of strings, each line formatted as
+            "<idx> <repr> <len>".
+        dtype: output dtype for tensor operations. Must be integer
+            or string type.
+
+    Examples:
+    ```python
+    vocab = ["0 ' ' 1", "1 '\\n' 1", "2 'the' 3", "3 'hello' 5"]
+    tok = RWKVTokenizer(vocabulary=vocab)
+    tok("hello the")
+    ```
+
+    Output:
+    [3, 0, 2]
+    """
+
+    backbone_cls = RWKV7Backbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        dtype="int32",
+        pad_token_id=0,
+        start_token_id=None,
+        end_token_id=None,
+        **kwargs,
+    ):
+        """Initialize RWKV tokenizer.
+
+        Args:
+            vocabulary: Vocabulary list.
+            dtype: Output data type.
+            **kwargs: Additional keyword arguments.
+        """
+        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
+            raise ValueError(
+                "Output dtype must be an integer type or a string. "
+                f"Received: dtype={dtype}"
+            )
+
+        super().__init__(dtype=dtype, **kwargs)
+
+        self.vocabulary = None
+        self.pad_token_id = pad_token_id
+        self.start_token_id = start_token_id
+        self.end_token_id = end_token_id or self.pad_token_id
+        if vocabulary is not None:
+            self.set_vocabulary(vocabulary)
+        self.file_assets = [VOCAB_FILENAME]
+
+    def set_vocabulary(self, vocabulary):
+        """Set the tokenizer vocabulary.
+
+        Args:
+            vocabulary: Vocabulary list to set.
+        """
+        self.vocabulary = vocabulary
+        self._tokenizer = RWKVTokenizerBase(vocabulary)
+        if self.end_token_id is None or self.end_token_id == self.pad_token_id:
+            for line in vocabulary:
+                idx = int(line[: line.index(" ")])
+                repr_str = eval(line[line.index(" ") : line.rindex(" ")])
+                if repr_str == "\n\n":
+                    self.end_token_id = idx
+                    break
+
+    def save_assets(self, dir_path):
+        """Save vocabulary to directory.
+
+        Args:
+            dir_path: Directory path to save to.
+        """
+        path = os.path.join(dir_path, VOCAB_FILENAME)
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("".join(self.vocabulary))
+
+    def load_assets(self, dir_path=""):
+        """Load vocabulary from directory.
+
+        Args:
+            dir_path: Directory path to load from.
+        """
+        path = os.path.join(dir_path, VOCAB_FILENAME)
+        with open(path, "r", encoding="utf-8") as f:
+            vocabulary = f.readlines()
+        self.set_vocabulary(vocabulary)
+
+    def _check_vocabulary(self):
+        """Check if vocabulary is set, raise error if not."""
+        if self.vocabulary is None:
+            raise ValueError(
+                "No vocabulary has been set for RWKVTokenizer. Make "
+                "sure to pass a `vocabulary` argument when creating the layer."
+            )
+
+    def vocabulary_size(self):
+        """Get the size of the vocabulary.
+
+        Returns:
+            Number of tokens in vocabulary.
+        """
+        self._check_vocabulary()
+        return int(len(self.vocabulary))
+
+    def get_vocabulary(self):
+        """Get the current vocabulary.
+
+        Returns:
+            Current vocabulary list.
+        """
+        self._check_vocabulary()
+        return tensor_to_list(self.vocabulary)
+
+    def id_to_token(self, id):
+        """Convert token ID to string representation.
+
+        Args:
+            id: Token ID to convert.
+
+        Returns:
+            String representation of token.
+        """
+        self._check_vocabulary()
+        if id >= self.vocabulary_size() or id < 0:
+            raise ValueError(
+                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
+                f"Received: {id}"
+            )
+        return self._tokenizer.idx2token[id]
+
+    def token_to_id(self, token):
+        """Convert a string token to an integer id."""
+        self._check_vocabulary()
+        return int(self._tokenizer.token2idx[token])
+
+    def get_config(self):
+        """Get tokenizer configuration.
+
+        Returns:
+            Configuration dictionary.
+        """
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary": self.vocabulary,
+                "end_token_id": self.end_token_id,
+                "pad_token_id": self.pad_token_id,
+                "start_token_id": self.start_token_id,
+            }
+        )
+        return config
+
+    @preprocessing_function
+    def tokenize(self, inputs):
+        self._check_vocabulary()
+
+        if not tf.executing_eagerly() and tf.is_tensor(inputs):
+
+            def tokenize_wrapper(text_tensor):
+                text_list = (
+                    text_tensor.numpy()
+                    if hasattr(text_tensor, "numpy")
+                    else text_tensor
+                )
+                if isinstance(text_list, bytes):
+                    text_list = [text_list.decode("utf-8")]
+                elif isinstance(text_list, np.ndarray):
+                    text_list = [x.decode("utf-8") for x in text_list.flatten()]
+
+                tokens = self._tokenizer.encode(text_list)
+
+                if is_string_dtype(self.dtype):
+                    result = [
+                        self.id_to_token(i).decode("utf-8", errors="replace")
+                        for i in tokens[0]
+                    ]
+                    return tf.constant(result, dtype=tf.string)
+                else:
+                    return tf.constant(tokens[0], dtype=self.compute_dtype)
+
+            if inputs.shape.rank == 0:
+                output = tf.py_function(
+                    tokenize_wrapper,
+                    [inputs],
+                    Tout=tf.string
+                    if is_string_dtype(self.dtype)
+                    else self.compute_dtype,
+                )
+                output.set_shape([None])
+                return output
+            else:
+                output = tf.map_fn(
+                    lambda x: tf.py_function(
+                        tokenize_wrapper,
+                        [x],
+                        Tout=tf.string
+                        if is_string_dtype(self.dtype)
+                        else self.compute_dtype,
+                    ),
+                    inputs,
+                    fn_output_signature=tf.TensorSpec(
+                        [None],
+                        dtype=tf.string
+                        if is_string_dtype(self.dtype)
+                        else self.compute_dtype,
+                    ),
+                )
+                return output
+
+        if tf.is_tensor(inputs):
+            inputs = tensor_to_list(inputs)
+
+        tokens = self._tokenizer.encode(inputs)
+
+        if is_string_dtype(self.dtype):
+
+            def ids_to_str(ids):
+                return [
+                    self.id_to_token(i).decode("utf-8", errors="replace")
+                    for i in ids
+                ]
+
+            if isinstance(inputs, str):
+                return ids_to_str(tokens)
+            return [ids_to_str(ts) for ts in tokens]
+
+        if isinstance(inputs, str):
+            return tf.convert_to_tensor(tokens, dtype=self.compute_dtype)
+        else:
+            return tf.ragged.constant(tokens, dtype=self.compute_dtype)
+
+    @preprocessing_function
+    def detokenize(self, inputs):
+        """Convert tokens back to text.
+
+        Args:
+            inputs: Tokens to convert.
+
+        Returns:
+            Detokenized text.
+        """
+        self._check_vocabulary()
+
+        if tf.is_tensor(inputs):
+            inputs = tensor_to_list(inputs)
+
+        if len(inputs) > 0 and isinstance(inputs[0], (int, np.integer)):
+            inputs = [inputs]
+
+        strip_zero_inputs = []
+        for seq in inputs:
+            if tf.is_tensor(seq):
+                seq = tensor_to_list(seq)
+            strip_zero_inputs.append([x for x in seq if x != 0])
+
+        result = self._tokenizer.decode(strip_zero_inputs)
+
+        return tf.convert_to_tensor(result, dtype=tf.string)
+
+    def compute_output_spec(self, input_spec):
+        """Compute output specification.
+
+        Args:
+            input_spec: Input specification.
+
+        Returns:
+            Output tensor specification.
+        """
+        return keras.KerasTensor(
+            input_spec.shape + (None,), dtype=self.compute_dtype
+        )
+
+    def call(self, inputs):
+        """Call the tokenizer on inputs.
+
+        Args:
+            inputs: Input text.
+
+        Returns:
+            Tokenized output.
+        """
+        return self.tokenize(inputs)
````
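The tokenizer above is driven by the byte-level trie plus a greedy longest-match search over the raw UTF-8 bytes. The following standalone sketch (independent of keras_hub, with a made-up vocabulary) illustrates just that matching rule; the real implementation uses the `TRIE` class for efficient lookup rather than the linear scan shown here:

```python
# Greedy longest-match tokenization over a byte-level vocabulary.
# The vocabulary and ids below are invented for illustration; the real RWKV
# vocabulary is read from "vocabulary.txt" lines of the form "<idx> <repr> <len>".
vocab = {b" ": 0, b"the": 1, b"theory": 2, b"o": 3, b"r": 4, b"y": 5}


def greedy_tokenize(text):
    src = text.encode("utf-8")
    ids, i = [], 0
    while i < len(src):
        # Take the longest vocabulary entry that matches at position i,
        # which is what TRIE.find_longest computes (via a trie walk).
        for j in range(len(src), i, -1):
            if src[i:j] in vocab:
                ids.append(vocab[src[i:j]])
                i = j
                break
        else:
            raise ValueError(f"no token matches at byte offset {i}")
    return ids


print(greedy_tokenize("theory the"))  # [2, 0, 1]: "theory", " ", "the"
```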
keras_hub/src/models/sam/sam_backbone.py:

````diff
@@ -16,7 +16,11 @@ class SAMBackbone(Backbone):
         mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
             generate segmentation masks given the embeddings generated by the
             backbone and the prompt encoder.
-        dtype: The dtype
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the models computations and weights. Note that some
+            computations, such as softmax and layer normalization will always
+            be done in float32 precision regardless of dtype. Defaults to
+            `bfloat16`.
 
     Example:
     ```python
````
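The rewritten `dtype` documentation follows the wording used across keras-hub backbones. A hedged sketch of how the argument is typically passed (the preset name here is illustrative and not taken from this diff):

```python
import keras_hub

# dtype may be a string or a keras.mixed_precision.DTypePolicy; softmax and
# layer normalization still run in float32 internally, as the docstring notes.
backbone = keras_hub.models.Backbone.from_preset(
    "sam_base_sa1b",  # illustrative SAM preset name
    dtype="bfloat16",
)
```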
keras_hub/src/models/sam3/__init__.py (new file):

````diff
@@ -0,0 +1,7 @@
+from keras_hub.src.models.sam3.sam3_pc_backbone import (
+    SAM3PromptableConceptBackbone,
+)
+from keras_hub.src.models.sam3.sam3_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, SAM3PromptableConceptBackbone)
````
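This `__init__.py` follows the usual keras-hub pattern: importing the package registers the SAM3 presets against the backbone class, so their names become resolvable through `from_preset`. A sketch of the effect, under the assumption that the class exposes the standard `presets` mapping (the actual preset names live in `sam3_presets.py`, which is not shown in this diff):

```python
from keras_hub.src.models.sam3 import SAM3PromptableConceptBackbone

# After the register_presets() call runs at import time, the registered
# preset names are listed on the class and each one can be passed to
# SAM3PromptableConceptBackbone.from_preset(). Sketch only.
print(list(SAM3PromptableConceptBackbone.presets.keys()))
```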