keras-hub-nightly 0.21.0.dev202504170402__py3-none-any.whl → 0.21.0.dev202504180401__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +1 -0
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +82 -0
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/RECORD +14 -13
- {keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/top_level.txt +0 -0
keras_hub/src/models/falcon/falcon_backbone.py CHANGED
@@ -29,7 +29,7 @@ class FalconBackbone(Backbone):
         layer_norm_epsilon: float. Epsilon for the layer normalization layers in
             the transformer decoder.
         attention_dropout_rate: float. Dropout probability for the attention.
-        feedforward_dropout_rate:
+        feedforward_dropout_rate: float. Dropout probability for the
             feedforward.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
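As a quick illustration of the argument this docstring change documents, here is a hypothetical minimal instantiation (layer sizes are invented for brevity, and the constructor arguments are assumed from the docstring above, not from the diff):

```python
import keras_hub

# Tiny, randomly initialized Falcon backbone; `feedforward_dropout_rate`
# is the argument whose description the hunk above completes.
backbone = keras_hub.models.FalconBackbone(
    vocabulary_size=100,
    num_layers=2,
    num_attention_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
    attention_dropout_rate=0.1,
    feedforward_dropout_rate=0.1,
)
```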
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py CHANGED
@@ -512,6 +512,7 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
         # Extract text part of the input.
         prompts, responses = x["prompts"], x["responses"]
+        tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])
 
         # Find out if the input is batched/not batched. Uprank if not batched.
         # In other preprocessors, we don't have to do this, but here, all
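For context, a standalone sketch of what the new assertion enforces: both inputs must be rank-1 tensors sharing the same leading dimension `N` (the example values are illustrative):

```python
import tensorflow as tf

prompts = tf.constant(["hi", "hello"])
responses = tf.constant(["there", "world"])

# Passes: both tensors are rank-1 with a shared leading dimension "N".
tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])

# Fails: a rank-2 tensor cannot satisfy the ("N",) spec.
nested = tf.constant([["hi"]])
try:
    tf.debugging.assert_shapes([(prompts, ("N",)), (nested, ("N",))])
except ValueError as e:
    print("shape check failed:", type(e).__name__)
```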
keras_hub/src/models/llama/llama_attention.py CHANGED
@@ -3,7 +3,9 @@ import math
 import keras
 from keras import ops
 
-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.llama.llama_rotary_embedding import (
+    LlamaRotaryEmbedding,
+)
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
@@ -16,7 +18,11 @@ class LlamaAttention(keras.layers.Layer):
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         kernel_initializer="glorot_uniform",
         dropout=0,
         **kwargs,
@@ -28,13 +34,16 @@ class LlamaAttention(keras.layers.Layer):
 
         self.num_key_value_groups = num_query_heads // num_key_value_heads
         self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
 
         self.kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
 
-        self.rope_scaling_factor = rope_scaling_factor
-
     def build(self, inputs_shape):
         # Einsum variables:
         # b = batch size
@@ -103,9 +112,13 @@ class LlamaAttention(keras.layers.Layer):
         )
         self._output_dense.build((None, None, self.num_query_heads, head_dim))
 
-        self.rotary_embedding_layer = RotaryEmbedding(
+        self.rotary_embedding_layer = LlamaRotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
-            scaling_factor=self.rope_scaling_factor,
+            position_scaling_factor=self.rope_position_scaling_factor,
+            frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+            low_freq_factor=self.rope_low_freq_factor,
+            high_freq_factor=self.rope_high_freq_factor,
+            pretraining_sequence_length=self.rope_pretraining_sequence_length,
             dtype=self.dtype_policy,
         )
 
@@ -224,6 +237,11 @@ class LlamaAttention(keras.layers.Layer):
                 "num_key_value_heads": self.num_key_value_heads,
                 "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),
keras_hub/src/models/llama/llama_backbone.py CHANGED
@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
     constructor.
 
     Args:
-        vocabulary_size (int): The size of the token vocabulary.
-        num_layers (int): The number of transformer layers.
-        num_query_heads (int): The number of query attention heads for
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads : int. The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
+        hidden_dim : int. The size of the transformer encoding and pooling
             layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in
+        intermediate_dim : int. The output dimension of the first Dense layer in
             a three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads
+        num_key_value_heads : int. The number of key and value attention heads
             for each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of
+        rope_max_wavelength : int. The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`
-        layer_norm_epsilon (float, optional): Epsilon for the layer
-            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        rope_position_scaling_factor: float. The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`
+        rope_frequency_adjustment_factor: float. The scaling factor
+            used to scale the inverse frequencies. Defaults to `None`.
+        rope_low_freq_factor: float. The low frequency scaling
+            factor. Defaults to `None`.
+        rope_high_freq_factor: float. Used for Llama3.1+. The high
+            frequency scaling factor. Defaults to `None`.
+        rope_pretraining_sequence_length: int. Used for Llama3.1+.
+            Defaults to `None`.
+        layer_norm_epsilon : float. Epsilon for the layer normalization layers
+            in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
@@ -87,7 +95,11 @@ class LlamaBackbone(Backbone):
         intermediate_dim,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -110,7 +122,15 @@ class LlamaBackbone(Backbone):
             num_query_heads=num_query_heads,
             num_key_value_heads=num_key_value_heads,
             rope_max_wavelength=rope_max_wavelength,
-            rope_scaling_factor=rope_scaling_factor,
+            rope_position_scaling_factor=rope_position_scaling_factor,
+            rope_frequency_adjustment_factor=(
+                rope_frequency_adjustment_factor
+            ),
+            rope_low_freq_factor=rope_low_freq_factor,
+            rope_high_freq_factor=rope_high_freq_factor,
+            rope_pretraining_sequence_length=(
+                rope_pretraining_sequence_length
+            ),
             layer_norm_epsilon=layer_norm_epsilon,
             activation=ops.silu,
             kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ class LlamaBackbone(Backbone):
         self.num_query_heads = num_query_heads
         self.hidden_dim = hidden_dim
         self.intermediate_dim = intermediate_dim
-        self.rope_max_wavelength = rope_max_wavelength
         self.num_key_value_heads = num_key_value_heads
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ class LlamaBackbone(Backbone):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_position_scaling_factor": (
+                    self.rope_position_scaling_factor
+                ),
+                "rope_frequency_adjustment_factor": (
+                    self.rope_frequency_adjustment_factor
+                ),
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
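Taken together, these backbone changes replace the single `rope_scaling_factor` with five `rope_*` arguments. A hedged sketch of how a user would pass Llama 3.1-style values (the layer sizes are tiny and invented; the rope values are the ones Llama 3.1 publishes):

```python
import numpy as np

import keras_hub

# Randomly initialized backbone with Llama 3.1-style RoPE scaling.
backbone = keras_hub.models.LlamaBackbone(
    vocabulary_size=128,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    rope_position_scaling_factor=1.0,
    rope_frequency_adjustment_factor=8.0,
    rope_low_freq_factor=1.0,
    rope_high_freq_factor=4.0,
    rope_pretraining_sequence_length=8192,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 12, 64)
```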
keras_hub/src/models/llama/llama_decoder.py CHANGED
@@ -21,7 +21,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         activation="silu",
         layer_norm_epsilon=1e-5,
         kernel_initializer="glorot_uniform",
@@ -34,7 +38,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
         self.num_key_value_heads = num_key_value_heads
 
         self.rope_max_wavelength = rope_max_wavelength
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
 
         self.dropout = dropout
 
@@ -53,7 +61,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
             num_query_heads=self.num_query_heads,
             num_key_value_heads=self.num_key_value_heads,
             rope_max_wavelength=self.rope_max_wavelength,
-            rope_scaling_factor=self.rope_scaling_factor,
+            rope_position_scaling_factor=self.rope_position_scaling_factor,
+            rope_frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+            rope_low_freq_factor=self.rope_low_freq_factor,
+            rope_high_freq_factor=self.rope_high_freq_factor,
+            rope_pretraining_sequence_length=self.rope_pretraining_sequence_length,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             dropout=self.dropout,
             dtype=self.dtype_policy,
@@ -221,6 +233,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
                 "num_query_heads": self.num_query_heads,
                 "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "activation": keras.activations.serialize(self.activation),
                 "layer_norm_epsilon": self.layer_norm_epsilon,
keras_hub/src/models/llama/llama_rotary_embedding.py ADDED
@@ -0,0 +1,180 @@
+import math
+
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+
+
+class LlamaRotaryEmbedding(RotaryEmbedding):
+    """Rotary positional encoding layer.
+
+    This layer encodes absolute positional information with a rotation
+    matrix. It calculates the rotary encoding with a mix of sine and
+    cosine functions with geometrically increasing wavelengths.
+    Defined and formulated in
+    [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
+    The input must be a tensor with shape a sequence dimension and a feature
+    dimension. Typically, this will either an input with shape
+    `(batch_size, sequence_length, feature_length)` or
+    `(batch_size, sequence_length, num_heads, feature_length)`.
+    This layer will return a new tensor with the rotary embedding applied to
+    the input tensor.
+    It is extended from `RotaryEmbedding` layer in `keras_hub.layers`.
+    It has additional smoothening and interpolation for some frequency ranges.
+
+    Args:
+        max_wavelength: int. The maximum angular wavelength of the sine/cosine
+            curves. Defaults to `10000`.
+        position_scaling_factor: float. The scaling factor used to scale
+            positions of the tokens. Defaults to `1.0`.
+        frequency_adjustment_factor: float. The scaling factor used to scale the
+            inverse frequencies. Defaults to `None`.
+        low_freq_factor: float. The low frequency scaling factor.
+            Defaults to `None`.
+        high_freq_factor: float. The high frequency scaling factor.
+            Defaults to `None`.
+        pretraining_sequence_length: int. Used for Llama3.1+, the original
+            context length at time of pretraining. Defaults to `None`.
+        sequence_axis: int. Sequence axis in the input tensor.
+        feature_axis: int. Feature axis in the input tensor.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Call arguments:
+        inputs: The tensor inputs to apply the embedding to. This can have
+            any shape, but must contain both a sequence and feature axis. The
+            rotary embedding will be applied to `inputs` and returned.
+        start_index: An integer or integer tensor. The starting position to
+            compute the rotary embedding from. This is useful during cached
+            decoding, where each position is predicted separately in a loop.
+
+    Examples:
+
+    ```python
+    batch_size = 16
+    feature_length = 18
+    sequence_length = 256
+    num_heads = 8
+
+    # No multi-head dimension.
+    tensor = np.ones((batch_size, sequence_length, feature_length))
+    rot_emb_layer = RotaryEmbedding()
+    tensor_rot = rot_emb_layer(tensor)
+
+    # With multi-head dimension.
+    tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
+    tensor_rot = rot_emb_layer(tensor)
+    ```
+
+    References:
+        - [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
+    """
+
+    def __init__(
+        self,
+        max_wavelength=10000,
+        position_scaling_factor=1.0,
+        sequence_axis=1,
+        feature_axis=-1,
+        frequency_adjustment_factor=None,
+        low_freq_factor=None,
+        high_freq_factor=None,
+        pretraining_sequence_length=None,
+        **kwargs,
+    ):
+        super().__init__(
+            max_wavelength=max_wavelength,
+            scaling_factor=position_scaling_factor,
+            sequence_axis=sequence_axis,
+            feature_axis=feature_axis,
+            **kwargs,
+        )
+        self.max_wavelength = max_wavelength
+        self.sequence_axis = sequence_axis
+        self.feature_axis = feature_axis
+        self.position_scaling_factor = position_scaling_factor
+        self.frequency_adjustment_factor = frequency_adjustment_factor
+        self.low_freq_factor = low_freq_factor
+        self.high_freq_factor = high_freq_factor
+        self.pretraining_sequence_length = pretraining_sequence_length
+
+        grouped_args = [
+            low_freq_factor,
+            high_freq_factor,
+            frequency_adjustment_factor,
+            pretraining_sequence_length,
+        ]
+        args_none = [x is None for x in grouped_args]
+        if any(args_none) and not all(args_none):
+            raise ValueError(
+                "Either all of `low_freq_factor`,`high_freq_factor`, "
+                "`frequency_adjustment_factor` and "
+                "`pretraining_sequence_length` should be set, or all of should"
+                " be set `None`."
+            )
+        self.built = True
+
+    def _get_inverse_freq(self, rotary_dim):
+        freq_range = ops.divide(
+            ops.arange(0, rotary_dim, 2, dtype="float32"),
+            ops.cast(rotary_dim, "float32"),
+        )
+        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
+
+        # From llama3.1+ we have additional smoothening and interpolation.
+        # low_freq_factor, high_freq_factor, pretraining_sequence_length,
+        # frequency_adjustment_factor are all set at once so it is fine.
+        if self.low_freq_factor is not None:
+            low_freq_wavelen = (
+                self.pretraining_sequence_length / self.low_freq_factor
+            )
+            high_freq_wavelen = (
+                self.pretraining_sequence_length / self.high_freq_factor
+            )
+            wavelen = 2 * math.pi / inverse_freq
+
+            # wavelen < high_freq_wavelen: do nothing
+            # wavelen > low_freq_wavelen: divide by factor
+            inverse_freq = ops.where(
+                ops.greater(wavelen, low_freq_wavelen),
+                (inverse_freq / self.frequency_adjustment_factor),
+                inverse_freq,
+            )
+
+            # otherwise: interpolate between the two, using a smooth factor
+            smooth_factor = (
+                (self.pretraining_sequence_length / wavelen)
+                - self.low_freq_factor
+            ) / (self.high_freq_factor - self.low_freq_factor)
+            smoothed_inv_freq = (1 - smooth_factor) * (
+                inverse_freq / self.frequency_adjustment_factor
+            ) + (smooth_factor * inverse_freq)
+            is_medium_freq = ops.logical_and(
+                ops.greater_equal(wavelen, high_freq_wavelen),
+                ops.less_equal(wavelen, low_freq_wavelen),
+            )
+
+            inverse_freq = ops.where(
+                is_medium_freq, smoothed_inv_freq, inverse_freq
+            )
+
+        return inverse_freq
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "max_wavelength": self.max_wavelength,
+                "sequence_axis": self.sequence_axis,
+                "feature_axis": self.feature_axis,
+                "position_scaling_factor": self.position_scaling_factor,
+                "frequency_adjustment_factor": self.frequency_adjustment_factor,
+                "low_freq_factor": self.low_freq_factor,
+                "high_freq_factor": self.high_freq_factor,
+                "original_max_embeddings": self.pretraining_sequence_length,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
keras_hub/src/models/llama3/llama3_backbone.py CHANGED
@@ -32,8 +32,16 @@ class Llama3Backbone(LlamaBackbone):
             fo each transformer.
         rope_max_wavelength (int, optional): The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`
+        rope_position_scaling_factor (float, optional): The scaling factor for
+            calculation of roatary embedding. Defaults to `1.0`
+        rope_requency_adjustment_factor (float, optional): The scaling factor
+            used to scale the inverse frequencies.
+        rope_low_freq_factor (float, optional): The low frequency factor.
+            Defaults to None.
+        rope_high_freq_factor: (float, optional) Used for Llama3.1+. The high
+            frequency factor. Defaults to None.
+        rope_pretraining_sequence_length: (int, optional) Sequence length during
+            original pretraining. Defaults to None.
         layer_norm_epsilon (float, optional): Epsilon for the layer
             normalization layers in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
keras_hub/src/models/llama3/llama3_presets.py CHANGED
@@ -45,4 +45,86 @@ backbone_presets = {
             "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/2"
         ),
     },
+    "llama3.1_8b": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, based LLaMA 3.1 model. "
+            ),
+            "params": 8030261248,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_8b/1"),
+    },
+    "llama3.1_instruct_8b": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, instruction tuned LLaMA 3.1. "
+            ),
+            "params": 8030261248,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/lama3.1_instruct_8b/1"),
+    },
+    "llama3.1_guard_8b": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, LLaMA 3.1 fine-tuned for "
+                "consent safety classification. "
+            ),
+            "params": 8030261248,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_guard_8b/1"),
+    },
+    "llama3.2_1b": {
+        "metadata": {
+            "description": (
+                "1 billion parameter, 16-layer, based LLaMA 3.2 model. "
+            ),
+            "params": 1498482688,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_1b/1"),
+    },
+    "llama3.2_instruct_1b": {
+        "metadata": {
+            "description": (
+                "1 billion parameter, 16-layer, instruction tuned LLaMA 3.2. "
+            ),
+            "params": 1498482688,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_1b/1"),
+    },
+    "llama3.2_3b": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, 26-layer, based LLaMA 3.2 model. "
+            ),
+            "params": 3606752256,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_3b/1"),
+    },
+    "llama3.2_instruct_3b": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, 28-layer, instruction tuned LLaMA 3.2. "
+            ),
+            "params": 3606752256,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_3b/1"),
+    },
+    "llama3.2_guard_1b": {
+        "metadata": {
+            "description": (
+                "1 billion parameter, 16-layer, based LLaMA 3.2 model "
+                "fine-tuned for consent safety classification. "
+            ),
+            "params": 1498482688,
+            "path": "llama3",
+        },
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_guard_1b/1"),
+    },
 }
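Assuming the Kaggle handles above are live, any of the new presets would load through the standard `from_preset` API; a usage sketch (preset names taken from the diff):

```python
import keras_hub

# Downloads config and weights from the registered Kaggle handle.
backbone = keras_hub.models.Llama3Backbone.from_preset("llama3.2_1b")
causal_lm = keras_hub.models.Llama3CausalLM.from_preset("llama3.2_instruct_1b")
print(causal_lm.generate("What is Keras?", max_length=64))
```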
keras_hub/src/utils/transformers/convert_llama3.py CHANGED
@@ -7,7 +7,7 @@ backbone_cls = Llama3Backbone
 
 
 def convert_backbone_config(transformers_config):
-    return {
+    backbone_config = {
         "vocabulary_size": transformers_config["vocab_size"],
         "num_layers": transformers_config["num_hidden_layers"],
         "num_query_heads": transformers_config["num_attention_heads"],
@@ -15,8 +15,28 @@ def convert_backbone_config(transformers_config):
         "intermediate_dim": transformers_config["intermediate_size"],
         "num_key_value_heads": transformers_config["num_key_value_heads"],
         "tie_word_embeddings": transformers_config["tie_word_embeddings"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
     }
 
+    if transformers_config.get("rope_scaling", None) is not None:
+        if transformers_config["rope_scaling"]["rope_type"] != "llama3":
+            raise ValueError("The config should be a valid llama3 config.")
+        backbone_config["rope_frequency_adjustment_factor"] = (
+            transformers_config["rope_scaling"]["factor"]
+        )
+        backbone_config["rope_low_freq_factor"] = transformers_config[
+            "rope_scaling"
+        ]["low_freq_factor"]
+        backbone_config["rope_high_freq_factor"] = transformers_config[
+            "rope_scaling"
+        ]["high_freq_factor"]
+        backbone_config["rope_pretraining_sequence_length"] = (
+            transformers_config["rope_scaling"][
+                "original_max_position_embeddings"
+            ]
+        )
+    return backbone_config
+
 
 def convert_weights(backbone, loader, transformers_config):
     loader.port_weight(
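For orientation, here is the shape of the Hugging Face config this converter consumes, annotated with the keras-hub keys it maps to (the values shown are the ones Llama 3.1 publishes; a sketch, not part of the converter):

```python
# Fragment of a Llama 3.1 config.json, as a Python dict.
transformers_config = {
    "rope_theta": 500000.0,     # -> rope_max_wavelength
    "rope_scaling": {
        "rope_type": "llama3",  # anything else raises ValueError above
        "factor": 8.0,          # -> rope_frequency_adjustment_factor
        "low_freq_factor": 1.0,      # -> rope_low_freq_factor
        "high_freq_factor": 4.0,     # -> rope_high_freq_factor
        "original_max_position_embeddings": 8192,
        # ^ -> rope_pretraining_sequence_length
    },
}
```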
keras_hub/src/version_utils.py CHANGED

{keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-hub-nightly
-Version: 0.21.0.dev202504170402
+Version: 0.21.0.dev202504180401
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team
{keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/RECORD CHANGED
@@ -8,7 +8,7 @@ keras_hub/api/tokenizers/__init__.py,sha256=NCQSOg3vf3KlM2YBsxApcJUVu9MH2jV0NQrM
 keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version_utils.py,sha256=
+keras_hub/src/version_utils.py,sha256=jjtNdFgTpwdTKpH773EBDjVAe0GNxJuSTxnmUiyM280,222
 keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
@@ -173,7 +173,7 @@ keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py,sha256=UUa7RKyl
 keras_hub/src/models/f_net/f_net_tokenizer.py,sha256=ZRTaSfgZnYLTVXgM51303LpryRsSL5GaC2Cl_D7g27A,2285
 keras_hub/src/models/falcon/__init__.py,sha256=IVwPgPbw0l8XJRPQETmeQNvpdn_SneXhe_3oRMOvdx8,257
 keras_hub/src/models/falcon/falcon_attention.py,sha256=fRHuK_y_w64hrqq0XYfcsycs3KD1_3RmeKP7j8LEjGU,4559
-keras_hub/src/models/falcon/falcon_backbone.py,sha256=
+keras_hub/src/models/falcon/falcon_backbone.py,sha256=hRwomKH_GIKJ0KMfccpHVU43HVN0WQy1n9PldvlUaTM,5451
 keras_hub/src/models/falcon/falcon_causal_lm.py,sha256=2UEIeju5Tg-FstVuusejJ-MbHZ6vsNfsSJzzBM89fnU,10908
 keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py,sha256=nI9E8N9enx5DppDHpLwGslb65rqGorL2sEz1jzet4gA,3033
 keras_hub/src/models/falcon/falcon_presets.py,sha256=PDghkND0-7le4W-atm4BitzA127z-5ZyQguCnCChSBo,463
@@ -199,7 +199,7 @@ keras_hub/src/models/gemma3/__init__.py,sha256=oPFadkdK5DRLD6sYx83iTetY5daWuSzmJ
 keras_hub/src/models/gemma3/gemma3_attention.py,sha256=VstFCTVsplcDNSgnyBcSpLgKn-pktJ39D5Ri-Bb7BQA,13628
 keras_hub/src/models/gemma3/gemma3_backbone.py,sha256=xw6gbFZWZuREcN1iyPj-1Hm-3EmRglgFD5fQSzDp3zA,16439
 keras_hub/src/models/gemma3/gemma3_causal_lm.py,sha256=U3C9TWlIz8VefAxQ0wJ6bDz18wqHBie8B26Ub_nFZs4,13843
-keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py,sha256=
+keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py,sha256=vjt4N-zr0Eb5kvkOR-WUgskDTNe64L_6tYnhyNb6xaE,29601
 keras_hub/src/models/gemma3/gemma3_decoder_block.py,sha256=6PLlpDxxF67stDv74fw9nNgUHBWmTLx6qGygJwyu5FY,10819
 keras_hub/src/models/gemma3/gemma3_image_converter.py,sha256=czi5JrTyKiK0nFzvonviBIX8jjvLHqvGNA9RyheB31k,536
 keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py,sha256=_Q5hvhA93HAJe-A2IBRKVu0_RDVht61lFQiYse_9Rm4,4597
@@ -222,19 +222,20 @@ keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py,sha256=YiVz9q
 keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py,sha256=hmB81V0SuI6bEsxEuFkYgq58wbcrv1YLvmXGin5T3E0,9732
 keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py,sha256=aKso-8yGrynn3tZ5xm2egcXIBQo3__sWZDBtjmS3ZgU,1991
 keras_hub/src/models/llama/__init__.py,sha256=svVZjGi71R3lVbq0AdbqlXj909mr3Rp9EPXdiO0w0G0,251
-keras_hub/src/models/llama/llama_attention.py,sha256=
-keras_hub/src/models/llama/llama_backbone.py,sha256=
+keras_hub/src/models/llama/llama_attention.py,sha256=UFHOWr69vTkOxLdgSUckGaSuUUyqlJ_xYoswWHVnTOU,8977
+keras_hub/src/models/llama/llama_backbone.py,sha256=AT8kUPHEn6DT-aGY838_sZkBhByIdh82DWW8y-Sp3mE,13614
 keras_hub/src/models/llama/llama_causal_lm.py,sha256=9bP4-XDCMgsZuH1ILIMzmwq2Fyy6vkk1Vsht-lMGCNo,13258
 keras_hub/src/models/llama/llama_causal_lm_preprocessor.py,sha256=VTboOMiRBoxHrwP343upLUTsv3AG65r2H8h_PNPVphE,3047
-keras_hub/src/models/llama/llama_decoder.py,sha256=
+keras_hub/src/models/llama/llama_decoder.py,sha256=CfWI8ru1-uWjDs0sL6H7g8ElYXWu6h7c5XIx-2Y8lX8,9668
 keras_hub/src/models/llama/llama_layernorm.py,sha256=LfRbePHUJs00Ptf7dvNaw3Aj9n1xBMBpE_rS5zzsYMo,1050
 keras_hub/src/models/llama/llama_presets.py,sha256=k0JPQggSQ0XUkhiPlfM0gTqHXGOt39InVLglPUi4AJU,1902
+keras_hub/src/models/llama/llama_rotary_embedding.py,sha256=nqQGl7lFXJq7xGBfoONx2-wuuvKdoydnzUjy6FGQjwo,7300
 keras_hub/src/models/llama/llama_tokenizer.py,sha256=NKWhxTutQ2jd6sd3NSTy9plQyKGCmuNG7U6kVxhZU4Y,1981
 keras_hub/src/models/llama3/__init__.py,sha256=Vqvr2E10cnANkrRQGNBJtVLNAu-Bg9Lx6sqKOZWFy_8,257
-keras_hub/src/models/llama3/llama3_backbone.py,sha256=
+keras_hub/src/models/llama3/llama3_backbone.py,sha256=TEocD8X7GihQFGJAz3jPwLCqDb86nyeZ1DqBF7RgQLE,3366
 keras_hub/src/models/llama3/llama3_causal_lm.py,sha256=qk_onuf7S6d7rxAntilq2Q2orggMbPEJbNHJNVe2G0U,1541
 keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py,sha256=twbXel9hsQgGxDAoQhEQuVm2udnEybI4fAQTJzXAuBs,3064
-keras_hub/src/models/llama3/llama3_presets.py,sha256=
+keras_hub/src/models/llama3/llama3_presets.py,sha256=--_6Uao-fK4xD4ShgsqzKmlyQPyO9tRkF0VDYKjGpNw,4302
 keras_hub/src/models/llama3/llama3_tokenizer.py,sha256=J-KxRc08vGs4olFw_4mtJs0W_dTeUyj_XxMycazBmxI,1934
 keras_hub/src/models/mistral/__init__.py,sha256=vjBlzcrIsFSwJKnfwfTNMKstIEKGFTE3kVcdAdfwlnE,263
 keras_hub/src/models/mistral/mistral_attention.py,sha256=nGDlD4NcIwIGlfbt3ArxdT5QAvamY7yiNEGDlTgWirU,8609
@@ -458,14 +459,14 @@ keras_hub/src/utils/transformers/convert_bert.py,sha256=4gQqXCJzC9QWdLPDUAq741K8
 keras_hub/src/utils/transformers/convert_distilbert.py,sha256=SlfIRhSRk5c1ir2HGiDPiXa5XdOId_DbcnZO9lbwyZ8,6498
 keras_hub/src/utils/transformers/convert_gemma.py,sha256=ElCgwBpSN5Q7rV5PJawTsoytPzs5ZjuwoY60YAe8y_A,6533
 keras_hub/src/utils/transformers/convert_gpt2.py,sha256=HCeHN_-GiQJRxLCM9OCJJ1watPVpIBF8ujS8pGbBOWc,5703
-keras_hub/src/utils/transformers/convert_llama3.py,sha256=
+keras_hub/src/utils/transformers/convert_llama3.py,sha256=c5phNl-QayQ_BS0s-lenbu6oHxqfwDShKJoh9DluxUU,6146
 keras_hub/src/utils/transformers/convert_mistral.py,sha256=kVhN9h1ZFVhwkNW8p3wnS7eANJUXIsNy1RxWXy20Gqw,4760
 keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYumf66hIid07k5NLqoeWAJgPnaLs,10649
 keras_hub/src/utils/transformers/convert_qwen.py,sha256=WUxMAEFVqRs7TRw7QU5TH3_ev4yf02R1xFVliMvTQqg,5886
 keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
 keras_hub/src/utils/transformers/preset_loader.py,sha256=0Hi7R8HnATcwFVLsJwMMIMWTCXHNfep4IPiRpQXqM-w,3933
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/METADATA,sha256=
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
-keras_hub_nightly-0.21.0.dev202504170402.dist-info/RECORD,,
+keras_hub_nightly-0.21.0.dev202504180401.dist-info/METADATA,sha256=IPS1Mx1IcGzE10Z-je3R99kEyVnTYVXg0DQ-lFDqTLE,7715
+keras_hub_nightly-0.21.0.dev202504180401.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+keras_hub_nightly-0.21.0.dev202504180401.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.21.0.dev202504180401.dist-info/RECORD,,
{keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/WHEEL: file without changes
{keras_hub_nightly-0.21.0.dev202504170402.dist-info → keras_hub_nightly-0.21.0.dev202504180401.dist-info}/top_level.txt: file without changes