keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +6 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/gemma3/__init__.py +5 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +315 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +352 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +306 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +691 -0
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +305 -0
- keras_hub/src/models/gemma3/gemma3_image_converter.py +8 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +79 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +93 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +87 -0
- keras_hub/src/models/gemma3/gemma3_vit.py +608 -0
- keras_hub/src/models/gemma3/rms_normalization.py +26 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/RECORD +20 -8
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/top_level.txt +0 -0
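Taken together, the new files register Gemma3 in the public keras-hub API: a backbone, a causal LM task with its preprocessor and tokenizer, an image converter, plus the Gemma3-specific attention, decoder block, ViT, and preset definitions. A minimal sketch of how the newly exported symbols would be reached, assuming this nightly follows keras-hub's usual keras_hub.models / keras_hub.layers / keras_hub.tokenizers re-export pattern (the checks below are illustrative and not part of the diff):

# Illustrative only: verify the symbols added in this nightly are reachable
# from the public namespaces (assumes keras-hub's usual re-export layout).
import keras_hub

for name in ("Gemma3Backbone", "Gemma3CausalLM", "Gemma3CausalLMPreprocessor"):
    print(name, hasattr(keras_hub.models, name))
print("Gemma3Tokenizer", hasattr(keras_hub.tokenizers, "Gemma3Tokenizer"))
print("Gemma3ImageConverter", hasattr(keras_hub.layers, "Gemma3ImageConverter"))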
keras_hub/api/layers/__init__.py
CHANGED
@@ -54,6 +54,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import (
 from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
     EfficientNetImageConverter,
 )
+from keras_hub.src.models.gemma3.gemma3_image_converter import (
+    Gemma3ImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
 from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
     MobileNetImageConverter,
keras_hub/api/models/__init__.py
CHANGED
@@ -177,6 +177,12 @@ from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
     GemmaCausalLMPreprocessor,
 )
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
+from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
+    Gemma3CausalLMPreprocessor,
+)
+from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
keras_hub/api/tokenizers/__init__.py
CHANGED
@@ -19,6 +19,7 @@ from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
keras_hub/src/models/gemma3/gemma3_attention.py
ADDED
@@ -0,0 +1,315 @@
+import inspect
+
+import keras
+import numpy as np
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import running_on_tpu
+
+
+class CachedGemma3Attention(keras.layers.Layer):
+    """A cached grouped query attention layer for Gemma3.
+
+    This is different from Gemma and Gemma2 in several ways:
+
+    - `use_query_key_norm`: Applies RMS Norm on query, key.
+    - `rope_wavelength`: RoPE wavelength differs from local to global attention
+      layers.
+    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
+      attention layers.
+    """
+
+    def __init__(
+        self,
+        head_dim,
+        num_query_heads,
+        num_key_value_heads,
+        kernel_initializer="glorot_uniform",
+        logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
+        query_head_dim_normalize=True,
+        use_query_key_norm=False,
+        layer_norm_epsilon=1e-6,
+        rope_wavelength=10_000.0,
+        rope_scaling_factor=1.0,
+        dropout=0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.logit_soft_cap = logit_soft_cap
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_query_key_norm = use_query_key_norm
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.rope_wavelength = rope_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.dropout = dropout
+
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.query_head_dim_normalize = query_head_dim_normalize
+
+    def build(self, inputs_shape):
+        self.hidden_dim = inputs_shape[-1]
+
+        self.query_dense = keras.layers.EinsumDense(
+            "btd,ndh->btnh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            "bsd,kdh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            "bsd,kdh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        if self.use_query_key_norm:
+            self.query_norm = RMSNormalization(
+                epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name="query_norm",
+            )
+            self.query_norm.build(
+                self.query_dense.compute_output_shape(inputs_shape)
+            )
+
+            self.key_norm = RMSNormalization(
+                epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name="key_norm",
+            )
+            self.key_norm.build(
+                self.key_dense.compute_output_shape(inputs_shape)
+            )
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self.output_dense = keras.layers.EinsumDense(
+            equation="btnh,nhd->btd",
+            output_shape=(None, self.hidden_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self.head_dim)
+        )
+        self.softmax = keras.layers.Softmax(dtype="float32")
+
+        self.rope_layer = RotaryEmbedding(
+            max_wavelength=self.rope_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self.built = True
+
+    def _apply_rope(self, x, start_index):
+        """Rope rotate q or k."""
+        x = self.rope_layer(x, start_index=start_index)
+        return x
+
+    def _can_use_flash_attention(self):
+        if not has_flash_attention_support():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if self.logit_soft_cap is None:
+            return True
+        sig = inspect.signature(ops.dot_product_attention)
+        # We can currently only run soft capped attention for keras >= 3.10
+        # and only on TPU.
+        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+
+    def _compute_attention(
+        self,
+        q,
+        k,
+        v,
+        attention_mask,
+        training=False,
+        cache_update_index=0,
+    ):
+        if self.query_head_dim_normalize:
+            query_normalization = 1 / np.sqrt(self.head_dim)
+        else:
+            query_normalization = 1 / np.sqrt(
+                self.hidden_dim // self.num_query_heads
+            )
+        if self._can_use_flash_attention():
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            # Only pass soft cap if needed as not all keras versions support.
+            if self.logit_soft_cap:
+                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+            else:
+                kwargs = {}
+            return ops.dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                mask=attention_mask,
+                scale=query_normalization,
+                **kwargs,
+            )
+
+        q *= ops.cast(query_normalization, dtype=q.dtype)
+        q_shape = ops.shape(q)
+        q = ops.reshape(
+            q,
+            (
+                *q_shape[:-2],
+                self.num_key_value_heads,
+                self.num_query_heads // self.num_key_value_heads,
+                q_shape[-1],
+            ),
+        )
+        b, q_len, _, _, h = ops.shape(q)
+
+        # Fallback to standard attention if flash attention is disabled
+        attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
+        if self.logit_soft_cap is not None:
+            attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
+            attention_logits = ops.multiply(
+                ops.tanh(attention_logits), self.logit_soft_cap
+            )
+
+        if self.use_sliding_window_attention:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        attention_mask = attention_mask[:, None, None, :, :]
+        orig_dtype = attention_logits.dtype
+        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
+        attention_softmax = ops.cast(attention_softmax, orig_dtype)
+
+        if self.dropout:
+            attention_softmax = self.dropout_layer(
+                attention_softmax, training=training
+            )
+
+        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
+        return ops.reshape(results, (b, q_len, self.num_query_heads, h))
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: trui/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def call(
+        self,
+        x,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=0,
+        training=False,
+    ):
+        query = self.query_dense(x)
+
+        if self.use_query_key_norm:
+            query = self.query_norm(query)
+
+        query = self._apply_rope(query, cache_update_index)
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            key_update = self.key_dense(x)
+
+            if self.use_query_key_norm:
+                key_update = self.key_norm(key_update)
+
+            key_update = self._apply_rope(key_update, cache_update_index)
+            value_update = self.value_dense(x)
+            start = [0, cache_update_index, 0, 0]
+            key = ops.slice_update(key_cache, start, key_update)
+            value = ops.slice_update(value_cache, start, value_update)
+            cache = ops.stack((key, value), axis=1)
+        else:
+            key = self.key_dense(x)
+
+            if self.use_query_key_norm:
+                key = self.key_norm(key)
+
+            key = self._apply_rope(key, cache_update_index)
+            value = self.value_dense(x)
+
+        attention_vec = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
+        )
+
+        # Wipe attn vec if there are no attended tokens.
+        no_attended_tokens = ops.all(
+            ops.equal(attention_mask, 0), axis=-1, keepdims=True
+        )[..., None]
+        attention_vec = ops.where(
+            no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
+        )
+
+        attention_output = self.output_dense(attention_vec)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
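For readers skimming the new attention layer, the standalone NumPy sketch below mirrors the grouped-query einsum pattern used in _compute_attention and the band mask built in _mask_sliding_window. The shapes, window size, and variable names are illustrative assumptions, not values taken from the package.

# Minimal NumPy sketch of grouped-query attention with a sliding-window mask.
import numpy as np

batch, q_len, kv_len = 1, 6, 6
num_q_heads, num_kv_heads, head_dim = 8, 4, 16
groups = num_q_heads // num_kv_heads  # query heads served per key/value head

rng = np.random.default_rng(0)
q = rng.standard_normal((batch, q_len, num_q_heads, head_dim))
k = rng.standard_normal((batch, kv_len, num_kv_heads, head_dim))
v = rng.standard_normal((batch, kv_len, num_kv_heads, head_dim))

# Split the query heads into (kv_head, group) pairs, mirroring the reshape
# in `_compute_attention`, and apply the 1/sqrt(head_dim) scaling.
q = q.reshape(batch, q_len, num_kv_heads, groups, head_dim) / np.sqrt(head_dim)
logits = np.einsum("btkgh,bskh->bkgts", q, k)  # (batch, kv_heads, groups, t, s)

# Sliding-window band intersected with a causal mask: query i attends only to
# keys j with i - (window - 1) <= j <= i.
window = 3
ones = np.ones((kv_len, kv_len), dtype=bool)
band = np.triu(ones, -window + 1) & np.tril(ones, window - 1)
mask = band & np.tril(ones)

logits = np.where(mask[None, None, None], logits, -1e30)
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = probs / probs.sum(axis=-1, keepdims=True)

out = np.einsum("bkgts,bskh->btkgh", probs, v)
out = out.reshape(batch, q_len, num_q_heads, head_dim)
print(out.shape)  # (1, 6, 8, 16)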