keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import keras
|
|
2
2
|
from keras import activations
|
|
3
|
+
from keras.layers import ReversibleEmbedding
|
|
3
4
|
|
|
4
5
|
from keras_hub.src.api_export import keras_hub_export
|
|
5
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
6
|
-
ReversibleEmbedding,
|
|
7
|
-
)
|
|
8
6
|
from keras_hub.src.models.backbone import Backbone
|
|
9
7
|
from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
|
|
10
8
|
from keras_hub.src.models.roformer_v2.roformer_v2_encoder import (
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from keras_hub.src.models.rqvae.rqvae_backbone import RQVAEBackbone
|
|
2
|
+
from keras_hub.src.models.rqvae.rqvae_layers import Decoder
|
|
3
|
+
from keras_hub.src.models.rqvae.rqvae_layers import Encoder
|
|
4
|
+
from keras_hub.src.models.rqvae.rqvae_layers import ResidualVectorQuantizer
|
|
5
|
+
from keras_hub.src.models.rqvae.rqvae_layers import VectorQuantizerEMA
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import ops
|
|
3
|
+
|
|
4
|
+
from keras_hub.src.models.backbone import Backbone
|
|
5
|
+
from keras_hub.src.models.rqvae import rqvae_layers
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RQVAEBackbone(Backbone):
    """Residual Quantized Variational Autoencoder (RQ-VAE) backbone.

    The backbone wires three sub-modules into one functional graph: a
    feed-forward encoder, a chain of EMA vector quantizers applied to
    successive residuals, and a feed-forward decoder that reconstructs the
    input from the summed quantized codes.

    Args:
        input_dim: Integer. Dimensionality of the (flat) input vectors.
        encoder_layer_dims: List of integers. Hidden `Dense` layer sizes of
            the encoder.
        output_dim: Integer. Dimensionality of the latent space (embedding
            dimension).
        decoder_layer_dims: List of integers. Hidden `Dense` layer sizes of
            the decoder.
        num_embeddings: Integer. Codebook size shared by every quantizer.
        num_quantizers: Integer. Number of sequential quantizers in the
            residual vector quantizer.
        decay: Float. EMA decay rate used by each quantizer. Defaults to
            `0.99`.
        data_variance: Float. Variance of the data, used to scale the
            reconstruction loss. Defaults to `1.0`.
        commitment_cost: Float. Weight of the quantization (commitment)
            loss in the total loss. Defaults to `0.25`.
        dtype: Optional dtype, dtype policy, or policy map for the model's
            computations and weights. Defaults to `None`.
        **kwargs: Base backbone keyword arguments.

    References:
        - [SoundStream: An End-to-End Neural Audio Codec](https://arxiv.org/abs/2107.03312)

    Example:
    >>> model = RQVAEBackbone(
    ...     input_dim=10,
    ...     encoder_layer_dims=[32, 16],
    ...     output_dim=8,
    ...     decoder_layer_dims=[16, 32],
    ...     num_embeddings=64,
    ...     num_quantizers=4,
    ... )
    >>> x = keras.random.uniform(shape=(1, 10))
    >>> outputs = model(x)
    >>> tuple(outputs["reconstructions"].shape)
    (1, 10)
    """

    def __init__(
        self,
        input_dim,
        encoder_layer_dims,
        output_dim,
        decoder_layer_dims,
        num_embeddings,
        num_quantizers,
        decay=0.99,
        data_variance=1.0,
        commitment_cost=0.25,
        dtype=None,
        **kwargs,
    ):
        # `dtype` may be a plain dtype string, a dtype policy, or a policy
        # map; the functional `Input` needs a concrete compute dtype.
        input_dtype = dtype
        if dtype is not None:
            if isinstance(dtype, keras.dtype_policies.DTypePolicyMap):
                input_dtype = dtype.default_policy.compute_dtype
            elif getattr(dtype, "compute_dtype", None):
                input_dtype = dtype.compute_dtype

        inputs = keras.Input(shape=(input_dim,), dtype=input_dtype)

        # === Sub-layers ===
        encoder = rqvae_layers.Encoder(
            layer_dims=encoder_layer_dims,
            output_dim=output_dim,
            dtype=dtype,
            name="encoder",
        )
        quantizers = [
            rqvae_layers.VectorQuantizerEMA(
                num_embeddings=num_embeddings,
                embedding_dim=output_dim,
                decay=decay,
                dtype=dtype,
                name=f"quantizer_{i}",
            )
            for i in range(num_quantizers)
        ]
        residual_quantizer = rqvae_layers.ResidualVectorQuantizer(
            quantizers=quantizers, dtype=dtype, name="residual_quantizer"
        )
        decoder = rqvae_layers.Decoder(
            layer_dims=decoder_layer_dims,
            output_dim=input_dim,
            dtype=dtype,
            name="decoder",
        )

        # === Functional graph ===
        latents = encoder(inputs)
        quantized, encodings, usage_ratios, quantization_loss = (
            residual_quantizer(latents)
        )
        reconstructions = decoder(quantized)

        super().__init__(
            inputs=inputs,
            outputs={
                "reconstructions": reconstructions,
                "encodings": encodings,
                "usage_ratios": usage_ratios,
                "quantization_loss": quantization_loss,
            },
            dtype=dtype,
            **kwargs,
        )

        # Sub-layer handles and loss weighting.
        self.encoder = encoder
        self.residual_quantizer = residual_quantizer
        self.decoder = decoder
        self.data_variance = data_variance
        self.commitment_cost = commitment_cost

        # Constructor arguments retained for `get_config`.
        self.input_dim = input_dim
        self.encoder_layer_dims = encoder_layer_dims
        self.output_dim = output_dim
        self.decoder_layer_dims = decoder_layer_dims
        self.num_embeddings = num_embeddings
        self.num_quantizers = num_quantizers
        self.decay = decay

    def compute_loss(self, x, y, y_pred, sample_weight=None):
        """Variance-scaled reconstruction MSE plus weighted VQ loss.

        When no explicit target `y` is given, the model trains as an
        autoencoder against its own input `x`.
        """
        target = x if y is None else y
        reconstruction_loss = (
            ops.mean((y_pred["reconstructions"] - target) ** 2)
            / self.data_variance
        )
        return (
            reconstruction_loss
            + self.commitment_cost * y_pred["quantization_loss"]
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "input_dim": self.input_dim,
                "encoder_layer_dims": self.encoder_layer_dims,
                "output_dim": self.output_dim,
                "decoder_layer_dims": self.decoder_layer_dims,
                "num_embeddings": self.num_embeddings,
                "num_quantizers": self.num_quantizers,
                "decay": self.decay,
                "data_variance": self.data_variance,
                "commitment_cost": self.commitment_cost,
            }
        )
        return config
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import layers
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Encoder(layers.Layer):
    """Feed-forward encoder: a ReLU MLP followed by a linear projection.

    A stack of `Dense` layers with ReLU activation is applied in order,
    then a final `Dense` layer with no activation maps to `output_dim`.

    Args:
        layer_dims: A list of integers, one per hidden `Dense` layer.
        output_dim: Integer. Width of the final (linear) `Dense` layer.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.

    Example:
    >>> encoder = Encoder(layer_dims=[64, 32], output_dim=16)
    >>> x = keras.random.uniform(shape=(1, 10))
    >>> output = encoder(x)
    >>> tuple(output.shape)
    (1, 16)
    """

    def __init__(self, layer_dims, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.layer_dims = layer_dims
        self.output_dim = output_dim
        # Hidden ReLU stack, then a linear output projection.
        self.dense_layers = [
            layers.Dense(units, activation="relu") for units in layer_dims
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        hidden = inputs
        for dense in self.dense_layers:
            hidden = dense(hidden)
        return self.output_layer(hidden)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "layer_dims": self.layer_dims,
                "output_dim": self.output_dim,
            }
        )
        return config
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Decoder(layers.Layer):
    """Feed-forward decoder: a ReLU MLP followed by a linear projection.

    Mirrors `Encoder`: hidden `Dense` layers with ReLU activation are
    applied in order, then a final activation-free `Dense` layer maps to
    `output_dim` (typically the original input dimensionality).

    Args:
        layer_dims: A list of integers, one per hidden `Dense` layer.
        output_dim: Integer. Width of the final (linear) `Dense` layer.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.

    Example:
    >>> decoder = Decoder(layer_dims=[32, 64], output_dim=10)
    >>> x = keras.random.uniform(shape=(1, 16))
    >>> output = decoder(x)
    >>> tuple(output.shape)
    (1, 10)
    """

    def __init__(self, layer_dims, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.layer_dims = layer_dims
        self.output_dim = output_dim
        # Hidden ReLU stack, then a linear output projection.
        self.dense_layers = []
        for units in layer_dims:
            self.dense_layers.append(layers.Dense(units, activation="relu"))
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        features = inputs
        for dense in self.dense_layers:
            features = dense(features)
        return self.output_layer(features)

    def get_config(self):
        config = super().get_config()
        config["layer_dims"] = self.layer_dims
        config["output_dim"] = self.output_dim
        return config
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class VectorQuantizerEMA(layers.Layer):
    """Vector Quantizer with Exponential Moving Average (EMA) updates.

    This layer implements a vector quantization module that updates its
    codebook with EMA statistics instead of gradients, which stabilizes
    training and mitigates codebook collapse. It flattens the input to a
    batch of vectors and maps each one to the nearest codebook entry.

    Args:
        num_embeddings: Integer. The number of embeddings in the codebook.
        embedding_dim: Integer. The dimensionality of each embedding vector.
        decay: Float. The decay rate for the EMA updates. Defaults to `0.99`.
        eps: Float. A small epsilon value for numerical stability to avoid
            division by zero. Defaults to `1e-5`.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.

    References:
        - [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)

    Example:
    >>> vq = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
    >>> x = keras.random.uniform(shape=(1, 5, 16))
    >>> quantized, encodings, usage_ratio, loss = vq(x)
    >>> tuple(quantized.shape)
    (1, 5, 16)
    """

    def __init__(
        self, num_embeddings, embedding_dim, decay=0.99, eps=1e-5, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.decay = decay
        self.eps = eps

    def build(self, input_shape):
        # All three weights are non-trainable: the codebook is maintained
        # purely by the EMA assignments in `call`, never by backprop.
        self.embeddings = self.add_weight(
            shape=(self.num_embeddings, self.embedding_dim),
            initializer="random_normal",
            trainable=False,
            name="embeddings",
        )
        # Running (decayed) count of vectors assigned to each code.
        self.ema_cluster_size = self.add_weight(
            shape=(self.num_embeddings,),
            initializer="zeros",
            trainable=False,
            name="ema_cluster_size",
        )
        # Running (decayed) sum of the vectors assigned to each code.
        self.ema_w = self.add_weight(
            shape=(self.num_embeddings, self.embedding_dim),
            initializer="random_normal",
            trainable=False,
            name="ema_w",
        )

    def _codebook_usage(self, encodings):
        """Return the fraction of codebook entries used at least once
        in this batch of one-hot `encodings` of shape (N, num_embeddings)."""
        usage_counts = ops.sum(encodings, axis=0)  # (num_embeddings,)
        num_used = ops.sum(ops.cast(usage_counts > 0, "float32"))
        return num_used / self.num_embeddings

    def call(self, inputs, training=False):
        # Returns (quantized, encodings, usage_ratio, quantization_loss).
        # `quantized` has the same shape as `inputs`; `quantization_loss`
        # has shape (1,) and is 0 at inference time.
        input_shape = ops.shape(inputs)
        # Flatten inputs to (N, D); assumes the last axis is the embedding
        # dimension (any leading axes are folded into N).
        flattened_inputs = ops.reshape(inputs, (-1, self.embedding_dim))

        # Squared Euclidean distances expanded as x^2 + c^2 - 2xc, which
        # avoids materializing an (N, E, D) difference tensor.
        # inputs: (N, D), codebook: (E, D)
        input_sq = ops.sum(flattened_inputs**2, axis=1, keepdims=True)
        codebook_sq = ops.sum(self.embeddings**2, axis=1)  # (E,)
        dot_product = ops.matmul(
            flattened_inputs, ops.transpose(self.embeddings)
        )  # (N, E)

        distances = input_sq + codebook_sq - 2 * dot_product

        # Nearest-code assignment, one-hot encoded for the EMA statistics.
        encoding_indices = ops.argmin(distances, axis=-1)
        encodings = ops.one_hot(encoding_indices, self.num_embeddings)  # (N, E)

        # Look up the assigned codebook vectors.
        quantized_flat = ops.take(self.embeddings, encoding_indices, axis=0)

        if training:
            # EMA update of the per-code assignment counts.
            current_counts = ops.sum(encodings, axis=0)
            updated_ema_cluster_size = (
                self.ema_cluster_size * self.decay
                + (1.0 - self.decay) * current_counts
            )

            # Laplace smoothing keeps every cluster size strictly positive
            # while preserving the total count `n`.
            n = ops.sum(updated_ema_cluster_size)
            updated_ema_cluster_size = (
                (updated_ema_cluster_size + self.eps)
                / (n + self.num_embeddings * self.eps)
                * n
            )

            self.ema_cluster_size.assign(updated_ema_cluster_size)

            # Sum of the input vectors assigned to each code:
            # total_assignment_sums = encodings.T @ inputs -> (E, D)
            total_assignment_sums = ops.matmul(
                ops.transpose(encodings), flattened_inputs
            )

            updated_ema_w = (
                self.ema_w * self.decay
                + (1.0 - self.decay) * total_assignment_sums
            )
            self.ema_w.assign(updated_ema_w)

            # New codebook = EMA mean of assigned vectors. Note this reads
            # the freshly assigned `ema_w` / `ema_cluster_size`, so the
            # two assigns above must happen first.
            updated_embeddings = self.ema_w / (
                ops.expand_dims(self.ema_cluster_size, axis=1) + self.eps
            )
            self.embeddings.assign(updated_embeddings)

            # Commitment-style loss: mean squared distance between the
            # encoder output and its (stop-gradient) quantized version.
            quantization_loss = ops.mean(
                (flattened_inputs - quantized_flat) ** 2
            )
            quantization_loss = ops.reshape(quantization_loss, (1,))

            # Straight-through estimator: forward pass uses the quantized
            # values, backward pass copies gradients to `inputs`.
            quantized = ops.reshape(quantized_flat, input_shape)
            quantized_flow = inputs + ops.stop_gradient(quantized - inputs)

            usage_ratio = self._codebook_usage(encodings)

            return quantized_flow, encodings, usage_ratio, quantization_loss
        else:
            # Inference: no EMA updates, no STE, and the quantization loss
            # is defined to be zero.
            quantized = ops.reshape(quantized_flat, input_shape)
            usage_ratio = self._codebook_usage(encodings)
            quantization_loss = ops.convert_to_tensor(0.0, dtype=inputs.dtype)
            quantization_loss = ops.reshape(quantization_loss, (1,))
            return quantized, encodings, usage_ratio, quantization_loss

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_embeddings": self.num_embeddings,
                "embedding_dim": self.embedding_dim,
                "decay": self.decay,
                "eps": self.eps,
            }
        )
        return config
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class ResidualVectorQuantizer(layers.Layer):
    """A Residual Vector Quantizer.

    Applies a sequence of vector quantizers to successive residuals: the
    first quantizer quantizes the input, the second quantizes the error
    (residual) left by the first, and so on. The overall quantization is
    the element-wise sum of every stage's output.

    Args:
        quantizers: A list of `VectorQuantizerEMA` instances (or compatible
            layers) applied sequentially.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.

    References:
        - [SoundStream: An End-to-End Neural Audio Codec](https://arxiv.org/abs/2107.03312)

    Example:
    >>> vq1 = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
    >>> vq2 = VectorQuantizerEMA(num_embeddings=10, embedding_dim=16)
    >>> rvq = ResidualVectorQuantizer(quantizers=[vq1, vq2])
    >>> x = keras.random.uniform(shape=(1, 5, 16))
    >>> quantized_sum, encodings, usage_ratios, loss = rvq(x)
    >>> tuple(quantized_sum.shape)
    (1, 5, 16)
    """

    def __init__(self, quantizers, **kwargs):
        super().__init__(**kwargs)
        # Quantizer layer instances, applied in order to the residuals.
        self.quantizers = quantizers

    def call(self, inputs, training=False):
        # Per-stage results; each quantizer returns a 4-tuple.
        stage_quantized = []
        stage_encodings = []
        stage_usage = []
        residual = inputs
        total_loss = ops.convert_to_tensor(0.0, dtype=inputs.dtype)

        for quantizer in self.quantizers:
            quantized, encodings, usage_ratio, stage_loss = quantizer(
                residual, training=training
            )
            total_loss = total_loss + stage_loss
            stage_quantized.append(quantized)
            stage_encodings.append(encodings)
            stage_usage.append(usage_ratio)
            # The next stage quantizes whatever this stage failed to capture.
            residual = residual - quantized

        # Element-wise sum of stage outputs reconstructs the full code;
        # encodings and usage ratios are stacked along a new leading axis.
        return (
            sum(stage_quantized),
            ops.stack(stage_encodings, axis=0),
            ops.stack(stage_usage, axis=0),
            total_loss,
        )

    def get_config(self):
        config = super().get_config()
        config["quantizers"] = [
            keras.utils.serialize_keras_object(q) for q in self.quantizers
        ]
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the quantizer layer instances from their serialized form.
        quantizers = [
            keras.utils.deserialize_keras_object(q)
            for q in config.pop("quantizers")
        ]
        return cls(quantizers=quantizers, **config)
|