keras-hub-nightly 0.21.0.dev202504170402__py3-none-any.whl → 0.21.0.dev202504190357__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,7 +29,7 @@ class FalconBackbone(Backbone):
         layer_norm_epsilon: float. Epsilon for the layer normalization layers in
             the transformer decoder.
         attention_dropout_rate: float. Dropout probability for the attention.
-        feedforward_dropout_rate: flaot. Dropout probability for the
+        feedforward_dropout_rate: float. Dropout probability for the
             feedforward.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
@@ -512,6 +512,7 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
         # Extract text part of the input.
        prompts, responses = x["prompts"], x["responses"]
+        tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])
 
         # Find out if the input is batched/not batched. Uprank if not batched.
         # In other preprocessors, we don't have to do this, but here, all
@@ -3,7 +3,9 @@ import math
 import keras
 from keras import ops
 
-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.llama.llama_rotary_embedding import (
+    LlamaRotaryEmbedding,
+)
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
@@ -16,7 +18,11 @@ class LlamaAttention(keras.layers.Layer):
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         kernel_initializer="glorot_uniform",
         dropout=0,
         **kwargs,
@@ -28,13 +34,16 @@ class LlamaAttention(keras.layers.Layer):
 
         self.num_key_value_groups = num_query_heads // num_key_value_heads
         self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
 
         self.kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
 
-        self.rope_scaling_factor = rope_scaling_factor
-
     def build(self, inputs_shape):
         # Einsum variables:
         # b = batch size
@@ -103,9 +112,13 @@ class LlamaAttention(keras.layers.Layer):
         )
         self._output_dense.build((None, None, self.num_query_heads, head_dim))
 
-        self.rotary_embedding_layer = RotaryEmbedding(
+        self.rotary_embedding_layer = LlamaRotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
-            scaling_factor=self.rope_scaling_factor,
+            position_scaling_factor=self.rope_position_scaling_factor,
+            frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+            low_freq_factor=self.rope_low_freq_factor,
+            high_freq_factor=self.rope_high_freq_factor,
+            pretraining_sequence_length=self.rope_pretraining_sequence_length,
             dtype=self.dtype_policy,
         )
 
@@ -224,6 +237,11 @@ class LlamaAttention(keras.layers.Layer):
                 "num_key_value_heads": self.num_key_value_heads,
                 "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),
@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
     constructor.
 
     Args:
-        vocabulary_size (int): The size of the token vocabulary.
-        num_layers (int): The number of transformer layers.
-        num_query_heads (int): The number of query attention heads for
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads: int. The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
+        hidden_dim: int. The size of the transformer encoding and pooling
             layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in
+        intermediate_dim: int. The output dimension of the first Dense layer in
             a three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads
+        num_key_value_heads: int. The number of key and value attention heads
             for each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of
+        rope_max_wavelength: int. The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`.
-        layer_norm_epsilon (float, optional): Epsilon for the layer
-            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        rope_position_scaling_factor: float. The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`.
+        rope_frequency_adjustment_factor: float. The scaling factor
+            used to scale the inverse frequencies. Defaults to `None`.
+        rope_low_freq_factor: float. The low frequency scaling
+            factor. Defaults to `None`.
+        rope_high_freq_factor: float. Used for Llama3.1+. The high
+            frequency scaling factor. Defaults to `None`.
+        rope_pretraining_sequence_length: int. Used for Llama3.1+.
+            Defaults to `None`.
+        layer_norm_epsilon: float. Epsilon for the layer normalization layers
+            in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
@@ -87,7 +95,11 @@ class LlamaBackbone(Backbone):
         intermediate_dim,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -110,7 +122,15 @@ class LlamaBackbone(Backbone):
                 num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
                 rope_max_wavelength=rope_max_wavelength,
-                rope_scaling_factor=rope_scaling_factor,
+                rope_position_scaling_factor=rope_position_scaling_factor,
+                rope_frequency_adjustment_factor=(
+                    rope_frequency_adjustment_factor
+                ),
+                rope_low_freq_factor=rope_low_freq_factor,
+                rope_high_freq_factor=rope_high_freq_factor,
+                rope_pretraining_sequence_length=(
+                    rope_pretraining_sequence_length
+                ),
                 layer_norm_epsilon=layer_norm_epsilon,
                 activation=ops.silu,
                 kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ class LlamaBackbone(Backbone):
         self.num_query_heads = num_query_heads
         self.hidden_dim = hidden_dim
         self.intermediate_dim = intermediate_dim
-        self.rope_max_wavelength = rope_max_wavelength
         self.num_key_value_heads = num_key_value_heads
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ class LlamaBackbone(Backbone):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_position_scaling_factor": (
+                    self.rope_position_scaling_factor
+                ),
+                "rope_frequency_adjustment_factor": (
+                    self.rope_frequency_adjustment_factor
+                ),
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
@@ -21,7 +21,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         activation="silu",
         layer_norm_epsilon=1e-5,
         kernel_initializer="glorot_uniform",
@@ -34,7 +38,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
         self.num_key_value_heads = num_key_value_heads
 
         self.rope_max_wavelength = rope_max_wavelength
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
 
         self.dropout = dropout
 
@@ -53,7 +61,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
             num_query_heads=self.num_query_heads,
             num_key_value_heads=self.num_key_value_heads,
             rope_max_wavelength=self.rope_max_wavelength,
-            rope_scaling_factor=self.rope_scaling_factor,
+            rope_position_scaling_factor=self.rope_position_scaling_factor,
+            rope_frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+            rope_low_freq_factor=self.rope_low_freq_factor,
+            rope_high_freq_factor=self.rope_high_freq_factor,
+            rope_pretraining_sequence_length=self.rope_pretraining_sequence_length,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             dropout=self.dropout,
             dtype=self.dtype_policy,
@@ -221,6 +233,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
                 "num_query_heads": self.num_query_heads,
                 "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "activation": keras.activations.serialize(self.activation),
                 "layer_norm_epsilon": self.layer_norm_epsilon,
@@ -0,0 +1,180 @@
+import math
+
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+
+
+class LlamaRotaryEmbedding(RotaryEmbedding):
+    """Rotary positional encoding layer.
+
+    This layer encodes absolute positional information with a rotation
+    matrix. It calculates the rotary encoding with a mix of sine and
+    cosine functions with geometrically increasing wavelengths.
+    Defined and formulated in
+    [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
+    The input must be a tensor with a sequence dimension and a feature
+    dimension. Typically, this will either be an input with shape
+    `(batch_size, sequence_length, feature_length)` or
+    `(batch_size, sequence_length, num_heads, feature_length)`.
+    This layer will return a new tensor with the rotary embedding applied to
+    the input tensor.
+    It extends the `RotaryEmbedding` layer in `keras_hub.layers` with
+    additional smoothing and interpolation for some frequency ranges.
+
+    Args:
+        max_wavelength: int. The maximum angular wavelength of the sine/cosine
+            curves. Defaults to `10000`.
+        position_scaling_factor: float. The scaling factor used to scale
+            positions of the tokens. Defaults to `1.0`.
+        frequency_adjustment_factor: float. The scaling factor used to scale the
+            inverse frequencies. Defaults to `None`.
+        low_freq_factor: float. The low frequency scaling factor.
+            Defaults to `None`.
+        high_freq_factor: float. The high frequency scaling factor.
+            Defaults to `None`.
+        pretraining_sequence_length: int. Used for Llama3.1+, the original
+            context length at time of pretraining. Defaults to `None`.
+        sequence_axis: int. Sequence axis in the input tensor.
+        feature_axis: int. Feature axis in the input tensor.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Call arguments:
+        inputs: The tensor inputs to apply the embedding to. This can have
+            any shape, but must contain both a sequence and feature axis. The
+            rotary embedding will be applied to `inputs` and returned.
+        start_index: An integer or integer tensor. The starting position to
+            compute the rotary embedding from. This is useful during cached
+            decoding, where each position is predicted separately in a loop.
+
+    Examples:
+
+    ```python
+    batch_size = 16
+    feature_length = 18
+    sequence_length = 256
+    num_heads = 8
+
+    # No multi-head dimension.
+    tensor = np.ones((batch_size, sequence_length, feature_length))
+    rot_emb_layer = LlamaRotaryEmbedding()
+    tensor_rot = rot_emb_layer(tensor)
+
+    # With multi-head dimension.
+    tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
+    tensor_rot = rot_emb_layer(tensor)
+    ```
+
+    References:
+     - [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
+    """
+
+    def __init__(
+        self,
+        max_wavelength=10000,
+        position_scaling_factor=1.0,
+        sequence_axis=1,
+        feature_axis=-1,
+        frequency_adjustment_factor=None,
+        low_freq_factor=None,
+        high_freq_factor=None,
+        pretraining_sequence_length=None,
+        **kwargs,
+    ):
+        super().__init__(
+            max_wavelength=max_wavelength,
+            scaling_factor=position_scaling_factor,
+            sequence_axis=sequence_axis,
+            feature_axis=feature_axis,
+            **kwargs,
+        )
+        self.max_wavelength = max_wavelength
+        self.sequence_axis = sequence_axis
+        self.feature_axis = feature_axis
+        self.position_scaling_factor = position_scaling_factor
+        self.frequency_adjustment_factor = frequency_adjustment_factor
+        self.low_freq_factor = low_freq_factor
+        self.high_freq_factor = high_freq_factor
+        self.pretraining_sequence_length = pretraining_sequence_length
+
+        grouped_args = [
+            low_freq_factor,
+            high_freq_factor,
+            frequency_adjustment_factor,
+            pretraining_sequence_length,
+        ]
+        args_none = [x is None for x in grouped_args]
+        if any(args_none) and not all(args_none):
+            raise ValueError(
+                "Either all of `low_freq_factor`, `high_freq_factor`, "
+                "`frequency_adjustment_factor` and "
+                "`pretraining_sequence_length` should be set, or all of them "
+                "should be `None`."
+            )
+        self.built = True
+
+    def _get_inverse_freq(self, rotary_dim):
+        freq_range = ops.divide(
+            ops.arange(0, rotary_dim, 2, dtype="float32"),
+            ops.cast(rotary_dim, "float32"),
+        )
+        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
+
+        # From llama3.1+ we have additional smoothing and interpolation.
+        # low_freq_factor, high_freq_factor, pretraining_sequence_length and
+        # frequency_adjustment_factor are always set together, so checking one is fine.
+        if self.low_freq_factor is not None:
+            low_freq_wavelen = (
+                self.pretraining_sequence_length / self.low_freq_factor
+            )
+            high_freq_wavelen = (
+                self.pretraining_sequence_length / self.high_freq_factor
+            )
+            wavelen = 2 * math.pi / inverse_freq
+
+            # wavelen < high_freq_wavelen: do nothing
+            # wavelen > low_freq_wavelen: divide by factor
+            inverse_freq = ops.where(
+                ops.greater(wavelen, low_freq_wavelen),
+                (inverse_freq / self.frequency_adjustment_factor),
+                inverse_freq,
+            )
+
+            # otherwise: interpolate between the two, using a smooth factor
+            smooth_factor = (
+                (self.pretraining_sequence_length / wavelen)
+                - self.low_freq_factor
+            ) / (self.high_freq_factor - self.low_freq_factor)
+            smoothed_inv_freq = (1 - smooth_factor) * (
+                inverse_freq / self.frequency_adjustment_factor
+            ) + (smooth_factor * inverse_freq)
+            is_medium_freq = ops.logical_and(
+                ops.greater_equal(wavelen, high_freq_wavelen),
+                ops.less_equal(wavelen, low_freq_wavelen),
+            )
+
+            inverse_freq = ops.where(
+                is_medium_freq, smoothed_inv_freq, inverse_freq
+            )
+
+        return inverse_freq
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "max_wavelength": self.max_wavelength,
+                "sequence_axis": self.sequence_axis,
+                "feature_axis": self.feature_axis,
+                "position_scaling_factor": self.position_scaling_factor,
+                "frequency_adjustment_factor": self.frequency_adjustment_factor,
+                "low_freq_factor": self.low_freq_factor,
+                "high_freq_factor": self.high_freq_factor,
+                "original_max_embeddings": self.pretraining_sequence_length,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
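To make the piecewise adjustment in `_get_inverse_freq` concrete, here is a standalone NumPy sketch of the same math. The constants are the `rope_scaling` values commonly published for Llama 3.1 (adjustment factor 8, low/high frequency factors 1 and 4, original context 8192) and are assumptions used for illustration, not values stated in this diff.

```python
import math

import numpy as np

rotary_dim = 128  # per-head feature dimension
max_wavelength = 500000.0
factor = 8.0  # frequency_adjustment_factor
low_freq_factor = 1.0
high_freq_factor = 4.0
pretraining_sequence_length = 8192

# Base RoPE inverse frequencies, identical to the unscaled case.
freq_range = np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
inverse_freq = 1.0 / (max_wavelength**freq_range)

low_freq_wavelen = pretraining_sequence_length / low_freq_factor
high_freq_wavelen = pretraining_sequence_length / high_freq_factor
wavelen = 2 * math.pi / inverse_freq

# Long wavelengths (low frequencies) are divided by the adjustment factor.
adjusted = np.where(
    wavelen > low_freq_wavelen, inverse_freq / factor, inverse_freq
)

# Medium wavelengths are smoothly interpolated between the two regimes.
smooth = (pretraining_sequence_length / wavelen - low_freq_factor) / (
    high_freq_factor - low_freq_factor
)
smoothed = (1 - smooth) * (inverse_freq / factor) + smooth * inverse_freq
is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
adjusted = np.where(is_medium, smoothed, adjusted)

# High-frequency bands are unchanged; low-frequency bands are divided by 8.
print(adjusted[:4], adjusted[-4:])
```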
@@ -32,8 +32,16 @@ class Llama3Backbone(LlamaBackbone):
            fo each transformer.
         rope_max_wavelength (int, optional): The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`.
+        rope_position_scaling_factor (float, optional): The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`.
+        rope_frequency_adjustment_factor (float, optional): The scaling factor
+            used to scale the inverse frequencies. Defaults to None.
+        rope_low_freq_factor (float, optional): The low frequency factor.
+            Defaults to None.
+        rope_high_freq_factor (float, optional): Used for Llama3.1+. The high
+            frequency factor. Defaults to None.
+        rope_pretraining_sequence_length (int, optional): Sequence length during
+            original pretraining. Defaults to None.
         layer_norm_epsilon (float, optional): Epsilon for the layer
             normalization layers in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
@@ -45,4 +45,86 @@ backbone_presets = {
             "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/2"
         ),
     },
+    "llama3.1_8b": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, based LLaMA 3.1 model. "
+            ),
+            "params": 8030261248,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_8b/1"),
+    },
+    "llama3.1_instruct_8b": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, instruction tuned LLaMA 3.1. "
+            ),
+            "params": 8030261248,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/lama3.1_instruct_8b/1"),
+    },
+    "llama3.1_guard_8b": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, LLaMA 3.1 fine-tuned for "
+                "consent safety classification. "
+            ),
+            "params": 8030261248,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_guard_8b/1"),
+    },
+    "llama3.2_1b": {
+        "metadata": {
+            "description": (
+                "1 billion parameter, 16-layer, based LLaMA 3.2 model. "
+            ),
+            "params": 1498482688,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_1b/1"),
+    },
+    "llama3.2_instruct_1b": {
+        "metadata": {
+            "description": (
+                "1 billion parameter, 16-layer, instruction tuned LLaMA 3.2. "
+            ),
+            "params": 1498482688,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_1b/1"),
+    },
+    "llama3.2_3b": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, 26-layer, based LLaMA 3.2 model. "
+            ),
+            "params": 3606752256,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_3b/1"),
+    },
+    "llama3.2_instruct_3b": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, 28-layer, instruction tuned LLaMA 3.2. "
+            ),
+            "params": 3606752256,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_3b/1"),
+    },
+    "llama3.2_guard_1b": {
+        "metadata": {
+            "description": (
+                "1 billion parameter, 16-layer, based LLaMA 3.2 model "
+                "fine-tuned for consent safety classification. "
+            ),
+            "params": 1498482688,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_guard_1b/1"),
+    },
 }
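The keys added above are the preset names users pass to `from_preset`. A hedged loading sketch, assuming the corresponding Kaggle handles are published and accessible from your environment:

```python
import keras_hub

# Load one of the newly registered Llama 3.2 presets by its dictionary key.
# This downloads weights from the Kaggle handle registered above.
causal_lm = keras_hub.models.Llama3CausalLM.from_preset("llama3.2_1b")
print(causal_lm.generate("What is Keras?", max_length=64))
```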
@@ -7,7 +7,7 @@ backbone_cls = Llama3Backbone
 
 
 def convert_backbone_config(transformers_config):
-    return {
+    backbone_config = {
         "vocabulary_size": transformers_config["vocab_size"],
         "num_layers": transformers_config["num_hidden_layers"],
         "num_query_heads": transformers_config["num_attention_heads"],
@@ -15,8 +15,28 @@ def convert_backbone_config(transformers_config):
         "intermediate_dim": transformers_config["intermediate_size"],
         "num_key_value_heads": transformers_config["num_key_value_heads"],
         "tie_word_embeddings": transformers_config["tie_word_embeddings"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
     }
 
+    if transformers_config.get("rope_scaling", None) is not None:
+        if transformers_config["rope_scaling"]["rope_type"] != "llama3":
+            raise ValueError("The config should be a valid llama3 config.")
+        backbone_config["rope_frequency_adjustment_factor"] = (
+            transformers_config["rope_scaling"]["factor"]
+        )
+        backbone_config["rope_low_freq_factor"] = transformers_config[
+            "rope_scaling"
+        ]["low_freq_factor"]
+        backbone_config["rope_high_freq_factor"] = transformers_config[
+            "rope_scaling"
+        ]["high_freq_factor"]
+        backbone_config["rope_pretraining_sequence_length"] = (
+            transformers_config["rope_scaling"][
+                "original_max_position_embeddings"
+            ]
+        )
+    return backbone_config
+
 
 def convert_weights(backbone, loader, transformers_config):
     loader.port_weight(
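For reference, the converter above maps a Hugging Face `rope_scaling` block onto the new backbone arguments. Here is a hedged sketch of that key mapping using an illustrative config fragment; the numeric values are typical Llama 3.1 settings assumed for the example, not values taken from this diff.

```python
# Illustrative Hugging Face-style config fragment (values are assumptions).
transformers_config = {
    "rope_theta": 500000.0,
    "rope_scaling": {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}

# The same key mapping performed by convert_backbone_config above.
rope_scaling = transformers_config["rope_scaling"]
rope_kwargs = {
    "rope_max_wavelength": transformers_config["rope_theta"],
    "rope_frequency_adjustment_factor": rope_scaling["factor"],
    "rope_low_freq_factor": rope_scaling["low_freq_factor"],
    "rope_high_freq_factor": rope_scaling["high_freq_factor"],
    "rope_pretraining_sequence_length": rope_scaling[
        "original_max_position_embeddings"
    ],
}
print(rope_kwargs)
```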
@@ -1,7 +1,7 @@
 from keras_hub.src.api_export import keras_hub_export
 
 # Unique source of truth for the version number.
-__version__ = "0.21.0.dev202504170402"
+__version__ = "0.21.0.dev202504190357"
 
 
 @keras_hub_export("keras_hub.version")
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-hub-nightly
-Version: 0.21.0.dev202504170402
+Version: 0.21.0.dev202504190357
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team
@@ -8,7 +8,7 @@ keras_hub/api/tokenizers/__init__.py,sha256=NCQSOg3vf3KlM2YBsxApcJUVu9MH2jV0NQrM
 keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version_utils.py,sha256=um5nImV3kQfkhp9f7hoNHS8pkeGqPkhA4xKqbhBdupQ,222
+keras_hub/src/version_utils.py,sha256=JHxB700m8f2SDoTYZCedVnki2N4OQO1V_32-jRHk4tU,222
 keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
@@ -173,7 +173,7 @@ keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py,sha256=UUa7RKyl
 keras_hub/src/models/f_net/f_net_tokenizer.py,sha256=ZRTaSfgZnYLTVXgM51303LpryRsSL5GaC2Cl_D7g27A,2285
 keras_hub/src/models/falcon/__init__.py,sha256=IVwPgPbw0l8XJRPQETmeQNvpdn_SneXhe_3oRMOvdx8,257
 keras_hub/src/models/falcon/falcon_attention.py,sha256=fRHuK_y_w64hrqq0XYfcsycs3KD1_3RmeKP7j8LEjGU,4559
-keras_hub/src/models/falcon/falcon_backbone.py,sha256=nGJcHnbqncZRTPERRi4ZuYGcODpkH2Mu0-Db59vH5io,5451
+keras_hub/src/models/falcon/falcon_backbone.py,sha256=hRwomKH_GIKJ0KMfccpHVU43HVN0WQy1n9PldvlUaTM,5451
 keras_hub/src/models/falcon/falcon_causal_lm.py,sha256=2UEIeju5Tg-FstVuusejJ-MbHZ6vsNfsSJzzBM89fnU,10908
 keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py,sha256=nI9E8N9enx5DppDHpLwGslb65rqGorL2sEz1jzet4gA,3033
 keras_hub/src/models/falcon/falcon_presets.py,sha256=PDghkND0-7le4W-atm4BitzA127z-5ZyQguCnCChSBo,463
@@ -199,7 +199,7 @@ keras_hub/src/models/gemma3/__init__.py,sha256=oPFadkdK5DRLD6sYx83iTetY5daWuSzmJ
 keras_hub/src/models/gemma3/gemma3_attention.py,sha256=VstFCTVsplcDNSgnyBcSpLgKn-pktJ39D5Ri-Bb7BQA,13628
 keras_hub/src/models/gemma3/gemma3_backbone.py,sha256=xw6gbFZWZuREcN1iyPj-1Hm-3EmRglgFD5fQSzDp3zA,16439
 keras_hub/src/models/gemma3/gemma3_causal_lm.py,sha256=U3C9TWlIz8VefAxQ0wJ6bDz18wqHBie8B26Ub_nFZs4,13843
-keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py,sha256=HPIkSRAevePLEWx-t6oqtaOdxF0FjeBQKAg2Ey9axLA,29524
+keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py,sha256=vjt4N-zr0Eb5kvkOR-WUgskDTNe64L_6tYnhyNb6xaE,29601
 keras_hub/src/models/gemma3/gemma3_decoder_block.py,sha256=6PLlpDxxF67stDv74fw9nNgUHBWmTLx6qGygJwyu5FY,10819
 keras_hub/src/models/gemma3/gemma3_image_converter.py,sha256=czi5JrTyKiK0nFzvonviBIX8jjvLHqvGNA9RyheB31k,536
 keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py,sha256=_Q5hvhA93HAJe-A2IBRKVu0_RDVht61lFQiYse_9Rm4,4597
@@ -222,19 +222,20 @@ keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py,sha256=YiVz9q
 keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py,sha256=hmB81V0SuI6bEsxEuFkYgq58wbcrv1YLvmXGin5T3E0,9732
 keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py,sha256=aKso-8yGrynn3tZ5xm2egcXIBQo3__sWZDBtjmS3ZgU,1991
 keras_hub/src/models/llama/__init__.py,sha256=svVZjGi71R3lVbq0AdbqlXj909mr3Rp9EPXdiO0w0G0,251
-keras_hub/src/models/llama/llama_attention.py,sha256=Q5N37sAESAjdFg9GNlanvNbD-dHS3mNNtt3vMXAFKMs,7931
-keras_hub/src/models/llama/llama_backbone.py,sha256=tjNEIKIL9ncoEL5KNFE5i0oTUkysjmJmh3mHmCz4RCw,11861
+keras_hub/src/models/llama/llama_attention.py,sha256=UFHOWr69vTkOxLdgSUckGaSuUUyqlJ_xYoswWHVnTOU,8977
+keras_hub/src/models/llama/llama_backbone.py,sha256=AT8kUPHEn6DT-aGY838_sZkBhByIdh82DWW8y-Sp3mE,13614
 keras_hub/src/models/llama/llama_causal_lm.py,sha256=9bP4-XDCMgsZuH1ILIMzmwq2Fyy6vkk1Vsht-lMGCNo,13258
 keras_hub/src/models/llama/llama_causal_lm_preprocessor.py,sha256=VTboOMiRBoxHrwP343upLUTsv3AG65r2H8h_PNPVphE,3047
-keras_hub/src/models/llama/llama_decoder.py,sha256=6iERIblED0ZB5w_EUlHks4UvMnsrWONdO_Xdz2OzhWM,8623
+keras_hub/src/models/llama/llama_decoder.py,sha256=CfWI8ru1-uWjDs0sL6H7g8ElYXWu6h7c5XIx-2Y8lX8,9668
 keras_hub/src/models/llama/llama_layernorm.py,sha256=LfRbePHUJs00Ptf7dvNaw3Aj9n1xBMBpE_rS5zzsYMo,1050
 keras_hub/src/models/llama/llama_presets.py,sha256=k0JPQggSQ0XUkhiPlfM0gTqHXGOt39InVLglPUi4AJU,1902
+keras_hub/src/models/llama/llama_rotary_embedding.py,sha256=nqQGl7lFXJq7xGBfoONx2-wuuvKdoydnzUjy6FGQjwo,7300
 keras_hub/src/models/llama/llama_tokenizer.py,sha256=NKWhxTutQ2jd6sd3NSTy9plQyKGCmuNG7U6kVxhZU4Y,1981
 keras_hub/src/models/llama3/__init__.py,sha256=Vqvr2E10cnANkrRQGNBJtVLNAu-Bg9Lx6sqKOZWFy_8,257
-keras_hub/src/models/llama3/llama3_backbone.py,sha256=g_IkHys5cr0gBXhDiqgIICO93RdGAm6WS5NK2SPhFvM,2866
+keras_hub/src/models/llama3/llama3_backbone.py,sha256=TEocD8X7GihQFGJAz3jPwLCqDb86nyeZ1DqBF7RgQLE,3366
 keras_hub/src/models/llama3/llama3_causal_lm.py,sha256=qk_onuf7S6d7rxAntilq2Q2orggMbPEJbNHJNVe2G0U,1541
 keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py,sha256=twbXel9hsQgGxDAoQhEQuVm2udnEybI4fAQTJzXAuBs,3064
-keras_hub/src/models/llama3/llama3_presets.py,sha256=PWEW_hLMCD9SIYm3QLhRVIcwjrPuqv-KDebXACXRNbM,1579
+keras_hub/src/models/llama3/llama3_presets.py,sha256=--_6Uao-fK4xD4ShgsqzKmlyQPyO9tRkF0VDYKjGpNw,4302
 keras_hub/src/models/llama3/llama3_tokenizer.py,sha256=J-KxRc08vGs4olFw_4mtJs0W_dTeUyj_XxMycazBmxI,1934
 keras_hub/src/models/mistral/__init__.py,sha256=vjBlzcrIsFSwJKnfwfTNMKstIEKGFTE3kVcdAdfwlnE,263
 keras_hub/src/models/mistral/mistral_attention.py,sha256=nGDlD4NcIwIGlfbt3ArxdT5QAvamY7yiNEGDlTgWirU,8609
@@ -458,14 +459,14 @@ keras_hub/src/utils/transformers/convert_bert.py,sha256=4gQqXCJzC9QWdLPDUAq741K8
 keras_hub/src/utils/transformers/convert_distilbert.py,sha256=SlfIRhSRk5c1ir2HGiDPiXa5XdOId_DbcnZO9lbwyZ8,6498
 keras_hub/src/utils/transformers/convert_gemma.py,sha256=ElCgwBpSN5Q7rV5PJawTsoytPzs5ZjuwoY60YAe8y_A,6533
 keras_hub/src/utils/transformers/convert_gpt2.py,sha256=HCeHN_-GiQJRxLCM9OCJJ1watPVpIBF8ujS8pGbBOWc,5703
-keras_hub/src/utils/transformers/convert_llama3.py,sha256=zlg0yFscjytyOFymDwqnbuXkmYvb88qqYzAROKcpaPU,5250
+keras_hub/src/utils/transformers/convert_llama3.py,sha256=c5phNl-QayQ_BS0s-lenbu6oHxqfwDShKJoh9DluxUU,6146
 keras_hub/src/utils/transformers/convert_mistral.py,sha256=kVhN9h1ZFVhwkNW8p3wnS7eANJUXIsNy1RxWXy20Gqw,4760
 keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYumf66hIid07k5NLqoeWAJgPnaLs,10649
 keras_hub/src/utils/transformers/convert_qwen.py,sha256=WUxMAEFVqRs7TRw7QU5TH3_ev4yf02R1xFVliMvTQqg,5886
 keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
 keras_hub/src/utils/transformers/preset_loader.py,sha256=0Hi7R8HnATcwFVLsJwMMIMWTCXHNfep4IPiRpQXqM-w,3933
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/METADATA,sha256=ADjI3HBVo2_uW8UYGR2aA6Uuetr05sEOpr2BcAR9SY0,7715
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/RECORD,,
+keras_hub_nightly-0.21.0.dev202504190357.dist-info/METADATA,sha256=upX5yCkjjUPOo3UnzvGqlIQVsM8hYfO5VlOKWL7aOI4,7715
+keras_hub_nightly-0.21.0.dev202504190357.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+keras_hub_nightly-0.21.0.dev202504190357.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.21.0.dev202504190357.dist-info/RECORD,,