keras-hub 0.20.0.dev1-py3-none-any.whl → 0.21.0.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
keras_hub/src/models/moonshine/moonshine_audio_converter.py (new file)
@@ -0,0 +1,301 @@
+import keras
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+
+
+@keras_hub_export("keras_hub.layers.MoonshineAudioConverter")
+class MoonshineAudioConverter(AudioConverter):
+    """Moonshine audio preprocessing layer.
+
+    This layer processes raw audio waveforms for the Moonshine ASR model. Audio
+    is formatted as a batched tensor at a 16kHz sample rate and validated for
+    length (0.1 to 64 seconds). The layer handles padding and optional
+    normalization. It does not contain trainable weights.
+
+    Args:
+        sampling_rate: int, optional. The audio sampling rate in Hz. Defaults to
+            16,000.
+        padding_value: float, optional. The value for padding. Defaults to 0.0.
+        do_normalize: bool, optional. Whether to normalize inputs. Defaults to
+            False.
+        **kwargs: Additional keyword arguments passed to the base AudioConverter
+            class for customizing the underlying preprocessing behavior.
+
+    Call arguments:
+        - `inputs`: The raw audio data to be processed. It should be a tensor of
+          shape `(batch_size, time_steps, 1)` for mono audio. If the input has
+          shape `(batch_size, time_steps)`, the layer will add the channel
+          dimension.
+        - `sampling_rate`: The sampling rate of the audio in Hz. If
+          provided, it must match the expected sampling rate set during
+          initialization (default is 16,000 Hz). If not provided, the expected
+          sampling rate is taken from the initialization arguments.
+        - `padding`: The padding strategy to apply. If provided, can be one of:
+            - `"longest"`: If `pad_to_multiple_of` is set, pads the audio to
+              make the time_steps dimension a multiple of `pad_to_multiple_of`.
+            - `"max_length"`: Pads or truncates the audio to `max_length` time
+              steps. If `pad_to_multiple_of` is set, the target length will be
+              the smallest multiple of `pad_to_multiple_of` that is greater than
+              or equal to `max_length`.
+            - If not specified or `None`, no padding is applied.
+        - `max_length`: The target number of time steps when `padding` is
+          `"max_length"`. If not provided and `padding` is `"max_length"`, no
+          padding or truncation is applied.
+        - `pad_to_multiple_of`: If set, the padded time_steps will be a
+          multiple of this value for the chosen padding strategy.
+
+    Examples:
+    ```python
+    import keras
+    from keras_hub.layers import MoonshineAudioConverter
+
+    # Create a dummy audio input (1 second at 16kHz).
+    dummy_audio = keras.ops.convert_to_tensor(
+        [[0.1] * 16000],
+        dtype="float32"
+    )
+    dummy_audio = keras.ops.expand_dims(dummy_audio, axis=-1)
+
+    # Initialize the preprocessor.
+    preprocessor = MoonshineAudioConverter(do_normalize=True)
+
+    # Process the audio.
+    processed_audio = preprocessor(dummy_audio)
+
+    # Output shape.
+    print(processed_audio.shape)  # Expected: (1, 16000, 1) or padded length
+    ```
+    """
+
+    # References:
+    # Defined and formulated based on the UsefulSensors implementation of audio
+    # preprocessing logic (https://github.com/usefulsensors/moonshine/blob/main/moonshine/transcribe.py).
+
+    backbone_cls = MoonshineBackbone
+
+    def __init__(
+        self,
+        sampling_rate=16000,
+        padding_value=0.0,
+        do_normalize=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._convert_input_args = False
+        self._allow_non_tensor_positional_args = True
+        self.sampling_rate = sampling_rate
+        self.padding_value = padding_value
+        self.do_normalize = do_normalize
+
+    def call(
+        self,
+        inputs,
+        sampling_rate=None,
+        padding=None,
+        max_length=None,
+        pad_to_multiple_of=None,
+    ):
+        # Validate sampling rate.
+        if sampling_rate is not None and sampling_rate != self.sampling_rate:
+            raise ValueError(
+                f"Expected sampling_rate {self.sampling_rate}, got "
+                f"{sampling_rate}"
+            )
+
+        # Ensure inputs are (batch_size, time_steps, 1).
+        input_shape = keras.ops.shape(inputs)
+        input_rank = len(input_shape)
+        if input_rank == 2:
+            processed_inputs = keras.ops.expand_dims(inputs, axis=-1)
+        elif input_rank == 3:
+            processed_inputs = inputs
+        else:
+            raise ValueError(
+                "Inputs must be mono audio: (batch_size, time_steps, 1)"
+            )
+
+        # Get original length and validate duration.
+        current_shape = keras.ops.shape(processed_inputs)
+        original_length = current_shape[1]
+        duration = (
+            keras.ops.cast(original_length, keras.backend.floatx())
+            / self.sampling_rate
+        )
+        # Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/transcribe.py#L20
+        is_invalid_duration = keras.ops.logical_or(
+            keras.ops.less(duration, 0.1), keras.ops.greater(duration, 64.0)
+        )
+
+        def print_warning_fn():
+            import warnings
+
+            warnings.warn(
+                "Audio duration must be between 0.1 and 64 seconds. For "
+                "transcribing longer segments, pre-segment your audio and "
+                "provide shorter segments."
+            )
+            return keras.ops.convert_to_tensor(True, dtype="bool")
+
+        is_tf_symbolic = (
+            tf is not None
+            and hasattr(processed_inputs, "graph")
+            and hasattr(processed_inputs.graph, "as_graph_def")
+        )
+        use_tf_graph_ops = tf is not None and is_tf_symbolic
+        if use_tf_graph_ops and keras.config.backend() != "torch":
+            _ = tf.cond(
+                is_invalid_duration,
+                print_warning_fn,
+                lambda: keras.ops.convert_to_tensor(False, dtype="bool"),
+            )
+        else:
+            if keras.ops.convert_to_numpy(is_invalid_duration):
+                print_warning_fn()
+
+        # Handle padding.
+        if padding == "longest":
+            target_length = original_length
+            if pad_to_multiple_of:
+                target_length = (
+                    (target_length + pad_to_multiple_of - 1)
+                    // pad_to_multiple_of
+                ) * pad_to_multiple_of
+
+            needs_padding = keras.ops.greater(target_length, original_length)
+
+            def pad_fn():
+                padding_amount = target_length - original_length
+                paddings = [[0, 0], [0, padding_amount], [0, 0]]
+                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
+                    return tf.pad(
+                        processed_inputs,
+                        paddings,
+                        mode="CONSTANT",
+                        constant_values=float(self.padding_value),
+                    )
+                else:
+                    return keras.ops.pad(
+                        processed_inputs,
+                        paddings,
+                        mode="constant",
+                        constant_values=self.padding_value,
+                    )
+
+            if use_tf_graph_ops and keras.config.backend() != "torch":
+                processed_inputs = tf.cond(
+                    needs_padding, pad_fn, lambda: processed_inputs
+                )
+            else:
+                processed_inputs = keras.ops.cond(
+                    needs_padding, pad_fn, lambda: processed_inputs
+                )
+
+        elif padding == "max_length" and max_length is not None:
+            target_length_const = max_length
+            if pad_to_multiple_of:
+                target_length_const = (
+                    (target_length_const + pad_to_multiple_of - 1)
+                    // pad_to_multiple_of
+                ) * pad_to_multiple_of
+
+            needs_padding = keras.ops.less(original_length, target_length_const)
+            needs_truncating = keras.ops.greater(
+                original_length, target_length_const
+            )
+
+            def pad_fn():
+                padding_amount = target_length_const - original_length
+                paddings = [[0, 0], [0, padding_amount], [0, 0]]
+                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
+                    return tf.pad(
+                        processed_inputs,
+                        paddings,
+                        mode="CONSTANT",
+                        constant_values=float(self.padding_value),
+                    )
+                else:
+                    return keras.ops.pad(
+                        processed_inputs,
+                        paddings,
+                        mode="constant",
+                        constant_values=self.padding_value,
+                    )
+
+            def trunc_fn():
+                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
+                    return processed_inputs[:, :target_length_const, :]
+                else:
+                    return keras.ops.slice(
+                        processed_inputs,
+                        [0, 0, 0],
+                        [-1, target_length_const, -1],
+                    )
+
+            if use_tf_graph_ops and keras.config.backend() != "torch":
+                processed_inputs = tf.cond(
+                    needs_padding,
+                    pad_fn,
+                    lambda: tf.cond(
+                        needs_truncating, trunc_fn, lambda: processed_inputs
+                    ),
+                )
+            else:
+                needs_padding = keras.ops.less(
+                    original_length, target_length_const
+                )
+                needs_truncating = keras.ops.greater(
+                    original_length, target_length_const
+                )
+                needs_padding_bool = keras.ops.convert_to_numpy(needs_padding)
+                needs_truncating_bool = keras.ops.convert_to_numpy(
+                    needs_truncating
+                )
+
+                if needs_padding_bool:
+                    padding_amount = target_length_const - original_length
+                    paddings = [[0, 0], [0, padding_amount], [0, 0]]
+                    processed_inputs = keras.ops.pad(
+                        processed_inputs,
+                        paddings,
+                        mode="constant",
+                        constant_values=self.padding_value,
+                    )
+                elif needs_truncating_bool:
+                    processed_inputs = processed_inputs[
+                        :, :target_length_const, :
+                    ]
+
+        # Normalize if enabled.
+        if self.do_normalize:
+            mean = keras.ops.mean(processed_inputs, axis=1, keepdims=True)
+            var = keras.ops.var(processed_inputs, axis=1, keepdims=True)
+            processed_inputs = (processed_inputs - mean) / keras.ops.sqrt(
+                var + 1e-7
+            )
+
+        return processed_inputs
+
+    def compute_output_shape(self, input_shape):
+        # [batch_size, time_steps] → [batch_size, time_steps, 1].
+        if len(input_shape) == 2 or len(input_shape) == 3:
+            return (input_shape[0], None, 1)
+        else:
+            raise ValueError("Input shape must be rank 2 or 3.")
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sampling_rate": self.sampling_rate,
+                "padding_value": self.padding_value,
+                "do_normalize": self.do_normalize,
+            }
+        )
+        return config
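Below is a minimal usage sketch of the padding behavior documented in the converter's docstring above. It assumes an environment with this 0.21.0.dev1 wheel installed; the concrete sizes (8000 input steps, `max_length=12000`, `pad_to_multiple_of=4096`) are illustrative values, not taken from the package.

```python
import keras
from keras_hub.layers import MoonshineAudioConverter

# Half a second of mono audio at 16 kHz: (batch_size, time_steps, 1).
audio = keras.random.normal((1, 8000, 1))

converter = MoonshineAudioConverter(do_normalize=True)

# "max_length" pads (or truncates) to `max_length`, rounded up to the next
# multiple of `pad_to_multiple_of`: here 12000 rounds up to 12288 steps.
padded = converter(
    audio,
    padding="max_length",
    max_length=12000,
    pad_to_multiple_of=4096,
)
print(padded.shape)  # Expected: (1, 12288, 1)
```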
keras_hub/src/models/moonshine/moonshine_audio_to_text.py (new file)
@@ -0,0 +1,383 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.audio_to_text import AudioToText
+from keras_hub.src.models.moonshine.moonshine_audio_to_text_preprocessor import (  # noqa: E501
+    MoonshineAudioToTextPreprocessor,
+)
+from keras_hub.src.models.moonshine.moonshine_backbone import Arange
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_backbone import (
+    compute_output_lengths,
+)
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.MoonshineAudioToText")
+class MoonshineAudioToText(AudioToText):
+    """An end-to-end Moonshine model for audio-to-text tasks.
+
+    A Seq2Seq LM designed for audio-to-text tasks, such as speech recognition.
+    The encoder processes audio features, and the decoder generates text
+    transcriptions. You can finetune `MoonshineAudioToText` for any
+    audio-to-text task (e.g., live transcription or voice commands).
+
+    This model includes a `generate()` method for text generation based on audio
+    inputs and an optional text prompt for the decoder. The generation strategy
+    is controlled by a `sampler` argument passed to `compile()`. By default,
+    `"top_k"` sampling is used.
+
+    Args:
+        backbone: A `keras_hub.models.MoonshineBackbone` instance.
+        preprocessor: A `keras_hub.models.MoonshineAudioToTextPreprocessor` or
+            `None`. If `None`, inputs must be preprocessed before calling the
+            model.
+
+    Examples:
+    ```python
+    # Initialize model from preset.
+    moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
+        "moonshine_base"
+    )
+
+    # Generate with single audio input.
+    audio_tensor = keras.random.normal((1, 16000, 1))
+    moonshine_lm.generate({"audio": audio_tensor})
+
+    # Generate with text prompt.
+    moonshine_lm.generate({"audio": audio_tensor, "text": "quick"})
+
+    # Use different sampling strategy.
+    moonshine_lm.compile(sampler="greedy")
+    moonshine_lm.generate({"audio": audio_tensor})
+    ```
+    """
+
+    # References:
+    # Defined and formulated based on the Hugging Face implementation of the
+    # MoonshineForConditionalGeneration class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1509-L1626).
+
+    backbone_cls = MoonshineBackbone
+    preprocessor_cls = MoonshineAudioToTextPreprocessor
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        hidden_states = backbone(inputs)["decoder_sequence_output"]
+        outputs = backbone.token_embedding(hidden_states, reverse=True)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def call_decoder_with_cache(
+        self,
+        encoder_hidden_states,
+        encoder_padding_mask,
+        decoder_token_ids,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        cross_attention_cache=None,
+    ):
+        """Process decoder inputs with attention caching for efficient
+        generation.
+
+        Args:
+            encoder_hidden_states: Tensor. Encoder outputs.
+            encoder_padding_mask: Tensor. Padding mask for encoder outputs.
+            decoder_token_ids: Tensor. Decoder input token IDs.
+            self_attention_cache: Tensor, optional. Cache for self-attention
+                layers.
+            self_attention_cache_update_index: int, optional. Index for cache
+                updates.
+            cross_attention_cache: Tensor, optional. Cache for cross-attention
+                layers. This cache is computed once and reused.
+
+        Returns:
+            Tuple: Tuple of (logits, hidden_states, new_self_attention_cache,
+            cross_attention_cache).
+        """
+        tokens = self.backbone.token_embedding(decoder_token_ids)
+        x = tokens
+
+        # Cache management for audio-to-text generation.
+        self_attention_caches = []
+        position = keras.ops.array(
+            [self_attention_cache_update_index], dtype="int32"
+        )
+        rotary_embedding = self.backbone.decoder_rotary_embedding(position)
+
+        for i, layer in enumerate(self.backbone.decoder_blocks):
+            current_self_cache = self_attention_cache[:, i, ...]
+            current_cross_cache = cross_attention_cache[:, i, ...]
+            x, new_self_cache = layer(
+                decoder_sequence=x,
+                encoder_sequence=encoder_hidden_states,
+                rotary_embedding=rotary_embedding,
+                encoder_padding_mask=encoder_padding_mask,
+                self_attention_cache=current_self_cache,
+                self_attention_cache_update_index=self_attention_cache_update_index,
+                cross_attention_cache=current_cross_cache,
+                training=False,
+            )
+            # Update self-attention cache.
+            self_attention_caches.append(new_self_cache)
+
+        # [batch_size, num_layers, 2, seq_len, num_heads, head_dim].
+        new_self_attention_cache = keras.ops.stack(
+            self_attention_caches, axis=1
+        )
+        hidden_states = self.backbone.decoder_post_norm(x)
+        logits = self.backbone.token_embedding(hidden_states, reverse=True)
+        return (
+            logits,
+            hidden_states,
+            new_self_attention_cache,
+            cross_attention_cache,
+        )
+
+    def _build_cache(
+        self,
+        encoder_input_values,
+        encoder_padding_mask,
+        decoder_token_ids,
+        decoder_padding_mask,
+    ):
+        """Build initial cache states from inputs."""
+        encoder_hidden_states, encoder_attention_mask_for_decoder = (
+            self.call_encoder(
+                encoder_input_values=encoder_input_values,
+                padding_mask=encoder_padding_mask,
+            )
+        )
+        precomputed_cross_caches = []
+        for layer in self.backbone.decoder_blocks:
+            cross_k = layer.cross_attention._key_dense(encoder_hidden_states)
+            cross_v = layer.cross_attention._value_dense(encoder_hidden_states)
+            layer_cross_cache = keras.ops.stack([cross_k, cross_v], axis=1)
+            precomputed_cross_caches.append(layer_cross_cache)
+        precomputed_cross_cache = keras.ops.stack(
+            precomputed_cross_caches, axis=1
+        )
+        batch_size = keras.ops.shape(encoder_input_values)[0]
+        num_layers = self.backbone.decoder_num_layers
+        num_heads = self.backbone.decoder_num_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.decoder_num_heads
+        if self.backbone.pad_head_dim_to_multiple_of is not None:
+            head_dim = (
+                (head_dim + self.backbone.pad_head_dim_to_multiple_of - 1)
+                // self.backbone.pad_head_dim_to_multiple_of
+            ) * self.backbone.pad_head_dim_to_multiple_of
+        # Use the full sequence length for the cache dimension.
+        cache_length = keras.ops.shape(decoder_token_ids)[1]
+        initial_self_cache_shape = (
+            batch_size,
+            num_layers,
+            2,
+            cache_length,
+            num_heads,
+            head_dim,
+        )
+        initial_self_cache = keras.ops.zeros(
+            initial_self_cache_shape, dtype=self.compute_dtype
+        )
+        tokens = self.backbone.token_embedding(decoder_token_ids)
+        x = tokens
+        positions = keras.ops.arange(0, cache_length, dtype="int32")
+        rotary_embedding = self.backbone.decoder_rotary_embedding(positions)
+        seeded_self_caches = []
+        for i, layer in enumerate(self.backbone.decoder_blocks):
+            current_initial_self_cache = initial_self_cache[:, i, ...]
+            current_precomputed_cross_cache = precomputed_cross_cache[:, i, ...]
+            x, seeded_self_cache_layer = layer(
+                decoder_sequence=x,
+                encoder_sequence=encoder_hidden_states,
+                rotary_embedding=rotary_embedding,
+                decoder_padding_mask=decoder_padding_mask,
+                encoder_padding_mask=encoder_attention_mask_for_decoder,
+                self_attention_cache=current_initial_self_cache,
+                self_attention_cache_update_index=0,
+                cross_attention_cache=current_precomputed_cross_cache,
+                training=False,
+            )
+            seeded_self_caches.append(seeded_self_cache_layer)
+        hidden_states = self.backbone.decoder_post_norm(x)
+        self_attn_cache = keras.ops.stack(seeded_self_caches, axis=1)
+        return (
+            hidden_states,
+            self_attn_cache,
+            precomputed_cross_cache,
+            encoder_hidden_states,
+            encoder_attention_mask_for_decoder,
+        )
+
+    def call_encoder(self, encoder_input_values, padding_mask):
+        """Process audio input through the encoder stack."""
+        x = self.backbone.conv1(encoder_input_values)
+        x = self.backbone.tanh_after_conv1(x)
+        x = self.backbone.group_norm(x)
+        x = self.backbone.conv2(x)
+        x = self.backbone.gelu_after_conv2(x)
+        x = self.backbone.conv3(x)
+        x = self.backbone.gelu_after_conv3(x)
+        original_lengths = keras.ops.sum(
+            keras.ops.cast(padding_mask, "int32"), axis=1
+        )
+        output_lengths = compute_output_lengths(original_lengths)
+        padding_mask = self.backbone._compute_mask_layer(x, output_lengths)
+        positions = Arange(name="encoder_positions")(x)
+        rotary_embedding = self.backbone.encoder_rotary_embedding(positions)
+        x = self.backbone.encoder_dropout(x, training=False)
+        for transformer_layer in self.backbone.encoder_blocks:
+            x = transformer_layer(
+                inputs=x,
+                rotary_embedding=rotary_embedding,
+                attention_mask=padding_mask,
+                training=False,
+            )
+        x = self.backbone.encoder_final_layer_norm(x)
+        return x, padding_mask
+
+    # Source: https://github.com/huggingface/transformers/blob/9e94801146ceeb3b215bbdb9492be74d7d7b7210/src/transformers/generation/utils.py#L1970-L2463
+    def generate_step(self, inputs, stop_token_ids=None):
+        """A compilable generation function for a batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"encoder_input_values"`,
+        `"encoder_padding_mask"`, `"decoder_token_ids"` and
+        `"decoder_padding_mask"`.
+
+        Args:
+            inputs: A dictionary with four keys - `"encoder_input_values"`,
+                `"encoder_padding_mask"`, `"decoder_token_ids"` and
+                `"decoder_padding_mask"`, with batched tensor values.
+            stop_token_ids: Tuple of id's of end token's to stop on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+
+        Returns:
+            Dictionary: A dictionary with two keys - `"decoder_token_ids"`
+                containing the updated token sequence with newly generated
+                tokens, and `"decoder_padding_mask"` containing the updated
+                padding mask for the generated sequence.
+        """
+        encoder_input_values = inputs["encoder_input_values"]
+        encoder_padding_mask = inputs["encoder_padding_mask"]
+        decoder_token_ids = inputs["decoder_token_ids"]
+        decoder_padding_mask = inputs["decoder_padding_mask"]
+
+        if (
+            encoder_input_values is None
+            or encoder_padding_mask is None
+            or decoder_token_ids is None
+        ):
+            raise ValueError("Input tensors cannot be None")
+
+        (
+            hidden_states,
+            self_attention_cache,
+            cross_attention_cache,
+            encoder_hidden_states,
+            encoder_attention_mask_for_decoder,
+        ) = self._build_cache(
+            encoder_input_values,
+            encoder_padding_mask,
+            decoder_token_ids,
+            decoder_padding_mask,
+        )
+        row_lengths = keras.ops.sum(
+            keras.ops.cast(decoder_padding_mask, "int32"),
+            axis=-1,
+        )
+        index = keras.ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            if isinstance(cache, tuple) and len(cache) == 2:
+                current_self_attention_cache = cache[0]
+                current_cross_attention_cache = cache[1]
+            elif cache is not None and not isinstance(cache, tuple):
+                current_self_attention_cache = cache
+                current_cross_attention_cache = cross_attention_cache
+            else:
+                cache = None
+            cache_index = index - 1
+            num_samples = keras.ops.shape(prompt)[0]
+            next_token_input = keras.ops.slice(
+                prompt, [0, cache_index], [num_samples, 1]
+            )
+
+            batch_size = keras.ops.shape(encoder_input_values)[0]
+
+            def repeat_tensor(x):
+                if keras.ops.shape(x)[0] == num_samples:
+                    return x
+                return keras.ops.repeat(
+                    x, repeats=num_samples // batch_size, axis=0
+                )
+
+            cross_attention_cache_repeated = repeat_tensor(
+                current_cross_attention_cache
+            )
+            logits, hidden_states, new_self_attention_cache, _ = (
+                self.call_decoder_with_cache(
+                    encoder_hidden_states=repeat_tensor(encoder_hidden_states),
+                    encoder_padding_mask=repeat_tensor(
+                        encoder_attention_mask_for_decoder
+                    ),
+                    decoder_token_ids=next_token_input,
+                    self_attention_cache=current_self_attention_cache,
+                    self_attention_cache_update_index=cache_index,
+                    cross_attention_cache=cross_attention_cache_repeated,
+                )
+            )
+            return (
+                logits[:, 0, :],
+                hidden_states[:, 0, :],
+                (new_self_attention_cache, current_cross_attention_cache),
+            )
+
+        decoder_token_ids = self.sampler(
+            next=next,
+            prompt=decoder_token_ids,
+            cache=(self_attention_cache, cross_attention_cache),
+            index=index,
+            mask=keras.ops.cast(
+                decoder_token_ids != self.preprocessor.tokenizer.pad_token_id
+                if self.preprocessor is not None
+                else decoder_padding_mask,
+                dtype="bool",
+            ),
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        if stop_token_ids is not None:
+            end_locations = any_equal(
+                decoder_token_ids,
+                stop_token_ids,
+                decoder_token_ids == self.preprocessor.tokenizer.pad_token_id
+                if self.preprocessor is not None
+                else False,
+            )
+            end_locations = keras.ops.cast(end_locations, "int32")
+            cumsum = keras.ops.cumsum(end_locations, axis=-1)
+            overflow = cumsum - end_locations
+            decoder_padding_mask = keras.ops.logical_not(
+                keras.ops.cast(overflow, "bool")
+            )
+        else:
+            decoder_padding_mask = keras.ops.ones_like(
+                decoder_token_ids, dtype="bool"
+            )
+
+        return {
+            "decoder_token_ids": decoder_token_ids,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
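For completeness, a short end-to-end sketch of the new task model, following the docstring example above. It assumes this wheel is installed and that the `"moonshine_base"` preset named in the docstring is available at load time.

```python
import keras
import keras_hub

# Load the audio-to-text task (preset name taken from the docstring above).
moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
    "moonshine_base"
)

# Pick a decoding strategy, then transcribe one second of 16 kHz audio.
moonshine_lm.compile(sampler="greedy")
audio_tensor = keras.random.normal((1, 16000, 1))
print(moonshine_lm.generate({"audio": audio_tensor}))
```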