keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
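
The bulk of this release is new model families: EDRec, RQ-VAE, RWKV-7, and SAM3, alongside new Qwen and Gemma3 presets and expanded Hugging Face export support. The three RWKV-7 files reproduced below define the backbone, the causal LM task, and its preprocessor. As a quick orientation, here is a minimal usage sketch; it assumes the `"RWKV7_G1a_0.1B"` preset name mentioned in the new docstrings resolves to a published preset:

```python
import keras_hub

# Sketch only: load the new RWKV-7 task model from a preset, assuming the
# "RWKV7_G1a_0.1B" preset referenced in the RWKV7CausalLM docstring exists.
causal_lm = keras_hub.models.RWKV7CausalLM.from_preset("RWKV7_G1a_0.1B")
causal_lm.compile(sampler="greedy")
print(causal_lm.generate(["Hello World\n"], max_length=128))
```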
@@ -0,0 +1,180 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block
+
+
+def rwkv7_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.RWKV7Backbone")
+class RWKV7Backbone(Backbone):
+    """The RWKV7 Transformer core architecture with hyperparameters.
+
+    This network implements an RNN-based decoder network,
+    Goose, as described in
+    [RWKV-7](https://arxiv.org/abs/2503.14456).
+
+    This network implements a modern RNN architecture based on linear
+    attention mechanisms with recurrent processing, as described in the
+    RWKV papers. It includes the embedding lookups and RWKV-7 blocks.
+
+    The default constructor gives a fully customizable, randomly initialized
+    RWKV-7 model with any number of layers, heads, and embedding dimensions.
+    To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        hidden_size: int. The size of the transformer encoding and pooling
+            layers.
+        head_size: int. The size of each attention head.
+        num_layers: int. The number of transformer layers.
+        vocabulary_size: int. The size of the token vocabulary.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        gate_lora: int. LoRA dimension for gating. Defaults to `128`.
+        mv_lora: int. LoRA dimension for value mixing. Defaults to `32`.
+        aaa_lora: int. LoRA dimension for alpha parameters. Defaults to `64`.
+        decay_lora: int. LoRA dimension for decay parameters. Defaults to `64`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+        dropout_rate: float. Dropout rate for the dropout layer.
+
+    Examples:
+
+    ```python
+    input_data = np.ones(shape=(1, 12), dtype="int32")
+
+
+    # Randomly initialized RWKV-7 decoder with custom config.
+    model = keras_hub.models.RWKV7Backbone(
+        vocabulary_size=10,
+        hidden_size=512,
+        num_layers=2,
+        head_size=64,
+        intermediate_dim=1024,
+        dtype="float32"
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        head_size,
+        num_layers,
+        vocabulary_size,
+        intermediate_dim,
+        gate_lora=128,
+        mv_lora=32,
+        aaa_lora=64,
+        decay_lora=64,
+        dtype=None,
+        dropout_rate=0,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = keras.layers.Embedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_size,
+            embeddings_initializer=rwkv7_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+
+        self.output_layer_norm = keras.layers.LayerNormalization(
+            epsilon=1e-5,
+            name="output_norm",
+            dtype=dtype,
+        )
+        self.dropout = keras.layers.Dropout(
+            dropout_rate,
+            dtype=dtype,
+            name="dropout",
+        )
+        self.rwkv_layers = []
+        for i in range(num_layers):
+            layer = RWKV7_Block(
+                hidden_size,
+                head_size,
+                intermediate_dim,
+                gate_lora,
+                mv_lora,
+                aaa_lora,
+                decay_lora,
+                use_initial_norm=i == 0,
+                kernel_initializer=rwkv7_kernel_initializer(),
+                dtype=dtype,
+                name=f"rwkv_layer_{i}",
+            )
+
+            self.rwkv_layers.append(layer)
+        self.head = keras.layers.Dense(
+            units=vocabulary_size,
+            kernel_initializer=rwkv7_kernel_initializer(),
+            use_bias=False,
+            name="head",
+            dtype=dtype,
+        )
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+
+        x = self.token_embedding(token_id_input)
+        padding_mask = ops.cast(padding_mask_input, dtype=x.dtype)
+        v_first = None
+        for rwkv_layer in self.rwkv_layers:
+            x, v_first = rwkv_layer(x, v_first, padding_mask)
+        x = self.dropout(x)
+        sequence_output = self.output_layer_norm(x)
+        sequence_output = self.head(sequence_output)
+
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        self.num_layers = num_layers
+        self.head_size = head_size
+        self.hidden_size = hidden_size
+        self.gate_lora = gate_lora
+        self.mv_lora = mv_lora
+        self.aaa_lora = aaa_lora
+        self.decay_lora = decay_lora
+        self.vocabulary_size = vocabulary_size
+        self.dropout_rate = dropout_rate
+        self.intermediate_dim = intermediate_dim
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_size": self.hidden_size,
+                "head_size": self.head_size,
+                "gate_lora": self.gate_lora,
+                "mv_lora": self.mv_lora,
+                "aaa_lora": self.aaa_lora,
+                "decay_lora": self.decay_lora,
+                "vocabulary_size": self.vocabulary_size,
+                "dropout_rate": self.dropout_rate,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+            }
+        )
+        return config
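
A minimal sketch of how this backbone is driven, restating the functional inputs and `get_config()` behavior defined above. The shapes and key names come from the constructor; whether the forward pass actually runs on a given backend depends on the kernel requirements noted in the preprocessor further down:

```python
import numpy as np
import keras_hub

# Sketch only: same randomly initialized config as the docstring example.
backbone = keras_hub.models.RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=512,
    num_layers=2,
    head_size=64,
    intermediate_dim=1024,
    dtype="float32",
)

# The functional model takes a dict with "token_ids" and "padding_mask".
inputs = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
logits = backbone(inputs)  # shape (1, 12, vocabulary_size)

# get_config() records every constructor argument, so the backbone can be
# rebuilt from its config.
restored = keras_hub.models.RWKV7Backbone.from_config(backbone.get_config())
```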
@@ -0,0 +1,259 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm import CausalLM
+from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
+    RWKV7CausalLMPreprocessor,
+)
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.RWKV7CausalLM")
+class RWKV7CausalLM(CausalLM):
+    """An end-to-end RWKV-7 model for causal language modeling.
+
+    A causal language model (LM) predicts the next token based on previous
+    tokens. This task setup can be used to train the model unsupervised on
+    plain text input, or to autoregressively generate plain text similar to
+    the data used for training. This task can be used for pre-training or
+    fine-tuning an RWKV-7 model, simply by calling `fit()`.
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_hub.samplers` objects to control the generation. By
+    default, `"greedy"` sampling will be used.
+
+    Args:
+        backbone: A `keras_hub.models.RWKV7Backbone` instance.
+        preprocessor: A `keras_hub.models.RWKV7CausalLMPreprocessor` or `None`.
+            If `None`, this model will not apply preprocessing, and inputs
+            should be preprocessed before calling the model.
+
+    Examples:
+    ```python
+    # Initialize the tokenizer and load assets from a local path.
+    tokenizer = RWKVTokenizer()
+    tokenizer.load_assets(rwkv_path)
+
+    # Create a preprocessor with a sequence length of 8.
+    preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8)
+
+    # Initialize the model with a backbone and preprocessor.
+    causal_lm = RWKV7CausalLM(backbone, preprocessor)
+
+    # You can also load the model with `from_preset()`.
+    rwkv_path = "RWKV7_G1a_0.1B"
+    tokenizer = RWKVTokenizer.from_preset(rwkv_path)
+    causal_lm = RWKV7CausalLM.from_preset(rwkv_path)
+
+    prompts = ["Bubble sort\n```python", "Hello World\n```python\n"]
+
+    causal_lm.compile(sampler="greedy")
+
+    outputs = causal_lm.generate(prompts, max_length=128)
+    for out in outputs:
+        print(out)
+        print("-" * 100)
+    ```
+    """
+
+    backbone_cls = RWKV7Backbone
+    preprocessor_cls = RWKV7CausalLMPreprocessor
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        """Initialize the RWKV-7 causal language model.
+
+        Args:
+            backbone: The backbone model.
+            preprocessor: The preprocessor for tokenization.
+            **kwargs: Additional keyword arguments.
+        """
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        super().__init__(
+            inputs=backbone.input,
+            outputs=backbone.output,
+            **kwargs,
+        )
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        compute_head=True,
+        padding_mask=None,
+        rnn_mode=True,
+    ):
+        """Forward pass of `RWKV7CausalLM` with cache.
+
+        `call_with_cache` adds an additional forward pass for the model for
+        autoregressive inference. Unlike calling the model directly, this method
+        allows caching previous state Tensors in RWKV layers, and avoids
+        recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of state and token values.
+            compute_head: bool, whether to compute the output head.
+            padding_mask: a dense bool Tensor, the padding mask.
+            rnn_mode: bool, whether to use RNN mode.
+
+        Returns:
+            A (logits, hidden_states, cache) tuple, where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the decoding cache.
+        """
+        state_cache, last_token_cache = cache
+        x = self.backbone.token_embedding(token_ids)
+        if padding_mask is None:
+            padding_mask = ops.not_equal(token_ids, 0)
+        padding_mask = ops.cast(padding_mask, x.dtype)
+        v_first = None
+        updated_state_cache = []
+        updated_last_token_cache = []
+
+        for i in range(self.backbone.num_layers):
+            current_state_cache = state_cache[:, i, ...]
+            current_token_cache = last_token_cache[:, i, ...]
+            x, v_first, new_cache_state, cache_tmix_x, cache_cmix_x = (
+                self.backbone.rwkv_layers[i].generate_call(
+                    x,
+                    v_first=v_first,
+                    padding_mask=padding_mask,
+                    cache_state=current_state_cache,
+                    cache_tmix_x=current_token_cache[:, 0],
+                    cache_cmix_x=current_token_cache[:, 1],
+                    rnn_mode=rnn_mode,
+                )
+            )
+            new_token_cache = ops.stack([cache_tmix_x, cache_cmix_x], axis=1)
+            updated_state_cache.append(new_cache_state)
+            updated_last_token_cache.append(new_token_cache)
+        cache = [
+            ops.stack(updated_state_cache, axis=1),
+            ops.stack(updated_last_token_cache, axis=1),
+        ]
+        hidden_states = x = self.backbone.output_layer_norm(x)
+        if compute_head:
+            logits = self.backbone.head(x)
+        else:
+            logits = None
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids, padding_mask):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        num_layers = self.backbone.num_layers
+        head_dim = self.backbone.head_size
+        hidden_size = self.backbone.hidden_size
+        num_heads = hidden_size // head_dim
+
+        state_cache = ops.zeros(
+            [batch_size, num_layers, num_heads, head_dim, head_dim],
+            dtype=self.compute_dtype,
+        )
+        last_token_cache = ops.zeros(
+            [batch_size, num_layers, 2, 1, hidden_size],
+            dtype=self.compute_dtype,
+        )
+        cache = [state_cache, last_token_cache]
+
+        # Seed the cache.
+        # The prefill stage can use the kernel path for better performance.
+        _, hidden_states, cache = self.call_with_cache(
+            token_ids,
+            cache,
+            rnn_mode=False,
+            compute_head=False,
+            padding_mask=padding_mask,
+        )
+
+        return hidden_states, cache
+
+    def generate_step(
+        self,
+        inputs,
+        stop_token_ids=None,
+    ):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+        Args:
+            inputs: A dictionary with keys `"token_ids"`, `"padding_mask"`, and
+                `"predict_token_ids"` with batched tensor values.
+            stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+        """
+        token_ids, padding_mask, predict_token_ids = (
+            inputs["token_ids"],
+            inputs["padding_mask"],
+            inputs["predict_token_ids"],
+        )
+        # Create and seed cache with a single forward pass.
+
+        hidden_states, cache = self._build_cache(
+            token_ids, inputs["input_padding_mask"]
+        )
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt,
+                cache,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        output_ids = self.sampler(
+            next=next,
+            prompt=predict_token_ids,
+            cache=cache,
+            index=1,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+        padding_mask = ops.concatenate(
+            [
+                ops.cast(ops.not_equal(token_ids, 0), padding_mask.dtype),
+                padding_mask,
+            ],
+            axis=1,
+        )
+        token_ids = ops.concatenate([token_ids, output_ids], axis=1)
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of stop token locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": ops.cast(padding_mask, token_ids.dtype),
+        }
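
The cache that `_build_cache()` constructs above has a fixed layout per layer. A sketch that just restates those shapes, with a hypothetical configuration (the batch size, layer count, and widths are chosen for illustration only):

```python
from keras import ops

# Hypothetical configuration, for illustration only.
batch_size, num_layers, hidden_size, head_size = 2, 12, 768, 64
num_heads = hidden_size // head_size  # 12 heads here

# Per layer and head: one (head_size x head_size) recurrent state matrix.
state_cache = ops.zeros(
    (batch_size, num_layers, num_heads, head_size, head_size), "float32"
)
# Per layer: the previous token's hidden state for the time-mix branch
# (index 0) and the channel-mix branch (index 1).
last_token_cache = ops.zeros(
    (batch_size, num_layers, 2, 1, hidden_size), "float32"
)
cache = [state_cache, last_token_cache]
```

`generate_step()` then seeds this cache with a single prefill pass (`rnn_mode=False`, `compute_head=False`) and advances it one token at a time through `call_with_cache()` inside the sampler loop.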
@@ -0,0 +1,214 @@
+import keras
+import numpy as np
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+from keras_hub.src.utils.tensor_utils import tf
+
+
+@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
+class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
+    """RWKV-7 Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.RWKV7CausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_hub.models.RWKV7CausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_hub.models.RWKVTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Defaults to `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured sequence_length of
+            the layer.
+
+
+    Examples:
+    ```python
+    # Initialize the tokenizer and load assets from a local path.
+    tokenizer = RWKVTokenizer()
+    tokenizer.load_assets(rwkv_path)
+
+    # Create a preprocessor with a sequence length of 8.
+    preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8)
+
+    # Tokenize and pack a batch of sentences.
+    preprocessor(["Bubble sort\n```python", "Hello World\n```python\n"])
+
+    # Preprocess inputs for generation with a maximum generation length of 16.
+    preprocessor.generate_preprocess(
+        ["Bubble sort\n```python", "Hello World\n```python\n"], 16
+    )
+    ```
+    """
+
+    backbone_cls = RWKV7Backbone
+    tokenizer_cls = RWKVTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        add_start_token=False,
+        **kwargs,
+    ):
+        """Initialize the preprocessor.
+
+        Args:
+            tokenizer: The tokenizer to use.
+            add_start_token: Whether to add a start token.
+            **kwargs: Additional arguments.
+        """
+        super().__init__(
+            tokenizer=tokenizer, add_start_token=add_start_token, **kwargs
+        )
+
+    @preprocessing_function
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        """Preprocess the input for training.
+
+        Args:
+            x: Input text data.
+            y: Target data (optional).
+            sample_weight: Sample weights (optional).
+            sequence_length: Desired sequence length.
+
+        Returns:
+            Preprocessed data tuple (x, y, sample_weight).
+        """
+        if isinstance(x, str):
+            x = [x]
+        sequence_length = sequence_length or self.sequence_length
+        # Pad the length to a multiple of 16 to meet kernel requirements.
+        if sequence_length is None:
+            raise ValueError("sequence_length must be specified.")
+        if keras.config.backend() in ["torch", "jax"]:
+            # When using rwkv_ops, ensure sequence_length is divisible by 16.
+            try:
+                import rwkv_ops  # noqa: F401
+
+                if sequence_length % 16 != 0:
+                    sequence_length += (16 - sequence_length % 16) % 16
+            except ImportError:
+                pass
+        x = self.tokenizer(x)
+
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length + 1, add_end_value=False
+        )
+
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    def build(self, input_shape):
+        self.packer = StartEndPacker(
+            start_value=None,
+            end_value=None,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+            padding_side="left",  # RWKV uses left-padding exclusively
+        )
+        self.built = True
+
+    @preprocessing_function
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Preprocess input for generation.
+
+        Args:
+            x: Input text data.
+            sequence_length: Maximum generation length.
+
+        Returns:
+            Dictionary with preprocessed inputs for generation.
+        """
+        if isinstance(x, str):
+            x = [x]
+
+        if not self.built:
+            self.build(None)
+        # Align with the Keras API: the `sequence_length` argument here is
+        # the maximum generation length, while `self.sequence_length`
+        # corresponds to the prefill max length.
+        generate_length = sequence_length
+        if sequence_length is None:
+            raise ValueError("`sequence_length` must be specified.")
+        sequence_length = self.sequence_length
+
+        x = [t[-sequence_length:] for t in self.tokenizer(x)]
+        y = tf.zeros((len(x), generate_length), "int32")
+        # Exploit the RNN formulation: prefill and decode are two separate
+        # sequences, but the first decode token must be the last prefill token.
+        start_token = [[t[-1]] for t in x]
+        x = [np.array(t[:-1]) if len(t) > 1 else [0] for t in x]
+        x = tf.ragged.constant(x)
+        token_ids, input_padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        start_token = tf.convert_to_tensor(start_token, "int32")
+
+        y = tf.concat([start_token, y], axis=1)
+        padding_mask = tf.not_equal(y, 0)
+
+        return {
+            "token_ids": token_ids,
+            "input_padding_mask": input_padding_mask,
+            "padding_mask": padding_mask,
+            "predict_token_ids": y,
+        }
+
+    @preprocessing_function
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()` by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+
+        Args:
+            x: Dictionary containing token_ids and padding_mask.
+
+        Returns:
+            Detokenized string output.
+        """
+        if not self.built:
+            self.build(None)
+
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        return self.tokenizer.detokenize(token_ids * padding_mask)
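
To tie the preprocessor back to `RWKV7CausalLM.generate_step()`, here is a sketch of the dictionary `generate_preprocess()` returns, assuming a `preprocessor` built as in the docstring example above; the key names come straight from the return statement in the code:

```python
# Sketch only: inputs handed to RWKV7CausalLM.generate_step().
features = preprocessor.generate_preprocess(["Hello World\n"], 16)

# features["token_ids"]          left-padded prompt ids for the prefill pass
# features["input_padding_mask"] mask over those prefill ids
# features["predict_token_ids"]  last prompt token followed by zeros to be
#                                filled in during decoding
# features["padding_mask"]       mask over predict_token_ids
```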