keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,180 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block
+
+
+def rwkv7_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.RWKV7Backbone")
+class RWKV7Backbone(Backbone):
+    """The RWKV-7 core architecture with hyperparameters.
+
+    This network implements an RNN-based decoder network, Goose, as
+    described in [RWKV-7](https://arxiv.org/abs/2503.14456).
+
+    It is a modern RNN architecture based on linear attention mechanisms
+    with recurrent processing, as described in the RWKV papers. It includes
+    the embedding lookups and the RWKV-7 blocks.
+
+    The default constructor gives a fully customizable, randomly initialized
+    RWKV-7 model with any number of layers, heads, and embedding dimensions.
+    To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        hidden_size: int. The size of the embeddings and hidden states.
+        head_size: int. The size of each attention head.
+        num_layers: int. The number of RWKV-7 blocks.
+        vocabulary_size: int. The size of the token vocabulary.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in the two-layer feedforward network of each RWKV-7 block.
+        gate_lora: int. LoRA dimension for gating. Defaults to `128`.
+        mv_lora: int. LoRA dimension for value mixing. Defaults to `32`.
+        aaa_lora: int. LoRA dimension for alpha parameters. Defaults to `64`.
+        decay_lora: int. LoRA dimension for decay parameters. Defaults to
+            `64`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
+            use for model computations and weights. Note that some
+            computations, such as softmax and layer normalization, will
+            always be done at float32 precision regardless of dtype.
+        dropout_rate: float. Dropout rate for the dropout layer.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.ones(shape=(1, 12), dtype="int32"),
+    }
+
+    # Randomly initialized RWKV-7 decoder with custom config.
+    model = keras_hub.models.RWKV7Backbone(
+        vocabulary_size=10,
+        hidden_size=512,
+        num_layers=2,
+        head_size=64,
+        intermediate_dim=1024,
+        dtype="float32",
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        head_size,
+        num_layers,
+        vocabulary_size,
+        intermediate_dim,
+        gate_lora=128,
+        mv_lora=32,
+        aaa_lora=64,
+        decay_lora=64,
+        dtype=None,
+        dropout_rate=0,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = keras.layers.Embedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_size,
+            embeddings_initializer=rwkv7_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+
+        self.output_layer_norm = keras.layers.LayerNormalization(
+            epsilon=1e-5,
+            name="output_norm",
+            dtype=dtype,
+        )
+        self.dropout = keras.layers.Dropout(
+            dropout_rate,
+            dtype=dtype,
+            name="dropout",
+        )
+        self.rwkv_layers = []
+        for i in range(num_layers):
+            layer = RWKV7_Block(
+                hidden_size,
+                head_size,
+                intermediate_dim,
+                gate_lora,
+                mv_lora,
+                aaa_lora,
+                decay_lora,
+                use_initial_norm=i == 0,
+                kernel_initializer=rwkv7_kernel_initializer(),
+                dtype=dtype,
+                name=f"rwkv_layer_{i}",
+            )
+            self.rwkv_layers.append(layer)
+        self.head = keras.layers.Dense(
+            units=vocabulary_size,
+            kernel_initializer=rwkv7_kernel_initializer(),
+            use_bias=False,
+            name="head",
+            dtype=dtype,
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+
+        x = self.token_embedding(token_id_input)
+        padding_mask = ops.cast(padding_mask_input, dtype=x.dtype)
+        v_first = None
+        for rwkv_layer in self.rwkv_layers:
+            x, v_first = rwkv_layer(x, v_first, padding_mask)
+        x = self.dropout(x)
+        sequence_output = self.output_layer_norm(x)
+        sequence_output = self.head(sequence_output)
+
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_layers = num_layers
+        self.head_size = head_size
+        self.hidden_size = hidden_size
+        self.gate_lora = gate_lora
+        self.mv_lora = mv_lora
+        self.aaa_lora = aaa_lora
+        self.decay_lora = decay_lora
+        self.vocabulary_size = vocabulary_size
+        self.dropout_rate = dropout_rate
+        self.intermediate_dim = intermediate_dim
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_size": self.hidden_size,
+                "head_size": self.head_size,
+                "gate_lora": self.gate_lora,
+                "mv_lora": self.mv_lora,
+                "aaa_lora": self.aaa_lora,
+                "decay_lora": self.decay_lora,
+                "vocabulary_size": self.vocabulary_size,
+                "dropout_rate": self.dropout_rate,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+            }
+        )
+        return config
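The new backbone registers under `keras_hub.models.RWKV7Backbone` and wires two named inputs, `token_ids` and `padding_mask`, to per-token logits over the vocabulary. Below is a minimal sketch of exercising it, assuming this 0.26.0.dev0 build is installed; the concrete sizes are illustrative, not a recommended configuration.

```python
import numpy as np
import keras_hub

# Randomly initialized RWKV-7 backbone; sizes are illustrative only.
backbone = keras_hub.models.RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=512,  # presumably a multiple of head_size (see _build_cache)
    num_layers=2,
    head_size=64,
    intermediate_dim=1024,
    dtype="float32",
)

# The functional model defined in __init__ maps the two named inputs to
# logits of shape (batch_size, sequence_length, vocabulary_size).
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 12, 10)
```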
@@ -0,0 +1,259 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm import CausalLM
+from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
+    RWKV7CausalLMPreprocessor,
+)
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.RWKV7CausalLM")
+class RWKV7CausalLM(CausalLM):
+    """An end-to-end RWKV-7 model for causal language modeling.
+
+    A causal language model (LM) predicts the next token based on previous
+    tokens. This task setup can be used to train the model unsupervised on
+    plain text input, or to autoregressively generate plain text similar to
+    the data used for training. This task can be used for pre-training or
+    fine-tuning an RWKV-7 model, simply by calling `fit()`.
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_hub.samplers` objects to control the generation. By
+    default, `"greedy"` sampling will be used.
+
+    Args:
+        backbone: A `keras_hub.models.RWKV7Backbone` instance.
+        preprocessor: A `keras_hub.models.RWKV7CausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+
+    Examples:
+    ```python
+    # Initialize the tokenizer and load assets from a local path.
+    rwkv_path = "RWKV7_G1a_0.1B"
+    tokenizer = RWKVTokenizer()
+    tokenizer.load_assets(rwkv_path)
+
+    # Create a preprocessor with a sequence length of 8.
+    preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8)
+
+    # Initialize the model with a backbone and preprocessor.
+    causal_lm = RWKV7CausalLM(backbone, preprocessor)
+
+    # You can also load the tokenizer and model with `from_preset`.
+    tokenizer = RWKVTokenizer.from_preset(rwkv_path)
+    causal_lm = RWKV7CausalLM.from_preset(rwkv_path)
+
+    prompts = ["Bubble sort\n```python", "Hello World\n```python\n"]
+
+    causal_lm.compile(sampler="greedy")
+
+    outputs = causal_lm.generate(prompts, max_length=128)
+    for out in outputs:
+        print(out)
+        print("-" * 100)
+    ```
+    """
+
+    backbone_cls = RWKV7Backbone
+    preprocessor_cls = RWKV7CausalLMPreprocessor
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        """Initialize the RWKV-7 causal language model.
+
+        Args:
+            backbone: The backbone model.
+            preprocessor: The preprocessor for tokenization.
+            **kwargs: Additional keyword arguments.
+        """
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        super().__init__(
+            inputs=backbone.input,
+            outputs=backbone.output,
+            **kwargs,
+        )
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        compute_head=True,
+        padding_mask=None,
+        rnn_mode=True,
+    ):
+        """Forward pass of `RWKV7CausalLM` with cache.
+
+        `call_with_cache` adds an additional forward pass for the model for
+        autoregressive inference. Unlike calling the model directly, this
+        method caches the previous state tensors of the RWKV layers and
+        avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape
+                `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of state and token values.
+            compute_head: bool, whether to compute the output head.
+            padding_mask: a dense bool Tensor, the padding mask.
+            rnn_mode: bool, whether to use RNN mode.
+
+        Returns:
+            A `(logits, hidden_states, cache)` tuple, where `logits` is the
+            language model logits for the input token_ids, `hidden_states`
+            is the final hidden representation of the input tokens, and
+            `cache` is the decoding cache.
+        """
+        state_cache, last_token_cache = cache
+        x = self.backbone.token_embedding(token_ids)
+        if padding_mask is None:
+            padding_mask = ops.not_equal(token_ids, 0)
+        padding_mask = ops.cast(padding_mask, x.dtype)
+        v_first = None
+        updated_state_cache = []
+        updated_last_token_cache = []
+
+        for i in range(self.backbone.num_layers):
+            current_state_cache = state_cache[:, i, ...]
+            current_token_cache = last_token_cache[:, i, ...]
+            x, v_first, new_cache_state, cache_tmix_x, cache_cmix_x = (
+                self.backbone.rwkv_layers[i].generate_call(
+                    x,
+                    v_first=v_first,
+                    padding_mask=padding_mask,
+                    cache_state=current_state_cache,
+                    cache_tmix_x=current_token_cache[:, 0],
+                    cache_cmix_x=current_token_cache[:, 1],
+                    rnn_mode=rnn_mode,
+                )
+            )
+            new_token_cache = ops.stack([cache_tmix_x, cache_cmix_x], axis=1)
+            updated_state_cache.append(new_cache_state)
+            updated_last_token_cache.append(new_token_cache)
+        cache = [
+            ops.stack(updated_state_cache, axis=1),
+            ops.stack(updated_last_token_cache, axis=1),
+        ]
+        hidden_states = x = self.backbone.output_layer_norm(x)
+        if compute_head:
+            logits = self.backbone.head(x)
+        else:
+            logits = None
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids, padding_mask):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        num_layers = self.backbone.num_layers
+        head_dim = self.backbone.head_size
+        hidden_size = self.backbone.hidden_size
+        num_heads = hidden_size // head_dim
+
+        state_cache = ops.zeros(
+            [batch_size, num_layers, num_heads, head_dim, head_dim],
+            dtype=self.compute_dtype,
+        )
+        last_token_cache = ops.zeros(
+            [batch_size, num_layers, 2, 1, hidden_size],
+            dtype=self.compute_dtype,
+        )
+        cache = [state_cache, last_token_cache]
+
+        # Seed the cache. The prefill stage can use the kernel for better
+        # performance.
+        _, hidden_states, cache = self.call_with_cache(
+            token_ids,
+            cache,
+            rnn_mode=False,
+            compute_head=False,
+            padding_mask=padding_mask,
+        )
+
+        return hidden_states, cache
+
+    def generate_step(
+        self,
+        inputs,
+        stop_token_ids=None,
+    ):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation
+        function for a single batch of inputs. Inputs should have the same
+        structure as model inputs, a dictionary with keys `"token_ids"` and
+        `"padding_mask"`.
+
+        Args:
+            inputs: A dictionary with keys `"token_ids"`, `"padding_mask"`,
+                and `"predict_token_ids"` with batched tensor values.
+            stop_token_ids: Tuple of ids of stop tokens. If all sequences
+                have produced a new stop token, generation will stop.
+        """
+        token_ids, padding_mask, predict_token_ids = (
+            inputs["token_ids"],
+            inputs["padding_mask"],
+            inputs["predict_token_ids"],
+        )
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(
+            token_ids, inputs["input_padding_mask"]
+        )
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(
+                prompt, [0, cache_update_index], [batch_size, 1]
+            )
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt,
+                cache,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        output_ids = self.sampler(
+            next=next,
+            prompt=predict_token_ids,
+            cache=cache,
+            index=1,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+        padding_mask = ops.concatenate(
+            [
+                ops.cast(ops.not_equal(token_ids, 0), padding_mask.dtype),
+                padding_mask,
+            ],
+            axis=1,
+        )
+        token_ids = ops.concatenate([token_ids, output_ids], axis=1)
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of stop token locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": ops.cast(padding_mask, token_ids.dtype),
+        }
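For readers following the cache plumbing in `call_with_cache()` and `_build_cache()`: the recurrent state is carried as two stacked tensors rather than a transformer-style key/value cache. The sketch below only spells out the shapes `_build_cache()` allocates; the layout comes from the code above, while the concrete numbers are an illustrative, hypothetical configuration.

```python
# Shapes allocated by `_build_cache()` for a hypothetical configuration.
batch_size = 2
num_layers = 12
hidden_size = 768
head_size = 64
num_heads = hidden_size // head_size  # 12

# Per layer: one (head_size x head_size) recurrent state matrix per head.
state_cache_shape = (batch_size, num_layers, num_heads, head_size, head_size)

# Per layer: the last-seen token's hidden vector for the time-mix and
# channel-mix branches (hence the extra dimension of size 2).
last_token_cache_shape = (batch_size, num_layers, 2, 1, hidden_size)

print(state_cache_shape)       # (2, 12, 12, 64, 64)
print(last_token_cache_shape)  # (2, 12, 2, 1, 768)
```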
@@ -0,0 +1,214 @@
+import keras
+import numpy as np
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+from keras_hub.src.utils.tensor_utils import tf
+
+
+@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
+class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
+    """RWKV-7 causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.RWKV7CausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where
+    the `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this
+    preprocessor is attached to a `keras_hub.models.RWKV7CausalLM` instance,
+    these methods will be called implicitly in `generate()`. They can also
+    be called standalone (e.g. to precompute preprocessing inputs for
+    generation in a separate process).
+
+    Args:
+        tokenizer: A `keras_hub.models.RWKVTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the
+            tokenizer start token to each input sequence. Defaults to
+            `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length`
+            of the layer.
+
+    Examples:
+    ```python
+    # Initialize the tokenizer and load assets from a local path.
+    tokenizer = RWKVTokenizer()
+    tokenizer.load_assets(rwkv_path)
+
+    # Create a preprocessor with a sequence length of 8.
+    preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8)
+
+    # Tokenize and pack a batch of sentences.
+    preprocessor(["Bubble sort\n```python", "Hello World\n```python\n"])
+
+    # Preprocess inputs for generation with a maximum generation length
+    # of 16.
+    preprocessor.generate_preprocess(
+        ["Bubble sort\n```python", "Hello World\n```python\n"], 16
+    )
+    ```
+    """
+
+    backbone_cls = RWKV7Backbone
+    tokenizer_cls = RWKVTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        add_start_token=False,
+        **kwargs,
+    ):
+        """Initialize the preprocessor.
+
+        Args:
+            tokenizer: The tokenizer to use.
+            add_start_token: Whether to add a start token.
+            **kwargs: Additional arguments.
+        """
+        super().__init__(
+            tokenizer=tokenizer, add_start_token=add_start_token, **kwargs
+        )
+
+    @preprocessing_function
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        """Preprocess the input for training.
+
+        Args:
+            x: Input text data.
+            y: Target data (optional).
+            sample_weight: Sample weights (optional).
+            sequence_length: Desired sequence length.
+
+        Returns:
+            Preprocessed data tuple `(x, y, sample_weight)`.
+        """
+        if isinstance(x, str):
+            x = [x]
+        sequence_length = sequence_length or self.sequence_length
+        if sequence_length is None:
+            raise ValueError("`sequence_length` must be specified.")
+        # Pad the length to a multiple of 16 to meet kernel requirements.
+        if keras.config.backend() in ["torch", "jax"]:
+            # When using rwkv_ops, ensure sequence_length is divisible by 16.
+            try:
+                import rwkv_ops  # noqa: F401
+
+                if sequence_length % 16 != 0:
+                    sequence_length += (16 - sequence_length % 16) % 16
+            except ImportError:
+                pass
+        x = self.tokenizer(x)
+
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length + 1, add_end_value=False
+        )
+
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    def build(self, input_shape):
+        self.packer = StartEndPacker(
+            start_value=None,
+            end_value=None,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+            padding_side="left",  # RWKV uses left-padding exclusively.
+        )
+        self.built = True
+
+    @preprocessing_function
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Preprocess input for generation.
+
+        Args:
+            x: Input text data.
+            sequence_length: Maximum generation length.
+
+        Returns:
+            Dictionary with preprocessed inputs for generation.
+        """
+        if isinstance(x, str):
+            x = [x]
+
+        if not self.built:
+            self.build(None)
+        # Align with the Keras API: the `sequence_length` argument here is
+        # the maximum generation length, while `self.sequence_length`
+        # corresponds to the maximum prefill length.
+        generate_length = sequence_length
+        if sequence_length is None:
+            raise ValueError("`sequence_length` must be specified.")
+        sequence_length = self.sequence_length
+
+        x = [t[-sequence_length:] for t in self.tokenizer(x)]
+        y = tf.zeros((len(x), generate_length), "int32")
+        # Exploit the RNN formulation: prefill and decode are two separate
+        # sequences, but the first decode token must be the last prefill
+        # token.
+        start_token = [[t[-1]] for t in x]
+        x = [np.array(t[:-1]) if len(t) > 1 else [0] for t in x]
+        x = tf.ragged.constant(x)
+        token_ids, input_padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        start_token = tf.convert_to_tensor(start_token, "int32")
+
+        y = tf.concat([start_token, y], axis=1)
+        padding_mask = tf.not_equal(y, 0)
+
+        return {
+            "token_ids": token_ids,
+            "input_padding_mask": input_padding_mask,
+            "padding_mask": padding_mask,
+            "predict_token_ids": y,
+        }
+
+    @preprocessing_function
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()` by first removing all
+        padding and start/end tokens, and then converting the integer
+        sequence back to a string.
+
+        Args:
+            x: Dictionary containing `token_ids` and `padding_mask`.
+
+        Returns:
+            Detokenized string output.
+        """
+        if not self.built:
+            self.build(None)
+
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        return self.tokenizer.detokenize(token_ids * padding_mask)
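Taken together, the three new files give the standard keras-hub `CausalLM` task flow. Below is a hedged end-to-end sketch mirroring the docstring examples above; the `"RWKV7_G1a_0.1B"` preset name is copied from the docstring and may not actually be bundled with this dev build.

```python
import keras_hub

# Preset name copied from the docstring example; whether it ships with
# this 0.26.0.dev0 build is not confirmed by the diff.
causal_lm = keras_hub.models.RWKV7CausalLM.from_preset("RWKV7_G1a_0.1B")

# `generate()` runs `generate_preprocess()` (a left-padded prefill sequence
# plus a `predict_token_ids` buffer seeded with the last prompt token), the
# cached decode loop in `generate_step()`, then `generate_postprocess()`.
causal_lm.compile(sampler="greedy")
outputs = causal_lm.generate(
    ["Bubble sort\n```python", "Hello World\n```python\n"],
    max_length=128,
)
for out in outputs:
    print(out)
```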