keras-hub 0.20.0.dev1-py3-none-any.whl → 0.21.0.dev1-py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (105)
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0

keras_hub/src/models/gemma3/gemma3_presets.py
@@ -55,7 +55,7 @@ backbone_presets = {
  "params": 11765788416,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/2",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/3",
  },
  "gemma3_instruct_12b_text": {
  "metadata": {
@@ -66,7 +66,7 @@ backbone_presets = {
  "params": 11765788416,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/2",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/3",
  },
  "gemma3_27b_text": {
  "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
  "params": 27009002240,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/3",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/4",
  },
  "gemma3_instruct_27b_text": {
  "metadata": {
@@ -88,7 +88,7 @@ backbone_presets = {
  "params": 27009002240,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/2",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/3",
  },
  "gemma3_4b": {
  "metadata": {
@@ -121,7 +121,7 @@ backbone_presets = {
  "params": 12187079280,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/1",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/2",
  },
  "gemma3_instruct_12b": {
  "metadata": {
@@ -132,7 +132,7 @@ backbone_presets = {
  "params": 12187079280,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/1",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/2",
  },
  "gemma3_27b": {
  "metadata": {
@@ -143,7 +143,7 @@ backbone_presets = {
  "params": 27432062576,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/1",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/2",
  },
  "gemma3_instruct_27b": {
  "metadata": {
@@ -154,6 +154,6 @@ backbone_presets = {
  "params": 27432062576,
  "path": "gemma3",
  },
- "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/1",
+ "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/2",
  },
  }

keras_hub/src/models/gemma3/gemma3_vision_encoder.py
@@ -488,7 +488,7 @@ class Gemma3VisionEncoderBlock(keras.layers.Layer):
  # Fix the compatibility issue with Keras 3.1 where
  # `compute_output_spec` fails to propagate `inputs_shape`
  # correctly, causing it to be `None`.
- inputs_shape = [None, None, None]
+ return [None, None, self.hidden_dim]
  return [
  None,
  (inputs_shape[2] // self.patch_size) ** 2,

keras_hub/src/models/llama/llama_attention.py
@@ -3,7 +3,9 @@ import math
  import keras
  from keras import ops

- from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+ from keras_hub.src.models.llama.llama_rotary_embedding import (
+ LlamaRotaryEmbedding,
+ )
  from keras_hub.src.utils.keras_utils import clone_initializer
  from keras_hub.src.utils.keras_utils import fused_attention_op_available

@@ -16,7 +18,11 @@ class LlamaAttention(keras.layers.Layer):
  num_query_heads,
  num_key_value_heads,
  rope_max_wavelength=10000,
- rope_scaling_factor=1.0,
+ rope_position_scaling_factor=1.0,
+ rope_frequency_adjustment_factor=None,
+ rope_low_freq_factor=None,
+ rope_high_freq_factor=None,
+ rope_pretraining_sequence_length=None,
  kernel_initializer="glorot_uniform",
  dropout=0,
  **kwargs,
@@ -28,13 +34,16 @@ class LlamaAttention(keras.layers.Layer):

  self.num_key_value_groups = num_query_heads // num_key_value_heads
  self.rope_max_wavelength = rope_max_wavelength
+ self.rope_position_scaling_factor = rope_position_scaling_factor
+ self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+ self.rope_low_freq_factor = rope_low_freq_factor
+ self.rope_high_freq_factor = rope_high_freq_factor
+ self.rope_pretraining_sequence_length = rope_pretraining_sequence_length

  self.kernel_initializer = keras.initializers.get(
  clone_initializer(kernel_initializer)
  )

- self.rope_scaling_factor = rope_scaling_factor
-
  def build(self, inputs_shape):
  # Einsum variables:
  # b = batch size
@@ -103,9 +112,13 @@ class LlamaAttention(keras.layers.Layer):
  )
  self._output_dense.build((None, None, self.num_query_heads, head_dim))

- self.rotary_embedding_layer = RotaryEmbedding(
+ self.rotary_embedding_layer = LlamaRotaryEmbedding(
  max_wavelength=self.rope_max_wavelength,
- scaling_factor=self.rope_scaling_factor,
+ position_scaling_factor=self.rope_position_scaling_factor,
+ frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+ low_freq_factor=self.rope_low_freq_factor,
+ high_freq_factor=self.rope_high_freq_factor,
+ pretraining_sequence_length=self.rope_pretraining_sequence_length,
  dtype=self.dtype_policy,
  )

@@ -224,6 +237,11 @@ class LlamaAttention(keras.layers.Layer):
  "num_key_value_heads": self.num_key_value_heads,
  "rope_max_wavelength": self.rope_max_wavelength,
  "rope_scaling_factor": self.rope_scaling_factor,
+ "rope_low_freq_factor": self.rope_low_freq_factor,
+ "rope_high_freq_factor": self.rope_high_freq_factor,
+ "rope_pretraining_sequence_length": (
+ self.rope_pretraining_sequence_length
+ ),
  "kernel_initializer": keras.initializers.serialize(
  self.kernel_initializer
  ),

keras_hub/src/models/llama/llama_backbone.py
@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
  constructor.

  Args:
- vocabulary_size (int): The size of the token vocabulary.
- num_layers (int): The number of transformer layers.
- num_query_heads (int): The number of query attention heads for
+ vocabulary_size: int. The size of the token vocabulary.
+ num_layers: int. The number of transformer layers.
+ num_query_heads : int. The number of query attention heads for
  each transformer.
- hidden_dim (int): The size of the transformer encoding and pooling
+ hidden_dim : int. The size of the transformer encoding and pooling
  layers.
- intermediate_dim (int): The output dimension of the first Dense layer in
+ intermediate_dim : int. The output dimension of the first Dense layer in
  a three-layer feedforward network for each transformer.
- num_key_value_heads (int): The number of key and value attention heads
+ num_key_value_heads : int. The number of key and value attention heads
  for each transformer.
- rope_max_wavelength (int, optional): The maximum angular wavelength of
+ rope_max_wavelength : int. The maximum angular wavelength of
  the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
- rope_scaling_factor (float, optional): The scaling factor for
- calculation of roatary embedding. Defaults to `1.0`.
- layer_norm_epsilon (float, optional): Epsilon for the layer
- normalization layers in the transformer decoder. Defaults to `1e-6`.
+ rope_position_scaling_factor: float. The scaling factor for
+ calculation of rotary embedding. Defaults to `1.0`
+ rope_frequency_adjustment_factor: float. The scaling factor
+ used to scale the inverse frequencies. Defaults to `None`.
+ rope_low_freq_factor: float. The low frequency scaling
+ factor. Defaults to `None`.
+ rope_high_freq_factor: float. Used for Llama3.1+. The high
+ frequency scaling factor. Defaults to `None`.
+ rope_pretraining_sequence_length: int. Used for Llama3.1+.
+ Defaults to `None`.
+ layer_norm_epsilon : float. Epsilon for the layer normalization layers
+ in the transformer decoder. Defaults to `1e-6`.
  dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
  for model computations and weights. Note that some computations,
  such as softmax and layer normalization, will always be done at
@@ -87,7 +95,11 @@ class LlamaBackbone(Backbone):
  intermediate_dim,
  num_key_value_heads,
  rope_max_wavelength=10000,
- rope_scaling_factor=1.0,
+ rope_position_scaling_factor=1.0,
+ rope_frequency_adjustment_factor=None,
+ rope_low_freq_factor=None,
+ rope_high_freq_factor=None,
+ rope_pretraining_sequence_length=None,
  layer_norm_epsilon=1e-6,
  dropout=0,
  dtype=None,
@@ -110,7 +122,15 @@ class LlamaBackbone(Backbone):
  num_query_heads=num_query_heads,
  num_key_value_heads=num_key_value_heads,
  rope_max_wavelength=rope_max_wavelength,
- rope_scaling_factor=rope_scaling_factor,
+ rope_position_scaling_factor=rope_position_scaling_factor,
+ rope_frequency_adjustment_factor=(
+ rope_frequency_adjustment_factor
+ ),
+ rope_low_freq_factor=rope_low_freq_factor,
+ rope_high_freq_factor=rope_high_freq_factor,
+ rope_pretraining_sequence_length=(
+ rope_pretraining_sequence_length
+ ),
  layer_norm_epsilon=layer_norm_epsilon,
  activation=ops.silu,
  kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ class LlamaBackbone(Backbone):
  self.num_query_heads = num_query_heads
  self.hidden_dim = hidden_dim
  self.intermediate_dim = intermediate_dim
- self.rope_max_wavelength = rope_max_wavelength
  self.num_key_value_heads = num_key_value_heads
- self.rope_scaling_factor = rope_scaling_factor
+ self.rope_max_wavelength = rope_max_wavelength
+ self.rope_position_scaling_factor = rope_position_scaling_factor
+ self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+ self.rope_low_freq_factor = rope_low_freq_factor
+ self.rope_high_freq_factor = rope_high_freq_factor
+ self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
  self.layer_norm_epsilon = layer_norm_epsilon
  self.dropout = dropout
  self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ class LlamaBackbone(Backbone):
  "hidden_dim": self.hidden_dim,
  "intermediate_dim": self.intermediate_dim,
  "rope_max_wavelength": self.rope_max_wavelength,
- "rope_scaling_factor": self.rope_scaling_factor,
+ "rope_position_scaling_factor": (
+ self.rope_position_scaling_factor
+ ),
+ "rope_frequency_adjustment_factor": (
+ self.rope_frequency_adjustment_factor
+ ),
+ "rope_low_freq_factor": self.rope_low_freq_factor,
+ "rope_high_freq_factor": self.rope_high_freq_factor,
+ "rope_pretraining_sequence_length": (
+ self.rope_pretraining_sequence_length
+ ),
  "num_key_value_heads": self.num_key_value_heads,
  "layer_norm_epsilon": self.layer_norm_epsilon,
  "dropout": self.dropout,

keras_hub/src/models/llama/llama_decoder.py
@@ -21,7 +21,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
  num_query_heads,
  num_key_value_heads,
  rope_max_wavelength=10000,
- rope_scaling_factor=1.0,
+ rope_position_scaling_factor=1.0,
+ rope_frequency_adjustment_factor=None,
+ rope_low_freq_factor=None,
+ rope_high_freq_factor=None,
+ rope_pretraining_sequence_length=None,
  activation="silu",
  layer_norm_epsilon=1e-5,
  kernel_initializer="glorot_uniform",
@@ -34,7 +38,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
  self.num_key_value_heads = num_key_value_heads

  self.rope_max_wavelength = rope_max_wavelength
- self.rope_scaling_factor = rope_scaling_factor
+ self.rope_position_scaling_factor = rope_position_scaling_factor
+ self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+ self.rope_low_freq_factor = rope_low_freq_factor
+ self.rope_high_freq_factor = rope_high_freq_factor
+ self.rope_pretraining_sequence_length = rope_pretraining_sequence_length

  self.dropout = dropout

@@ -53,7 +61,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
  num_query_heads=self.num_query_heads,
  num_key_value_heads=self.num_key_value_heads,
  rope_max_wavelength=self.rope_max_wavelength,
- rope_scaling_factor=self.rope_scaling_factor,
+ rope_position_scaling_factor=self.rope_position_scaling_factor,
+ rope_frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+ rope_low_freq_factor=self.rope_low_freq_factor,
+ rope_high_freq_factor=self.rope_high_freq_factor,
+ rope_pretraining_sequence_length=self.rope_pretraining_sequence_length,
  kernel_initializer=clone_initializer(self.kernel_initializer),
  dropout=self.dropout,
  dtype=self.dtype_policy,
@@ -221,6 +233,11 @@ class LlamaTransformerDecoder(keras.layers.Layer):
  "num_query_heads": self.num_query_heads,
  "rope_max_wavelength": self.rope_max_wavelength,
  "rope_scaling_factor": self.rope_scaling_factor,
+ "rope_low_freq_factor": self.rope_low_freq_factor,
+ "rope_high_freq_factor": self.rope_high_freq_factor,
+ "rope_pretraining_sequence_length": (
+ self.rope_pretraining_sequence_length
+ ),
  "num_key_value_heads": self.num_key_value_heads,
  "activation": keras.activations.serialize(self.activation),
  "layer_norm_epsilon": self.layer_norm_epsilon,

keras_hub/src/models/llama/llama_presets.py
@@ -8,7 +8,7 @@ backbone_presets = {
  "params": 6738415616,
  "path": "llama",
  },
- "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/2",
+ "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/3",
  },
  "llama2_7b_en_int8": {
  "metadata": {
@@ -30,7 +30,7 @@ backbone_presets = {
  "params": 6738415616,
  "path": "llama",
  },
- "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/2",
+ "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/3",
  },
  "llama2_instruct_7b_en_int8": {
  "metadata": {
@@ -52,6 +52,6 @@ backbone_presets = {
  "params": 6738415616,
  "path": "llama",
  },
- "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/2",
+ "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/3",
  },
  }

keras_hub/src/models/llama/llama_rotary_embedding.py (new file)
@@ -0,0 +1,180 @@
+ import math
+
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+
+
+ class LlamaRotaryEmbedding(RotaryEmbedding):
+ """Rotary positional encoding layer.
+
+ This layer encodes absolute positional information with a rotation
+ matrix. It calculates the rotary encoding with a mix of sine and
+ cosine functions with geometrically increasing wavelengths.
+ Defined and formulated in
+ [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
+ The input must be a tensor with shape a sequence dimension and a feature
+ dimension. Typically, this will either an input with shape
+ `(batch_size, sequence_length, feature_length)` or
+ `(batch_size, sequence_length, num_heads, feature_length)`.
+ This layer will return a new tensor with the rotary embedding applied to
+ the input tensor.
+ It is extended from `RotaryEmbedding` layer in `keras_hub.layers`.
+ It has additional smoothening and interpolation for some frequency ranges.
+
+ Args:
+ max_wavelength: int. The maximum angular wavelength of the sine/cosine
+ curves. Defaults to `10000`.
+ position_scaling_factor: float. The scaling factor used to scale
+ positions of the tokens. Defaults to `1.0`.
+ frequency_adjustment_factor: float. The scaling factor used to scale the
+ inverse frequencies. Defaults to `None`.
+ low_freq_factor: float. The low frequency scaling factor.
+ Defaults to `None`.
+ high_freq_factor: float. The high frequency scaling factor.
+ Defaults to `None`.
+ pretraining_sequence_length: int. Used for Llama3.1+, the original
+ context length at time of pretraining. Defaults to `None`.
+ sequence_axis: int. Sequence axis in the input tensor.
+ feature_axis: int. Feature axis in the input tensor.
+ **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+ including `name`, `trainable`, `dtype` etc.
+
+ Call arguments:
+ inputs: The tensor inputs to apply the embedding to. This can have
+ any shape, but must contain both a sequence and feature axis. The
+ rotary embedding will be applied to `inputs` and returned.
+ start_index: An integer or integer tensor. The starting position to
+ compute the rotary embedding from. This is useful during cached
+ decoding, where each position is predicted separately in a loop.
+
+ Examples:
+
+ ```python
+ batch_size = 16
+ feature_length = 18
+ sequence_length = 256
+ num_heads = 8
+
+ # No multi-head dimension.
+ tensor = np.ones((batch_size, sequence_length, feature_length))
+ rot_emb_layer = RotaryEmbedding()
+ tensor_rot = rot_emb_layer(tensor)
+
+ # With multi-head dimension.
+ tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
+ tensor_rot = rot_emb_layer(tensor)
+ ```
+
+ References:
+ - [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
+ """
+
+ def __init__(
+ self,
+ max_wavelength=10000,
+ position_scaling_factor=1.0,
+ sequence_axis=1,
+ feature_axis=-1,
+ frequency_adjustment_factor=None,
+ low_freq_factor=None,
+ high_freq_factor=None,
+ pretraining_sequence_length=None,
+ **kwargs,
+ ):
+ super().__init__(
+ max_wavelength=max_wavelength,
+ scaling_factor=position_scaling_factor,
+ sequence_axis=sequence_axis,
+ feature_axis=feature_axis,
+ **kwargs,
+ )
+ self.max_wavelength = max_wavelength
+ self.sequence_axis = sequence_axis
+ self.feature_axis = feature_axis
+ self.position_scaling_factor = position_scaling_factor
+ self.frequency_adjustment_factor = frequency_adjustment_factor
+ self.low_freq_factor = low_freq_factor
+ self.high_freq_factor = high_freq_factor
+ self.pretraining_sequence_length = pretraining_sequence_length
+
+ grouped_args = [
+ low_freq_factor,
+ high_freq_factor,
+ frequency_adjustment_factor,
+ pretraining_sequence_length,
+ ]
+ args_none = [x is None for x in grouped_args]
+ if any(args_none) and not all(args_none):
+ raise ValueError(
+ "Either all of `low_freq_factor`,`high_freq_factor`, "
+ "`frequency_adjustment_factor` and "
+ "`pretraining_sequence_length` should be set, or all of should"
+ " be set `None`."
+ )
+ self.built = True
+
+ def _get_inverse_freq(self, rotary_dim):
+ freq_range = ops.divide(
+ ops.arange(0, rotary_dim, 2, dtype="float32"),
+ ops.cast(rotary_dim, "float32"),
+ )
+ inverse_freq = 1.0 / (self.max_wavelength**freq_range)
+
+ # From llama3.1+ we have additional smoothening and interpolation.
+ # low_freq_factor, high_freq_factor, pretraining_sequence_length,
+ # frequency_adjustment_factor are all set at once so it is fine.
+ if self.low_freq_factor is not None:
+ low_freq_wavelen = (
+ self.pretraining_sequence_length / self.low_freq_factor
+ )
+ high_freq_wavelen = (
+ self.pretraining_sequence_length / self.high_freq_factor
+ )
+ wavelen = 2 * math.pi / inverse_freq
+
+ # wavelen < high_freq_wavelen: do nothing
+ # wavelen > low_freq_wavelen: divide by factor
+ inverse_freq = ops.where(
+ ops.greater(wavelen, low_freq_wavelen),
+ (inverse_freq / self.frequency_adjustment_factor),
+ inverse_freq,
+ )
+
+ # otherwise: interpolate between the two, using a smooth factor
+ smooth_factor = (
+ (self.pretraining_sequence_length / wavelen)
+ - self.low_freq_factor
+ ) / (self.high_freq_factor - self.low_freq_factor)
+ smoothed_inv_freq = (1 - smooth_factor) * (
+ inverse_freq / self.frequency_adjustment_factor
+ ) + (smooth_factor * inverse_freq)
+ is_medium_freq = ops.logical_and(
+ ops.greater_equal(wavelen, high_freq_wavelen),
+ ops.less_equal(wavelen, low_freq_wavelen),
+ )
+
+ inverse_freq = ops.where(
+ is_medium_freq, smoothed_inv_freq, inverse_freq
+ )
+
+ return inverse_freq
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "max_wavelength": self.max_wavelength,
+ "sequence_axis": self.sequence_axis,
+ "feature_axis": self.feature_axis,
+ "position_scaling_factor": self.position_scaling_factor,
+ "frequency_adjustment_factor": self.frequency_adjustment_factor,
+ "low_freq_factor": self.low_freq_factor,
+ "high_freq_factor": self.high_freq_factor,
+ "original_max_embeddings": self.pretraining_sequence_length,
+ }
+ )
+ return config
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
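
For reference, the piecewise rule that `_get_inverse_freq` implements can be restated as follows (this restatement is not part of the diff). Writing $\theta_i$ for the base inverse frequencies, $\lambda_i = 2\pi/\theta_i$ for the corresponding wavelengths, $\alpha$ for `frequency_adjustment_factor`, $L$ for `pretraining_sequence_length`, and $f_{\text{low}}$, $f_{\text{high}}$ for the low/high frequency factors:

$$
\theta_i' =
\begin{cases}
\theta_i, & \lambda_i < L/f_{\text{high}},\\
(1 - s_i)\,\dfrac{\theta_i}{\alpha} + s_i\,\theta_i, & L/f_{\text{high}} \le \lambda_i \le L/f_{\text{low}},\\
\dfrac{\theta_i}{\alpha}, & \lambda_i > L/f_{\text{low}},
\end{cases}
\qquad
s_i = \frac{L/\lambda_i - f_{\text{low}}}{f_{\text{high}} - f_{\text{low}}}.
$$

Short wavelengths (high frequencies) are left untouched, long wavelengths are uniformly compressed by $\alpha$, and the band in between is linearly interpolated, which is the Llama 3.1-style frequency smoothing the code comments refer to.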

keras_hub/src/models/llama3/llama3_backbone.py
@@ -32,8 +32,16 @@ class Llama3Backbone(LlamaBackbone):
  fo each transformer.
  rope_max_wavelength (int, optional): The maximum angular wavelength of
  the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
- rope_scaling_factor (float, optional): The scaling factor for
- calculation of roatary embedding. Defaults to `1.0`.
+ rope_position_scaling_factor (float, optional): The scaling factor for
+ calculation of roatary embedding. Defaults to `1.0`
+ rope_requency_adjustment_factor (float, optional): The scaling factor
+ used to scale the inverse frequencies.
+ rope_low_freq_factor (float, optional): The low frequency factor.
+ Defaults to None.
+ rope_high_freq_factor: (float, optional) Used for Llama3.1+. The high
+ frequency factor. Defaults to None.
+ rope_pretraining_sequence_length: (int, optional) Sequence length during
+ original pretraining. Defaults to None.
  layer_norm_epsilon (float, optional): Epsilon for the layer
  normalization layers in the transformer decoder. Defaults to `1e-6`.
  dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use

keras_hub/src/models/llama3/llama3_presets.py
@@ -8,7 +8,7 @@ backbone_presets = {
  "params": 8030261248,
  "path": "llama3",
  },
- "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/4",
+ "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/5",
  },
  "llama3_8b_en_int8": {
  "metadata": {
@@ -30,7 +30,7 @@ backbone_presets = {
  "params": 8030261248,
  "path": "llama3",
  },
- "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/4",
+ "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/5",
  },
  "llama3_instruct_8b_en_int8": {
  "metadata": {
@@ -45,4 +45,86 @@ backbone_presets = {
  "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/2"
  ),
  },
+ "llama3.1_8b": {
+ "metadata": {
+ "description": (
+ "8 billion parameter, 32-layer, based LLaMA 3.1 model. "
+ ),
+ "params": 8030261248,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_8b/1"),
+ },
+ "llama3.1_instruct_8b": {
+ "metadata": {
+ "description": (
+ "8 billion parameter, 32-layer, instruction tuned LLaMA 3.1. "
+ ),
+ "params": 8030261248,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/lama3.1_instruct_8b/1"),
+ },
+ "llama3.1_guard_8b": {
+ "metadata": {
+ "description": (
+ "8 billion parameter, 32-layer, LLaMA 3.1 fine-tuned for "
+ "consent safety classification. "
+ ),
+ "params": 8030261248,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_guard_8b/1"),
+ },
+ "llama3.2_1b": {
+ "metadata": {
+ "description": (
+ "1 billion parameter, 16-layer, based LLaMA 3.2 model. "
+ ),
+ "params": 1498482688,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_1b/1"),
+ },
+ "llama3.2_instruct_1b": {
+ "metadata": {
+ "description": (
+ "1 billion parameter, 16-layer, instruction tuned LLaMA 3.2. "
+ ),
+ "params": 1498482688,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_1b/1"),
+ },
+ "llama3.2_3b": {
+ "metadata": {
+ "description": (
+ "3 billion parameter, 26-layer, based LLaMA 3.2 model. "
+ ),
+ "params": 3606752256,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_3b/1"),
+ },
+ "llama3.2_instruct_3b": {
+ "metadata": {
+ "description": (
+ "3 billion parameter, 28-layer, instruction tuned LLaMA 3.2. "
+ ),
+ "params": 3606752256,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_instruct_3b/1"),
+ },
+ "llama3.2_guard_1b": {
+ "metadata": {
+ "description": (
+ "1 billion parameter, 16-layer, based LLaMA 3.2 model "
+ "fine-tuned for consent safety classification. "
+ ),
+ "params": 1498482688,
+ "path": "llama3",
+ },
+ "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.2_guard_1b/1"),
+ },
  }
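
The eight Llama 3.1 / 3.2 presets added above become loadable by name once registered. A minimal, hypothetical usage sketch follows; the `from_preset` call pattern is standard keras_hub usage rather than something shown in this diff, and it assumes the preset keys above resolve through the usual registry.

```python
import keras_hub

# "llama3.2_1b" is one of the preset keys added above; weights are
# fetched from the kaggle_handle listed for that entry.
causal_lm = keras_hub.models.Llama3CausalLM.from_preset("llama3.2_1b")
causal_lm.generate("What is Keras?", max_length=64)
```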

keras_hub/src/models/mistral/mistral_presets.py
@@ -8,7 +8,7 @@ backbone_presets = {
  "params": 7241732096,
  "path": "mistral",
  },
- "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/7",
+ "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/8",
  },
  "mistral_instruct_7b_en": {
  "metadata": {
@@ -16,7 +16,7 @@ backbone_presets = {
  "params": 7241732096,
  "path": "mistral",
  },
- "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/7",
+ "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/8",
  },
  "mistral_0.2_instruct_7b_en": {
  "metadata": {
@@ -24,6 +24,6 @@ backbone_presets = {
  "params": 7241732096,
  "path": "mistral",
  },
- "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/2",
+ "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/3",
  },
  }

keras_hub/src/models/mixtral/__init__.py (new file)
@@ -0,0 +1,5 @@
+ from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone
+ from keras_hub.src.models.mixtral.mixtral_presets import backbone_presets
+ from keras_hub.src.utils.preset_utils import register_presets
+
+ register_presets(backbone_presets, MixtralBackbone)