keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import layers
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
6
|
+
from keras_hub.src.layers.modeling.token_and_position_embedding import (
|
|
7
|
+
TokenAndPositionEmbedding,
|
|
8
|
+
)
|
|
9
|
+
from keras_hub.src.models.sam3.sam3_utils import create_bidirectional_mask
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CLIPEncoderLayer(layers.Layer):
    """A pre-normalization transformer encoder layer, as used in CLIP.

    The layer applies two residual sub-blocks: layer normalization followed
    by multi-head self-attention, then layer normalization followed by a
    two-layer MLP.

    Args:
        hidden_dim: int. Size of the model (input/output) dimension.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. Size of the MLP's hidden projection.
        intermediate_activation: str. Activation used inside the MLP.
            Defaults to `"gelu"`.
        layer_norm_epsilon: float. Epsilon for both layer norms. Defaults to
            `1e-5`.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        intermediate_activation="gelu",
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.intermediate_activation = intermediate_activation
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        # Sub-layer names are part of the checkpoint format; keep them stable.
        self.layer_norm_1 = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_1",
        )
        self.attention = layers.MultiHeadAttention(
            num_heads,
            hidden_dim // num_heads,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.layer_norm_2 = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_2",
        )
        self.dense_1 = layers.Dense(
            self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
        )
        self.activation = layers.Activation(
            intermediate_activation, dtype=self.dtype_policy, name="activation"
        )
        self.dense_2 = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
        )

    def build(self, inputs_shape, attention_mask_shape):
        self.layer_norm_1.build(inputs_shape)
        self.attention.build(inputs_shape, inputs_shape, inputs_shape)
        self.layer_norm_2.build(inputs_shape)
        self.dense_1.build(inputs_shape)
        mlp_hidden_shape = self.dense_1.compute_output_shape(inputs_shape)
        self.dense_2.build(mlp_hidden_shape)

    def compute_output_shape(self, inputs_shape, attention_mask_shape):
        # Only the channel dimension can change; it is always `hidden_dim`.
        return [*inputs_shape[:-1], self.hidden_dim]

    def call(self, inputs, attention_mask, training=None):
        # Attention sub-block (pre-norm) with a residual connection.
        skip = inputs
        hidden = self.layer_norm_1(inputs)
        hidden = self.attention(
            hidden,
            hidden,
            hidden,
            attention_mask=ops.cast(attention_mask, dtype="bool"),
            training=training,
        )
        hidden = ops.add(skip, hidden)

        # MLP sub-block (pre-norm) with a residual connection.
        skip = hidden
        hidden = self.layer_norm_2(hidden)
        hidden = self.dense_1(hidden)
        hidden = self.activation(hidden)
        hidden = self.dense_2(hidden)
        return ops.add(skip, hidden)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "intermediate_activation": self.intermediate_activation,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@keras_hub_export("keras_hub.layers.SAM3TextEncoder")
class SAM3TextEncoder(layers.Layer):
    """A text encoder for the Segment Anything Model 3 (SAM3).

    This layer implements a CLIP-style text encoder. It processes token IDs and
    padding masks to produce text embeddings that are used as prompts for
    segmentation.

    Args:
        vocabulary_size: int. The size of the vocabulary.
        embedding_dim: int. The dimension of the token embeddings.
        hidden_dim: int. The hidden dimension of the transformer layers.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads.
        intermediate_dim: int. The dimension of the intermediate layer in the
            transformer's MLP.
        intermediate_activation: str. The activation function for the
            transformer layers. Defaults to `"gelu"`.
        max_sequence_length: int. The maximum sequence length. Defaults to
            `32`.
        layer_norm_epsilon: float. The epsilon value for layer normalization.
            Defaults to `1e-5`.
    """

    def __init__(
        self,
        vocabulary_size,
        embedding_dim,
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        intermediate_activation="gelu",
        max_sequence_length=32,
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocabulary_size = int(vocabulary_size)
        self.embedding_dim = int(embedding_dim)
        self.hidden_dim = int(hidden_dim)
        self.num_layers = int(num_layers)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.intermediate_activation = intermediate_activation
        self.max_sequence_length = int(max_sequence_length)
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        self.embedding = TokenAndPositionEmbedding(
            vocabulary_size=self.vocabulary_size,
            sequence_length=self.max_sequence_length,
            embedding_dim=self.embedding_dim,
            dtype=self.dtype_policy,
            name="embedding",
        )
        self.encoder_layers = [
            CLIPEncoderLayer(
                self.hidden_dim,
                self.num_heads,
                self.intermediate_dim,
                self.intermediate_activation,
                # Propagate the configured epsilon so all layer norms agree;
                # previously only the final `layer_norm` received it.
                layer_norm_epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name=f"encoder_layer_{i}",
            )
            for i in range(self.num_layers)
        ]
        self.layer_norm = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm",
        )

    def build(self, token_ids_shape, padding_masks_shape):
        self.embedding.build(token_ids_shape)
        x_shape = self.embedding.compute_output_shape(token_ids_shape)
        for layer in self.encoder_layers:
            layer.build(x_shape, padding_masks_shape)
        self.layer_norm.build(x_shape)

    def call(self, token_ids, padding_masks, training=None):
        x = self.embedding(token_ids, training=training)
        # Expand the (batch, seq) padding mask to a full bidirectional
        # attention mask before feeding the encoder stack.
        padding_masks = create_bidirectional_mask(x, padding_masks)
        for layer in self.encoder_layers:
            x = layer(x, padding_masks, training=training)
        x = self.layer_norm(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "embedding_dim": self.embedding_dim,
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "intermediate_activation": self.intermediate_activation,
                "max_sequence_length": self.max_sequence_length,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    def compute_output_shape(self, token_ids_shape, padding_masks_shape):
        # Encoder layers preserve the embedding's output shape.
        return self.embedding.compute_output_shape(token_ids_shape)

    def compute_output_spec(self, token_ids, padding_masks):
        output_shape = self.compute_output_shape(
            token_ids.shape, padding_masks.shape
        )
        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
2
|
+
from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
|
|
3
|
+
from keras_hub.src.models.sam3.sam3_pc_backbone import (
|
|
4
|
+
SAM3PromptableConceptBackbone,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@keras_hub_export(
    [
        "keras_hub.tokenizers.SAM3Tokenizer",
        "keras_hub.models.SAM3Tokenizer",
    ]
)
class SAM3Tokenizer(CLIPTokenizer):
    """A SAM3 tokenizer using Byte-Pair Encoding subword segmentation.

    This tokenizer turns raw strings into integer token-id sequences and is
    built on `keras_hub.tokenizers.BytePairTokenizer`. Unlike the underlying
    tokenizer, it checks for every special token required by SAM3 models and
    offers a `from_preset()` method that downloads a matching vocabulary for
    a SAM3 preset automatically.

    For batched string input (rank > 0) the layer outputs a
    `tf.RaggedTensor` whose last dimension is ragged. For a scalar string
    (rank == 0) it outputs a dense `tf.Tensor` with static shape `[None]`.

    Args:
        vocabulary: string or dict, maps token to integer ids. If it is a
            string, it should be the file path to a json file.
        merges: string or list, contains the merge rule. If it is a string,
            it should be the file path to merge rules. The merge rule file
            should have one merge rule per line. Every merge rule contains
            merge entities separated by a space.

    Examples:

    ```python
    # Unbatched input.
    tokenizer = keras_hub.models.SAM3Tokenizer.from_preset("sam3_pcs")
    tokenizer("The quick brown fox jumped.")

    # Batched input.
    tokenizer(["The quick brown fox jumped.", "The fox slept."])

    # Detokenization.
    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
    ```
    """

    backbone_cls = SAM3PromptableConceptBackbone

    def __init__(self, vocabulary=None, merges=None, **kwargs):
        # SAM3 always pads with the end token, so the option is hard-coded.
        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            pad_with_end_token=True,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        # Not a constructor argument of this class; always True for SAM3.
        config.pop("pad_with_end_token")
        return config
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def window_partition(x, height, width, window_size, hidden_dim):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: tensor reshapeable to (batch, height, width, hidden_dim).
        height: int. Feature-map height; must be divisible by `window_size`.
        width: int. Feature-map width; must be divisible by `window_size`.
        window_size: int. Side length of each window.
        hidden_dim: int. Channel dimension.

    Returns:
        A tensor of shape
        (batch * num_windows, window_size, window_size, hidden_dim).
    """
    num_windows_h = height // window_size
    num_windows_w = width // window_size
    windows = ops.reshape(
        x,
        (
            -1,
            num_windows_h,
            window_size,
            num_windows_w,
            window_size,
            hidden_dim,
        ),
    )
    # Bring the two window-grid axes together before flattening.
    windows = ops.transpose(windows, axes=(0, 1, 3, 2, 4, 5))
    return ops.reshape(windows, (-1, window_size, window_size, hidden_dim))
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def window_unpartition(x, height, width, window_size, hidden_dim):
    """Merge square windows back into a full feature map.

    Inverse of `window_partition`.

    Args:
        x: tensor of windows, reshapeable to
            (batch * num_windows, window_size, window_size, hidden_dim).
        height: int. Target feature-map height.
        width: int. Target feature-map width.
        window_size: int. Side length of each window.
        hidden_dim: int. Channel dimension.

    Returns:
        A tensor of shape (batch, height, width, hidden_dim).
    """
    num_windows_h = height // window_size
    num_windows_w = width // window_size
    grid = ops.reshape(
        x,
        (
            -1,
            num_windows_h,
            num_windows_w,
            window_size,
            window_size,
            hidden_dim,
        ),
    )
    # Interleave window-grid and intra-window axes back into rows/columns.
    grid = ops.transpose(grid, axes=(0, 1, 3, 2, 4, 5))
    return ops.reshape(grid, (-1, height, width, hidden_dim))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def box_cxcywh_to_xyxy(boxes):
    """Convert boxes from (cx, cy, w, h) to (x_min, y_min, x_max, y_max).

    Args:
        boxes: tensor whose last axis holds (center_x, center_y, width,
            height).

    Returns:
        Tensor of the same shape with corner coordinates on the last axis.
    """
    center_x, center_y, width, height = ops.unstack(boxes, num=4, axis=-1)
    half_w = ops.multiply(0.5, width)
    half_h = ops.multiply(0.5, height)
    corners = [
        ops.subtract(center_x, half_w),
        ops.subtract(center_y, half_h),
        ops.add(center_x, half_w),
        ops.add(center_y, half_h),
    ]
    return ops.stack(corners, axis=-1)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def concatenate_padded_sequences(
    sequence1, mask1, sequence_len1, sequence2, mask2, sequence_len2, hidden_dim
):
    """Concatenate two sequences with padding masks.

    Valid (unpadded) tokens of `sequence2` are packed immediately after the
    valid tokens of `sequence1`, so the result is left-aligned with all
    padding at the end.

    Args:
        sequence1: A tensor of shape (batch_size, sequence_length1, hidden_dim).
        mask1: A boolean tensor of shape (batch_size, sequence_length1).
        sequence_len1: int. The static (padded) length of `sequence1`.
        sequence2: A tensor of shape (batch_size, sequence_length2, hidden_dim).
        mask2: A boolean tensor of shape (batch_size, sequence_length2).
        sequence_len2: int. The static (padded) length of `sequence2`.
        hidden_dim: An integer representing the hidden dimension.

    Returns:
        concatenated_sequence: A tensor of shape
            (batch_size, sequence_length1 + sequence_length2, hidden_dim).
        concatenated_mask: A boolean tensor of shape
            (batch_size, sequence_length1 + sequence_length2).
    """
    batch_size = ops.shape(sequence1)[0]
    max_length = sequence_len1 + sequence_len2

    # Per-row counts of valid tokens, derived from the padding masks.
    actual_sequence_1_lengths = ops.sum(ops.cast(mask1, dtype="int32"), axis=1)
    actual_sequence_2_lengths = ops.sum(ops.cast(mask2, dtype="int32"), axis=1)
    final_lengths = ops.add(
        actual_sequence_1_lengths, actual_sequence_2_lengths
    )

    # A position is valid in the result iff its index is below the row's
    # combined valid length (result is left-aligned).
    concatenated_mask = ops.less(
        ops.tile(
            ops.expand_dims(ops.arange(max_length, dtype="int32"), axis=0),
            [batch_size, 1],
        ),
        ops.expand_dims(final_lengths, axis=1),
    )
    # Start from `sequence1` padded out to the full length with zeros;
    # `sequence2` tokens are scattered into place below.
    concatenated_sequence = ops.concatenate(
        [
            sequence1,
            ops.zeros(
                (batch_size, sequence_len2, hidden_dim), dtype=sequence1.dtype
            ),
        ],
        axis=1,
    )

    # Create the indices.
    indices = ops.tile(
        ops.expand_dims(ops.arange(sequence_len2, dtype="int32"), axis=0),
        [batch_size, 1],
    )
    # Shift each row's targets past that row's valid `sequence1` tokens.
    indices = ops.add(
        indices, ops.expand_dims(actual_sequence_1_lengths, axis=-1)
    )
    # Adjust the indices to account for batch dimension.
    to_add = ops.multiply(ops.arange(batch_size, dtype="int32"), max_length)
    indices = ops.add(
        indices, ops.cast(ops.expand_dims(to_add, axis=-1), "int32")
    )
    # `ops.scatter_update` requires 2D indices. We flatten the inputs before
    # scattering and reshape back after.
    flat_concatenated_sequence = ops.scatter_update(
        ops.reshape(concatenated_sequence, (-1, hidden_dim)),
        ops.reshape(indices, (-1, 1)),
        ops.reshape(
            ops.cast(sequence2, concatenated_sequence.dtype), (-1, hidden_dim)
        ),
    )

    concatenated_sequence = ops.reshape(
        flat_concatenated_sequence, (batch_size, -1, hidden_dim)
    )
    return concatenated_sequence, concatenated_mask
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def create_bidirectional_mask(input_embeds, attention_mask):
    """Expand a (batch, seq) padding mask to a (batch, 1, seq, seq) mask.

    Every query position receives the same key mask, i.e. attention is
    bidirectional and restricted only by padding.
    """
    sequence_length = ops.shape(input_embeds)[1]
    # (batch, seq) -> (batch, 1, 1, seq).
    expanded = ops.expand_dims(ops.expand_dims(attention_mask, axis=1), axis=1)
    # Repeat over the query axis: (batch, 1, seq, seq).
    return ops.tile(expanded, (1, 1, sequence_length, 1))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def inverse_sigmoid(x, eps=1e-3):
    """Numerically stable inverse of the sigmoid (logit) function.

    Args:
        x: tensor of values, clipped into [0, 1] before inversion.
        eps: float. Lower bound on both the numerator and denominator to
            avoid log(0) and division by zero. Defaults to `1e-3`.

    Returns:
        log(x / (1 - x)) with both terms floored at `eps`.
    """
    clipped = ops.clip(x, 0.0, 1.0)
    numerator = ops.maximum(clipped, eps)
    denominator = ops.maximum(ops.subtract(1.0, clipped), eps)
    return ops.log(ops.divide(numerator, denominator))
|