keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/layers/__init__.py
CHANGED
|
@@ -138,6 +138,27 @@ from keras_hub.src.models.sam.sam_mask_decoder import (
|
|
|
138
138
|
from keras_hub.src.models.sam.sam_prompt_encoder import (
|
|
139
139
|
SAMPromptEncoder as SAMPromptEncoder,
|
|
140
140
|
)
|
|
141
|
+
from keras_hub.src.models.sam3.sam3_detr_decoder import (
|
|
142
|
+
SAM3DetrDecoder as SAM3DetrDecoder,
|
|
143
|
+
)
|
|
144
|
+
from keras_hub.src.models.sam3.sam3_detr_encoder import (
|
|
145
|
+
SAM3DetrEncoder as SAM3DetrEncoder,
|
|
146
|
+
)
|
|
147
|
+
from keras_hub.src.models.sam3.sam3_geometry_encoder import (
|
|
148
|
+
SAM3GeometryEncoder as SAM3GeometryEncoder,
|
|
149
|
+
)
|
|
150
|
+
from keras_hub.src.models.sam3.sam3_image_converter import (
|
|
151
|
+
SAM3ImageConverter as SAM3ImageConverter,
|
|
152
|
+
)
|
|
153
|
+
from keras_hub.src.models.sam3.sam3_mask_decoder import (
|
|
154
|
+
SAM3MaskDecoder as SAM3MaskDecoder,
|
|
155
|
+
)
|
|
156
|
+
from keras_hub.src.models.sam3.sam3_text_encoder import (
|
|
157
|
+
SAM3TextEncoder as SAM3TextEncoder,
|
|
158
|
+
)
|
|
159
|
+
from keras_hub.src.models.sam3.sam3_vision_encoder import (
|
|
160
|
+
SAM3VisionEncoder as SAM3VisionEncoder,
|
|
161
|
+
)
|
|
141
162
|
from keras_hub.src.models.segformer.segformer_image_converter import (
|
|
142
163
|
SegFormerImageConverter as SegFormerImageConverter,
|
|
143
164
|
)
|
keras_hub/models/__init__.py
CHANGED
|
@@ -211,6 +211,12 @@ from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor i
|
|
|
211
211
|
from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
|
|
212
212
|
DistilBertTokenizer as DistilBertTokenizer,
|
|
213
213
|
)
|
|
214
|
+
from keras_hub.src.models.edrec.edrec_backbone import (
|
|
215
|
+
EdRecBackbone as EdRecBackbone,
|
|
216
|
+
)
|
|
217
|
+
from keras_hub.src.models.edrec.edrec_seq2seq_lm import (
|
|
218
|
+
EdRecSeq2SeqLM as EdRecSeq2SeqLM,
|
|
219
|
+
)
|
|
214
220
|
from keras_hub.src.models.efficientnet.efficientnet_backbone import (
|
|
215
221
|
EfficientNetBackbone as EfficientNetBackbone,
|
|
216
222
|
)
|
|
@@ -629,6 +635,15 @@ from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor i
|
|
|
629
635
|
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
|
|
630
636
|
RoformerV2Tokenizer as RoformerV2Tokenizer,
|
|
631
637
|
)
|
|
638
|
+
from keras_hub.src.models.rwkv7.rwkv7_backbone import (
|
|
639
|
+
RWKV7Backbone as RWKV7Backbone,
|
|
640
|
+
)
|
|
641
|
+
from keras_hub.src.models.rwkv7.rwkv7_causal_lm import (
|
|
642
|
+
RWKV7CausalLM as RWKV7CausalLM,
|
|
643
|
+
)
|
|
644
|
+
from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
|
|
645
|
+
RWKV7CausalLMPreprocessor as RWKV7CausalLMPreprocessor,
|
|
646
|
+
)
|
|
632
647
|
from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone
|
|
633
648
|
from keras_hub.src.models.sam.sam_image_segmenter import (
|
|
634
649
|
SAMImageSegmenter as SAMImageSegmenter,
|
|
@@ -636,6 +651,18 @@ from keras_hub.src.models.sam.sam_image_segmenter import (
|
|
|
636
651
|
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
|
|
637
652
|
SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor,
|
|
638
653
|
)
|
|
654
|
+
from keras_hub.src.models.sam3.sam3_pc_backbone import (
|
|
655
|
+
SAM3PromptableConceptBackbone as SAM3PromptableConceptBackbone,
|
|
656
|
+
)
|
|
657
|
+
from keras_hub.src.models.sam3.sam3_pc_image_segmenter import (
|
|
658
|
+
SAM3PromptableConceptImageSegmenter as SAM3PromptableConceptImageSegmenter,
|
|
659
|
+
)
|
|
660
|
+
from keras_hub.src.models.sam3.sam3_pc_image_segmenter_preprocessor import (
|
|
661
|
+
SAM3PromptableConceptImageSegmenterPreprocessor as SAM3PromptableConceptImageSegmenterPreprocessor,
|
|
662
|
+
)
|
|
663
|
+
from keras_hub.src.models.sam3.sam3_tokenizer import (
|
|
664
|
+
SAM3Tokenizer as SAM3Tokenizer,
|
|
665
|
+
)
|
|
639
666
|
from keras_hub.src.models.segformer.segformer_backbone import (
|
|
640
667
|
SegFormerBackbone as SegFormerBackbone,
|
|
641
668
|
)
|
|
@@ -290,16 +290,19 @@ def non_max_suppression(
|
|
|
290
290
|
"int32",
|
|
291
291
|
)
|
|
292
292
|
idx = ops.minimum(idx, num_boxes - 1)
|
|
293
|
+
idx = ops.cast(idx, "int32")
|
|
293
294
|
|
|
294
295
|
index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32")
|
|
295
296
|
take_along_axis_idx = ops.reshape(
|
|
296
297
|
idx + ops.expand_dims(index_offsets, 1), [-1]
|
|
297
298
|
)
|
|
299
|
+
take_along_axis_idx = ops.cast(take_along_axis_idx, "int32")
|
|
298
300
|
|
|
299
301
|
if keras.backend.backend() != "tensorflow":
|
|
300
|
-
|
|
301
|
-
ops.reshape(sorted_indices, [-1]),
|
|
302
|
+
sorted_indices_int = ops.cast(
|
|
303
|
+
ops.reshape(sorted_indices, [-1]), "int32"
|
|
302
304
|
)
|
|
305
|
+
idx = ops.take_along_axis(sorted_indices_int, take_along_axis_idx)
|
|
303
306
|
else:
|
|
304
307
|
import tensorflow as tf
|
|
305
308
|
|
|
@@ -1,281 +1,8 @@
|
|
|
1
|
-
import inspect
|
|
2
|
-
|
|
3
1
|
import keras
|
|
4
|
-
from keras import ops
|
|
5
2
|
|
|
6
3
|
from keras_hub.src.api_export import keras_hub_export
|
|
7
4
|
|
|
8
5
|
|
|
9
6
|
@keras_hub_export("keras_hub.layers.ReversibleEmbedding")
|
|
10
|
-
class ReversibleEmbedding(keras.layers.Embedding):
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
This layer is an extension of `keras.layers.Embedding` for language models.
|
|
14
|
-
This layer can be called "in reverse" with `reverse=True`, in which case the
|
|
15
|
-
layer will linearly project from `output_dim` back to `input_dim`.
|
|
16
|
-
|
|
17
|
-
By default, the reverse projection will use the transpose of the
|
|
18
|
-
`embeddings` weights to project to `input_dim` (weights are "tied"). If
|
|
19
|
-
`tie_weights=False`, the model will use a separate, trainable variable for
|
|
20
|
-
reverse projection.
|
|
21
|
-
|
|
22
|
-
This layer has no bias terms.
|
|
23
|
-
|
|
24
|
-
Args:
|
|
25
|
-
input_dim: Integer. Size of the vocabulary,
|
|
26
|
-
i.e. maximum integer index + 1.
|
|
27
|
-
output_dim: Integer. Dimension of the dense embedding.
|
|
28
|
-
tie_weights: Boolean, whether or not the matrix for embedding and
|
|
29
|
-
the matrix for the `reverse` projection should share the same
|
|
30
|
-
weights.
|
|
31
|
-
embeddings_initializer: Initializer for the `embeddings`
|
|
32
|
-
matrix (see `keras.initializers`).
|
|
33
|
-
embeddings_regularizer: Regularizer function applied to
|
|
34
|
-
the `embeddings` matrix (see `keras.regularizers`).
|
|
35
|
-
embeddings_constraint: Constraint function applied to
|
|
36
|
-
the `embeddings` matrix (see `keras.constraints`).
|
|
37
|
-
mask_zero: Boolean, whether or not the input value 0 is a special
|
|
38
|
-
"padding" value that should be masked out.
|
|
39
|
-
reverse_dtype: The dtype for the reverse projection computation.
|
|
40
|
-
Defaults to the `compute_dtype` of the layer.
|
|
41
|
-
logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
|
|
42
|
-
output logits will be scaled by
|
|
43
|
-
`tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
|
|
44
|
-
range of output logits and can improve training.
|
|
45
|
-
**kwargs: other keyword arguments passed to `keras.layers.Embedding`,
|
|
46
|
-
including `name`, `trainable`, `dtype` etc.
|
|
47
|
-
|
|
48
|
-
Call arguments:
|
|
49
|
-
inputs: The tensor inputs to the layer.
|
|
50
|
-
reverse: Boolean. If `True` the layer will perform a linear projection
|
|
51
|
-
from `output_dim` to `input_dim`, instead of a normal embedding
|
|
52
|
-
call. Default to `False`.
|
|
53
|
-
|
|
54
|
-
Example:
|
|
55
|
-
```python
|
|
56
|
-
batch_size = 16
|
|
57
|
-
vocab_size = 100
|
|
58
|
-
hidden_dim = 32
|
|
59
|
-
seq_length = 50
|
|
60
|
-
|
|
61
|
-
# Generate random inputs.
|
|
62
|
-
token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
|
|
63
|
-
|
|
64
|
-
embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
|
|
65
|
-
# Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
|
|
66
|
-
hidden_states = embedding(token_ids)
|
|
67
|
-
# Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
|
|
68
|
-
logits = embedding(hidden_states, reverse=True)
|
|
69
|
-
```
|
|
70
|
-
|
|
71
|
-
References:
|
|
72
|
-
- [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
|
|
73
|
-
- [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
|
|
74
|
-
"""
|
|
75
|
-
|
|
76
|
-
def __init__(
|
|
77
|
-
self,
|
|
78
|
-
input_dim,
|
|
79
|
-
output_dim,
|
|
80
|
-
tie_weights=True,
|
|
81
|
-
embeddings_initializer="uniform",
|
|
82
|
-
embeddings_regularizer=None,
|
|
83
|
-
embeddings_constraint=None,
|
|
84
|
-
mask_zero=False,
|
|
85
|
-
reverse_dtype=None,
|
|
86
|
-
logit_soft_cap=None,
|
|
87
|
-
**kwargs,
|
|
88
|
-
):
|
|
89
|
-
super().__init__(
|
|
90
|
-
input_dim,
|
|
91
|
-
output_dim,
|
|
92
|
-
embeddings_initializer=embeddings_initializer,
|
|
93
|
-
embeddings_regularizer=embeddings_regularizer,
|
|
94
|
-
embeddings_constraint=embeddings_constraint,
|
|
95
|
-
mask_zero=mask_zero,
|
|
96
|
-
**kwargs,
|
|
97
|
-
)
|
|
98
|
-
self.tie_weights = tie_weights
|
|
99
|
-
self.reverse_dtype = reverse_dtype
|
|
100
|
-
self.logit_soft_cap = logit_soft_cap
|
|
101
|
-
|
|
102
|
-
def build(self, inputs_shape=None):
|
|
103
|
-
super().build(inputs_shape)
|
|
104
|
-
if (
|
|
105
|
-
not self.tie_weights
|
|
106
|
-
and getattr(self, "quantization_mode", None) != "int8"
|
|
107
|
-
):
|
|
108
|
-
self.reverse_embeddings = self.add_weight(
|
|
109
|
-
name="reverse_embeddings",
|
|
110
|
-
shape=(self.output_dim, self.input_dim),
|
|
111
|
-
initializer=self.embeddings_initializer,
|
|
112
|
-
dtype=self.dtype,
|
|
113
|
-
)
|
|
114
|
-
|
|
115
|
-
def call(self, inputs, reverse=False):
|
|
116
|
-
if reverse:
|
|
117
|
-
if self.tie_weights:
|
|
118
|
-
kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
|
|
119
|
-
else:
|
|
120
|
-
kernel = self.reverse_embeddings
|
|
121
|
-
if self.reverse_dtype is not None:
|
|
122
|
-
inputs = ops.cast(inputs, self.reverse_dtype)
|
|
123
|
-
kernel = ops.cast(kernel, self.reverse_dtype)
|
|
124
|
-
logits = ops.matmul(inputs, kernel)
|
|
125
|
-
# Optionally soft-cap logits.
|
|
126
|
-
if self.logit_soft_cap is not None:
|
|
127
|
-
soft_cap = self.logit_soft_cap
|
|
128
|
-
logits = ops.tanh(logits / soft_cap) * soft_cap
|
|
129
|
-
return logits
|
|
130
|
-
|
|
131
|
-
return super().call(inputs)
|
|
132
|
-
|
|
133
|
-
def get_config(self):
|
|
134
|
-
config = super().get_config()
|
|
135
|
-
config.update(
|
|
136
|
-
{
|
|
137
|
-
"tie_weights": self.tie_weights,
|
|
138
|
-
"reverse_dtype": self.reverse_dtype,
|
|
139
|
-
"logit_soft_cap": self.logit_soft_cap,
|
|
140
|
-
}
|
|
141
|
-
)
|
|
142
|
-
return config
|
|
143
|
-
|
|
144
|
-
def save_own_variables(self, store):
|
|
145
|
-
if not self.built:
|
|
146
|
-
return
|
|
147
|
-
super().save_own_variables(store)
|
|
148
|
-
target_variables = []
|
|
149
|
-
if not self.tie_weights:
|
|
150
|
-
# Store the reverse embedding weights as the last weights.
|
|
151
|
-
target_variables.append(self.reverse_embeddings)
|
|
152
|
-
if getattr(self, "quantization_mode", None) == "int8":
|
|
153
|
-
target_variables.append(self.reverse_embeddings_scale)
|
|
154
|
-
for i, variable in enumerate(target_variables, start=len(store)):
|
|
155
|
-
store[str(i)] = variable
|
|
156
|
-
|
|
157
|
-
def load_own_variables(self, store):
|
|
158
|
-
if not self.built:
|
|
159
|
-
self.build()
|
|
160
|
-
super().load_own_variables(store)
|
|
161
|
-
if not self.tie_weights:
|
|
162
|
-
# Last weights in the stores are the reverse embedding weights.
|
|
163
|
-
target_variables = [self.reverse_embeddings]
|
|
164
|
-
if getattr(self, "quantization_mode", None) == "int8":
|
|
165
|
-
target_variables.append(self.reverse_embeddings_scale)
|
|
166
|
-
for i, variable in enumerate(
|
|
167
|
-
target_variables, start=len(store) - len(target_variables)
|
|
168
|
-
):
|
|
169
|
-
variable.assign(store[str(i)])
|
|
170
|
-
|
|
171
|
-
def compute_output_spec(self, inputs, reverse=False):
|
|
172
|
-
output_shape = list(inputs.shape)
|
|
173
|
-
if reverse:
|
|
174
|
-
output_shape[-1] = self.input_dim
|
|
175
|
-
else:
|
|
176
|
-
output_shape += [self.output_dim]
|
|
177
|
-
return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
|
|
178
|
-
|
|
179
|
-
# Quantization-related (int8) methods
|
|
180
|
-
|
|
181
|
-
def quantized_call(self, inputs, reverse=False):
|
|
182
|
-
# TODO (hongyu): This function could be removed once we add `*args` and
|
|
183
|
-
# `**kwargs` for `Embedding.quantized_call`
|
|
184
|
-
if self.quantization_mode == "int8":
|
|
185
|
-
return self._int8_call(inputs, reverse=reverse)
|
|
186
|
-
else:
|
|
187
|
-
self._quantization_mode_error(self.quantization_mode)
|
|
188
|
-
|
|
189
|
-
def _int8_build(self, embeddings_shape=None):
|
|
190
|
-
if (
|
|
191
|
-
"embeddings_shape"
|
|
192
|
-
in inspect.signature(super()._int8_build).parameters
|
|
193
|
-
):
|
|
194
|
-
if embeddings_shape is None:
|
|
195
|
-
embeddings_shape = (self.input_dim, self.output_dim)
|
|
196
|
-
super()._int8_build(embeddings_shape=embeddings_shape)
|
|
197
|
-
else:
|
|
198
|
-
# Backward compatibility for older versions of Keras.
|
|
199
|
-
super()._int8_build()
|
|
200
|
-
self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
|
|
201
|
-
if not self.tie_weights:
|
|
202
|
-
self.reverse_embeddings = self.add_weight(
|
|
203
|
-
name="reverse_embeddings",
|
|
204
|
-
shape=(self.output_dim, self.input_dim),
|
|
205
|
-
initializer="zeros",
|
|
206
|
-
dtype="int8",
|
|
207
|
-
trainable=False,
|
|
208
|
-
)
|
|
209
|
-
self.reverse_embeddings_scale = self.add_weight(
|
|
210
|
-
name="reverse_embeddings_scale",
|
|
211
|
-
shape=(self.input_dim,),
|
|
212
|
-
initializer="ones",
|
|
213
|
-
trainable=False,
|
|
214
|
-
)
|
|
215
|
-
self._is_quantized = True
|
|
216
|
-
|
|
217
|
-
def _int8_call(self, inputs, reverse=False):
|
|
218
|
-
if reverse:
|
|
219
|
-
if self.tie_weights:
|
|
220
|
-
kernel = ops.transpose(self._embeddings)
|
|
221
|
-
scale = ops.transpose(self.embeddings_scale)
|
|
222
|
-
else:
|
|
223
|
-
kernel = self.reverse_embeddings
|
|
224
|
-
scale = self.reverse_embeddings_scale
|
|
225
|
-
inputs, inputs_scale = self.inputs_quantizer(inputs)
|
|
226
|
-
logits = ops.matmul(inputs, kernel)
|
|
227
|
-
# De-scale outputs
|
|
228
|
-
logits = ops.cast(logits, self.compute_dtype)
|
|
229
|
-
logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
|
|
230
|
-
# Optionally soft-cap logits.
|
|
231
|
-
if self.logit_soft_cap is not None:
|
|
232
|
-
soft_cap = self.logit_soft_cap
|
|
233
|
-
logits = ops.tanh(logits / soft_cap) * soft_cap
|
|
234
|
-
return logits
|
|
235
|
-
|
|
236
|
-
return super()._int8_call(inputs)
|
|
237
|
-
|
|
238
|
-
def quantize(self, mode, type_check=True, config=None):
|
|
239
|
-
del config
|
|
240
|
-
if type_check and type(self) is not ReversibleEmbedding:
|
|
241
|
-
raise self._not_implemented_error(self.quantize)
|
|
242
|
-
|
|
243
|
-
def abs_max_quantize(inputs, axis):
|
|
244
|
-
return keras.quantizers.abs_max_quantize(
|
|
245
|
-
inputs, axis=axis, to_numpy=True
|
|
246
|
-
)
|
|
247
|
-
|
|
248
|
-
if mode != "int8":
|
|
249
|
-
raise NotImplementedError(
|
|
250
|
-
"Invalid quantization mode. Expected 'int8'. "
|
|
251
|
-
f"Received: quantization_mode={mode}"
|
|
252
|
-
)
|
|
253
|
-
|
|
254
|
-
embeddings_shape = (self.input_dim, self.output_dim)
|
|
255
|
-
if mode == "int8":
|
|
256
|
-
embeddings, embeddings_scale = abs_max_quantize(
|
|
257
|
-
self._embeddings, axis=-1
|
|
258
|
-
)
|
|
259
|
-
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
|
|
260
|
-
del self._embeddings
|
|
261
|
-
if not self.tie_weights:
|
|
262
|
-
reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
|
|
263
|
-
self.reverse_embeddings, axis=0
|
|
264
|
-
)
|
|
265
|
-
reverse_embeddings_scale = ops.squeeze(
|
|
266
|
-
reverse_embeddings_scale, axis=0
|
|
267
|
-
)
|
|
268
|
-
del self.reverse_embeddings
|
|
269
|
-
self.quantized_build(embeddings_shape, mode)
|
|
270
|
-
if mode == "int8":
|
|
271
|
-
self._embeddings.assign(embeddings)
|
|
272
|
-
self.embeddings_scale.assign(embeddings_scale)
|
|
273
|
-
if not self.tie_weights:
|
|
274
|
-
self.reverse_embeddings.assign(reverse_embeddings)
|
|
275
|
-
self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
|
|
276
|
-
|
|
277
|
-
if self.dtype_policy.quantization_mode is None:
|
|
278
|
-
policy = keras.dtype_policies.get(
|
|
279
|
-
f"{mode}_from_{self.dtype_policy.name}"
|
|
280
|
-
)
|
|
281
|
-
self.dtype_policy = policy
|
|
7
|
+
class ReversibleEmbedding(keras.layers.ReversibleEmbedding):
|
|
8
|
+
pass
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
3
|
+
from keras.src.backend import get_keras_mask
|
|
4
|
+
from keras.src.backend import set_keras_mask
|
|
2
5
|
|
|
3
6
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
7
|
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
|
5
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
6
|
-
ReversibleEmbedding,
|
|
7
|
-
)
|
|
8
8
|
from keras_hub.src.utils.keras_utils import clone_initializer
|
|
9
9
|
|
|
10
10
|
|
|
@@ -128,10 +128,10 @@ class TokenAndPositionEmbedding(keras.layers.Layer):
|
|
|
128
128
|
positions=positions,
|
|
129
129
|
)
|
|
130
130
|
outputs = embedded_tokens + embedded_positions
|
|
131
|
+
mask = get_keras_mask(embedded_tokens)
|
|
132
|
+
if mask is not None:
|
|
133
|
+
set_keras_mask(outputs, mask)
|
|
131
134
|
return outputs
|
|
132
135
|
|
|
133
|
-
def compute_mask(self, inputs, mask=None):
|
|
134
|
-
return self.token_embedding.compute_mask(inputs, mask=mask)
|
|
135
|
-
|
|
136
136
|
def compute_output_shape(self, input_shape):
|
|
137
137
|
return tuple(input_shape) + (self.embedding_dim,)
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from absl import logging
|
|
2
2
|
from keras import ops
|
|
3
|
+
from keras.src.backend import get_keras_mask
|
|
3
4
|
|
|
4
5
|
|
|
5
6
|
def _check_masks_shapes(inputs, padding_mask, attention_mask):
|
|
6
7
|
mask = padding_mask
|
|
7
|
-
if
|
|
8
|
-
mask = inputs._keras_mask
|
|
8
|
+
if mask is None:
|
|
9
|
+
mask = get_keras_mask(inputs)
|
|
9
10
|
if mask is not None:
|
|
10
11
|
if len(mask.shape) != 2:
|
|
11
12
|
raise ValueError(
|
|
@@ -68,17 +69,16 @@ def merge_padding_and_attention_mask(
|
|
|
68
69
|
returned mask is padding_mask with one additional axis.
|
|
69
70
|
"""
|
|
70
71
|
_check_masks_shapes(inputs, padding_mask, attention_mask)
|
|
71
|
-
mask
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
mask = inputs._keras_mask
|
|
77
|
-
else:
|
|
72
|
+
# We look for a padding mask from the input data.
|
|
73
|
+
mask = get_keras_mask(inputs)
|
|
74
|
+
# But if padding mask is explicitly provided, we use it.
|
|
75
|
+
if padding_mask is not None:
|
|
76
|
+
if mask is not None:
|
|
78
77
|
logging.warning(
|
|
79
78
|
"You are explicitly setting `padding_mask` while the `inputs` "
|
|
80
79
|
"have built-in mask, so the built-in mask is ignored."
|
|
81
80
|
)
|
|
81
|
+
mask = padding_mask
|
|
82
82
|
if mask is not None:
|
|
83
83
|
# Add an axis for broadcasting, the attention mask should be 2D
|
|
84
84
|
# (not including the batch axis).
|
|
@@ -7,9 +7,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
|
|
|
7
7
|
|
|
8
8
|
try:
|
|
9
9
|
import tensorflow as tf
|
|
10
|
-
import tensorflow_text as tf_text
|
|
11
10
|
except ImportError:
|
|
12
11
|
tf = None
|
|
12
|
+
try:
|
|
13
|
+
import tensorflow_text as tf_text
|
|
14
|
+
except ImportError:
|
|
13
15
|
tf_text = None
|
|
14
16
|
|
|
15
17
|
|
|
@@ -8,9 +8,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
|
|
|
8
8
|
|
|
9
9
|
try:
|
|
10
10
|
import tensorflow as tf
|
|
11
|
-
import tensorflow_text as tf_text
|
|
12
11
|
except ImportError:
|
|
13
12
|
tf = None
|
|
13
|
+
try:
|
|
14
|
+
import tensorflow_text as tf_text
|
|
15
|
+
except ImportError:
|
|
14
16
|
tf_text = None
|
|
15
17
|
|
|
16
18
|
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
5
|
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
|
5
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
6
|
-
ReversibleEmbedding,
|
|
7
|
-
)
|
|
8
6
|
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
|
|
9
7
|
from keras_hub.src.models.backbone import Backbone
|
|
10
8
|
from keras_hub.src.utils.keras_utils import gelu_approximate
|
keras_hub/src/models/backbone.py
CHANGED
|
@@ -107,6 +107,9 @@ class Backbone(keras.Model):
|
|
|
107
107
|
def from_config(cls, config):
|
|
108
108
|
# The default `from_config()` for functional models will return a
|
|
109
109
|
# vanilla `keras.Model`. We override it to get a subclass instance back.
|
|
110
|
+
config = config.copy()
|
|
111
|
+
if "dtype" in config and isinstance(config["dtype"], dict):
|
|
112
|
+
config["dtype"] = keras.dtype_policies.get(config["dtype"])
|
|
110
113
|
return cls(**config)
|
|
111
114
|
|
|
112
115
|
@classproperty
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
5
|
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
|
5
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
6
|
-
ReversibleEmbedding,
|
|
7
|
-
)
|
|
8
6
|
from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
|
|
9
7
|
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
|
|
10
8
|
from keras_hub.src.models.backbone import Backbone
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
5
|
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
|
5
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
6
|
-
ReversibleEmbedding,
|
|
7
|
-
)
|
|
8
6
|
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
|
|
9
7
|
from keras_hub.src.models.backbone import Backbone
|
|
10
8
|
from keras_hub.src.utils.keras_utils import gelu_approximate
|
|
@@ -35,7 +33,7 @@ class BertBackbone(Backbone):
|
|
|
35
33
|
vocabulary_size: int. The size of the token vocabulary.
|
|
36
34
|
num_layers: int. The number of transformer layers.
|
|
37
35
|
num_heads: int. The number of attention heads for each transformer.
|
|
38
|
-
The
|
|
36
|
+
The hidden_dim must be divisible by the number of attention heads.
|
|
39
37
|
hidden_dim: int. The size of the transformer encoding and pooler layers.
|
|
40
38
|
intermediate_dim: int. The output dimension of the first Dense layer in
|
|
41
39
|
a two-layer feedforward network for each transformer.
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
5
|
-
ReversibleEmbedding,
|
|
6
|
-
)
|
|
7
5
|
from keras_hub.src.models.backbone import Backbone
|
|
8
6
|
from keras_hub.src.models.bloom.bloom_decoder import BloomDecoder
|
|
9
7
|
|
|
@@ -196,7 +196,7 @@ class CausalLM(Task):
|
|
|
196
196
|
|
|
197
197
|
# Create an explicit tuple of all variable state.
|
|
198
198
|
state = (
|
|
199
|
-
self.sampler.variables,
|
|
199
|
+
[v.value for v in self.sampler.variables],
|
|
200
200
|
# Use the explicit variable.value to preserve the
|
|
201
201
|
# sharding spec of distribution.
|
|
202
202
|
[v.value for v in self.trainable_variables],
|
|
@@ -431,7 +431,7 @@ class CausalLM(Task):
|
|
|
431
431
|
self.generate_function = None
|
|
432
432
|
|
|
433
433
|
def get_quantization_layer_structure(self, mode):
|
|
434
|
-
if mode
|
|
434
|
+
if mode not in ["gptq", "awq"]:
|
|
435
435
|
return None
|
|
436
436
|
|
|
437
437
|
backbone = self.backbone
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
5
|
-
ReversibleEmbedding,
|
|
6
|
-
)
|
|
7
5
|
from keras_hub.src.models.backbone import Backbone
|
|
8
6
|
from keras_hub.src.models.deberta_v3.disentangled_attention_encoder import (
|
|
9
7
|
DisentangledAttentionEncoder,
|