keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/rwkv7/rwkv7_layer.py (new file)
@@ -0,0 +1,724 @@
+ import warnings
+
+ import keras
+ from keras import initializers
+ from keras import ops
+ from keras.layers import Layer
+
+
+ def transpose_head(x, head_first):
+     x = ops.cast(x, dtype="float32")
+     if head_first:
+         return ops.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def rnn_generalized_delta_rule(
+     r,
+     w,
+     k,
+     v,
+     a,
+     b,
+     initial_state=None,
+     output_final_state=True,
+     head_first=False,
+ ):
+     DTYPE = r.dtype
+     B, T, H, N = ops.shape(r)
+     r = transpose_head(r, head_first)
+
+     k = transpose_head(k, head_first)
+
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+     w = transpose_head(w, head_first)
+     w = ops.exp(-ops.exp(w))
+
+     if initial_state is not None:
+         state = initial_state
+         if ops.shape(state)[0] == 1:
+             state = ops.broadcast_to(state, (B, H, N, N))
+     else:
+         state = ops.zeros((B, H, N, N))
+     state = ops.cast(state, "float32")
+
+     keras_backend = keras.config.backend()
+
+     def step(t, inputs):
+         state, out = inputs
+         kk = ops.reshape(k[:, t, :], (B, H, 1, N))
+         rr = ops.reshape(r[:, t, :], (B, H, N, 1))
+         vv = ops.reshape(v[:, t, :], (B, H, N, 1))
+         aa = ops.reshape(a[:, t, :], (B, H, N, 1))
+         bb = ops.reshape(b[:, t, :], (B, H, 1, N))
+         state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
+         o = ops.cast((state @ rr), out.dtype)
+         if keras_backend == "tensorflow":
+             out = out.write(t, ops.reshape(o, (B, H, N)))
+         elif keras_backend == "torch":
+             out[:, t : t + 1] = ops.reshape(o, (B, 1, H, N))
+         else:
+             out = ops.slice_update(
+                 out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N))
+             )
+         return [state, out]
+
+     if keras_backend == "tensorflow":
+         # slice_update has no gradient in the TensorFlow backend
+         import tensorflow as tf
+
+         out = tf.TensorArray(DTYPE, size=T)
+         for t in range(T):
+             state, out = step(t, [state, out])
+         out = ops.transpose(out.stack(), [1, 0, 2, 3])
+
+     else:
+         out = ops.zeros((B, T, H, N), DTYPE)
+         state, out = ops.fori_loop(0, T, step, [state, out])
+
+     if output_final_state:
+         return ops.cast(out, DTYPE), state
+     return ops.cast(out, DTYPE)
+
+
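To make the reference op above concrete, here is a minimal smoke test (illustrative only, not part of the diff): it feeds random tensors in the `(batch, time, heads, head_size)` layout the function expects with the default `head_first=False`, and checks the output and final-state shapes.

import numpy as np
from keras import ops

B, T, H, N = 2, 4, 2, 8
rng = np.random.default_rng(0)
# r, w, k, v, a, b all share the (B, T, H, N) layout; w carries the
# pre-activation decay, since the op applies w = exp(-exp(w)) itself.
r, w, k, v, a, b = [
    ops.convert_to_tensor(rng.standard_normal((B, T, H, N)), "float32")
    for _ in range(6)
]
out, state = rnn_generalized_delta_rule(r, w, k, v, a, b)
print(ops.shape(out))    # (2, 4, 2, 8)
print(ops.shape(state))  # (2, 2, 8, 8): one (N, N) matrix per head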
+ class TimeShift(Layer):
+     """Time shift layer that shifts the input sequence by one step.
+
+     It is also known as a short convolution.
+     """
+
+     def call(self, inputs, cache_x=None):
+         if cache_x is not None:
+             x = ops.concatenate([cache_x, inputs], axis=1)
+         else:
+             x = ops.pad(inputs, [[0, 0], [1, 0], [0, 0]], constant_values=0.0)
+         return x[:, :-1, :]
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+
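A quick illustration of the shift semantics (not part of the diff): each output row is the previous input row, with zero padding at the first step.

import numpy as np
from keras import ops

x = ops.convert_to_tensor(
    np.arange(6, dtype="float32").reshape(1, 3, 2)
)  # (batch, time, features) = (1, 3, 2)
shifted = TimeShift()(x)
# Row t of `shifted` is row t - 1 of `x`; row 0 is zeros:
# [[0. 0.], [0. 1.], [2. 3.]]
print(ops.convert_to_numpy(shifted))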
+ class RWKV7ChannelMix(Layer):
+     """RWKV-7 channel mixing layer."""
+
+     def __init__(self, dim_ffn, kernel_initializer="glorot_uniform", **kwargs):
+         """Initialize the RWKV-7 channel mixer.
+
+         Args:
+             dim_ffn: Feed-forward dimension.
+             kernel_initializer: Weight initializer.
+             **kwargs: Additional layer arguments.
+         """
+         super().__init__(**kwargs)
+         self.dim_ffn = dim_ffn
+         self.kernel_initializer = initializers.get(kernel_initializer)
+
+     def call(self, x, last_cache_x=None, not_generation_mode=True):
+         """Process input through the channel mixer.
+
+         Args:
+             x: Input tensor.
+             last_cache_x: Cached previous values.
+             not_generation_mode: Whether the layer is running outside
+                 generation mode.
+
+         Returns:
+             Mixed output tensor, plus the updated cache when in
+             generation mode.
+         """
+         xx = self.time_shift(x, last_cache_x) - x
+         if last_cache_x is not None or not not_generation_mode:
+             last_cache_x = x[:, -1:]
+         k = x + xx * self.x_k
+         k = ops.relu(self.key(k)) ** 2
+         output = self.value(k)
+         if not_generation_mode:
+             return output
+         return output, last_cache_x
+
+     def compute_output_shape(self, input_shape):
+         if isinstance(input_shape, list):
+             return input_shape[0]
+         return input_shape
+
+     def build(self, input_shape):
+         super().build(input_shape)
+         if isinstance(input_shape, list):
+             input_shape = input_shape[0]
+
+         self.x_k = self.add_weight(
+             shape=(input_shape[-1],),
+             name="time_mix_k",
+             initializer=self.kernel_initializer,
+         )
+         self.time_shift = TimeShift(
+             dtype=self.dtype_policy,
+         )
+         self.key = keras.layers.Dense(
+             self.dim_ffn,
+             use_bias=False,
+             name="dense_k",
+             kernel_initializer=self.kernel_initializer,
+             dtype=self.dtype_policy,
+         )
+         self.value = keras.layers.Dense(
+             input_shape[-1],
+             use_bias=False,
+             name="dense_v",
+             kernel_initializer=self.kernel_initializer,
+             dtype=self.dtype_policy,
+         )
+         self.key.build(input_shape)
+         self.value.build([None, None, self.dim_ffn])
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "dim_ffn": self.dim_ffn,
+                 "kernel_initializer": initializers.serialize(
+                     self.kernel_initializer
+                 ),
+             }
+         )
+         return config
+
+
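In equation form, the channel mixer computes value(relu(key(x + (shift(x) - x) * x_k)) ** 2): a squared-ReLU feed-forward block whose input is interpolated with the previous token via TimeShift. A hedged usage sketch follows (not part of the diff; the dimensions are made up):

import numpy as np
from keras import ops

mix = RWKV7ChannelMix(dim_ffn=256)
x = ops.convert_to_tensor(
    np.random.default_rng(0).standard_normal((2, 16, 64)).astype("float32")
)
y = mix(x)  # default not_generation_mode=True; weights build on first call
print(ops.shape(y))  # (2, 16, 64), same shape as the input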
+ class RWKV7TimeMix(Layer):
+     """RWKV-7 time mixing layer."""
+
+     def __init__(
+         self,
+         hidden_size,
+         head_size,
+         gate_lora=128,
+         mv_lora=32,
+         aaa_lora=64,
+         decay_lora=64,
+         kernel_initializer="glorot_uniform",
+         add_v_first=True,
+         **kwargs,
+     ):
+         """Initialize the RWKV-7 time mixer.
+
+         Args:
+             hidden_size: Hidden dimension size.
+             head_size: Attention head size.
+             gate_lora: LoRA dimension for gating.
+             mv_lora: LoRA dimension for value mixing.
+             aaa_lora: LoRA dimension for alpha parameters.
+             decay_lora: LoRA dimension for decay parameters.
+             kernel_initializer: Weight initializer.
+             add_v_first: Whether to mix the first layer's value
+                 (`v_first`) into this layer's value.
+             **kwargs: Additional layer arguments.
+         """
+         super().__init__(**kwargs)
+         self.head_size = head_size
+         self.hidden_size = hidden_size
+         self.n_head = hidden_size // self.head_size
+         self.gate_lora = gate_lora
+         self.mv_lora = mv_lora
+         self.aaa_lora = aaa_lora
+         self.decay_lora = decay_lora
+         self.kernel_initializer = initializers.get(kernel_initializer)
+         self.add_v_first = add_v_first
+         self.initial_state = None
+
+         self.RWKV7_OP = rnn_generalized_delta_rule
+         self.RWKV7_OP_RNN = rnn_generalized_delta_rule
+         self.RWKV7_OP_INFERENCE = rnn_generalized_delta_rule
+         if keras.config.backend() in ["torch", "jax"]:
+             # Only the torch and jax backends support the CUDA kernel
+             # speedup.
+             try:
+                 from rwkv_ops import rwkv7_op
+                 from rwkv_ops import rwkv7_op_inference
+                 from rwkv_ops import rwkv7_op_rnn
+
+                 self.RWKV7_OP = rwkv7_op
+                 # Faster inference op.
+                 self.RWKV7_OP_INFERENCE = rwkv7_op_inference
+                 self.RWKV7_OP_RNN = rwkv7_op_rnn
+             except ImportError:
+                 warnings.warn(
+                     "The 'rwkv_ops' package is not installed. "
+                     "Falling back to the default pure-Python "
+                     "operators, which will be very slow. "
+                     "Please `pip install rwkv_ops` to enable the "
+                     "CUDA kernels.",
+                     UserWarning,
+                     stacklevel=2,
+                 )
+
+         assert self.hidden_size % self.n_head == 0
+
+     def build(self, input_shape):
+         super().build(input_shape)
+         if isinstance(input_shape[0], list):
+             input_shape = input_shape[0]
+         H = self.n_head
+         N = self.head_size
+         B, T, C = input_shape
+
+         self.x_r = self.add_weight(
+             shape=(C,), name="x_r", initializer=self.kernel_initializer
+         )
+         self.x_w = self.add_weight(
+             shape=(C,), name="x_w", initializer=self.kernel_initializer
+         )
+         self.x_k = self.add_weight(
+             shape=(C,), name="x_k", initializer=self.kernel_initializer
+         )
+         self.x_v = self.add_weight(
+             shape=(C,), name="x_v", initializer=self.kernel_initializer
+         )
+         self.x_a = self.add_weight(
+             shape=(C,), name="x_a", initializer=self.kernel_initializer
+         )
+         self.x_g = self.add_weight(
+             shape=(C,), name="x_g", initializer=self.kernel_initializer
+         )
+
+         self.w0 = self.add_weight(
+             shape=(C,), name="w0", initializer=self.kernel_initializer
+         )
+         self.w1 = self.add_weight(
+             shape=(C, self.decay_lora),
+             name="w1",
+             initializer=self.kernel_initializer,
+         )
+         self.w2 = self.add_weight(
+             shape=(self.decay_lora, C),
+             name="w2",
+             initializer=self.kernel_initializer,
+         )
+
+         self.a0 = self.add_weight(
+             shape=(C,), name="a0", initializer=self.kernel_initializer
+         )
+         self.a1 = self.add_weight(
+             shape=(C, self.aaa_lora),
+             name="a1",
+             initializer=self.kernel_initializer,
+         )
+         self.a2 = self.add_weight(
+             shape=(self.aaa_lora, C),
+             name="a2",
+             initializer=self.kernel_initializer,
+         )
+         if self.add_v_first:
+             self.v0 = self.add_weight(
+                 shape=(C,), name="v0", initializer=self.kernel_initializer
+             )
+             self.v1 = self.add_weight(
+                 shape=(C, self.mv_lora),
+                 name="v1",
+                 initializer=self.kernel_initializer,
+             )
+             self.v2 = self.add_weight(
+                 shape=(self.mv_lora, C),
+                 name="v2",
+                 initializer=self.kernel_initializer,
+             )
+
+         self.g1 = self.add_weight(
+             shape=(C, self.gate_lora),
+             name="g1",
+             initializer=self.kernel_initializer,
+         )
+         self.g2 = self.add_weight(
+             shape=(self.gate_lora, C),
+             name="g2",
+             initializer=self.kernel_initializer,
+         )
+
+         self.k_k = self.add_weight(
+             shape=(C,), name="k_k", initializer=self.kernel_initializer
+         )
+         self.k_a = self.add_weight(
+             shape=(C,), name="k_a", initializer=self.kernel_initializer
+         )
+         self.r_k = self.add_weight(
+             shape=(H, N), name="r_k", initializer=self.kernel_initializer
+         )
+
+         self.time_shift = TimeShift(
+             dtype=self.dtype_policy,
+         )
+         self.receptance = keras.layers.Dense(
+             C,
+             use_bias=False,
+             kernel_initializer=self.kernel_initializer,
+             name="receptance",
+             dtype=self.dtype_policy,
+         )
+         self.key = keras.layers.Dense(
+             C,
+             use_bias=False,
+             kernel_initializer=self.kernel_initializer,
+             name="key",
+             dtype=self.dtype_policy,
+         )
+         self.value = keras.layers.Dense(
+             C,
+             use_bias=False,
+             kernel_initializer=self.kernel_initializer,
+             name="value",
+             dtype=self.dtype_policy,
+         )
+         self.output_layer = keras.layers.Dense(
+             C,
+             use_bias=False,
+             kernel_initializer=self.kernel_initializer,
+             name="output_layer",
+             dtype=self.dtype_policy,
+         )
+         self.ln_x = keras.layers.GroupNormalization(
+             groups=H,
+             epsilon=64e-5,
+             dtype=self.dtype_policy,
+         )
+
+         self.receptance.build(input_shape)
+         self.value.build(input_shape)
+         self.key.build(input_shape)
+         self.output_layer.build(input_shape)
+         self.ln_x.build((None, C))
+
+     def call(
+         self,
+         x,
+         v_first=None,
+         padding_mask=None,
+         last_cache_x=None,
+         cache_state=None,
+         rnn_mode=False,
+         not_generation_mode=True,
+         training=None,
+     ):
+         """Process input through the time mixer.
+
+         Args:
+             x: Input tensor.
+             v_first: First value for mixing.
+             padding_mask: Mask for padding tokens.
+             last_cache_x: Cached previous values.
+             cache_state: Cached recurrent state.
+             rnn_mode: Whether to use RNN mode.
+             not_generation_mode: Whether the layer is running outside
+                 generation mode.
+             training: Whether the layer is in training mode.
+
+         Returns:
+             Mixed output tensor and state information.
+         """
+         if cache_state is None:
+             initial_state = self.initial_state
+         else:
+             initial_state = cache_state
+         if padding_mask is not None:
+             if ops.ndim(padding_mask) == 2:
+                 padding_mask = padding_mask[..., None]
+             padding_mask = ops.cast(padding_mask, x.dtype)
+             x *= padding_mask
+         B, T, C = ops.shape(x)
+         H = self.n_head
+         xx = self.time_shift(x, last_cache_x) - x
+         if last_cache_x is not None or not not_generation_mode:
+             last_cache_x = x[:, -1:]
+         if padding_mask is not None:
+             xx *= padding_mask
+
+         xr = x + xx * self.x_r
+         xw = x + xx * self.x_w
+         xk = x + xx * self.x_k
+         xv = x + xx * self.x_v
+         xa = x + xx * self.x_a
+         xg = x + xx * self.x_g
+
+         r = self.receptance(xr)
+         w = (
+             -ops.softplus(
+                 -(
+                     self.w0
+                     + ops.matmul(ops.tanh(ops.matmul(xw, self.w1)), self.w2)
+                 )
+             )
+             - 0.5
+         )  # soft-clamp to (-inf, -0.5)
+         k = self.key(xk)
+         v = self.value(xv)
+         if v_first is None or not self.add_v_first:
+             v_first = v
+         else:
+             v = v + (v_first - v) * ops.sigmoid(
+                 self.v0 + ops.matmul(ops.matmul(xv, self.v1), self.v2)
+             )
+
+         a = ops.sigmoid(
+             self.a0 + ops.matmul(ops.matmul(xa, self.a1), self.a2)
+         )  # a is the "in-context learning rate"
+         g = ops.matmul(ops.sigmoid(ops.matmul(xg, self.g1)), self.g2)
+
+         kk = k * self.k_k
+
+         kk = self.normalize(ops.reshape(kk, (B, T, H, -1)))
+         kk = ops.reshape(kk, (B, T, C))
+
+         k = k * (1 + (a - 1) * self.k_a)
+         if padding_mask is not None:
+             w = ops.where(padding_mask, w, -1e9)
+         if training:
+             rwkv7_op = self.RWKV7_OP
+         elif rnn_mode:
+             # T=1, single step.
+             rwkv7_op = self.RWKV7_OP_RNN
+         else:
+             rwkv7_op = self.RWKV7_OP_INFERENCE
+
+         x, final_state = rwkv7_op(
+             ops.reshape(r, (B, T, self.n_head, self.head_size)),
+             ops.reshape(w, (B, T, self.n_head, self.head_size)),
+             ops.reshape(k, (B, T, self.n_head, self.head_size)),
+             ops.reshape(v, (B, T, self.n_head, self.head_size)),
+             ops.reshape(-kk, (B, T, self.n_head, self.head_size)),
+             ops.reshape(kk * a, (B, T, self.n_head, self.head_size)),
+             initial_state=ops.cast(initial_state, "float32")
+             if initial_state is not None
+             else None,
+         )
+         x = ops.reshape(x, (B, T, C))
+
+         x = ops.reshape(self.ln_x(ops.reshape(x, (B * T, C))), ops.shape(x))
+
+         x = ops.reshape(x, (B, T, C))
+         r = ops.reshape(r, (B, T, H, -1))
+         k = ops.reshape(k, (B, T, H, -1))
+         v = ops.reshape(v, (B, T, C))
+
+         rwkv = ops.sum(r * k * self.r_k, axis=-1, keepdims=True) * ops.reshape(
+             v, (B, T, H, -1)
+         )
+
+         x = x + ops.reshape(rwkv, (B, T, C))
+         x = self.output_layer(x * g)
+         if not_generation_mode:
+             return x, v_first, final_state
+         return x, v_first, last_cache_x, final_state
+
+     def compute_output_shape(self, input_shape):
+         output_shapes = [
+             input_shape,
+             input_shape,
+             [input_shape[0], self.n_head, self.head_size, self.head_size],
+         ]
+         return output_shapes
+
+     def normalize(
+         self,
+         x,
+         eps: float = 1e-12,
+     ):
+         square_sum = ops.sum(ops.square(x), axis=-1, keepdims=True)
+         inv_norm = ops.rsqrt(square_sum + eps)
+         inv_norm = ops.maximum(inv_norm, eps)
+         return x * inv_norm
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "head_size": self.head_size,
+                 "gate_lora": self.gate_lora,
+                 "mv_lora": self.mv_lora,
+                 "aaa_lora": self.aaa_lora,
+                 "decay_lora": self.decay_lora,
+                 "add_v_first": self.add_v_first,
+                 "kernel_initializer": initializers.serialize(
+                     self.kernel_initializer
+                 ),
+             }
+         )
+         return config
+
+
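A hedged single-layer sketch (not part of the diff): `hidden_size` must be an integer multiple of `head_size`, and without `rwkv_ops` installed the call runs through the pure-Python `rnn_generalized_delta_rule` above.

import numpy as np
from keras import ops

tmix = RWKV7TimeMix(hidden_size=64, head_size=16)  # 4 heads of size 16
x = ops.convert_to_tensor(
    np.random.default_rng(0).standard_normal((2, 8, 64)).astype("float32")
)
# Default not_generation_mode=True returns (output, v_first, final_state).
y, v_first, state = tmix(x)
print(ops.shape(y))      # (2, 8, 64)
print(ops.shape(state))  # (2, 4, 16, 16): per-head (N, N) recurrent state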
+ class RWKV7_Block(Layer):
+     def __init__(
+         self,
+         hidden_size,
+         head_size,
+         intermediate_dim,
+         gate_lora=128,
+         mv_lora=32,
+         aaa_lora=64,
+         decay_lora=64,
+         use_initial_norm=False,
+         kernel_initializer="glorot_uniform",
+         **kwargs,
+     ):
+         """Initialize the RWKV-7 block.
+
+         Args:
+             hidden_size: Hidden dimension size.
+             head_size: Attention head size.
+             intermediate_dim: Intermediate dimension for FFN.
+             gate_lora: LoRA dimension for gating.
+             mv_lora: LoRA dimension for value mixing.
+             aaa_lora: LoRA dimension for alpha parameters.
+             decay_lora: LoRA dimension for decay parameters.
+             use_initial_norm: Whether to use initial normalization.
+             kernel_initializer: Weight initializer.
+             **kwargs: Additional layer arguments.
+         """
+         super().__init__(**kwargs)
+         self.head_size = head_size
+         self.hidden_size = hidden_size
+         self.gate_lora = gate_lora
+         self.mv_lora = mv_lora
+         self.aaa_lora = aaa_lora
+         self.decay_lora = decay_lora
+         self.intermediate_dim = intermediate_dim
+         self.use_initial_norm = use_initial_norm
+         self.kernel_initializer = initializers.get(kernel_initializer)
+
+     def build(self, input_shape):
+         super().build(input_shape)
+         if self.use_initial_norm:
+             self.ln0 = keras.layers.LayerNormalization(
+                 epsilon=1e-5, dtype=self.dtype_policy, name="init_norm"
+             )
+             self.ln0.build(input_shape)
+
+         self.ln1 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy, name="att_norm"
+         )
+         self.ln1.build(input_shape)
+
+         self.ln2 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy, name="ffn_norm"
+         )
+         self.ln2.build(input_shape)
+
+         self.att = RWKV7TimeMix(
+             self.hidden_size,
+             self.head_size,
+             self.gate_lora,
+             self.mv_lora,
+             self.aaa_lora,
+             self.decay_lora,
+             name="RWKV_TIME_MIX",
+             add_v_first=not self.use_initial_norm,
+             kernel_initializer=self.kernel_initializer,
+             dtype=self.dtype_policy,
+         )
+         self.att.build(input_shape)
+
+         self.ffn = RWKV7ChannelMix(
+             self.intermediate_dim,
+             name="RWKV_CMIX",
+             kernel_initializer=self.kernel_initializer,
+             dtype=self.dtype_policy,
+         )
+         self.ffn.build(input_shape)
+
+     # The generate call must be kept separate from the `call` method;
+     # otherwise, keras.remat will report an error.
+     def generate_call(
+         self,
+         x,
+         v_first=None,
+         padding_mask=None,
+         cache_state=None,
+         cache_tmix_x=None,
+         cache_cmix_x=None,
+         rnn_mode=False,
+     ):
+         """Process input through the RWKV block in generation mode.
+
+         Args:
+             x: Input tensor.
+             v_first: First value for mixing.
+             padding_mask: Mask for padding tokens.
+             cache_state: Cached recurrent state.
+             cache_tmix_x: Cached time mixer values.
+             cache_cmix_x: Cached channel mixer values.
+             rnn_mode: Whether to use RNN mode.
+
+         Returns:
+             Processed output tensor and cache information.
+         """
+
+         not_generation_mode = False
+         if padding_mask is not None:
+             padding_mask = ops.cast(padding_mask, x.dtype)
+             padding_mask = ops.expand_dims(padding_mask, axis=-1)
+         if self.use_initial_norm:
+             x = self.ln0(x)
+         xx, v_first, cache_tmix_x, cache_state = self.att.call(
+             self.ln1(x),
+             v_first=v_first,
+             padding_mask=padding_mask,
+             last_cache_x=cache_tmix_x,
+             cache_state=cache_state,
+             rnn_mode=rnn_mode,
+             not_generation_mode=not_generation_mode,
+             training=False,
+         )
+         x = ops.cast(x, xx.dtype) + xx
+         xx = self.ln2(x)
+         if padding_mask is not None:
+             padding_mask = ops.cast(padding_mask, x.dtype)
+             xx = xx * padding_mask
+         xx, cache_cmix_x = self.ffn(
+             xx, cache_cmix_x, not_generation_mode=not_generation_mode
+         )
+         x = x + xx
+         return x, v_first, cache_state, cache_tmix_x, cache_cmix_x
+
+     def call(
+         self,
+         x,
+         v_first=None,
+         padding_mask=None,
+     ):
+         not_generation_mode = True
+         if padding_mask is not None:
+             padding_mask = ops.cast(padding_mask, x.dtype)
+             padding_mask = ops.expand_dims(padding_mask, axis=-1)
+         if self.use_initial_norm:
+             x = self.ln0(x)
+         xx = self.ln1(x)
+         # For our model, returning the state is not necessary.
+         # However, other researchers might need it when reusing this
+         # layer, so we return it anyway.
+         xx, v_first, state = self.att(
+             xx,
+             v_first=v_first,
+             padding_mask=padding_mask,
+             not_generation_mode=not_generation_mode,
+         )
+         x = ops.cast(x, xx.dtype) + xx
+         xx = self.ln2(x)
+         if padding_mask is not None:
+             padding_mask = ops.cast(padding_mask, x.dtype)
+             xx = xx * padding_mask
+         x = x + self.ffn(xx, not_generation_mode=not_generation_mode)
+         return x, v_first
+
+     def compute_output_shape(self, input_shape):
+         return [input_shape, input_shape]
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "head_size": self.head_size,
+                 "gate_lora": self.gate_lora,
+                 "mv_lora": self.mv_lora,
+                 "aaa_lora": self.aaa_lora,
+                 "decay_lora": self.decay_lora,
+                 "intermediate_dim": self.intermediate_dim,
+                 "use_initial_norm": self.use_initial_norm,
+                 "kernel_initializer": initializers.serialize(
+                     self.kernel_initializer
+                 ),
+             }
+         )
+         return config
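
Finally, a hedged end-to-end sketch of how a backbone might stack these blocks (not part of the diff; sizes are illustrative): the first block applies the initial LayerNorm and, because `add_v_first=not use_initial_norm`, produces `v_first`, which later blocks mix into their values.

import numpy as np
from keras import ops

blocks = [
    RWKV7_Block(
        hidden_size=64,
        head_size=16,
        intermediate_dim=256,
        use_initial_norm=(i == 0),  # only the first block normalizes input
    )
    for i in range(2)
]
x = ops.convert_to_tensor(
    np.random.default_rng(0).standard_normal((2, 8, 64)).astype("float32")
)
v_first = None
for block in blocks:
    x, v_first = block(x, v_first=v_first)
print(ops.shape(x))  # (2, 8, 64)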