keras-hub 0.24.0.dev0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
- keras_hub/src/models/esm/esm_attention.py +11 -4
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
- keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
- keras_hub/src/models/gpt_oss/__init__.py +5 -0
- keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +221 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
- keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
- keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
- keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
- keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
- keras_hub/src/models/llama3/llama3_presets.py +1 -1
- keras_hub/src/models/parseq/parseq_decoder.py +21 -9
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
- keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/METADATA +1 -1
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/RECORD +27 -16
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/WHEEL +0 -0
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/top_level.txt +0 -0
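
The headline addition is the GPT-OSS model family: attention, backbone, causal LM, preprocessor, decoder, layer norm, presets, tokenizer, plus a Hugging Face Transformers checkpoint converter. As rough orientation, a minimal usage sketch follows; the `keras_hub.models.GptOssCausalLM` export path and the `gpt_oss_20b_en` preset name are assumptions inferred from the file list and the backbone docstring further down, not confirmed documentation.

```python
import keras_hub

# Hypothetical quick-start for the new GPT-OSS causal LM. The class name,
# export path, and preset name are inferred from this diff, not from
# official docs; adjust to whatever the release actually exposes.
gpt_oss_lm = keras_hub.models.GptOssCausalLM.from_preset("gpt_oss_20b_en")
gpt_oss_lm.generate(
    "Explain mixture-of-experts routing in one sentence.", max_length=64
)
```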
--- /dev/null
+++ keras_hub/src/models/gpt_oss/gpt_oss_attention.py
@@ -0,0 +1,330 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class GptOssAttention(keras.layers.Layer):
+    """A cached attention layer with sliding window and sink tokens.
+
+    This layer implements the attention mechanism described in the GPT-OSS
+    paper. It includes grouped-query attention, rotary position embeddings,
+    sliding window attention, and sink tokens for improved performance on
+    long sequences.
+
+    Args:
+        num_query_heads: int. The number of query attention heads.
+        num_key_value_heads: int. The number of key and value attention
+            heads.
+        rope_max_wavelength: int. The maximum wavelength for the
+            rotary position embedding. Defaults to 10000.
+        rope_scaling_factor: float. The scaling factor for the
+            rotary position embedding. Defaults to 1.0.
+        kernel_initializer: str. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+        sliding_window: int. The size of the sliding window.
+            Defaults to 4096.
+        dropout: float. The dropout rate. Defaults to 0.
+        head_dim: int. Head dimension for attention. If None,
+            calculated as hidden_dim // num_query_heads. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        kernel_initializer="glorot_uniform",
+        sliding_window=4096,
+        dropout=0,
+        head_dim=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout
+        self.head_dim = head_dim
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = the model's hidden_dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        self._hidden_dim = inputs_shape[-1]
+
+        if self.head_dim is not None:
+            self._head_dim = self.head_dim
+        else:
+            self._head_dim = self._hidden_dim // self.num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
+
+        self._rotary_dim = (self._head_dim // 2) * 2
+
+        self.query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, self._head_dim),
+            bias_axes="uh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self._head_dim,
+            ),
+            bias_axes="vh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self._head_dim,
+            ),
+            bias_axes="vh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self.output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, self._hidden_dim),
+            bias_axes="m",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
+        )
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,  # YaRN scaling factor
+            rope_type="yarn",
+            beta_fast=32.0,
+            beta_slow=1.0,
+            original_max_position_embeddings=4096,
+            dtype=self.dtype_policy,
+        )
+
+        self.sinks = self.add_weight(
+            shape=(self.num_query_heads,),
+            initializer="random_normal",
+            dtype=self.dtype,
+            name="sinks",
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self.query_dense(hidden_states)
+
+        # Compute RoPE for queries (only
+        # to first _rotary_dim dimensions)
+        if self._rotary_dim < self._head_dim:
+            query_rot = query[..., : self._rotary_dim]
+            query_rot = self.rotary_embedding_layer(
+                query_rot, start_index=start_index
+            )
+            query = ops.concatenate(
+                [query_rot, query[..., self._rotary_dim :]], axis=-1
+            )
+        else:
+            query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self.key_dense(x), self.value_dense(x)
+            # Compute RoPE for keys (only apply to first _rotary_dim dimensions)
+            if self._rotary_dim < self._head_dim:
+                key_rot = key[..., : self._rotary_dim]
+                key_rot = self.rotary_embedding_layer(
+                    key_rot, start_index=start_index
+                )
+                key = ops.concatenate(
+                    [key_rot, key[..., self._rotary_dim :]], axis=-1
+                )
+            else:
+                key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query, key, value, attention_mask, start_index
+        )
+
+        attention_output = self.dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self.output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, start_index=0
+    ):
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+
+        # Apply sliding window mask if specified
+        if self.sliding_window is not None and self.sliding_window > 0:
+            q_len = ops.shape(attention_scores)[-2]
+            kv_len = ops.shape(attention_scores)[-1]
+
+            # Query positions are offset by start_index during generation
+            q_positions = ops.arange(q_len) + start_index
+            kv_positions = ops.arange(kv_len)
+
+            # Mask true for positions outside sliding window
+            # For causal attention: mask if kv_pos < q_pos - sliding_window
+            mask = (
+                kv_positions[None, :]
+                >= q_positions[:, None] - self.sliding_window
+            )
+            if self.compute_dtype == "float32":
+                sliding_adder = ops.cast(-1e9, self.compute_dtype)
+            else:
+                sliding_adder = ops.cast(-1e4, self.compute_dtype)
+            attention_scores = ops.where(
+                mask[None, None, :, :], attention_scores, sliding_adder
+            )
+
+        if attention_mask is not None:
+            # The mask is a boolean tensor, True for positions to be masked.
+            # We add a large negative number to the masked positions.
+            # Use a large negative value for masking
+            if self.compute_dtype == "float32":
+                adder = ops.cast(-1e9, self.compute_dtype)
+            else:
+                adder = ops.cast(-1e4, self.compute_dtype)
+            attention_scores = ops.where(
+                attention_mask[:, None, :, :], attention_scores, adder
+            )
+
+        # Handle sink tokens by concatenating them to the logits.
+        b = ops.shape(attention_scores)[0]
+        q = ops.shape(attention_scores)[2]
+
+        sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
+        sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
+        # attention_scores shape: [b, num_heads, q, k]
+        # sinks shape: [b, num_heads, q, 1]
+        # We need to concatenate along the last dimension
+        combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)
+
+        # Stabilize logits before softmax for numerical stability.
+        max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
+        max_logits = ops.stop_gradient(max_logits)
+        combined_logits = combined_logits - max_logits
+
+        probs = ops.softmax(combined_logits, axis=-1)
+
+        # Remove the sink probabilities before computing the output.
+        attention_scores = probs[..., :-1]
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self._kernel_initializer
+                ),
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
+                "head_dim": self.head_dim,
+            }
+        )
+        return config
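
For readers skimming `_compute_attention` above, here is a small standalone sketch (not KerasHub code) of the sink-token step: one learned logit per head is appended to every row of attention scores, absorbs part of the softmax mass, and is then dropped before the values are combined, so each token row sums to less than one.

```python
import numpy as np
from keras import ops

# Toy shapes; `sinks` stands in for the learned per-head weight in the layer.
batch, heads, q_len, kv_len = 1, 2, 3, 5
scores = ops.convert_to_tensor(
    np.random.randn(batch, heads, q_len, kv_len).astype("float32")
)
sinks = ops.zeros((heads,), dtype="float32")

# Broadcast the sink logit to one extra "column" per score row.
sink_col = ops.broadcast_to(
    ops.reshape(sinks, (1, heads, 1, 1)), (batch, heads, q_len, 1)
)
combined = ops.concatenate([scores, sink_col], axis=-1)
combined = combined - ops.max(combined, axis=-1, keepdims=True)  # stabilize
probs = ops.softmax(combined, axis=-1)

# Drop the sink column; the remaining entries in each row sum to < 1, the
# difference being the probability mass soaked up by the sink.
token_probs = probs[..., :-1]
print(ops.sum(token_probs, axis=-1))
```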
--- /dev/null
+++ keras_hub/src/models/gpt_oss/gpt_oss_backbone.py
@@ -0,0 +1,221 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.gpt_oss.gpt_oss_decoder import (
+    GptOssTransformerDecoder,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_layer_norm import (
+    GptOssLayerNormalization,
+)
+
+
+def _gpt_oss_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.GptOssBackbone")
+class GptOssBackbone(Backbone):
+    """A GPT-style Transformer with a Mixture of Experts.
+
+    This network implements a GPT-style decoder network with Mixture of Expert
+    (MoE) layers, similar to the architecture described in
+    ["Mixtral of Experts"](https://arxiv.org/pdf/2401.04088) but with
+    customizations found in some open-source GPT models. It includes the
+    embedding lookups and transformer layers.
+
+    The default constructor gives a fully customizable, randomly initialized
+    GptOss model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads: int. The number of query attention heads for
+            each transformer.
+        hidden_dim: int. The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in a three-layer feedforward network for each transformer.
+        num_key_value_heads: int. The number of key and value attention heads
+            for each transformer.
+        num_experts: int. The number of experts for the MoE layers.
+        top_k: int. The number of experts to use for each token.
+            Defaults to `2`.
+        rope_max_wavelength: int. The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor: float. The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`.
+        layer_norm_epsilon: float. Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        sliding_window: int. The sliding window for the attention
+            layers. This controls the maximum cache size for the attention
+            layers in each transformer decoder. Only `sliding_window` number
+            of tokens are saved in the cache and used to generate the next
+            token. Defaults to `4096`.
+        head_dim: int. Head dimension for attention layers. This
+            parameter is accepted for HuggingFace compatibility but ignored.
+            The head dimension is calculated dynamically as hidden_dim //
+            num_query_heads. Defaults to `None`.
+        dropout: float. Attention dropout probability.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            `float32` precision regardless of dtype.
+
+    Examples:
+
+    ```python
+    import numpy as np
+    import keras_hub
+
+    # Load a pretrained GptOss backbone from a preset.
+    model = keras_hub.models.GptOssBackbone.from_preset("gpt_oss_20b_en")
+
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array(
+            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
+        ),
+    }
+
+    model(input_data)
+
+    # Randomly initialized GptOss decoder with custom config.
+    model = keras_hub.models.GptOssBackbone(
+        vocabulary_size=10,
+        hidden_dim=512,
+        num_layers=2,
+        num_query_heads=32,
+        num_key_value_heads=8,
+        intermediate_dim=1024,
+        num_experts=4,
+        top_k=2,
+        sliding_window=256,
+        layer_norm_epsilon=1e-6,
+        dtype="float32"
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_query_heads,
+        hidden_dim,
+        intermediate_dim,
+        num_key_value_heads,
+        num_experts,
+        top_k=2,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        layer_norm_epsilon=1e-6,
+        sliding_window=4096,
+        head_dim=None,
+        dropout=0,
+        output_router_logits=False,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            tie_weights=False,
+            embeddings_initializer=_gpt_oss_kernel_initializer(stddev=0.01),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = GptOssTransformerDecoder(
+                intermediate_dim=intermediate_dim,
+                num_query_heads=num_query_heads,
+                num_key_value_heads=num_key_value_heads,
+                num_experts=num_experts,
+                top_k=top_k,
+                output_router_logits=output_router_logits,
+                rope_max_wavelength=rope_max_wavelength,
+                rope_scaling_factor=rope_scaling_factor,
+                layer_norm_epsilon=layer_norm_epsilon,
+                kernel_initializer=_gpt_oss_kernel_initializer(stddev=0.02),
+                # GPT-OSS uses SW attention in every other layer
+                sliding_window=sliding_window if i % 2 == 1 else None,
+                dropout=dropout,
+                head_dim=head_dim,
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+        self.layer_norm = GptOssLayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="sequence_output_layernorm",
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+        x = self.token_embedding(token_id_input)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
+        sequence_output = self.layer_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_query_heads = num_query_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.sliding_window = sliding_window
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.dropout = dropout
+        self.output_router_logits = output_router_logits
+        self.head_dim = head_dim
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_query_heads": self.num_query_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_experts": self.num_experts,
+                "top_k": self.top_k,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "num_key_value_heads": self.num_key_value_heads,
+                "sliding_window": self.sliding_window,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
+                "output_router_logits": self.output_router_logits,
+                "head_dim": self.head_dim,
+            }
+        )
+        return config
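
One detail of the backbone constructor worth noting is `sliding_window=sliding_window if i % 2 == 1 else None`: odd-indexed decoder layers use sliding-window attention while even-indexed layers attend over the full sequence. A tiny illustrative helper (hypothetical, not part of the package) showing the resulting per-layer layout:

```python
def layer_sliding_windows(num_layers, sliding_window=4096):
    # Mirrors the loop in GptOssBackbone.__init__: odd layers get the
    # sliding window, even layers fall back to full attention.
    return [sliding_window if i % 2 == 1 else None for i in range(num_layers)]

print(layer_sliding_windows(6, 128))
# [None, 128, None, 128, None, 128]
```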