keras-hub-nightly 0.21.0.dev202505140407__py3-none-any.whl → 0.21.0.dev202505150407__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
@@ -0,0 +1,313 @@
+ import keras
+
+ from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
+ from keras_hub.src.models.moonshine.moonshine_layers import (
+     moonshine_kernel_initializer,
+ )
+ from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
+     MoonshineMultiHeadAttention,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ @keras.saving.register_keras_serializable(package="keras_hub")
+ class MoonshineDecoderBlock(TransformerDecoder):
+     """Moonshine decoder block for sequence processing.
+
+     This layer implements a decoder block that includes self-attention with
+     causal masking, cross-attention with precomputed key/value pairs, and a
+     feedforward network.
+
+     Args:
+         hidden_dim: int. The dimensionality of the model's hidden
+             representations.
+         intermediate_dim: int. The dimensionality of the intermediate
+             representations in the feedforward network.
+         num_heads: int. The number of attention heads for multi-head attention
+             mechanisms.
+         feedforward_expansion_factor: int, optional. A multiplicative factor for
+             scaling the feedforward network dimension. Defaults to 4.
+         use_swiglu_activation: bool, optional. Whether to use the SwiGLU
+             activation in the feedforward network for improved performance.
+             Defaults to True.
+         pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
+             dimension to be a multiple of this value for performance
+             optimization. Defaults to None.
+         initializer_range: float, optional. The standard deviation of the
+             truncated normal distribution used to initialize model weights.
+             Defaults to 0.02.
+         attention_bias: bool, optional. Whether to add a bias term to the
+             attention computations. Defaults to False.
+         attention_dropout: float, optional. The dropout rate applied to
+             attention weights during training. Defaults to 0.0.
+         dtype: str, optional. The data type to use for model computations and
+             weights. Defaults to None.
+         **kwargs: Additional keyword arguments passed to the base layer.
+     """
+
+     # References:
+     # Defined and formulated based on the UsefulSensors implementation of the
+     # DecoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L348-L466).
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         num_heads,
+         feedforward_expansion_factor=4,
+         use_swiglu_activation=True,
+         pad_head_dim_to_multiple_of=None,
+         initializer_range=0.02,
+         attention_bias=False,
+         attention_dropout=0.0,
+         dtype=None,
+         **kwargs,
+     ):
+         kwargs.pop("dropout", None)
+         kwargs.pop("activation", None)
+         kwargs.pop("kernel_initializer", None)
+         self.kernel_initializer = moonshine_kernel_initializer(
+             initializer_range=initializer_range
+         )
+         super().__init__(
+             intermediate_dim=intermediate_dim,
+             num_heads=num_heads,
+             dropout=attention_dropout,
+             activation="gelu" if use_swiglu_activation else "silu",
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dtype=dtype,
+             **kwargs,
+         )
+         self.initializer_range = initializer_range
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.feedforward_expansion_factor = feedforward_expansion_factor
+         self.use_swiglu_activation = use_swiglu_activation
+         self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+         self.attention_dropout = attention_dropout
+         self.attention_bias = attention_bias
+
+         self.head_dim = hidden_dim // num_heads
+         if pad_head_dim_to_multiple_of is not None:
+             self.head_dim = (
+                 (self.head_dim + pad_head_dim_to_multiple_of - 1)
+                 // pad_head_dim_to_multiple_of
+             ) * pad_head_dim_to_multiple_of
+
+         self.norm1 = keras.layers.LayerNormalization(
+             axis=-1,
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             dtype=self.dtype,
+         )
+         self.self_attention = MoonshineMultiHeadAttention(
+             num_heads=num_heads,
+             key_dim=self.head_dim,
+             use_bias=False,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             attention_bias=attention_bias,
+             attention_dropout=attention_dropout,
+             use_causal_mask=True,
+             apply_rotary_embedding=True,
+             dtype=self.dtype,
+         )
+         self.norm2 = keras.layers.LayerNormalization(
+             axis=-1,
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             dtype=self.dtype,
+         )
+         self.cross_attention = MoonshineMultiHeadAttention(
+             num_heads=num_heads,
+             key_dim=self.head_dim,
+             use_bias=False,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             attention_bias=attention_bias,
+             attention_dropout=attention_dropout,
+             use_causal_mask=False,
+             apply_rotary_embedding=False,
+             dtype=self.dtype,
+         )
+         self.norm3 = keras.layers.LayerNormalization(
+             axis=-1,
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             dtype=self.dtype,
+         )
+         self.ff = MoonshineMLP(
+             hidden_dim=hidden_dim,
+             feedforward_expansion_factor=feedforward_expansion_factor,
+             use_swiglu_activation=use_swiglu_activation,
+             initializer_range=initializer_range,
+             dtype=self.dtype,
+         )
+
+     def build(self, decoder_sequence_shape, encoder_sequence_shape=None):
+         if encoder_sequence_shape is None:
+             raise ValueError(
+                 "Encoder sequence shape must be provided for "
+                 "MoonshineDecoderBlock."
+             )
+         context_shape = encoder_sequence_shape  # Shape of context
+
+         # Build sublayers.
+         self.norm1.build(decoder_sequence_shape)
+         self.norm2.build(decoder_sequence_shape)
+         self.norm3.build(decoder_sequence_shape)
+
+         self.self_attention.build(
+             query_shape=decoder_sequence_shape,
+             key_shape=decoder_sequence_shape,
+             value_shape=decoder_sequence_shape,
+         )
+
+         self.cross_attention.build(
+             query_shape=decoder_sequence_shape,
+             key_shape=context_shape,
+             value_shape=context_shape,
+         )
+
+         self.ff.build(decoder_sequence_shape)
+         self.built = True
+
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             inputs=decoder_sequence,
+             padding_mask=decoder_padding_mask,
+             attention_mask=None,
+         )
+         if self.self_attention.use_causal_mask:
+             batch_size = keras.ops.shape(decoder_sequence)[0]
+             output_length = keras.ops.shape(decoder_sequence)[1]
+             current_cache_update_index = (
+                 0
+                 if self_attention_cache_update_index is None
+                 else self_attention_cache_update_index
+             )
+             if self_attention_cache is not None:
+                 input_length = keras.ops.shape(self_attention_cache)[2]
+             else:
+                 input_length = output_length
+             causal_mask = compute_causal_mask(
+                 batch_size,
+                 input_length,
+                 output_length,
+                 current_cache_update_index,
+             )
+             return (
+                 keras.ops.minimum(decoder_mask, causal_mask)
+                 if decoder_mask is not None
+                 else causal_mask
+             )
+         return decoder_mask
+
+     def call(
+         self,
+         decoder_sequence,
+         encoder_sequence,
+         rotary_embedding,
+         encoder_attention_mask=None,
+         decoder_padding_mask=None,
+         encoder_padding_mask=None,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         cross_attention_cache=None,
+         cross_attention_cache_update_index=None,
+         training=None,
+     ):
+         x = decoder_sequence
+         context = encoder_sequence
+         has_self_attention_cache = self_attention_cache is not None
+         has_cross_attention_cache = cross_attention_cache is not None
+
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=x,
+             decoder_padding_mask=decoder_padding_mask,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=self_attention_cache_update_index,
+         )
+
+         # Self attention block.
+         residual = x
+         x_norm1 = self.norm1(x)
+         x_self_attn = self.self_attention(
+             query=x_norm1,
+             key=x_norm1,
+             value=x_norm1,
+             rotary_embedding=rotary_embedding,
+             cache=self_attention_cache,
+             cache_update_index=self_attention_cache_update_index,
+             attention_mask=self_attention_mask,
+             training=training,
+         )
+         if has_self_attention_cache:
+             x_self_attn, self_attention_cache = x_self_attn
+         x = x_self_attn + residual
+         # Cross attention block.
+         residual = x
+         x_norm2 = self.norm2(x)
+         cross_attention_mask = merge_padding_and_attention_mask(
+             inputs=encoder_sequence,
+             padding_mask=encoder_padding_mask,
+             attention_mask=encoder_attention_mask,
+         )
+         x_cross_attn = self.cross_attention(
+             query=x_norm2,
+             key=context,
+             value=context,
+             cache=cross_attention_cache,
+             cache_update_index=cross_attention_cache_update_index,
+             attention_mask=cross_attention_mask,
+             training=training,
+         )
+         if has_cross_attention_cache:
+             x_cross_attn, cross_attention_cache = x_cross_attn
+         x = x_cross_attn + residual
+         residual = x
+         x_norm3 = self.norm3(x)
+         x_ff = self.ff(x_norm3)
+         x = x_ff + residual
+
+         if has_self_attention_cache:
+             return x, self_attention_cache
+         return x
+
+     def compute_output_shape(
+         self, decoder_sequence_shape, encoder_sequence_shape=None
+     ):
+         return decoder_sequence_shape
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
+                 "use_swiglu_activation": self.use_swiglu_activation,
+                 "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,  # noqa: E501
+                 "initializer_range": self.initializer_range,
+                 "attention_bias": self.attention_bias,
+                 "attention_dropout": self.attention_dropout,
+                 "dtype": self.dtype,
+             }
+         )
+         return config
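
The two pieces of arithmetic in MoonshineDecoderBlock that are easiest to misread are the pad_head_dim_to_multiple_of rounding in __init__ and the cache-offset causal mask assembled in _compute_self_attention_mask. The sketch below reproduces both in isolation as a minimal, standalone example: padded_head_dim and causal_mask are illustrative helpers written for this note (not keras-hub APIs), the dimension values are made up, and the mask construction only mirrors what compute_causal_mask is expected to return rather than calling the keras-hub utility itself.

    from keras import ops

    def padded_head_dim(hidden_dim, num_heads, multiple=None):
        # Mirrors the rounding in MoonshineDecoderBlock.__init__: round the
        # per-head dimension up to the nearest requested multiple.
        head_dim = hidden_dim // num_heads
        if multiple is not None:
            head_dim = ((head_dim + multiple - 1) // multiple) * multiple
        return head_dim

    print(padded_head_dim(416, 8))      # 52: no padding requested
    print(padded_head_dim(416, 8, 64))  # 64: 52 rounded up to a multiple of 64

    def causal_mask(batch_size, input_length, output_length, index=0):
        # During cached decoding the query block occupies rows
        # [index, index + output_length) of an input_length-wide cache, so
        # query row r may attend to cache positions 0 .. index + r.
        i = ops.arange(output_length)[:, None] + index
        j = ops.arange(input_length)[None, :]
        mask = ops.cast(j <= i, "int32")
        return ops.broadcast_to(mask, (batch_size, output_length, input_length))

    print(causal_mask(1, 4, 1, index=2))  # [[[1 1 1 0]]]

When no cache is passed, input_length equals output_length and the mask degenerates to the usual lower-triangular form; the block then intersects it with the padding mask through the elementwise minimum visible in _compute_self_attention_mask.
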
@@ -0,0 +1,212 @@
+ import keras
+
+ from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
+ from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
+ from keras_hub.src.models.moonshine.moonshine_layers import (
+     moonshine_kernel_initializer,
+ )
+ from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
+     MoonshineMultiHeadAttention,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ @keras.saving.register_keras_serializable(package="keras_hub")
+ class MoonshineEncoderBlock(TransformerEncoder):
+     """
+     Moonshine encoder block for sequence processing.
+
+     Implements a standard encoder block with self-attention and feedforward
+     sublayers, including residual connections and layer normalization. The
+     implementation utilizes Moonshine-specific attention and feedforward
+     mechanisms.
+
+     Args:
+         hidden_dim: int. The dimensionality of the model's hidden
+             representations throughout the block.
+         intermediate_dim: int. The dimensionality used in projections before
+             applying non-linearities.
+         num_heads: int. The number of attention heads for multi-head attention
+             computation.
+         feedforward_expansion_factor: int, optional. A multiplier for expanding
+             the dimension in the feedforward network. Defaults to 4.
+         use_swiglu_activation: bool, optional. Whether to use SwiGLU activation
+             (True) or LinearGeLU (False) in the feedforward sublayer. Defaults
+             to False.
+         pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
+             dimension to be a multiple of this value for hardware optimization.
+             Defaults to None.
+         initializer_range: float, optional. The standard deviation of the
+             truncated normal distribution used for weight initialization.
+             Defaults to 0.02.
+         attention_bias: bool, optional. Whether to use a bias term in the
+             attention mechanism. Defaults to False.
+         attention_dropout: float, optional. The dropout rate applied to the
+             attention weights. Defaults to 0.0.
+         dtype: str, optional. The data type to use for model computations and
+             weights. Defaults to None.
+         **kwargs: Additional keyword arguments passed to the base layer.
+     """
+
+     # References:
+     # Defined and formulated based on the UsefulSensors implementation of the
+     # EncoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L124-L161).
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         num_heads,
+         feedforward_expansion_factor=4,
+         use_swiglu_activation=False,
+         pad_head_dim_to_multiple_of=None,
+         dtype=None,
+         initializer_range=0.02,
+         attention_bias=False,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         kwargs.pop("dropout", None)
+         kwargs.pop("activation", None)
+         kwargs.pop("kernel_initializer", None)
+         self.kernel_initializer = moonshine_kernel_initializer(
+             initializer_range=initializer_range
+         )
+         super().__init__(
+             intermediate_dim=intermediate_dim,
+             num_heads=num_heads,
+             dropout=attention_dropout,
+             activation="gelu" if use_swiglu_activation else "silu",
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dtype=dtype,
+             **kwargs,
+         )
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.initializer_range = initializer_range
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.feedforward_expansion_factor = feedforward_expansion_factor
+         self.use_swiglu_activation = use_swiglu_activation
+
+         # Self-attention sublayers.
+         self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
+         self.head_dim = hidden_dim // num_heads
+         if pad_head_dim_to_multiple_of is not None:
+             self.head_dim = (
+                 (self.head_dim + pad_head_dim_to_multiple_of - 1)
+                 // pad_head_dim_to_multiple_of
+             ) * pad_head_dim_to_multiple_of
+
+         self.self_attention_layer = MoonshineMultiHeadAttention(
+             num_heads=num_heads,
+             key_dim=self.head_dim,
+             use_bias=False,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             attention_bias=attention_bias,
+             attention_dropout=attention_dropout,
+             use_causal_mask=False,
+             apply_rotary_embedding=True,
+             name="self_attention_layer",
+             dtype=self.dtype,
+         )
+         self.self_attention_layer_norm = keras.layers.LayerNormalization(
+             axis=-1,
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             name="self_attention_layer_norm",
+             dtype=self.dtype,
+         )
+
+         # Feedforward sublayers.
+         self.feedforward_layer_norm = keras.layers.LayerNormalization(
+             axis=-1,
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             name="feedforward_layer_norm",
+             dtype=self.dtype,
+         )
+         self.feedforward = MoonshineMLP(
+             hidden_dim=hidden_dim,
+             feedforward_expansion_factor=feedforward_expansion_factor,
+             use_swiglu_activation=use_swiglu_activation,
+             initializer_range=initializer_range,
+             name="feedforward",
+             dtype=self.dtype,
+         )
+
+     def build(self, input_shape):
+         if isinstance(input_shape, dict):
+             encoder_input_shape = input_shape["input_values"]
+         else:
+             encoder_input_shape = input_shape
+         # Build self-attention branch.
+         self.self_attention_layer_norm.build(encoder_input_shape)
+         self.self_attention_layer.build(
+             encoder_input_shape, encoder_input_shape, encoder_input_shape
+         )
+         # Build feedforward branch.
+         self.feedforward_layer_norm.build(encoder_input_shape)
+         # The feedforward layer expects the last dimension to be hidden_dim.
+         feed_forward_input_shape = list(encoder_input_shape)
+         feed_forward_input_shape[-1] = self.hidden_dim
+         self.feedforward.build(tuple(feed_forward_input_shape))
+         self.built = True
+
+     def call(
+         self,
+         inputs,
+         rotary_embedding,
+         attention_mask=None,
+         training=None,
+         **kwargs,
+     ):
+         x = inputs
+
+         # Self-attention block with residual connection.
+         attention_residual = x
+         x = self.self_attention_layer_norm(x)
+         x = self.self_attention_layer(
+             query=x,
+             value=x,
+             key=x,
+             rotary_embedding=rotary_embedding,
+             attention_mask=attention_mask,
+             training=training,
+             **kwargs,
+         )
+         x = x + attention_residual
+
+         # Feedforward block with residual connection.
+         ff_residual = x
+         x = self.feedforward_layer_norm(x)
+         x = self.feedforward(x)
+         x = x + ff_residual
+
+         return x
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+     def get_config(self):
+         # ==== Config ====
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
+                 "use_swiglu_activation": self.use_swiglu_activation,
+                 "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,
+                 "initializer_range": self.initializer_range,
+                 "attention_bias": self.attention_bias,
+                 "attention_dropout": self.attention_dropout,
+                 "dtype": self.dtype,
+             }
+         )
+         return config
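
As its docstring notes, MoonshineEncoderBlock follows the standard pre-norm residual layout: normalize, attend, add the residual, then normalize, run the MLP, add the residual again. The sketch below replays that data flow with stock Keras layers only, as a minimal schematic rather than the actual implementation: MultiHeadAttention and a two-layer Dense stack stand in for the Moonshine-specific attention and MoonshineMLP, there is no rotary embedding, and the sizes are illustrative rather than Moonshine defaults.

    import keras
    from keras import layers

    hidden_dim, num_heads, seq_len = 64, 4, 10  # illustrative sizes
    x = keras.random.normal((2, seq_len, hidden_dim))

    # Self-attention sublayer: pre-norm -> attention -> residual add.
    attn_norm = layers.LayerNormalization(epsilon=1e-5, center=False)
    attn = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=hidden_dim // num_heads
    )
    h = attn_norm(x)
    x = x + attn(query=h, value=h, key=h)

    # Feedforward sublayer: pre-norm -> MLP -> residual add.
    ff_norm = layers.LayerNormalization(epsilon=1e-5, center=False)
    mlp = keras.Sequential(
        [
            layers.Dense(4 * hidden_dim, activation="gelu"),
            layers.Dense(hidden_dim),
        ]
    )
    x = x + mlp(ff_norm(x))

    print(x.shape)  # (2, 10, 64): the block preserves its input shape.

The shape-preserving behavior shown at the end is exactly what the block's compute_output_shape encodes by returning input_shape unchanged.
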