keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +15 -33
- keras_hub/layers/__init__.py +134 -0
- keras_hub/metrics/__init__.py +11 -0
- keras_hub/models/__init__.py +642 -0
- keras_hub/samplers/__init__.py +18 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
- keras_hub/src/layers/preprocessing/image_converter.py +1 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
- keras_hub/src/layers/preprocessing/random_swap.py +1 -1
- keras_hub/src/models/audio_to_text.py +66 -0
- keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
- keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma/gemma_presets.py +10 -10
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
- keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_presets.py +3 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +84 -2
- keras_hub/src/models/mistral/mistral_presets.py +3 -3
- keras_hub/src/models/mixtral/__init__.py +5 -0
- keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/moonshine/__init__.py +5 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
- keras_hub/src/models/qwen/__init__.py +4 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_backbone.py +8 -1
- keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
- keras_hub/src/models/qwen_moe/__init__.py +5 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
- keras_hub/src/models/segformer/segformer_presets.py +12 -12
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/models/xception/__init__.py +5 -0
- keras_hub/src/models/xception/xception_backbone.py +188 -0
- keras_hub/src/models/xception/xception_image_classifier.py +12 -0
- keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/xception/xception_image_converter.py +8 -0
- keras_hub/src/models/xception/xception_presets.py +14 -0
- keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
- keras_hub/src/utils/coco/__init__.py +0 -0
- keras_hub/src/utils/coco/coco_utils.py +133 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +70 -10
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/timm/convert_cspnet.py +94 -23
- keras_hub/src/utils/timm/preset_loader.py +6 -6
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/{version_utils.py → version.py} +1 -1
- keras_hub/tokenizers/__init__.py +117 -0
- keras_hub/utils/__init__.py +21 -0
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
- keras_hub/api/__init__.py +0 -15
- keras_hub/api/layers/__init__.py +0 -86
- keras_hub/api/metrics/__init__.py +0 -11
- keras_hub/api/models/__init__.py +0 -416
- keras_hub/api/samplers/__init__.py +0 -16
- keras_hub/api/tokenizers/__init__.py +0 -58
- keras_hub/api/utils/__init__.py +0 -9
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
|
@@ -55,7 +55,7 @@ backbone_presets = {
|
|
|
55
55
|
"params": 11765788416,
|
|
56
56
|
"path": "gemma3",
|
|
57
57
|
},
|
|
58
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/
|
|
58
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/3",
|
|
59
59
|
},
|
|
60
60
|
"gemma3_instruct_12b_text": {
|
|
61
61
|
"metadata": {
|
|
@@ -66,7 +66,7 @@ backbone_presets = {
|
|
|
66
66
|
"params": 11765788416,
|
|
67
67
|
"path": "gemma3",
|
|
68
68
|
},
|
|
69
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/
|
|
69
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/3",
|
|
70
70
|
},
|
|
71
71
|
"gemma3_27b_text": {
|
|
72
72
|
"metadata": {
|
|
@@ -77,7 +77,7 @@ backbone_presets = {
|
|
|
77
77
|
"params": 27009002240,
|
|
78
78
|
"path": "gemma3",
|
|
79
79
|
},
|
|
80
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/
|
|
80
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/4",
|
|
81
81
|
},
|
|
82
82
|
"gemma3_instruct_27b_text": {
|
|
83
83
|
"metadata": {
|
|
@@ -88,7 +88,7 @@ backbone_presets = {
|
|
|
88
88
|
"params": 27009002240,
|
|
89
89
|
"path": "gemma3",
|
|
90
90
|
},
|
|
91
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/
|
|
91
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/3",
|
|
92
92
|
},
|
|
93
93
|
"gemma3_4b": {
|
|
94
94
|
"metadata": {
|
|
@@ -121,7 +121,7 @@ backbone_presets = {
|
|
|
121
121
|
"params": 12187079280,
|
|
122
122
|
"path": "gemma3",
|
|
123
123
|
},
|
|
124
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/
|
|
124
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/2",
|
|
125
125
|
},
|
|
126
126
|
"gemma3_instruct_12b": {
|
|
127
127
|
"metadata": {
|
|
@@ -132,7 +132,7 @@ backbone_presets = {
|
|
|
132
132
|
"params": 12187079280,
|
|
133
133
|
"path": "gemma3",
|
|
134
134
|
},
|
|
135
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/
|
|
135
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/2",
|
|
136
136
|
},
|
|
137
137
|
"gemma3_27b": {
|
|
138
138
|
"metadata": {
|
|
@@ -143,7 +143,7 @@ backbone_presets = {
|
|
|
143
143
|
"params": 27432062576,
|
|
144
144
|
"path": "gemma3",
|
|
145
145
|
},
|
|
146
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/
|
|
146
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/2",
|
|
147
147
|
},
|
|
148
148
|
"gemma3_instruct_27b": {
|
|
149
149
|
"metadata": {
|
|
@@ -154,6 +154,6 @@ backbone_presets = {
|
|
|
154
154
|
"params": 27432062576,
|
|
155
155
|
"path": "gemma3",
|
|
156
156
|
},
|
|
157
|
-
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/
|
|
157
|
+
"kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/2",
|
|
158
158
|
},
|
|
159
159
|
}
|
|
@@ -488,7 +488,7 @@ class Gemma3VisionEncoderBlock(keras.layers.Layer):
|
|
|
488
488
|
# Fix the compatibility issue with Keras 3.1 where
|
|
489
489
|
# `compute_output_spec` fails to propagate `inputs_shape`
|
|
490
490
|
# correctly, causing it to be `None`.
|
|
491
|
-
|
|
491
|
+
return [None, None, self.hidden_dim]
|
|
492
492
|
return [
|
|
493
493
|
None,
|
|
494
494
|
(inputs_shape[2] // self.patch_size) ** 2,
|
|
@@ -3,7 +3,9 @@ import math
|
|
|
3
3
|
import keras
|
|
4
4
|
from keras import ops
|
|
5
5
|
|
|
6
|
-
from keras_hub.src.
|
|
6
|
+
from keras_hub.src.models.llama.llama_rotary_embedding import (
|
|
7
|
+
LlamaRotaryEmbedding,
|
|
8
|
+
)
|
|
7
9
|
from keras_hub.src.utils.keras_utils import clone_initializer
|
|
8
10
|
from keras_hub.src.utils.keras_utils import fused_attention_op_available
|
|
9
11
|
|
|
@@ -16,7 +18,11 @@ class LlamaAttention(keras.layers.Layer):
|
|
|
16
18
|
num_query_heads,
|
|
17
19
|
num_key_value_heads,
|
|
18
20
|
rope_max_wavelength=10000,
|
|
19
|
-
|
|
21
|
+
rope_position_scaling_factor=1.0,
|
|
22
|
+
rope_frequency_adjustment_factor=None,
|
|
23
|
+
rope_low_freq_factor=None,
|
|
24
|
+
rope_high_freq_factor=None,
|
|
25
|
+
rope_pretraining_sequence_length=None,
|
|
20
26
|
kernel_initializer="glorot_uniform",
|
|
21
27
|
dropout=0,
|
|
22
28
|
**kwargs,
|
|
@@ -28,13 +34,16 @@ class LlamaAttention(keras.layers.Layer):
|
|
|
28
34
|
|
|
29
35
|
self.num_key_value_groups = num_query_heads // num_key_value_heads
|
|
30
36
|
self.rope_max_wavelength = rope_max_wavelength
|
|
37
|
+
self.rope_position_scaling_factor = rope_position_scaling_factor
|
|
38
|
+
self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
|
|
39
|
+
self.rope_low_freq_factor = rope_low_freq_factor
|
|
40
|
+
self.rope_high_freq_factor = rope_high_freq_factor
|
|
41
|
+
self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
|
|
31
42
|
|
|
32
43
|
self.kernel_initializer = keras.initializers.get(
|
|
33
44
|
clone_initializer(kernel_initializer)
|
|
34
45
|
)
|
|
35
46
|
|
|
36
|
-
self.rope_scaling_factor = rope_scaling_factor
|
|
37
|
-
|
|
38
47
|
def build(self, inputs_shape):
|
|
39
48
|
# Einsum variables:
|
|
40
49
|
# b = batch size
|
|
@@ -103,9 +112,13 @@ class LlamaAttention(keras.layers.Layer):
|
|
|
103
112
|
)
|
|
104
113
|
self._output_dense.build((None, None, self.num_query_heads, head_dim))
|
|
105
114
|
|
|
106
|
-
self.rotary_embedding_layer =
|
|
115
|
+
self.rotary_embedding_layer = LlamaRotaryEmbedding(
|
|
107
116
|
max_wavelength=self.rope_max_wavelength,
|
|
108
|
-
|
|
117
|
+
position_scaling_factor=self.rope_position_scaling_factor,
|
|
118
|
+
frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
|
|
119
|
+
low_freq_factor=self.rope_low_freq_factor,
|
|
120
|
+
high_freq_factor=self.rope_high_freq_factor,
|
|
121
|
+
pretraining_sequence_length=self.rope_pretraining_sequence_length,
|
|
109
122
|
dtype=self.dtype_policy,
|
|
110
123
|
)
|
|
111
124
|
|
|
@@ -224,6 +237,11 @@ class LlamaAttention(keras.layers.Layer):
|
|
|
224
237
|
"num_key_value_heads": self.num_key_value_heads,
|
|
225
238
|
"rope_max_wavelength": self.rope_max_wavelength,
|
|
226
239
|
"rope_scaling_factor": self.rope_scaling_factor,
|
|
240
|
+
"rope_low_freq_factor": self.rope_low_freq_factor,
|
|
241
|
+
"rope_high_freq_factor": self.rope_high_freq_factor,
|
|
242
|
+
"rope_pretraining_sequence_length": (
|
|
243
|
+
self.rope_pretraining_sequence_length
|
|
244
|
+
),
|
|
227
245
|
"kernel_initializer": keras.initializers.serialize(
|
|
228
246
|
self.kernel_initializer
|
|
229
247
|
),
|
|
@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
|
|
|
30
30
|
constructor.
|
|
31
31
|
|
|
32
32
|
Args:
|
|
33
|
-
vocabulary_size
|
|
34
|
-
num_layers
|
|
35
|
-
num_query_heads
|
|
33
|
+
vocabulary_size: int. The size of the token vocabulary.
|
|
34
|
+
num_layers: int. The number of transformer layers.
|
|
35
|
+
num_query_heads : int. The number of query attention heads for
|
|
36
36
|
each transformer.
|
|
37
|
-
hidden_dim
|
|
37
|
+
hidden_dim : int. The size of the transformer encoding and pooling
|
|
38
38
|
layers.
|
|
39
|
-
intermediate_dim
|
|
39
|
+
intermediate_dim : int. The output dimension of the first Dense layer in
|
|
40
40
|
a three-layer feedforward network for each transformer.
|
|
41
|
-
num_key_value_heads
|
|
41
|
+
num_key_value_heads : int. The number of key and value attention heads
|
|
42
42
|
for each transformer.
|
|
43
|
-
rope_max_wavelength
|
|
43
|
+
rope_max_wavelength : int. The maximum angular wavelength of
|
|
44
44
|
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
|
|
45
|
-
|
|
46
|
-
calculation of
|
|
47
|
-
|
|
48
|
-
|
|
45
|
+
rope_position_scaling_factor: float. The scaling factor for
|
|
46
|
+
calculation of rotary embedding. Defaults to `1.0`
|
|
47
|
+
rope_frequency_adjustment_factor: float. The scaling factor
|
|
48
|
+
used to scale the inverse frequencies. Defaults to `None`.
|
|
49
|
+
rope_low_freq_factor: float. The low frequency scaling
|
|
50
|
+
factor. Defaults to `None`.
|
|
51
|
+
rope_high_freq_factor: float. Used for Llama3.1+. The high
|
|
52
|
+
frequency scaling factor. Defaults to `None`.
|
|
53
|
+
rope_pretraining_sequence_length: int. Used for Llama3.1+.
|
|
54
|
+
Defaults to `None`.
|
|
55
|
+
layer_norm_epsilon : float. Epsilon for the layer normalization layers
|
|
56
|
+
in the transformer decoder. Defaults to `1e-6`.
|
|
49
57
|
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
|
|
50
58
|
for model computations and weights. Note that some computations,
|
|
51
59
|
such as softmax and layer normalization, will always be done at
|
|
@@ -87,7 +95,11 @@ class LlamaBackbone(Backbone):
|
|
|
87
95
|
intermediate_dim,
|
|
88
96
|
num_key_value_heads,
|
|
89
97
|
rope_max_wavelength=10000,
|
|
90
|
-
|
|
98
|
+
rope_position_scaling_factor=1.0,
|
|
99
|
+
rope_frequency_adjustment_factor=None,
|
|
100
|
+
rope_low_freq_factor=None,
|
|
101
|
+
rope_high_freq_factor=None,
|
|
102
|
+
rope_pretraining_sequence_length=None,
|
|
91
103
|
layer_norm_epsilon=1e-6,
|
|
92
104
|
dropout=0,
|
|
93
105
|
dtype=None,
|
|
@@ -110,7 +122,15 @@ class LlamaBackbone(Backbone):
|
|
|
110
122
|
num_query_heads=num_query_heads,
|
|
111
123
|
num_key_value_heads=num_key_value_heads,
|
|
112
124
|
rope_max_wavelength=rope_max_wavelength,
|
|
113
|
-
|
|
125
|
+
rope_position_scaling_factor=rope_position_scaling_factor,
|
|
126
|
+
rope_frequency_adjustment_factor=(
|
|
127
|
+
rope_frequency_adjustment_factor
|
|
128
|
+
),
|
|
129
|
+
rope_low_freq_factor=rope_low_freq_factor,
|
|
130
|
+
rope_high_freq_factor=rope_high_freq_factor,
|
|
131
|
+
rope_pretraining_sequence_length=(
|
|
132
|
+
rope_pretraining_sequence_length
|
|
133
|
+
),
|
|
114
134
|
layer_norm_epsilon=layer_norm_epsilon,
|
|
115
135
|
activation=ops.silu,
|
|
116
136
|
kernel_initializer=_llama_kernel_initializer(stddev=0.02),
|
|
@@ -152,9 +172,13 @@ class LlamaBackbone(Backbone):
|
|
|
152
172
|
self.num_query_heads = num_query_heads
|
|
153
173
|
self.hidden_dim = hidden_dim
|
|
154
174
|
self.intermediate_dim = intermediate_dim
|
|
155
|
-
self.rope_max_wavelength = rope_max_wavelength
|
|
156
175
|
self.num_key_value_heads = num_key_value_heads
|
|
157
|
-
self.
|
|
176
|
+
self.rope_max_wavelength = rope_max_wavelength
|
|
177
|
+
self.rope_position_scaling_factor = rope_position_scaling_factor
|
|
178
|
+
self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
|
|
179
|
+
self.rope_low_freq_factor = rope_low_freq_factor
|
|
180
|
+
self.rope_high_freq_factor = rope_high_freq_factor
|
|
181
|
+
self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
|
|
158
182
|
self.layer_norm_epsilon = layer_norm_epsilon
|
|
159
183
|
self.dropout = dropout
|
|
160
184
|
self.tie_word_embeddings = tie_word_embeddings
|
|
@@ -169,7 +193,17 @@ class LlamaBackbone(Backbone):
|
|
|
169
193
|
"hidden_dim": self.hidden_dim,
|
|
170
194
|
"intermediate_dim": self.intermediate_dim,
|
|
171
195
|
"rope_max_wavelength": self.rope_max_wavelength,
|
|
172
|
-
"
|
|
196
|
+
"rope_position_scaling_factor": (
|
|
197
|
+
self.rope_position_scaling_factor
|
|
198
|
+
),
|
|
199
|
+
"rope_frequency_adjustment_factor": (
|
|
200
|
+
self.rope_frequency_adjustment_factor
|
|
201
|
+
),
|
|
202
|
+
"rope_low_freq_factor": self.rope_low_freq_factor,
|
|
203
|
+
"rope_high_freq_factor": self.rope_high_freq_factor,
|
|
204
|
+
"rope_pretraining_sequence_length": (
|
|
205
|
+
self.rope_pretraining_sequence_length
|
|
206
|
+
),
|
|
173
207
|
"num_key_value_heads": self.num_key_value_heads,
|
|
174
208
|
"layer_norm_epsilon": self.layer_norm_epsilon,
|
|
175
209
|
"dropout": self.dropout,
|
|
@@ -21,7 +21,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
|
|
|
21
21
|
num_query_heads,
|
|
22
22
|
num_key_value_heads,
|
|
23
23
|
rope_max_wavelength=10000,
|
|
24
|
-
|
|
24
|
+
rope_position_scaling_factor=1.0,
|
|
25
|
+
rope_frequency_adjustment_factor=None,
|
|
26
|
+
rope_low_freq_factor=None,
|
|
27
|
+
rope_high_freq_factor=None,
|
|
28
|
+
rope_pretraining_sequence_length=None,
|
|
25
29
|
activation="silu",
|
|
26
30
|
layer_norm_epsilon=1e-5,
|
|
27
31
|
kernel_initializer="glorot_uniform",
|
|
@@ -34,7 +38,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
|
|
|
34
38
|
self.num_key_value_heads = num_key_value_heads
|
|
35
39
|
|
|
36
40
|
self.rope_max_wavelength = rope_max_wavelength
|
|
37
|
-
self.
|
|
41
|
+
self.rope_position_scaling_factor = rope_position_scaling_factor
|
|
42
|
+
self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
|
|
43
|
+
self.rope_low_freq_factor = rope_low_freq_factor
|
|
44
|
+
self.rope_high_freq_factor = rope_high_freq_factor
|
|
45
|
+
self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
|
|
38
46
|
|
|
39
47
|
self.dropout = dropout
|
|
40
48
|
|
|
@@ -53,7 +61,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
|
|
|
53
61
|
num_query_heads=self.num_query_heads,
|
|
54
62
|
num_key_value_heads=self.num_key_value_heads,
|
|
55
63
|
rope_max_wavelength=self.rope_max_wavelength,
|
|
56
|
-
|
|
64
|
+
rope_position_scaling_factor=self.rope_position_scaling_factor,
|
|
65
|
+
rope_frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
|
|
66
|
+
rope_low_freq_factor=self.rope_low_freq_factor,
|
|
67
|
+
rope_high_freq_factor=self.rope_high_freq_factor,
|
|
68
|
+
rope_pretraining_sequence_length=self.rope_pretraining_sequence_length,
|
|
57
69
|
kernel_initializer=clone_initializer(self.kernel_initializer),
|
|
58
70
|
dropout=self.dropout,
|
|
59
71
|
dtype=self.dtype_policy,
|
|
@@ -221,6 +233,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
|
|
|
221
233
|
"num_query_heads": self.num_query_heads,
|
|
222
234
|
"rope_max_wavelength": self.rope_max_wavelength,
|
|
223
235
|
"rope_scaling_factor": self.rope_scaling_factor,
|
|
236
|
+
"rope_low_freq_factor": self.rope_low_freq_factor,
|
|
237
|
+
"rope_high_freq_factor": self.rope_high_freq_factor,
|
|
238
|
+
"rope_pretraining_sequence_length": (
|
|
239
|
+
self.rope_pretraining_sequence_length
|
|
240
|
+
),
|
|
224
241
|
"num_key_value_heads": self.num_key_value_heads,
|
|
225
242
|
"activation": keras.activations.serialize(self.activation),
|
|
226
243
|
"layer_norm_epsilon": self.layer_norm_epsilon,
|
|
@@ -8,7 +8,7 @@ backbone_presets = {
|
|
|
8
8
|
"params": 6738415616,
|
|
9
9
|
"path": "llama",
|
|
10
10
|
},
|
|
11
|
-
"kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/
|
|
11
|
+
"kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/3",
|
|
12
12
|
},
|
|
13
13
|
"llama2_7b_en_int8": {
|
|
14
14
|
"metadata": {
|
|
@@ -30,7 +30,7 @@ backbone_presets = {
|
|
|
30
30
|
"params": 6738415616,
|
|
31
31
|
"path": "llama",
|
|
32
32
|
},
|
|
33
|
-
"kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/
|
|
33
|
+
"kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/3",
|
|
34
34
|
},
|
|
35
35
|
"llama2_instruct_7b_en_int8": {
|
|
36
36
|
"metadata": {
|
|
@@ -52,6 +52,6 @@ backbone_presets = {
|
|
|
52
52
|
"params": 6738415616,
|
|
53
53
|
"path": "llama",
|
|
54
54
|
},
|
|
55
|
-
"kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/
|
|
55
|
+
"kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/3",
|
|
56
56
|
},
|
|
57
57
|
}
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LlamaRotaryEmbedding(RotaryEmbedding):
|
|
9
|
+
"""Rotary positional encoding layer.
|
|
10
|
+
|
|
11
|
+
This layer encodes absolute positional information with a rotation
|
|
12
|
+
matrix. It calculates the rotary encoding with a mix of sine and
|
|
13
|
+
cosine functions with geometrically increasing wavelengths.
|
|
14
|
+
Defined and formulated in
|
|
15
|
+
[RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
|
|
16
|
+
The input must be a tensor with shape a sequence dimension and a feature
|
|
17
|
+
dimension. Typically, this will either an input with shape
|
|
18
|
+
`(batch_size, sequence_length, feature_length)` or
|
|
19
|
+
`(batch_size, sequence_length, num_heads, feature_length)`.
|
|
20
|
+
This layer will return a new tensor with the rotary embedding applied to
|
|
21
|
+
the input tensor.
|
|
22
|
+
It is extended from `RotaryEmbedding` layer in `keras_hub.layers`.
|
|
23
|
+
It has additional smoothening and interpolation for some frequency ranges.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
max_wavelength: int. The maximum angular wavelength of the sine/cosine
|
|
27
|
+
curves. Defaults to `10000`.
|
|
28
|
+
position_scaling_factor: float. The scaling factor used to scale
|
|
29
|
+
positions of the tokens. Defaults to `1.0`.
|
|
30
|
+
frequency_adjustment_factor: float. The scaling factor used to scale the
|
|
31
|
+
inverse frequencies. Defaults to `None`.
|
|
32
|
+
low_freq_factor: float. The low frequency scaling factor.
|
|
33
|
+
Defaults to `None`.
|
|
34
|
+
high_freq_factor: float. The high frequency scaling factor.
|
|
35
|
+
Defaults to `None`.
|
|
36
|
+
pretraining_sequence_length: int. Used for Llama3.1+, the original
|
|
37
|
+
context length at time of pretraining. Defaults to `None`.
|
|
38
|
+
sequence_axis: int. Sequence axis in the input tensor.
|
|
39
|
+
feature_axis: int. Feature axis in the input tensor.
|
|
40
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
|
41
|
+
including `name`, `trainable`, `dtype` etc.
|
|
42
|
+
|
|
43
|
+
Call arguments:
|
|
44
|
+
inputs: The tensor inputs to apply the embedding to. This can have
|
|
45
|
+
any shape, but must contain both a sequence and feature axis. The
|
|
46
|
+
rotary embedding will be applied to `inputs` and returned.
|
|
47
|
+
start_index: An integer or integer tensor. The starting position to
|
|
48
|
+
compute the rotary embedding from. This is useful during cached
|
|
49
|
+
decoding, where each position is predicted separately in a loop.
|
|
50
|
+
|
|
51
|
+
Examples:
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
batch_size = 16
|
|
55
|
+
feature_length = 18
|
|
56
|
+
sequence_length = 256
|
|
57
|
+
num_heads = 8
|
|
58
|
+
|
|
59
|
+
# No multi-head dimension.
|
|
60
|
+
tensor = np.ones((batch_size, sequence_length, feature_length))
|
|
61
|
+
rot_emb_layer = RotaryEmbedding()
|
|
62
|
+
tensor_rot = rot_emb_layer(tensor)
|
|
63
|
+
|
|
64
|
+
# With multi-head dimension.
|
|
65
|
+
tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
|
|
66
|
+
tensor_rot = rot_emb_layer(tensor)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
References:
|
|
70
|
+
- [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(
|
|
74
|
+
self,
|
|
75
|
+
max_wavelength=10000,
|
|
76
|
+
position_scaling_factor=1.0,
|
|
77
|
+
sequence_axis=1,
|
|
78
|
+
feature_axis=-1,
|
|
79
|
+
frequency_adjustment_factor=None,
|
|
80
|
+
low_freq_factor=None,
|
|
81
|
+
high_freq_factor=None,
|
|
82
|
+
pretraining_sequence_length=None,
|
|
83
|
+
**kwargs,
|
|
84
|
+
):
|
|
85
|
+
super().__init__(
|
|
86
|
+
max_wavelength=max_wavelength,
|
|
87
|
+
scaling_factor=position_scaling_factor,
|
|
88
|
+
sequence_axis=sequence_axis,
|
|
89
|
+
feature_axis=feature_axis,
|
|
90
|
+
**kwargs,
|
|
91
|
+
)
|
|
92
|
+
self.max_wavelength = max_wavelength
|
|
93
|
+
self.sequence_axis = sequence_axis
|
|
94
|
+
self.feature_axis = feature_axis
|
|
95
|
+
self.position_scaling_factor = position_scaling_factor
|
|
96
|
+
self.frequency_adjustment_factor = frequency_adjustment_factor
|
|
97
|
+
self.low_freq_factor = low_freq_factor
|
|
98
|
+
self.high_freq_factor = high_freq_factor
|
|
99
|
+
self.pretraining_sequence_length = pretraining_sequence_length
|
|
100
|
+
|
|
101
|
+
grouped_args = [
|
|
102
|
+
low_freq_factor,
|
|
103
|
+
high_freq_factor,
|
|
104
|
+
frequency_adjustment_factor,
|
|
105
|
+
pretraining_sequence_length,
|
|
106
|
+
]
|
|
107
|
+
args_none = [x is None for x in grouped_args]
|
|
108
|
+
if any(args_none) and not all(args_none):
|
|
109
|
+
raise ValueError(
|
|
110
|
+
"Either all of `low_freq_factor`,`high_freq_factor`, "
|
|
111
|
+
"`frequency_adjustment_factor` and "
|
|
112
|
+
"`pretraining_sequence_length` should be set, or all of should"
|
|
113
|
+
" be set `None`."
|
|
114
|
+
)
|
|
115
|
+
self.built = True
|
|
116
|
+
|
|
117
|
+
def _get_inverse_freq(self, rotary_dim):
|
|
118
|
+
freq_range = ops.divide(
|
|
119
|
+
ops.arange(0, rotary_dim, 2, dtype="float32"),
|
|
120
|
+
ops.cast(rotary_dim, "float32"),
|
|
121
|
+
)
|
|
122
|
+
inverse_freq = 1.0 / (self.max_wavelength**freq_range)
|
|
123
|
+
|
|
124
|
+
# From llama3.1+ we have additional smoothening and interpolation.
|
|
125
|
+
# low_freq_factor, high_freq_factor, pretraining_sequence_length,
|
|
126
|
+
# frequency_adjustment_factor are all set at once so it is fine.
|
|
127
|
+
if self.low_freq_factor is not None:
|
|
128
|
+
low_freq_wavelen = (
|
|
129
|
+
self.pretraining_sequence_length / self.low_freq_factor
|
|
130
|
+
)
|
|
131
|
+
high_freq_wavelen = (
|
|
132
|
+
self.pretraining_sequence_length / self.high_freq_factor
|
|
133
|
+
)
|
|
134
|
+
wavelen = 2 * math.pi / inverse_freq
|
|
135
|
+
|
|
136
|
+
# wavelen < high_freq_wavelen: do nothing
|
|
137
|
+
# wavelen > low_freq_wavelen: divide by factor
|
|
138
|
+
inverse_freq = ops.where(
|
|
139
|
+
ops.greater(wavelen, low_freq_wavelen),
|
|
140
|
+
(inverse_freq / self.frequency_adjustment_factor),
|
|
141
|
+
inverse_freq,
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
# otherwise: interpolate between the two, using a smooth factor
|
|
145
|
+
smooth_factor = (
|
|
146
|
+
(self.pretraining_sequence_length / wavelen)
|
|
147
|
+
- self.low_freq_factor
|
|
148
|
+
) / (self.high_freq_factor - self.low_freq_factor)
|
|
149
|
+
smoothed_inv_freq = (1 - smooth_factor) * (
|
|
150
|
+
inverse_freq / self.frequency_adjustment_factor
|
|
151
|
+
) + (smooth_factor * inverse_freq)
|
|
152
|
+
is_medium_freq = ops.logical_and(
|
|
153
|
+
ops.greater_equal(wavelen, high_freq_wavelen),
|
|
154
|
+
ops.less_equal(wavelen, low_freq_wavelen),
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
inverse_freq = ops.where(
|
|
158
|
+
is_medium_freq, smoothed_inv_freq, inverse_freq
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
return inverse_freq
|
|
162
|
+
|
|
163
|
+
def get_config(self):
|
|
164
|
+
config = super().get_config()
|
|
165
|
+
config.update(
|
|
166
|
+
{
|
|
167
|
+
"max_wavelength": self.max_wavelength,
|
|
168
|
+
"sequence_axis": self.sequence_axis,
|
|
169
|
+
"feature_axis": self.feature_axis,
|
|
170
|
+
"position_scaling_factor": self.position_scaling_factor,
|
|
171
|
+
"frequency_adjustment_factor": self.frequency_adjustment_factor,
|
|
172
|
+
"low_freq_factor": self.low_freq_factor,
|
|
173
|
+
"high_freq_factor": self.high_freq_factor,
|
|
174
|
+
"original_max_embeddings": self.pretraining_sequence_length,
|
|
175
|
+
}
|
|
176
|
+
)
|
|
177
|
+
return config
|
|
178
|
+
|
|
179
|
+
def compute_output_shape(self, input_shape):
|
|
180
|
+
return input_shape
|
|
@@ -32,8 +32,16 @@ class Llama3Backbone(LlamaBackbone):
|
|
|
32
32
|
fo each transformer.
|
|
33
33
|
rope_max_wavelength (int, optional): The maximum angular wavelength of
|
|
34
34
|
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
|
|
35
|
-
|
|
36
|
-
calculation of roatary embedding. Defaults to `1.0
|
|
35
|
+
rope_position_scaling_factor (float, optional): The scaling factor for
|
|
36
|
+
calculation of roatary embedding. Defaults to `1.0`
|
|
37
|
+
rope_requency_adjustment_factor (float, optional): The scaling factor
|
|
38
|
+
used to scale the inverse frequencies.
|
|
39
|
+
rope_low_freq_factor (float, optional): The low frequency factor.
|
|
40
|
+
Defaults to None.
|
|
41
|
+
rope_high_freq_factor: (float, optional) Used for Llama3.1+. The high
|
|
42
|
+
frequency factor. Defaults to None.
|
|
43
|
+
rope_pretraining_sequence_length: (int, optional) Sequence length during
|
|
44
|
+
original pretraining. Defaults to None.
|
|
37
45
|
layer_norm_epsilon (float, optional): Epsilon for the layer
|
|
38
46
|
normalization layers in the transformer decoder. Defaults to `1e-6`.
|
|
39
47
|
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
|
|
@@ -8,7 +8,7 @@ backbone_presets = {
|
|
|
8
8
|
"params": 8030261248,
|
|
9
9
|
"path": "llama3",
|
|
10
10
|
},
|
|
11
|
-
"kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/
|
|
11
|
+
"kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/5",
|
|
12
12
|
},
|
|
13
13
|
"llama3_8b_en_int8": {
|
|
14
14
|
"metadata": {
|
|
@@ -30,7 +30,7 @@ backbone_presets = {
|
|
|
30
30
|
"params": 8030261248,
|
|
31
31
|
"path": "llama3",
|
|
32
32
|
},
|
|
33
|
-
"kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/
|
|
33
|
+
"kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/5",
|
|
34
34
|
},
|
|
35
35
|
"llama3_instruct_8b_en_int8": {
|
|
36
36
|
"metadata": {
|
|
@@ -45,4 +45,86 @@ backbone_presets = {
|
|
|
45
45
|
"kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/2"
|
|
46
46
|
),
|
|
47
47
|
},
|
|
48
|
+
"llama3.1_8b": {
|
|
49
|
+
"metadata": {
|
|
50
|
+
"description": (
|
|
51
|
+
"8 billion parameter, 32-layer, based LLaMA 3.1 model. "
|
|
52
|
+
),
|
|
53
|
+
"params": 8030261248,
|
|
54
|
+
"path": "llama3",
|
|
55
|
+
},
|
|
56
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_8b/1"),
|
|
57
|
+
},
|
|
58
|
+
"llama3.1_instruct_8b": {
|
|
59
|
+
"metadata": {
|
|
60
|
+
"description": (
|
|
61
|
+
"8 billion parameter, 32-layer, instruction tuned LLaMA 3.1. "
|
|
62
|
+
),
|
|
63
|
+
"params": 8030261248,
|
|
64
|
+
"path": "llama3",
|
|
65
|
+
},
|
|
66
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/lama3.1_instruct_8b/1"),
|
|
67
|
+
},
|
|
68
|
+
"llama3.1_guard_8b": {
|
|
69
|
+
"metadata": {
|
|
70
|
+
"description": (
|
|
71
|
+
"8 billion parameter, 32-layer, LLaMA 3.1 fine-tuned for "
|
|
72
|
+
"consent safety classification. "
|
|
73
|
+
),
|
|
74
|
+
"params": 8030261248,
|
|
75
|
+
"path": "llama3",
|
|
76
|
+
},
|
|
77
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_guard_8b/1"),
|
|
78
|
+
},
|
|
79
|
+
"llama3.2_1b": {
|
|
80
|
+
"metadata": {
|
|
81
|
+
"description": (
|
|
82
|
+
"1 billion parameter, 16-layer, based LLaMA 3.2 model. "
|
|
83
|
+
),
|
|
84
|
+
"params": 1498482688,
|
|
85
|
+
"path": "llama3",
|
|
86
|
+
},
|
|
87
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_1b/1"),
|
|
88
|
+
},
|
|
89
|
+
"llama3.2_instruct_1b": {
|
|
90
|
+
"metadata": {
|
|
91
|
+
"description": (
|
|
92
|
+
"1 billion parameter, 16-layer, instruction tuned LLaMA 3.2. "
|
|
93
|
+
),
|
|
94
|
+
"params": 1498482688,
|
|
95
|
+
"path": "llama3",
|
|
96
|
+
},
|
|
97
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_1b/1"),
|
|
98
|
+
},
|
|
99
|
+
"llama3.2_3b": {
|
|
100
|
+
"metadata": {
|
|
101
|
+
"description": (
|
|
102
|
+
"3 billion parameter, 26-layer, based LLaMA 3.2 model. "
|
|
103
|
+
),
|
|
104
|
+
"params": 3606752256,
|
|
105
|
+
"path": "llama3",
|
|
106
|
+
},
|
|
107
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_3b/1"),
|
|
108
|
+
},
|
|
109
|
+
"llama3.2_instruct_3b": {
|
|
110
|
+
"metadata": {
|
|
111
|
+
"description": (
|
|
112
|
+
"3 billion parameter, 28-layer, instruction tuned LLaMA 3.2. "
|
|
113
|
+
),
|
|
114
|
+
"params": 3606752256,
|
|
115
|
+
"path": "llama3",
|
|
116
|
+
},
|
|
117
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_3b/1"),
|
|
118
|
+
},
|
|
119
|
+
"llama3.2_guard_1b": {
|
|
120
|
+
"metadata": {
|
|
121
|
+
"description": (
|
|
122
|
+
"1 billion parameter, 16-layer, based LLaMA 3.2 model "
|
|
123
|
+
"fine-tuned for consent safety classification. "
|
|
124
|
+
),
|
|
125
|
+
"params": 1498482688,
|
|
126
|
+
"path": "llama3",
|
|
127
|
+
},
|
|
128
|
+
"kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_guard_1b/1"),
|
|
129
|
+
},
|
|
48
130
|
}
|
|
@@ -8,7 +8,7 @@ backbone_presets = {
|
|
|
8
8
|
"params": 7241732096,
|
|
9
9
|
"path": "mistral",
|
|
10
10
|
},
|
|
11
|
-
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/
|
|
11
|
+
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/8",
|
|
12
12
|
},
|
|
13
13
|
"mistral_instruct_7b_en": {
|
|
14
14
|
"metadata": {
|
|
@@ -16,7 +16,7 @@ backbone_presets = {
|
|
|
16
16
|
"params": 7241732096,
|
|
17
17
|
"path": "mistral",
|
|
18
18
|
},
|
|
19
|
-
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/
|
|
19
|
+
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/8",
|
|
20
20
|
},
|
|
21
21
|
"mistral_0.2_instruct_7b_en": {
|
|
22
22
|
"metadata": {
|
|
@@ -24,6 +24,6 @@ backbone_presets = {
|
|
|
24
24
|
"params": 7241732096,
|
|
25
25
|
"path": "mistral",
|
|
26
26
|
},
|
|
27
|
-
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/
|
|
27
|
+
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/3",
|
|
28
28
|
},
|
|
29
29
|
}
|