keras-hub-nightly 0.20.0.dev202503250356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
@@ -54,6 +54,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import (
 from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
     EfficientNetImageConverter,
 )
+from keras_hub.src.models.gemma3.gemma3_image_converter import (
+    Gemma3ImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
 from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
     MobileNetImageConverter,
@@ -177,6 +177,12 @@ from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
     GemmaCausalLMPreprocessor,
 )
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
+from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
+    Gemma3CausalLMPreprocessor,
+)
+from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
@@ -19,6 +19,7 @@ from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
@@ -194,15 +194,15 @@ class Backbone(keras.Model):
         """
         return ["query_dense", "value_dense", "query", "value"]
 
-    def enable_lora(self, rank):
+    def enable_lora(self, rank, target_names=None):
         """Enable Lora on the backbone.
 
         Calling this method will freeze all weights on the backbone,
        while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
         """
-        target_names = self.get_lora_target_names()
-
+        if target_names is None:
+            target_names = self.get_lora_target_names()
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
@@ -0,0 +1,5 @@
+from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+from keras_hub.src.models.gemma3.gemma3_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, Gemma3Backbone)
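This new module registers the Gemma3 presets against `Gemma3Backbone`, which is what makes them resolvable by name through `from_preset`. A minimal sketch of the intended effect; the preset name below is a placeholder because `gemma3_presets.py` is not shown in this diff:

```python
from keras_hub.models import Gemma3Backbone

# Placeholder preset name; the registered names come from backbone_presets
# in gemma3_presets.py.
backbone = Gemma3Backbone.from_preset("gemma3_example")
```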
@@ -0,0 +1,315 @@
+import inspect
+
+import keras
+import numpy as np
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import running_on_tpu
+
+
+class CachedGemma3Attention(keras.layers.Layer):
+    """A cached grouped query attention layer for Gemma3.
+
+    This is different from Gemma and Gemma2 in several ways:
+
+    - `use_query_key_norm`: Applies RMS Norm on query, key.
+    - `rope_wavelength`: RoPE wavelength differs from local to global attention
+      layers.
+    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
+      attention layers.
+    """
+
+    def __init__(
+        self,
+        head_dim,
+        num_query_heads,
+        num_key_value_heads,
+        kernel_initializer="glorot_uniform",
+        logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
+        query_head_dim_normalize=True,
+        use_query_key_norm=False,
+        layer_norm_epsilon=1e-6,
+        rope_wavelength=10_000.0,
+        rope_scaling_factor=1.0,
+        dropout=0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.logit_soft_cap = logit_soft_cap
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_query_key_norm = use_query_key_norm
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.rope_wavelength = rope_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.dropout = dropout
+
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.query_head_dim_normalize = query_head_dim_normalize
+
+    def build(self, inputs_shape):
+        self.hidden_dim = inputs_shape[-1]
+
+        self.query_dense = keras.layers.EinsumDense(
+            "btd,ndh->btnh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            "bsd,kdh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            "bsd,kdh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        if self.use_query_key_norm:
+            self.query_norm = RMSNormalization(
+                epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name="query_norm",
+            )
+            self.query_norm.build(
+                self.query_dense.compute_output_shape(inputs_shape)
+            )
+
+            self.key_norm = RMSNormalization(
+                epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name="key_norm",
+            )
+            self.key_norm.build(
+                self.key_dense.compute_output_shape(inputs_shape)
+            )
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self.output_dense = keras.layers.EinsumDense(
+            equation="btnh,nhd->btd",
+            output_shape=(None, self.hidden_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self.head_dim)
+        )
+        self.softmax = keras.layers.Softmax(dtype="float32")
+
+        self.rope_layer = RotaryEmbedding(
+            max_wavelength=self.rope_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self.built = True
+
+    def _apply_rope(self, x, start_index):
+        """Rope rotate q or k."""
+        x = self.rope_layer(x, start_index=start_index)
+        return x
+
+    def _can_use_flash_attention(self):
+        if not has_flash_attention_support():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if self.logit_soft_cap is None:
+            return True
+        sig = inspect.signature(ops.dot_product_attention)
+        # We can currently only run soft capped attention for keras >= 3.10
+        # and only on TPU.
+        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+
+    def _compute_attention(
+        self,
+        q,
+        k,
+        v,
+        attention_mask,
+        training=False,
+        cache_update_index=0,
+    ):
+        if self.query_head_dim_normalize:
+            query_normalization = 1 / np.sqrt(self.head_dim)
+        else:
+            query_normalization = 1 / np.sqrt(
+                self.hidden_dim // self.num_query_heads
+            )
+        if self._can_use_flash_attention():
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            # Only pass soft cap if needed as not all keras versions support.
+            if self.logit_soft_cap:
+                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+            else:
+                kwargs = {}
+            return ops.dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                mask=attention_mask,
+                scale=query_normalization,
+                **kwargs,
+            )
+
+        q *= ops.cast(query_normalization, dtype=q.dtype)
+        q_shape = ops.shape(q)
+        q = ops.reshape(
+            q,
+            (
+                *q_shape[:-2],
+                self.num_key_value_heads,
+                self.num_query_heads // self.num_key_value_heads,
+                q_shape[-1],
+            ),
+        )
+        b, q_len, _, _, h = ops.shape(q)
+
+        # Fallback to standard attention if flash attention is disabled
+        attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
+        if self.logit_soft_cap is not None:
+            attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
+            attention_logits = ops.multiply(
+                ops.tanh(attention_logits), self.logit_soft_cap
+            )
+
+        if self.use_sliding_window_attention:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        attention_mask = attention_mask[:, None, None, :, :]
+        orig_dtype = attention_logits.dtype
+        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
+        attention_softmax = ops.cast(attention_softmax, orig_dtype)
+
+        if self.dropout:
+            attention_softmax = self.dropout_layer(
+                attention_softmax, training=training
+            )
+
+        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
+        return ops.reshape(results, (b, q_len, self.num_query_heads, h))
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: trui/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def call(
+        self,
+        x,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=0,
+        training=False,
+    ):
+        query = self.query_dense(x)
+
+        if self.use_query_key_norm:
+            query = self.query_norm(query)
+
+        query = self._apply_rope(query, cache_update_index)
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            key_update = self.key_dense(x)
+
+            if self.use_query_key_norm:
+                key_update = self.key_norm(key_update)
+
+            key_update = self._apply_rope(key_update, cache_update_index)
+            value_update = self.value_dense(x)
+            start = [0, cache_update_index, 0, 0]
+            key = ops.slice_update(key_cache, start, key_update)
+            value = ops.slice_update(value_cache, start, value_update)
+            cache = ops.stack((key, value), axis=1)
+        else:
+            key = self.key_dense(x)
+
+            if self.use_query_key_norm:
+                key = self.key_norm(key)
+
+            key = self._apply_rope(key, cache_update_index)
+            value = self.value_dense(x)
+
+        attention_vec = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
+        )
+
+        # Wipe attn vec if there are no attended tokens.
+        no_attended_tokens = ops.all(
+            ops.equal(attention_mask, 0), axis=-1, keepdims=True
+        )[..., None]
+        attention_vec = ops.where(
+            no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
+        )
+
+        attention_output = self.output_dense(attention_vec)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
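To make the new attention layer concrete, here is a small standalone smoke-test sketch. The import path and the tiny hyperparameters are assumptions for illustration (equal query and key/value head counts are chosen to sidestep grouped-query broadcasting differences across backends); none of them are taken from this diff:

```python
import numpy as np

# Assumed module path; the new file's location is not shown in this diff.
from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention

# Deliberately tiny, illustrative configuration.
layer = CachedGemma3Attention(
    head_dim=16,
    num_query_heads=4,
    num_key_value_heads=4,
    use_query_key_norm=True,
    rope_wavelength=10_000.0,
    rope_scaling_factor=1.0,
)

batch, seq_len, hidden_dim = 1, 8, 64
x = np.random.uniform(size=(batch, seq_len, hidden_dim)).astype("float32")
# Boolean mask of shape (batch, query_len, key_len); here a causal mask.
mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))[None, ...]

out = layer(x, attention_mask=mask)
print(out.shape)  # (1, 8, 64)
```

When `cache` and `cache_update_index` are supplied, `call` updates the key/value cache with `ops.slice_update` and returns `(attention_output, cache)` instead, as shown at the end of the file above.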