keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/rwkv7/rwkv7_layer.py
@@ -0,0 +1,724 @@
+import warnings
+
+import keras
+from keras import initializers
+from keras import ops
+from keras.layers import Layer
+
+
+def transpose_head(x, head_first):
+    x = ops.cast(x, dtype="float32")
+    if head_first:
+        return ops.transpose(x, (0, 2, 1, 3))
+    else:
+        return x
+
+
+def rnn_generalized_delta_rule(
+    r,
+    w,
+    k,
+    v,
+    a,
+    b,
+    initial_state=None,
+    output_final_state=True,
+    head_first=False,
+):
+    DTYPE = r.dtype
+    B, T, H, N = ops.shape(r)
+    r = transpose_head(r, head_first)
+
+    k = transpose_head(k, head_first)
+
+    v = transpose_head(v, head_first)
+    a = transpose_head(a, head_first)
+    b = transpose_head(b, head_first)
+    w = transpose_head(w, head_first)
+    w = ops.exp(-ops.exp(w))
+
+    if initial_state is not None:
+        state = initial_state
+        if ops.shape(state)[0] == 1:
+            state = ops.broadcast_to(state, (B, H, N, N))
+    else:
+        state = ops.zeros((B, H, N, N))
+    state = ops.cast(state, "float32")
+
+    keras_backend = keras.config.backend()
+
+    def step(t, inputs):
+        state, out = inputs
+        kk = ops.reshape(k[:, t, :], (B, H, 1, N))
+        rr = ops.reshape(r[:, t, :], (B, H, N, 1))
+        vv = ops.reshape(v[:, t, :], (B, H, N, 1))
+        aa = ops.reshape(a[:, t, :], (B, H, N, 1))
+        bb = ops.reshape(b[:, t, :], (B, H, 1, N))
+        state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
+        o = ops.cast((state @ rr), out.dtype)
+        if keras_backend == "tensorflow":
+            out = out.write(t, ops.reshape(o, (B, H, N)))
+        elif keras_backend == "torch":
+            out[:, t : t + 1] = ops.reshape(o, (B, 1, H, N))
+        else:
+            out = ops.slice_update(
+                out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N))
+            )
+        return [state, out]
+
+    if keras_backend == "tensorflow":
+        # slice_update has no gradient in the TensorFlow backend
+        import tensorflow as tf
+
+        out = tf.TensorArray(DTYPE, size=T)
+        for t in range(T):
+            state, out = step(t, [state, out])
+        out = ops.transpose(out.stack(), [1, 0, 2, 3])
+
+    else:
+        out = ops.zeros((B, T, H, N), DTYPE)
+        state, out = ops.fori_loop(0, T, step, [state, out])
+
+    if output_final_state:
+        return ops.cast(out, DTYPE), state
+    return ops.cast(out, DTYPE)
+
+
+class TimeShift(Layer):
+    """Time shift layer that shifts input sequence by one step.
+    It also be called short conv
+    """
+
+    def call(self, inputs, cache_x=None):
+        if cache_x is not None:
+            x = ops.concatenate([cache_x, inputs], axis=1)
+        else:
+            x = ops.pad(inputs, [[0, 0], [1, 0], [0, 0]], constant_values=0.0)
+        return x[:, :-1, :]
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class RWKV7ChannelMix(Layer):
+    """RWKV-7 channel mixing layer."""
+
+    def __init__(self, dim_ffn, kernel_initializer="glorot_uniform", **kwargs):
+        """Initialize RWKV7 channel mixer.
+
+        Args:
+            dim_ffn: Feed-forward dimension.
+            kernel_initializer: Weight initializer.
+            **kwargs: Additional layer arguments.
+        """
+        super().__init__(**kwargs)
+        self.dim_ffn = dim_ffn
+        self.kernel_initializer = initializers.get(kernel_initializer)
+
+    def call(self, x, last_cache_x=None, not_generation_mode=True):
+        """Process input through channel mixer.
+
+        Args:
+            x: Input tensor.
+            last_cache_x: Cached previous values.
+            not_generation_mode: Whether in generate mode.
+
+        Returns:
+            Mixed output tensor.
+        """
+        xx = self.time_shift(x, last_cache_x) - x
+        if last_cache_x is not None or not not_generation_mode:
+            last_cache_x = x[:, -1:]
+        k = x + xx * self.x_k
+        k = ops.relu(self.key(k)) ** 2
+        output = self.value(k)
+        if not_generation_mode:
+            return output
+        return output, last_cache_x
+
+    def compute_output_shape(self, input_shape):
+        if isinstance(input_shape, list):
+            return input_shape[0]
+        return input_shape
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        if isinstance(input_shape, list):
+            input_shape = input_shape[0]
+
+        self.x_k = self.add_weight(
+            shape=(input_shape[-1],),
+            name="time_mix_k",
+            initializer=self.kernel_initializer,
+        )
+        self.time_shift = TimeShift(
+            dtype=self.dtype_policy,
+        )
+        self.key = keras.layers.Dense(
+            self.dim_ffn,
+            use_bias=False,
+            name="dense_k",
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+        )
+        self.value = keras.layers.Dense(
+            input_shape[-1],
+            use_bias=False,
+            name="dense_v",
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+        )
+        self.key.build(input_shape)
+        self.value.build([None, None, self.dim_ffn])
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "dim_ffn": self.dim_ffn,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
+
+
+class RWKV7TimeMix(Layer):
+    """RWKV-7 time mixing layer."""
+
+    def __init__(
+        self,
+        hidden_size,
+        head_size,
+        gate_lora=128,
+        mv_lora=32,
+        aaa_lora=64,
+        decay_lora=64,
+        kernel_initializer="glorot_uniform",
+        add_v_first=True,
+        **kwargs,
+    ):
+        """Initialize RWKV7 time mixer.
+
+        Args:
+            hidden_size: Hidden dimension size.
+            head_size: Attention head size.
+            gate_lora: LoRA dimension for gating.
+            mv_lora: LoRA dimension for value mixing.
+            aaa_lora: LoRA dimension for alpha parameters.
+            decay_lora: LoRA dimension for decay parameters.
+            kernel_initializer: Weight initializer.
+            **kwargs: Additional layer arguments.
+        """
+        super().__init__(**kwargs)
+        self.head_size = head_size
+        self.hidden_size = hidden_size
+        self.n_head = hidden_size // self.head_size
+        self.gate_lora = gate_lora
+        self.mv_lora = mv_lora
+        self.aaa_lora = aaa_lora
+        self.decay_lora = decay_lora
+        self.kernel_initializer = initializers.get(kernel_initializer)
+        self.add_v_first = add_v_first
+        self.initial_state = None
+
+        self.RWKV7_OP = rnn_generalized_delta_rule
+        self.RWKV7_OP_RNN = rnn_generalized_delta_rule
+        self.RWKV7_OP_INFERENCE = rnn_generalized_delta_rule
+        if keras.config.backend() in ["torch", "jax"]:
+            # only torch and jax support cuda kernel speedup
+            try:
+                from rwkv_ops import rwkv7_op
+                from rwkv_ops import rwkv7_op_inference
+                from rwkv_ops import rwkv7_op_rnn
+
+                self.RWKV7_OP = rwkv7_op
+                # faster inference op
+                self.RWKV7_OP_INFERENCE = rwkv7_op_inference
+                self.RWKV7_OP_RNN = rwkv7_op_rnn
+            except ImportError:
+                warnings.warn(
+                    "The 'rwkv_ops' package is not installed. "
+                    "Falling back to the default (pure-Python) operators"
+                    "pure-Python which will be very slow. "
+                    "Please 'pip install rwkv_ops' to enable the cuda kernels",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        assert self.hidden_size % self.n_head == 0
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        if isinstance(input_shape[0], list):
+            input_shape = input_shape[0]
+        H = self.n_head
+        N = self.head_size
+        B, T, C = input_shape
+
+        self.x_r = self.add_weight(
+            shape=(C,), name="x_r", initializer=self.kernel_initializer
+        )
+        self.x_w = self.add_weight(
+            shape=(C,), name="x_w", initializer=self.kernel_initializer
+        )
+        self.x_k = self.add_weight(
+            shape=(C,), name="x_k", initializer=self.kernel_initializer
+        )
+        self.x_v = self.add_weight(
+            shape=(C,), name="x_v", initializer=self.kernel_initializer
+        )
+        self.x_a = self.add_weight(
+            shape=(C,), name="x_a", initializer=self.kernel_initializer
+        )
+        self.x_g = self.add_weight(
+            shape=(C,), name="x_g", initializer=self.kernel_initializer
+        )
+
+        self.w0 = self.add_weight(
+            shape=(C,), name="w0", initializer=self.kernel_initializer
+        )
+        self.w1 = self.add_weight(
+            shape=(C, self.decay_lora),
+            name="w1",
+            initializer=self.kernel_initializer,
+        )
+        self.w2 = self.add_weight(
+            shape=(self.decay_lora, C),
+            name="w2",
+            initializer=self.kernel_initializer,
+        )
+
+        self.a0 = self.add_weight(
+            shape=(C,), name="a0", initializer=self.kernel_initializer
+        )
+        self.a1 = self.add_weight(
+            shape=(C, self.aaa_lora),
+            name="a1",
+            initializer=self.kernel_initializer,
+        )
+        self.a2 = self.add_weight(
+            shape=(self.aaa_lora, C),
+            name="a2",
+            initializer=self.kernel_initializer,
+        )
+        if self.add_v_first:
+            self.v0 = self.add_weight(
+                shape=(C,), name="v0", initializer=self.kernel_initializer
+            )
+            self.v1 = self.add_weight(
+                shape=(C, self.mv_lora),
+                name="v1",
+                initializer=self.kernel_initializer,
+            )
+            self.v2 = self.add_weight(
+                shape=(self.mv_lora, C),
+                name="v2",
+                initializer=self.kernel_initializer,
+            )
+
+        self.g1 = self.add_weight(
+            shape=(C, self.gate_lora),
+            name="g1",
+            initializer=self.kernel_initializer,
+        )
+        self.g2 = self.add_weight(
+            shape=(self.gate_lora, C),
+            name="g2",
+            initializer=self.kernel_initializer,
+        )
+
+        self.k_k = self.add_weight(
+            shape=(C,), name="k_k", initializer=self.kernel_initializer
+        )
+        self.k_a = self.add_weight(
+            shape=(C,), name="k_a", initializer=self.kernel_initializer
+        )
+        self.r_k = self.add_weight(
+            shape=(H, N), name="r_k", initializer=self.kernel_initializer
+        )
+
+        self.time_shift = TimeShift(
+            dtype=self.dtype_policy,
+        )
+        self.receptance = keras.layers.Dense(
+            C,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            name="receptance",
+            dtype=self.dtype_policy,
+        )
+        self.key = keras.layers.Dense(
+            C,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            name="key",
+            dtype=self.dtype_policy,
+        )
+        self.value = keras.layers.Dense(
+            C,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            name="value",
+            dtype=self.dtype_policy,
+        )
+        self.output_layer = keras.layers.Dense(
+            C,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            name="output_layer",
+            dtype=self.dtype_policy,
+        )
+        self.ln_x = keras.layers.GroupNormalization(
+            groups=H,
+            epsilon=64e-5,
+            dtype=self.dtype_policy,
+        )
+
+        self.receptance.build(input_shape)
+        self.value.build(input_shape)
+        self.key.build(input_shape)
+        self.output_layer.build(input_shape)
+        self.ln_x.build((None, C))
+
+    def call(
+        self,
+        x,
+        v_first=None,
+        padding_mask=None,
+        last_cache_x=None,
+        cache_state=None,
+        rnn_mode=False,
+        not_generation_mode=True,
+        training=None,
+    ):
+        """Process input through time mixer.
+
+        Args:
+            x: Input tensor.
+            v_first: First value for mixing.
+            padding_mask: Mask for padding tokens.
+            last_cache_x: Cached previous values.
+            cache_state: Cached recurrent state.
+            rnn_mode: Whether to use RNN mode.
+            not_generation_mode: Whether in generate mode.
+
+        Returns:
+            Mixed output tensor and state information.
+        """
+        if cache_state is None:
+            initial_state = self.initial_state
+        else:
+            initial_state = cache_state
+        if padding_mask is not None:
+            if ops.ndim(padding_mask) == 2:
+                padding_mask = padding_mask[..., None]
+            padding_mask = ops.cast(padding_mask, x.dtype)
+            x *= padding_mask
+        B, T, C = ops.shape(x)
+        H = self.n_head
+        xx = self.time_shift(x, last_cache_x) - x
+        if last_cache_x is not None or not not_generation_mode:
+            last_cache_x = x[:, -1:]
+        if padding_mask is not None:
+            xx *= padding_mask
+
+        xr = x + xx * self.x_r
+        xw = x + xx * self.x_w
+        xk = x + xx * self.x_k
+        xv = x + xx * self.x_v
+        xa = x + xx * self.x_a
+        xg = x + xx * self.x_g
+
+        r = self.receptance(xr)
+        w = (
+            -ops.softplus(
+                -(
+                    self.w0
+                    + ops.matmul(ops.tanh(ops.matmul(xw, self.w1)), self.w2)
+                )
+            )
+            - 0.5
+        )  # soft-clamp to (-inf, -0.5)
+        k = self.key(xk)
+        v = self.value(xv)
+        if v_first is None or not self.add_v_first:
+            v_first = v
+        else:
+            v = v + (v_first - v) * ops.sigmoid(
+                self.v0 + ops.matmul(ops.matmul(xv, self.v1), self.v2)
+            )
+
+        a = ops.sigmoid(
+            self.a0 + ops.matmul(ops.matmul(xa, self.a1), self.a2)
+        )  # a is "in-context learning rate"
+        g = ops.matmul(ops.sigmoid(ops.matmul(xg, self.g1)), self.g2)
+
+        kk = k * self.k_k
+
+        kk = self.normalize(ops.reshape(kk, (B, T, H, -1)))
+        kk = ops.reshape(kk, (B, T, C))
+
+        k = k * (1 + (a - 1) * self.k_a)
+        if padding_mask is not None:
+            w = ops.where(padding_mask, w, -1e9)
+        if training:
+            rwkv7_op = self.RWKV7_OP
+        elif rnn_mode:
+            # T=1,single step
+            rwkv7_op = self.RWKV7_OP_RNN
+        else:
+            rwkv7_op = self.RWKV7_OP_INFERENCE
+
+        x, final_state = rwkv7_op(
+            ops.reshape(r, (B, T, self.n_head, self.head_size)),
+            ops.reshape(w, (B, T, self.n_head, self.head_size)),
+            ops.reshape(k, (B, T, self.n_head, self.head_size)),
+            ops.reshape(v, (B, T, self.n_head, self.head_size)),
+            ops.reshape(-kk, (B, T, self.n_head, self.head_size)),
+            ops.reshape(kk * a, (B, T, self.n_head, self.head_size)),
+            initial_state=ops.cast(initial_state, "float32")
+            if initial_state is not None
+            else None,
+        )
+        x = ops.reshape(x, (B, T, C))
+
+        x = ops.reshape(self.ln_x(ops.reshape(x, (B * T, C))), ops.shape(x))
+
+        x = ops.reshape(x, (B, T, C))
+        r = ops.reshape(r, (B, T, H, -1))
+        k = ops.reshape(k, (B, T, H, -1))
+        v = ops.reshape(v, (B, T, C))
+
+        rwkv = ops.sum(r * k * self.r_k, axis=-1, keepdims=True) * ops.reshape(
+            v, (B, T, H, -1)
+        )
+
+        x = x + ops.reshape(rwkv, (B, T, C))
+        x = self.output_layer(x * g)
+        if not_generation_mode:
+            return x, v_first, final_state
+        return x, v_first, last_cache_x, final_state
+
+    def compute_output_shape(self, input_shape):
+        output_shapes = [
+            input_shape,
+            input_shape,
+            [input_shape[0], self.n_head, self.head_size, self.head_size],
+        ]
+        return output_shapes
+
+    def normalize(
+        self,
+        x,
+        eps: float = 1e-12,
+    ):
+        square_sum = ops.sum(ops.square(x), axis=-1, keepdims=True)
+        inv_norm = ops.rsqrt(square_sum + eps)
+        inv_norm = ops.maximum(inv_norm, eps)
+        return x * inv_norm
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_size": self.hidden_size,
+                "head_size": self.head_size,
+                "gate_lora": self.gate_lora,
+                "mv_lora": self.mv_lora,
+                "aaa_lora": self.aaa_lora,
+                "decay_lora": self.decay_lora,
+                "add_v_first": self.add_v_first,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
+
+
+class RWKV7_Block(Layer):
+    def __init__(
+        self,
+        hidden_size,
+        head_size,
+        intermediate_dim,
+        gate_lora=128,
+        mv_lora=32,
+        aaa_lora=64,
+        decay_lora=64,
+        use_initial_norm=False,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        """Initialize RWKV7 block.
+
+        Args:
+            hidden_size: Hidden dimension size.
+            head_size: Attention head size.
+            intermediate_dim: Intermediate dimension for FFN.
+            gate_lora: LoRA dimension for gating.
+            mv_lora: LoRA dimension for value mixing.
+            aaa_lora: LoRA dimension for alpha parameters.
+            decay_lora: LoRA dimension for decay parameters.
+            use_initial_norm: Whether to use initial normalization.
+            kernel_initializer: Weight initializer.
+            **kwargs: Additional layer arguments.
+        """
+        super().__init__(**kwargs)
+        self.head_size = head_size
+        self.hidden_size = hidden_size
+        self.gate_lora = gate_lora
+        self.mv_lora = mv_lora
+        self.aaa_lora = aaa_lora
+        self.decay_lora = decay_lora
+        self.intermediate_dim = intermediate_dim
+        self.use_initial_norm = use_initial_norm
+        self.kernel_initializer = initializers.get(kernel_initializer)
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        if self.use_initial_norm:
+            self.ln0 = keras.layers.LayerNormalization(
+                epsilon=1e-5, dtype=self.dtype_policy, name="init_norm"
+            )
+            self.ln0.build(input_shape)
+
+        self.ln1 = keras.layers.LayerNormalization(
+            epsilon=1e-5, dtype=self.dtype_policy, name="att_norm"
+        )
+        self.ln1.build(input_shape)
+
+        self.ln2 = keras.layers.LayerNormalization(
+            epsilon=1e-5, dtype=self.dtype_policy, name="ffn_norm"
+        )
+        self.ln2.build(input_shape)
+
+        self.att = RWKV7TimeMix(
+            self.hidden_size,
+            self.head_size,
+            self.gate_lora,
+            self.mv_lora,
+            self.aaa_lora,
+            self.decay_lora,
+            name="RWKV_TIME_MIX",
+            add_v_first=not self.use_initial_norm,
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+        )
+        self.att.build(input_shape)
+
+        self.ffn = RWKV7ChannelMix(
+            self.intermediate_dim,
+            name="RWKV_CMIX",
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+        )
+        self.ffn.build(input_shape)
+
+    # The generate call should be separated from the call method.
+    # Otherwise, keras.remat will report an error.
+    def generate_call(
+        self,
+        x,
+        v_first=None,
+        padding_mask=None,
+        cache_state=None,
+        cache_tmix_x=None,
+        cache_cmix_x=None,
+        rnn_mode=False,
+    ):
+        """Process input through RWKV block.
+
+        Args:
+            x: Input tensor.
+            v_first: First value for mixing.
+            padding_mask: Mask for padding tokens.
+            cache_state: Cached recurrent state.
+            cache_tmix_x: Cached time mixer values.
+            cache_cmix_x: Cached channel mixer values.
+            rnn_mode: Whether to use RNN mode.
+
+        Returns:
+            Processed output tensor and cache information.
+        """
+
+        not_generation_mode = False
+        if padding_mask is not None:
+            padding_mask = ops.cast(padding_mask, x.dtype)
+            padding_mask = ops.expand_dims(padding_mask, axis=-1)
+        if self.use_initial_norm:
+            x = self.ln0(x)
+        xx, v_first, cache_tmix_x, cache_state = self.att.call(
+            self.ln1(x),
+            v_first=v_first,
+            padding_mask=padding_mask,
+            last_cache_x=cache_tmix_x,
+            cache_state=cache_state,
+            rnn_mode=rnn_mode,
+            not_generation_mode=not_generation_mode,
+            training=False,
+        )
+        x = ops.cast(x, xx.dtype) + xx
+        xx = self.ln2(x)
+        if padding_mask is not None:
+            padding_mask = ops.cast(padding_mask, x.dtype)
+            xx = xx * padding_mask
+        xx, cache_cmix_x = self.ffn(
+            xx, cache_cmix_x, not_generation_mode=not_generation_mode
+        )
+        x = x + xx
+        return x, v_first, cache_state, cache_tmix_x, cache_cmix_x
+
+    def call(
+        self,
+        x,
+        v_first=None,
+        padding_mask=None,
+    ):
+        not_generation_mode = True
+        if padding_mask is not None:
+            padding_mask = ops.cast(padding_mask, x.dtype)
+            padding_mask = ops.expand_dims(padding_mask, axis=-1)
+        if self.use_initial_norm:
+            x = self.ln0(x)
+        xx = self.ln1(x)
+        # For our model, returning the state is not necessary.
+        # However, other researchers might need it when using it
+        # so we provide a return.
+        xx, v_first, state = self.att(
+            xx,
+            v_first=v_first,
+            padding_mask=padding_mask,
+            not_generation_mode=not_generation_mode,
+        )
+        x = ops.cast(x, xx.dtype) + xx
+        xx = self.ln2(x)
+        if padding_mask is not None:
+            padding_mask = ops.cast(padding_mask, x.dtype)
+            xx = xx * padding_mask
+        x = x + self.ffn(xx, not_generation_mode=not_generation_mode)
+        return x, v_first
+
+    def compute_output_shape(self, input_shape):
+        return [input_shape, input_shape]
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_size": self.hidden_size,
+                "head_size": self.head_size,
+                "gate_lora": self.gate_lora,
+                "mv_lora": self.mv_lora,
+                "aaa_lora": self.aaa_lora,
+                "decay_lora": self.decay_lora,
+                "intermediate_dim": self.intermediate_dim,
+                "use_initial_norm": self.use_initial_norm,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
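
For orientation, a minimal usage sketch of the new `RWKV7_Block` layer follows. It is not part of the diff above: the import path is the internal module shown in the file listing, the dimensions are small illustrative values rather than any preset's, and it assumes the pure-Python fallback op (no `rwkv_ops` installed). The backbone in `rwkv7_backbone.py` presumably wires these blocks up with real configurations.

```python
import numpy as np

# Internal module path taken from the listing above; not a public API guarantee.
from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block

# Hypothetical sizes for illustration: hidden width 64, head size 16, FFN 256.
hidden_size, head_size, intermediate_dim = 64, 16, 256

# The first block applies the initial LayerNorm and seeds `v_first`;
# subsequent blocks mix their value projection against that first value.
blocks = [
    RWKV7_Block(hidden_size, head_size, intermediate_dim, use_initial_norm=True),
    RWKV7_Block(hidden_size, head_size, intermediate_dim),
]

x = np.random.rand(2, 8, hidden_size).astype("float32")  # (batch, time, hidden)
v_first = None
for block in blocks:
    # Training/prefill path: `call` returns the block output and `v_first`.
    x, v_first = block(x, v_first=v_first)

print(x.shape)  # (2, 8, 64)
```

During generation the model would instead use `generate_call`, which additionally threads the recurrent state and the time-shift/channel-mix caches between steps, as seen in the diff.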