keras-hub 0.24.0.dev0__py3-none-any.whl → 0.25.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. keras_hub/models/__init__.py +12 -0
  2. keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
  3. keras_hub/src/models/esm/esm_attention.py +11 -4
  4. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
  5. keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
  6. keras_hub/src/models/gpt_oss/__init__.py +5 -0
  7. keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
  8. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +221 -0
  9. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
  10. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
  11. keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
  12. keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
  13. keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
  14. keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
  15. keras_hub/src/models/llama3/llama3_presets.py +1 -1
  16. keras_hub/src/models/parseq/parseq_decoder.py +21 -9
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  18. keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
  19. keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
  20. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  21. keras_hub/src/version.py +1 -1
  22. keras_hub/tokenizers/__init__.py +3 -0
  23. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dev0.dist-info}/METADATA +1 -1
  24. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dev0.dist-info}/RECORD +26 -15
  25. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dev0.dist-info}/WHEEL +0 -0
  26. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dev0.dist-info}/top_level.txt +0 -0
@@ -340,6 +340,18 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
340
340
  from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
341
341
  GPTNeoXTokenizer as GPTNeoXTokenizer,
342
342
  )
343
+ from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
344
+ GptOssBackbone as GptOssBackbone,
345
+ )
346
+ from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
347
+ GptOssCausalLM as GptOssCausalLM,
348
+ )
349
+ from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
350
+ GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
351
+ )
352
+ from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
353
+ GptOssTokenizer as GptOssTokenizer,
354
+ )
343
355
  from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
344
356
  HGNetV2Backbone as HGNetV2Backbone,
345
357
  )
@@ -1,4 +1,5 @@
1
1
  import keras
2
+ import numpy as np
2
3
  from keras import ops
3
4
 
4
5
  from keras_hub.src.api_export import keras_hub_export
@@ -25,6 +26,17 @@ class RotaryEmbedding(keras.layers.Layer):
25
26
  curves.
26
27
  scaling_factor: float. The scaling factor used to scale positions of
27
28
  the tokens.
29
+ rope_type: str. The type of RoPE scaling to apply. Supported types:
30
+ "linear", "dynamic", "yarn". Defaults to "linear".
31
+ beta_fast: float. Beta fast parameter for YaRN scaling. Only used
32
+ when rope_type="yarn". Defaults to 32.0.
33
+ beta_slow: float. Beta slow parameter for YaRN scaling. Only used
34
+ when rope_type="yarn". Defaults to 1.0.
35
+ original_max_position_embeddings: int. Original maximum position
36
+ embeddings for YaRN scaling. Only used when rope_type="yarn".
37
+ Defaults to 4096.
38
+ truncate: bool. Whether to apply truncation for YaRN scaling. Only used
39
+ when rope_type="yarn". Defaults to False.
28
40
  sequence_axis: int. Sequence axis in the input tensor.
29
41
  feature_axis: int. Feature axis in the input tensor.
30
42
  **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -69,6 +81,11 @@ class RotaryEmbedding(keras.layers.Layer):
69
81
  self,
70
82
  max_wavelength=10000,
71
83
  scaling_factor=1.0,
84
+ rope_type="linear",
85
+ beta_fast=32.0,
86
+ beta_slow=1.0,
87
+ original_max_position_embeddings=4096,
88
+ truncate=False,
72
89
  sequence_axis=1,
73
90
  feature_axis=-1,
74
91
  **kwargs,
@@ -78,24 +95,70 @@ class RotaryEmbedding(keras.layers.Layer):
78
95
  self.sequence_axis = sequence_axis
79
96
  self.feature_axis = feature_axis
80
97
  self.scaling_factor = scaling_factor
98
+ self.rope_type = rope_type
99
+
100
+ # YaRN-specific parameters (only used when rope_type="yarn")
101
+ self.beta_fast = beta_fast
102
+ self.beta_slow = beta_slow
103
+ self.original_max_position_embeddings = original_max_position_embeddings
104
+ self.truncate = truncate
81
105
  self.built = True
82
106
 
107
+ def _normalize_axes(self, input_shape):
108
+ """Normalize and validate axis indices for the given input shape."""
109
+ rank = len(input_shape)
110
+
111
+ # Normalize negative indices
112
+ sequence_axis = self.sequence_axis
113
+ feature_axis = self.feature_axis
114
+
115
+ if sequence_axis < 0:
116
+ sequence_axis += rank
117
+ if feature_axis < 0:
118
+ feature_axis += rank
119
+
120
+ if sequence_axis < 0 or sequence_axis >= rank:
121
+ raise ValueError(
122
+ f"sequence_axis {self.sequence_axis} "
123
+ f"is out of range for input with rank {rank}"
124
+ )
125
+ if feature_axis < 0 or feature_axis >= rank:
126
+ raise ValueError(
127
+ f"feature_axis {self.feature_axis} "
128
+ f"is out of range for input with rank {rank}"
129
+ )
130
+ if sequence_axis == feature_axis:
131
+ raise ValueError("sequence_axis and feature_axis must be different")
132
+
133
+ return sequence_axis, feature_axis
134
+
135
+ def _validate_rotary_dimension(self, rotary_dim):
136
+ if rotary_dim % 2 != 0:
137
+ raise ValueError(
138
+ f"Rotary dimension must be even, got {rotary_dim}."
139
+ "The rotary embedding splits the feature dimension "
140
+ "into two halves. Consider using a different feature "
141
+ "dimension or padding."
142
+ )
143
+
83
144
  def call(self, inputs, start_index=0, positions=None):
145
+ input_shape = ops.shape(inputs)
146
+ sequence_axis, feature_axis = self._normalize_axes(input_shape)
147
+
148
+ rotary_dim = input_shape[feature_axis]
149
+ self._validate_rotary_dimension(rotary_dim)
150
+
84
151
  # Take care of unbatched `positions`.
85
152
  if positions is not None:
86
153
  if len(ops.shape(positions)) == 1:
87
154
  positions = ops.expand_dims(positions, axis=0)
88
155
 
89
- inputs = ops.moveaxis(
90
- inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
91
- )
156
+ inputs = ops.moveaxis(inputs, (feature_axis, sequence_axis), (-1, 1))
92
157
  cos_emb, sin_emb = self._compute_cos_sin_embedding(
93
158
  inputs, start_index, positions
94
159
  )
95
160
  output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
96
- return ops.moveaxis(
97
- output, (-1, 1), (self.feature_axis, self.sequence_axis)
98
- )
161
+ return ops.moveaxis(output, (-1, 1), (feature_axis, sequence_axis))
99
162
 
100
163
  def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
101
164
  x1, x2 = ops.split(tensor, 2, axis=-1)
@@ -113,19 +176,35 @@ class RotaryEmbedding(keras.layers.Layer):
113
176
  return positions + ops.cast(start_index, dtype="float32")
114
177
 
115
178
  def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
179
+ """Compute cos & sin RoPE embeddings with optional YaRN scaling.
180
+ Uses tensor ops only to remain JIT/backends friendly.
181
+ """
116
182
  batch_axis = 0
117
- feature_axis = len(inputs.shape) - 1
118
183
  sequence_axis = 1
184
+ feature_axis = len(inputs.shape) - 1
119
185
 
120
186
  rotary_dim = ops.shape(inputs)[feature_axis]
121
187
  inverse_freq = self._get_inverse_freq(rotary_dim)
122
188
 
123
189
  if positions is None:
124
190
  positions = self._compute_positions(inputs, start_index)
125
- positions = ops.expand_dims(positions, axis=batch_axis)
191
+ positions = ops.expand_dims(
192
+ positions, axis=batch_axis
193
+ ) # shape (1, seq_len)
126
194
  else:
127
195
  positions = ops.cast(positions, "float32")
128
- positions = positions / ops.cast(self.scaling_factor, "float32")
196
+ if len(ops.shape(positions)) == 1:
197
+ positions = ops.expand_dims(positions, axis=batch_axis)
198
+
199
+ if (
200
+ self.rope_type == "yarn"
201
+ and self.truncate
202
+ and self.original_max_position_embeddings is not None
203
+ ):
204
+ positions = ops.minimum(
205
+ positions,
206
+ ops.cast(self.original_max_position_embeddings, "float32"),
207
+ )
129
208
 
130
209
  freq = ops.einsum("bi,j->bij", positions, inverse_freq)
131
210
 
@@ -140,15 +219,103 @@ class RotaryEmbedding(keras.layers.Layer):
140
219
 
141
220
  cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
142
221
  sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
222
+
223
+ if self.rope_type == "yarn":
224
+ # YaRN temperature scaling
225
+ factor = ops.add(
226
+ ops.multiply(
227
+ ops.cast(0.1, self.compute_dtype),
228
+ ops.log(ops.cast(self.scaling_factor, self.compute_dtype)),
229
+ ),
230
+ ops.cast(1.0, self.compute_dtype),
231
+ )
232
+ cos_emb = cos_emb * factor
233
+ sin_emb = sin_emb * factor
143
234
  return cos_emb, sin_emb
144
235
 
145
236
  def _get_inverse_freq(self, rotary_dim):
146
- freq_range = ops.divide(
147
- ops.arange(0, rotary_dim, 2, dtype="float32"),
148
- ops.cast(rotary_dim, "float32"),
237
+ """Return inverse frequencies."""
238
+ idx = ops.arange(0, rotary_dim, 2, dtype="float32")
239
+ denom = ops.cast(rotary_dim, "float32")
240
+ freq_range = idx / denom
241
+ inv = ops.power(ops.cast(self.max_wavelength, "float32"), -freq_range)
242
+
243
+ if self.rope_type == "linear":
244
+ return inv / ops.cast(self.scaling_factor, "float32")
245
+ elif self.rope_type == "dynamic":
246
+ exponent = ops.cast(rotary_dim, "float32") / ops.cast(
247
+ max(1, rotary_dim - 2), "float32"
248
+ )
249
+ return inv / ops.power(
250
+ ops.cast(self.scaling_factor, "float32"), exponent
251
+ )
252
+ elif self.rope_type == "yarn":
253
+ return self._get_yarn_inverse_freq(rotary_dim)
254
+ else:
255
+ return inv
256
+
257
+ def _get_yarn_inverse_freq(self, rotary_dim):
258
+ # Get the base (rope_theta equivalent) from max_wavelength
259
+ base = ops.cast(self.max_wavelength, "float32")
260
+
261
+ # Compute base frequencies: base ** (idx / dim)
262
+ idx = ops.arange(0, rotary_dim, 2, dtype="float32")
263
+ pos_freqs = ops.power(base, idx / ops.cast(rotary_dim, "float32"))
264
+
265
+ # Compute interpolation and extrapolation frequencies
266
+ inv_freq_extrapolation = 1.0 / pos_freqs
267
+ inv_freq_interpolation = 1.0 / (
268
+ ops.cast(self.scaling_factor, "float32") * pos_freqs
149
269
  )
150
- inverse_freq = 1.0 / (self.max_wavelength**freq_range)
151
- return inverse_freq
270
+
271
+ # Find correction range
272
+ beta_fast = ops.cast(self.beta_fast, "float32")
273
+ beta_slow = ops.cast(self.beta_slow, "float32")
274
+
275
+ # Find correction dimensions for beta_fast and beta_slow
276
+ def find_correction_dim_tensor(num_rotations, dim):
277
+ max_pos = ops.cast(self.original_max_position_embeddings, "float32")
278
+ return (dim * ops.log(max_pos / (num_rotations * 2 * np.pi))) / (
279
+ 2 * ops.log(base)
280
+ )
281
+
282
+ low = find_correction_dim_tensor(
283
+ beta_fast, ops.cast(rotary_dim, "float32")
284
+ )
285
+ high = find_correction_dim_tensor(
286
+ beta_slow, ops.cast(rotary_dim, "float32")
287
+ )
288
+
289
+ # Apply truncation if specified
290
+ if self.truncate:
291
+ low = ops.floor(low)
292
+ high = ops.ceil(high)
293
+
294
+ # Clamp to valid range
295
+ low = ops.maximum(low, ops.cast(0, "float32"))
296
+ high = ops.minimum(high, ops.cast(rotary_dim // 2 - 1, "float32"))
297
+
298
+ # Linear ramp function
299
+ dim_half = rotary_dim // 2
300
+ idx_half = ops.arange(0, dim_half, dtype="float32")
301
+
302
+ # Prevent singularity
303
+ diff = high - low
304
+ diff = ops.maximum(diff, ops.cast(0.001, "float32"))
305
+
306
+ linear_func = (idx_half - low) / diff
307
+ ramp_func = ops.clip(linear_func, 0, 1)
308
+
309
+ # Apply the ramp to get extrapolation factor
310
+ inv_freq_extrapolation_factor = 1 - ramp_func
311
+
312
+ # Combine interpolation and extrapolation
313
+ scaled_inverse_freq = (
314
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
315
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
316
+ )
317
+
318
+ return scaled_inverse_freq
152
319
 
153
320
  def get_config(self):
154
321
  config = super().get_config()
@@ -156,6 +323,13 @@ class RotaryEmbedding(keras.layers.Layer):
156
323
  {
157
324
  "max_wavelength": self.max_wavelength,
158
325
  "scaling_factor": self.scaling_factor,
326
+ "rope_type": self.rope_type,
327
+ "beta_fast": self.beta_fast,
328
+ "beta_slow": self.beta_slow,
329
+ "original_max_position_embeddings": (
330
+ self.original_max_position_embeddings
331
+ ),
332
+ "truncate": self.truncate,
159
333
  "sequence_axis": self.sequence_axis,
160
334
  "feature_axis": self.feature_axis,
161
335
  }
@@ -14,7 +14,8 @@ class ESMRotaryEmbedding(RotaryEmbedding):
14
14
  inv_freq = self.scaling_factor / (
15
15
  self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
16
16
  )
17
- t = ops.arange(x.shape[position], dtype=x.dtype)
17
+ # Use ops.shape for dynamic shape compatibility with TFLite
18
+ t = ops.arange(ops.shape(x)[position], dtype=x.dtype)
18
19
  freqs = ops.outer(t, inv_freq)
19
20
  emb = ops.concatenate((freqs, freqs), axis=-1)
20
21
 
@@ -32,11 +33,17 @@ class ESMRotaryEmbedding(RotaryEmbedding):
32
33
 
33
34
  def rotate_half(self, x):
34
35
  x1, x2 = ops.split(x, 2, -1)
35
- return ops.concatenate((-x2, x1), axis=-1)
36
+ # Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
37
+ # backend. Use stack + reshape approach from base RotaryEmbedding.
38
+ half_rot_x = ops.stack((-x2, x1), axis=-2)
39
+ half_rot_x = ops.reshape(half_rot_x, ops.shape(x))
40
+ return half_rot_x
36
41
 
37
42
  def apply_rotary_pos_emb(self, x, cos, sin):
38
- cos = cos[:, : x.shape[1], :, :]
39
- sin = sin[:, : x.shape[1], :, :]
43
+ # Use ops.shape for dynamic shape compatibility with TFLite
44
+ seq_len = ops.shape(x)[1]
45
+ cos = cos[:, :seq_len, :, :]
46
+ sin = sin[:, :seq_len, :, :]
40
47
 
41
48
  return (x * cos) + (self.rotate_half(x) * sin)
42
49
 
@@ -283,9 +283,14 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
283
283
  # is `None`.
284
284
  self.text_only_model = self.image_converter is None
285
285
 
286
- self.image_placeholder = self.tokenizer.image_placeholder
287
- self.start_of_image_token = self.tokenizer.start_of_image_token
288
- self.end_of_image_token = self.tokenizer.end_of_image_token
286
+ if self.text_only_model:
287
+ self.image_placeholder = None
288
+ self.start_of_image_token = None
289
+ self.end_of_image_token = None
290
+ else:
291
+ self.image_placeholder = self.tokenizer.image_placeholder
292
+ self.start_of_image_token = self.tokenizer.start_of_image_token
293
+ self.end_of_image_token = self.tokenizer.end_of_image_token
289
294
 
290
295
  def build(self, input_shape):
291
296
  # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
77
77
 
78
78
  backbone_cls = Gemma3Backbone
79
79
 
80
- def __init__(self, proto, **kwargs):
80
+ def __init__(self, proto, has_vision_tokens=True, **kwargs):
81
81
  # Add special tokens.
82
82
 
83
+ self.has_vision_tokens = has_vision_tokens
83
84
  # The usual tokens.
84
85
  self._add_special_token("<bos>", "start_token")
85
86
  self._add_special_token("<eos>", "end_token")
86
87
  self._add_special_token("<pad>", "pad_token")
87
88
 
88
- # Image placeholder token.
89
- self._add_special_token("<img>", "image_placeholder")
90
-
91
- # Some tokens which are used in the preprocessor. We need to keep them
92
- # here so that the preprocessor works with `tf.data`.
93
- self._add_special_token("<start_of_image>", "start_of_image_token")
94
- self._add_special_token("<end_of_image>", "end_of_image_token")
89
+ if has_vision_tokens:
90
+ # Image placeholder token.
91
+ self._add_special_token("<img>", "image_placeholder")
92
+ # Some tokens which are used in the preprocessor.
93
+ # We need to keep them
94
+ # here so that the preprocessor works with tf.data.
95
+ self._add_special_token("<start_of_image>", "start_of_image_token")
96
+ self._add_special_token("<end_of_image>", "end_of_image_token")
97
+ else:
98
+ # For text-only, skip assigning token IDs or set to -1
99
+ self.start_of_image_token_id = -1
100
+ self.image_placeholder_token_id = -1
101
+ self.end_of_image_token_id = -1
95
102
 
96
103
  super().__init__(proto=proto, **kwargs)
104
+
105
+ def get_config(self):
106
+ config = super().get_config()
107
+ config.update({"has_vision_tokens": self.has_vision_tokens})
108
+ return config
@@ -0,0 +1,5 @@
1
+ from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
2
+ from keras_hub.src.models.gpt_oss.gpt_oss_presets import backbone_presets
3
+ from keras_hub.src.utils.preset_utils import register_presets
4
+
5
+ register_presets(backbone_presets, GptOssBackbone)