keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import activations
+from keras.layers import ReversibleEmbedding

 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
 from keras_hub.src.models.roformer_v2.roformer_v2_encoder import (
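
The hunk above swaps the vendored `ReversibleEmbedding` import for the core-Keras one; entry 4 in the file list (reversible_embedding.py, +2 -275) is consistent with the layer having moved upstream. A minimal usage sketch, assuming the upstreamed layer keeps the keras-hub call signature (`reverse=True` projects hidden states back to vocabulary logits):

import numpy as np
from keras.layers import ReversibleEmbedding  # new import path shown in the diff

# tie_weights=True shares the embedding matrix with the output projection.
embedding = ReversibleEmbedding(input_dim=100, output_dim=16, tie_weights=True)
token_ids = np.random.randint(0, 100, size=(4, 12))
hidden = embedding(token_ids)             # (4, 12, 16): token embedding
logits = embedding(hidden, reverse=True)  # (4, 12, 100): tied LM head
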
keras_hub/src/models/rqvae/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.rqvae.rqvae_backbone import RQVAEBackbone
+from keras_hub.src.models.rqvae.rqvae_layers import Decoder
+from keras_hub.src.models.rqvae.rqvae_layers import Encoder
+from keras_hub.src.models.rqvae.rqvae_layers import ResidualVectorQuantizer
+from keras_hub.src.models.rqvae.rqvae_layers import VectorQuantizerEMA
keras_hub/src/models/rqvae/rqvae_backbone.py
@@ -0,0 +1,167 @@
+import keras
+from keras import ops
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.rqvae import rqvae_layers
+
+
+class RQVAEBackbone(Backbone):
+    """Residual Quantized Variational Autoencoder (RQ-VAE) backbone.
+
+    This class implements the RQ-VAE backbone, which consists of an encoder,
+    a residual vector quantizer, and a decoder. It is used for learning
+    discrete representations of data.
+
+    Args:
+        input_dim: Integer. The dimensionality of the input data.
+        encoder_layer_dims: A list of integers specifying the size of each
+            hidden Dense layer in the encoder.
+        output_dim: Integer. The dimensionality of the latent space (the
+            embedding dimension).
+        decoder_layer_dims: A list of integers specifying the size of each
+            hidden Dense layer in the decoder.
+        num_embeddings: Integer. The number of embeddings in the codebook.
+        num_quantizers: Integer. The number of sequential quantizers in the
+            residual vector quantizer.
+        decay: Float. The decay rate for the EMA updates in the quantizers.
+            Defaults to `0.99`.
+        data_variance: Float. The variance of the data, used to scale the
+            reconstruction loss. Defaults to `1.0`.
+        commitment_cost: Float. The weight of the commitment (quantization)
+            loss in the total loss. Defaults to `0.25`.
+        dtype: Optional dtype of the layer's computations and weights.
+            Defaults to `None`.
+        **kwargs: Base backbone keyword arguments.
+
+    References:
+        - [SoundStream: An End-to-End Neural Audio Codec](https://arxiv.org/abs/2107.03312)
+
+    Example:
+    >>> model = RQVAEBackbone(
+    ...     input_dim=10,
+    ...     encoder_layer_dims=[32, 16],
+    ...     output_dim=8,
+    ...     decoder_layer_dims=[16, 32],
+    ...     num_embeddings=64,
+    ...     num_quantizers=4,
+    ... )
+    >>> x = keras.random.uniform(shape=(1, 10))
+    >>> outputs = model(x)
+    >>> tuple(outputs["reconstructions"].shape)
+    (1, 10)
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        encoder_layer_dims,
+        output_dim,
+        decoder_layer_dims,
+        num_embeddings,
+        num_quantizers,
+        decay=0.99,
+        data_variance=1.0,
+        commitment_cost=0.25,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Inputs ===
+        input_dtype = dtype
+        if dtype is not None:
+            if isinstance(dtype, keras.dtype_policies.DTypePolicyMap):
+                input_dtype = dtype.default_policy.compute_dtype
+            elif getattr(dtype, "compute_dtype", None):
+                input_dtype = dtype.compute_dtype
+
+        inputs = keras.Input(shape=(input_dim,), dtype=input_dtype)
+
+        # === Layers ===
+        encoder = rqvae_layers.Encoder(
+            layer_dims=encoder_layer_dims,
+            output_dim=output_dim,
+            dtype=dtype,
+            name="encoder",
+        )
+
+        quantizers = []
+        for i in range(num_quantizers):
+            quantizers.append(
+                rqvae_layers.VectorQuantizerEMA(
+                    num_embeddings=num_embeddings,
+                    embedding_dim=output_dim,
+                    decay=decay,
+                    dtype=dtype,
+                    name=f"quantizer_{i}",
+                )
+            )
+
+        residual_quantizer = rqvae_layers.ResidualVectorQuantizer(
+            quantizers=quantizers, dtype=dtype, name="residual_quantizer"
+        )
+
+        decoder = rqvae_layers.Decoder(
+            layer_dims=decoder_layer_dims,
+            output_dim=input_dim,
+            dtype=dtype,
+            name="decoder",
+        )
+
+        # === Functional model ===
+        x = encoder(inputs)
+        quantized, encodings, usage_ratios, quantization_loss = (
+            residual_quantizer(x)
+        )
+        reconstructions = decoder(quantized)
+
+        outputs = {
+            "reconstructions": reconstructions,
+            "encodings": encodings,
+            "usage_ratios": usage_ratios,
+            "quantization_loss": quantization_loss,
+        }
+
+        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)
+
+        self.encoder = encoder
+        self.residual_quantizer = residual_quantizer
+        self.decoder = decoder
+        self.data_variance = data_variance
+        self.commitment_cost = commitment_cost
+
+        # === Config ===
+        self.input_dim = input_dim
+        self.encoder_layer_dims = encoder_layer_dims
+        self.output_dim = output_dim
+        self.decoder_layer_dims = decoder_layer_dims
+        self.num_embeddings = num_embeddings
+        self.num_quantizers = num_quantizers
+        self.decay = decay
+
+    def compute_loss(self, x, y, y_pred, sample_weight=None):
+        reconstructions = y_pred["reconstructions"]
+        quantization_loss = y_pred["quantization_loss"]
+        # Fall back to the inputs as the reconstruction target.
+        target = y if y is not None else x
+        reconstruction_loss = (
+            ops.mean((reconstructions - target) ** 2) / self.data_variance
+        )
+
+        loss = reconstruction_loss + self.commitment_cost * quantization_loss
+        return loss
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "input_dim": self.input_dim,
+                "encoder_layer_dims": self.encoder_layer_dims,
+                "output_dim": self.output_dim,
+                "decoder_layer_dims": self.decoder_layer_dims,
+                "num_embeddings": self.num_embeddings,
+                "num_quantizers": self.num_quantizers,
+                "decay": self.decay,
+                "data_variance": self.data_variance,
+                "commitment_cost": self.commitment_cost,
+            }
+        )
+        return config
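
Since `compute_loss` falls back to the inputs when no target is provided (`target = y if y is not None else x`), the backbone trains as a plain autoencoder from `fit(x)` alone. A minimal training sketch, assuming only the module paths visible in this diff:

import keras
import numpy as np

from keras_hub.src.models.rqvae.rqvae_backbone import RQVAEBackbone

model = RQVAEBackbone(
    input_dim=10,
    encoder_layer_dims=[32, 16],
    output_dim=8,
    decoder_layer_dims=[16, 32],
    num_embeddings=64,
    num_quantizers=4,
)
# No loss is passed to compile(); the overridden compute_loss() combines the
# scaled reconstruction error with the accumulated quantization loss.
model.compile(optimizer=keras.optimizers.Adam(1e-3))
data = np.random.uniform(size=(256, 10)).astype("float32")
model.fit(data, batch_size=32, epochs=1)
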
keras_hub/src/models/rqvae/rqvae_layers.py
@@ -0,0 +1,335 @@
+import keras
+from keras import layers
+from keras import ops
+
+
+class Encoder(layers.Layer):
+    """A simple feed-forward encoder with ReLU activations.
+
+    This layer consists of a sequence of Dense layers with ReLU activation,
+    followed by a final Dense layer with no activation.
+
+    Args:
+        layer_dims: A list of integers specifying the size of each hidden
+            Dense layer.
+        output_dim: Integer. The size of the output Dense layer.
+        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+    Example:
+    >>> encoder = Encoder(layer_dims=[64, 32], output_dim=16)
+    >>> x = keras.random.uniform(shape=(1, 10))
+    >>> output = encoder(x)
+    >>> tuple(output.shape)
+    (1, 16)
+    """
+
+    def __init__(self, layer_dims, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.layer_dims = layer_dims
+        self.output_dim = output_dim
+        self.dense_layers = []
+        for dim in layer_dims:
+            self.dense_layers.append(layers.Dense(dim, activation="relu"))
+        self.output_layer = layers.Dense(output_dim)
+
+    def call(self, inputs):
+        x = inputs
+        for layer in self.dense_layers:
+            x = layer(x)
+        return self.output_layer(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "layer_dims": self.layer_dims,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+
+class Decoder(layers.Layer):
+    """A simple feed-forward decoder with ReLU activations.
+
+    This layer consists of a sequence of Dense layers with ReLU activation,
+    followed by a final Dense layer with no activation.
+
+    Args:
+        layer_dims: A list of integers specifying the size of each hidden
+            Dense layer.
+        output_dim: Integer. The size of the output Dense layer.
+        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+    Example:
+    >>> decoder = Decoder(layer_dims=[32, 64], output_dim=10)
+    >>> x = keras.random.uniform(shape=(1, 16))
+    >>> output = decoder(x)
+    >>> tuple(output.shape)
+    (1, 10)
+    """
+
+    def __init__(self, layer_dims, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.layer_dims = layer_dims
+        self.output_dim = output_dim
+        self.dense_layers = []
+        for dim in layer_dims:
+            self.dense_layers.append(layers.Dense(dim, activation="relu"))
+        self.output_layer = layers.Dense(output_dim)
+
+    def call(self, inputs):
+        x = inputs
+        for layer in self.dense_layers:
+            x = layer(x)
+        return self.output_layer(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "layer_dims": self.layer_dims,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+
+class VectorQuantizerEMA(layers.Layer):
+    """Vector Quantizer with Exponential Moving Average (EMA) updates.
+
+    This layer implements a vector quantization module whose codebook is
+    updated with EMA statistics rather than gradients, which stabilizes
+    training and mitigates codebook collapse. It flattens the input tensor
+    and maps each vector to the nearest element in a codebook (embeddings).
+
+    Args:
+        num_embeddings: Integer. The number of embeddings in the codebook.
+        embedding_dim: Integer. The dimensionality of each embedding vector.
+        decay: Float. The decay rate for the EMA updates. Defaults to `0.99`.
+        eps: Float. A small epsilon value for numerical stability to avoid
+            division by zero. Defaults to `1e-5`.
+        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+    References:
+        - [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)
+
+    Example:
+    >>> vq = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
+    >>> x = keras.random.uniform(shape=(1, 5, 16))
+    >>> quantized, encodings, usage_ratio, loss = vq(x)
+    >>> tuple(quantized.shape)
+    (1, 5, 16)
+    """
+
+    def __init__(
+        self, num_embeddings, embedding_dim, decay=0.99, eps=1e-5, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.decay = decay
+        self.eps = eps
+
+    def build(self, input_shape):
+        self.embeddings = self.add_weight(
+            shape=(self.num_embeddings, self.embedding_dim),
+            initializer="random_normal",
+            trainable=False,
+            name="embeddings",
+        )
+        self.ema_cluster_size = self.add_weight(
+            shape=(self.num_embeddings,),
+            initializer="zeros",
+            trainable=False,
+            name="ema_cluster_size",
+        )
+        self.ema_w = self.add_weight(
+            shape=(self.num_embeddings, self.embedding_dim),
+            initializer="random_normal",
+            trainable=False,
+            name="ema_w",
+        )
+
+    def _codebook_usage(self, encodings):
+        usage_counts = ops.sum(encodings, axis=0)  # (num_embeddings,)
+        num_used = ops.sum(ops.cast(usage_counts > 0, "float32"))
+        return num_used / self.num_embeddings
+
+    def call(self, inputs, training=False):
+        input_shape = ops.shape(inputs)
+        # Flatten inputs to (N, D).
+        flattened_inputs = ops.reshape(inputs, (-1, self.embedding_dim))
+
+        # Distances: x^2 + c^2 - 2xc.
+        # inputs: (N, D), codebook: (E, D)
+        input_sq = ops.sum(flattened_inputs**2, axis=1, keepdims=True)
+        codebook_sq = ops.sum(self.embeddings**2, axis=1)  # (E,)
+        dot_product = ops.matmul(
+            flattened_inputs, ops.transpose(self.embeddings)
+        )  # (N, E)
+
+        distances = input_sq + codebook_sq - 2 * dot_product
+
+        # Encode each vector as the index of its nearest codebook entry.
+        encoding_indices = ops.argmin(distances, axis=-1)
+        encodings = ops.one_hot(encoding_indices, self.num_embeddings)  # (N, E)
+
+        # Quantize.
+        quantized_flat = ops.take(self.embeddings, encoding_indices, axis=0)
+
+        if training:
+            # EMA update of the per-code assignment counts.
+            current_counts = ops.sum(encodings, axis=0)
+            updated_ema_cluster_size = (
+                self.ema_cluster_size * self.decay
+                + (1.0 - self.decay) * current_counts
+            )
+
+            # Laplace smoothing.
+            n = ops.sum(updated_ema_cluster_size)
+            updated_ema_cluster_size = (
+                (updated_ema_cluster_size + self.eps)
+                / (n + self.num_embeddings * self.eps)
+                * n
+            )
+
+            self.ema_cluster_size.assign(updated_ema_cluster_size)
+
+            # total_assignment_sums = encodings.T @ inputs -> (E, D)
+            total_assignment_sums = ops.matmul(
+                ops.transpose(encodings), flattened_inputs
+            )
+
+            updated_ema_w = (
+                self.ema_w * self.decay
+                + (1.0 - self.decay) * total_assignment_sums
+            )
+            self.ema_w.assign(updated_ema_w)
+
+            updated_embeddings = self.ema_w / (
+                ops.expand_dims(self.ema_cluster_size, axis=1) + self.eps
+            )
+            self.embeddings.assign(updated_embeddings)
+
+            # Quantization (commitment) loss.
+            quantization_loss = ops.mean(
+                (flattened_inputs - quantized_flat) ** 2
+            )
+            quantization_loss = ops.reshape(quantization_loss, (1,))
+
+            # Straight-through estimator.
+            quantized = ops.reshape(quantized_flat, input_shape)
+            quantized_flow = inputs + ops.stop_gradient(quantized - inputs)
+
+            usage_ratio = self._codebook_usage(encodings)
+
+            return quantized_flow, encodings, usage_ratio, quantization_loss
+        else:
+            quantized = ops.reshape(quantized_flat, input_shape)
+            usage_ratio = self._codebook_usage(encodings)
+            quantization_loss = ops.convert_to_tensor(0.0, dtype=inputs.dtype)
+            quantization_loss = ops.reshape(quantization_loss, (1,))
+            return quantized, encodings, usage_ratio, quantization_loss
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_embeddings": self.num_embeddings,
+                "embedding_dim": self.embedding_dim,
+                "decay": self.decay,
+                "eps": self.eps,
+            }
+        )
+        return config
+
+
+class ResidualVectorQuantizer(layers.Layer):
+    """A Residual Vector Quantizer.
+
+    This layer applies a sequence of vector quantizers to the residual of the
+    input. The first quantizer quantizes the input, the second quantizer
+    quantizes the error (residual) left by the first, and so on.
+
+    Args:
+        quantizers: A list of `VectorQuantizerEMA` instances (or compatible
+            layers) to be applied sequentially.
+        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+    References:
+        - [SoundStream: An End-to-End Neural Audio Codec](https://arxiv.org/abs/2107.03312)
+
+    Example:
+    >>> vq1 = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
+    >>> vq2 = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
+    >>> rvq = ResidualVectorQuantizer(quantizers=[vq1, vq2])
+    >>> x = keras.random.uniform(shape=(1, 5, 16))
+    >>> quantized_sum, encodings, usage_ratios, loss = rvq(x)
+    >>> tuple(quantized_sum.shape)
+    (1, 5, 16)
+    """
+
+    def __init__(self, quantizers, **kwargs):
+        super().__init__(**kwargs)
+        # `quantizers` is a list of layer instances; serialized configs are
+        # handled in `from_config`.
+        self.quantizers = quantizers
+
+    def call(self, inputs, training=False):
+        quantized_list = []
+        encodings_list = []
+        usage_ratios_list = []
+        residual = inputs
+        total_quantization_loss = ops.convert_to_tensor(0.0, dtype=inputs.dtype)
+
+        for quantizer in self.quantizers:
+            # Each quantizer returns 4 values in both training and inference.
+            (
+                current_quantized,
+                current_encoding,
+                usage_ratio,
+                quantization_loss,
+            ) = quantizer(residual, training=training)
+            total_quantization_loss = (
+                total_quantization_loss + quantization_loss
+            )
+
+            quantized_list.append(current_quantized)
+            residual = residual - current_quantized
+            encodings_list.append(current_encoding)
+            usage_ratios_list.append(usage_ratio)
+
+        # Element-wise sum of the per-stage quantized tensors; encodings and
+        # usage ratios are stacked along a new leading axis.
+        quantized_sum = sum(quantized_list)
+        encodings = ops.stack(encodings_list, axis=0)
+        usage_ratios = ops.stack(usage_ratios_list, axis=0)
+
+        return quantized_sum, encodings, usage_ratios, total_quantization_loss
+
+    def get_config(self):
+        config = super().get_config()
+        quantizers_config = []
+        for q in self.quantizers:
+            quantizers_config.append(keras.utils.serialize_keras_object(q))
+        config.update(
+            {
+                "quantizers": quantizers_config,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        quantizers_config = config.pop("quantizers")
+        quantizers = [
+            keras.utils.deserialize_keras_object(q) for q in quantizers_config
+        ]
+        return cls(quantizers=quantizers, **config)
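
Because inference-mode calls return the stacked one-hot `encodings` for every stage, the quantized output can be recovered from codebook lookups alone, which is what makes RVQ codes usable as discrete tokens. An illustrative round-trip check (not part of the diff), reusing the docstring example:

import keras
from keras import ops

vq1 = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
vq2 = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
rvq = ResidualVectorQuantizer(quantizers=[vq1, vq2])

x = keras.random.uniform(shape=(1, 5, 16))
quantized_sum, encodings, usage_ratios, loss = rvq(x)  # encodings: (2, 5, 10)

recovered = ops.zeros_like(quantized_sum)
for i, quantizer in enumerate(rvq.quantizers):
    # one-hot (N, E) @ codebook (E, D) -> this stage's quantized vectors.
    stage = ops.matmul(encodings[i], quantizer.embeddings)
    recovered = recovered + ops.reshape(stage, ops.shape(quantized_sum))
# `recovered` now equals `quantized_sum`: summing the per-stage codebook
# lookups reproduces the residual quantizer's output.
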
keras_hub/src/models/rwkv7/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
+from keras_hub.src.models.rwkv7.rwkv7_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, RWKV7Backbone)
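
`register_presets` attaches the entries of `backbone_presets` (defined in rwkv7_presets.py, whose contents are not shown in this diff) to `RWKV7Backbone`, making them reachable through the standard keras-hub preset API. A hedged sketch, assuming the backbone is also exported under `keras_hub.models` (consistent with the 27 lines added to keras_hub/models/__init__.py) and using a placeholder preset name:

import keras_hub

# The registered preset identifiers can be listed from the class itself.
print(keras_hub.models.RWKV7Backbone.presets.keys())

# "rwkv7_example_en" is a placeholder; substitute one of the printed keys.
backbone = keras_hub.models.RWKV7Backbone.from_preset("rwkv7_example_en")
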