keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,212 @@
1
+ import keras
2
+ from keras import layers
3
+ from keras import ops
4
+
5
+ from keras_hub.src.api_export import keras_hub_export
6
+ from keras_hub.src.layers.modeling.token_and_position_embedding import (
7
+ TokenAndPositionEmbedding,
8
+ )
9
+ from keras_hub.src.models.sam3.sam3_utils import create_bidirectional_mask
10
+
11
+
12
+ class CLIPEncoderLayer(layers.Layer):
13
+ def __init__(
14
+ self,
15
+ hidden_dim,
16
+ num_heads,
17
+ intermediate_dim,
18
+ intermediate_activation="gelu",
19
+ layer_norm_epsilon=1e-5,
20
+ **kwargs,
21
+ ):
22
+ super().__init__(**kwargs)
23
+ self.hidden_dim = int(hidden_dim)
24
+ self.num_heads = int(num_heads)
25
+ self.intermediate_dim = int(intermediate_dim)
26
+ self.intermediate_activation = intermediate_activation
27
+ self.layer_norm_epsilon = float(layer_norm_epsilon)
28
+
29
+ self.layer_norm_1 = layers.LayerNormalization(
30
+ epsilon=self.layer_norm_epsilon,
31
+ dtype=self.dtype_policy,
32
+ name="layer_norm_1",
33
+ )
34
+ self.attention = layers.MultiHeadAttention(
35
+ num_heads,
36
+ hidden_dim // num_heads,
37
+ dtype=self.dtype_policy,
38
+ name="attention",
39
+ )
40
+ self.layer_norm_2 = layers.LayerNormalization(
41
+ epsilon=self.layer_norm_epsilon,
42
+ dtype=self.dtype_policy,
43
+ name="layer_norm_2",
44
+ )
45
+ self.dense_1 = layers.Dense(
46
+ self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
47
+ )
48
+ self.activation = layers.Activation(
49
+ intermediate_activation, dtype=self.dtype_policy, name="activation"
50
+ )
51
+ self.dense_2 = layers.Dense(
52
+ self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
53
+ )
54
+
55
+ def build(self, inputs_shape, attention_mask_shape):
56
+ self.layer_norm_1.build(inputs_shape)
57
+ self.attention.build(inputs_shape, inputs_shape, inputs_shape)
58
+ self.layer_norm_2.build(inputs_shape)
59
+ self.dense_1.build(inputs_shape)
60
+ input_shape = self.dense_1.compute_output_shape(inputs_shape)
61
+ self.dense_2.build(input_shape)
62
+
63
+ def compute_output_shape(self, inputs_shape, attention_mask_shape):
64
+ outputs_shape = list(inputs_shape)
65
+ outputs_shape[-1] = self.hidden_dim
66
+ return outputs_shape
67
+
68
+ def call(self, inputs, attention_mask, training=None):
69
+ residual = inputs
70
+ x = self.layer_norm_1(inputs)
71
+ x = self.attention(
72
+ x,
73
+ x,
74
+ x,
75
+ attention_mask=ops.cast(attention_mask, dtype="bool"),
76
+ training=training,
77
+ )
78
+ x = ops.add(residual, x)
79
+
80
+ residual = x
81
+ x = self.dense_1(self.layer_norm_2(residual))
82
+ x = self.activation(x)
83
+ x = self.dense_2(x)
84
+ x = ops.add(residual, x)
85
+ return x
86
+
87
+ def get_config(self):
88
+ config = super().get_config()
89
+ config.update(
90
+ {
91
+ "hidden_dim": self.hidden_dim,
92
+ "num_heads": self.num_heads,
93
+ "intermediate_dim": self.intermediate_dim,
94
+ "intermediate_activation": self.intermediate_activation,
95
+ "layer_norm_epsilon": self.layer_norm_epsilon,
96
+ }
97
+ )
98
+ return config
99
+
100
+
101
@keras_hub_export("keras_hub.layers.SAM3TextEncoder")
class SAM3TextEncoder(layers.Layer):
    """A text encoder for the Segment Anything Model 3 (SAM3).

    This layer implements a CLIP-style text encoder. It processes token IDs and
    padding masks to produce text embeddings that are used as prompts for
    segmentation.

    Args:
        vocabulary_size: int. The size of the vocabulary.
        embedding_dim: int. The dimension of the token embeddings.
        hidden_dim: int. The hidden dimension of the transformer layers.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads.
        intermediate_dim: int. The dimension of the intermediate layer in the
            transformer's MLP.
        intermediate_activation: str. The activation function for the
            transformer layers. Defaults to `"gelu"`.
        max_sequence_length: int. The maximum sequence length. Defaults to
            `32`.
        layer_norm_epsilon: float. The epsilon value for layer normalization.
            Defaults to `1e-5`.
    """

    def __init__(
        self,
        vocabulary_size,
        embedding_dim,
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        intermediate_activation="gelu",
        max_sequence_length=32,
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocabulary_size = int(vocabulary_size)
        self.embedding_dim = int(embedding_dim)
        self.hidden_dim = int(hidden_dim)
        self.num_layers = int(num_layers)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.intermediate_activation = intermediate_activation
        self.max_sequence_length = int(max_sequence_length)
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        self.embedding = TokenAndPositionEmbedding(
            vocabulary_size=self.vocabulary_size,
            sequence_length=self.max_sequence_length,
            embedding_dim=self.embedding_dim,
            dtype=self.dtype_policy,
            name="embedding",
        )
        self.encoder_layers = [
            CLIPEncoderLayer(
                self.hidden_dim,
                self.num_heads,
                self.intermediate_dim,
                self.intermediate_activation,
                dtype=self.dtype_policy,
                name=f"encoder_layer_{i}",
            )
            for i in range(self.num_layers)
        ]
        self.layer_norm = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm",
        )

    def build(self, token_ids_shape, padding_masks_shape):
        self.embedding.build(token_ids_shape)
        # All encoder layers share the embedding output shape.
        x_shape = self.embedding.compute_output_shape(token_ids_shape)
        for layer in self.encoder_layers:
            layer.build(x_shape, padding_masks_shape)
        self.layer_norm.build(x_shape)

    def call(self, token_ids, padding_masks, training=None):
        x = self.embedding(token_ids, training=training)
        # Expand the 2D padding mask to a 4D bidirectional attention mask so
        # every query position attends only to unpadded key positions.
        padding_masks = create_bidirectional_mask(x, padding_masks)
        for layer in self.encoder_layers:
            x = layer(x, padding_masks, training=training)
        x = self.layer_norm(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "embedding_dim": self.embedding_dim,
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "intermediate_activation": self.intermediate_activation,
                "max_sequence_length": self.max_sequence_length,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    def compute_output_shape(self, token_ids_shape, padding_masks_shape):
        # NOTE(review): the encoder layers preserve the embedding output
        # shape, which assumes `embedding_dim == hidden_dim` — confirm at
        # construction time in callers.
        return self.embedding.compute_output_shape(token_ids_shape)

    def compute_output_spec(self, token_ids, padding_masks):
        output_shape = self.compute_output_shape(
            token_ids.shape, padding_masks.shape
        )
        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
@@ -0,0 +1,65 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
3
+ from keras_hub.src.models.sam3.sam3_pc_backbone import (
4
+ SAM3PromptableConceptBackbone,
5
+ )
6
+
7
+
8
@keras_hub_export(
    [
        "keras_hub.tokenizers.SAM3Tokenizer",
        "keras_hub.models.SAM3Tokenizer",
    ]
)
class SAM3Tokenizer(CLIPTokenizer):
    """A Byte-Pair Encoding tokenizer for SAM3 models.

    This tokenizer turns raw strings into integer token-id sequences and is
    built on `keras_hub.tokenizers.BytePairTokenizer`. On top of the base
    tokenizer it verifies that all special tokens required by SAM3 models are
    present, and it exposes a `from_preset()` constructor that downloads the
    vocabulary matching a SAM3 preset.

    For batched string input (rank > 0) the output is a `tf.RaggedTensor`
    whose last dimension is ragged. For a scalar string input (rank == 0) the
    output is a dense `tf.Tensor` with static shape `[None]`.

    Args:
        vocabulary: string or dict, maps token to integer ids. If it is a
            string, it should be the file path to a json file.
        merges: string or list, contains the merge rule. If it is a string,
            it should be the file path to merge rules. The merge rule file
            should have one merge rule per line. Every merge rule contains
            merge entities separated by a space.

    Examples:

    ```python
    # Unbatched input.
    tokenizer = keras_hub.models.SAM3Tokenizer.from_preset("sam3_pcs")
    tokenizer("The quick brown fox jumped.")

    # Batched input.
    tokenizer(["The quick brown fox jumped.", "The fox slept."])

    # Detokenization.
    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
    ```
    """

    backbone_cls = SAM3PromptableConceptBackbone

    def __init__(self, vocabulary=None, merges=None, **kwargs):
        # SAM3 always pads with the end token, so it is fixed here rather
        # than exposed as a constructor argument.
        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            pad_with_end_token=True,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        # `pad_with_end_token` is hard-coded to True for SAM3Tokenizer, so it
        # is not part of the serialized config.
        config.pop("pad_with_end_token")
        return config
@@ -0,0 +1,134 @@
1
+ from keras import ops
2
+
3
+
4
def window_partition(x, height, width, window_size, hidden_dim):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: Tensor reshapeable to `(-1, height, width, hidden_dim)`.
        height: int. Feature-map height; must be divisible by `window_size`.
        width: int. Feature-map width; must be divisible by `window_size`.
        window_size: int. Side length of each square window.
        hidden_dim: int. Channel dimension.

    Returns:
        Tensor of shape `(-1, window_size, window_size, hidden_dim)` where
        the leading axis enumerates batch entries times windows.
    """
    windows_per_col = height // window_size
    windows_per_row = width // window_size
    grid = ops.reshape(
        x,
        (
            -1,
            windows_per_col,
            window_size,
            windows_per_row,
            window_size,
            hidden_dim,
        ),
    )
    # Bring the two window-grid axes next to each other before flattening.
    grid = ops.transpose(grid, axes=(0, 1, 3, 2, 4, 5))
    return ops.reshape(grid, (-1, window_size, window_size, hidden_dim))
19
+
20
+
21
def window_unpartition(x, height, width, window_size, hidden_dim):
    """Merge square windows back into a full feature map.

    Inverse of `window_partition`.

    Args:
        x: Tensor of windows, shape `(-1, window_size, window_size,
            hidden_dim)`.
        height: int. Target feature-map height.
        width: int. Target feature-map width.
        window_size: int. Side length of each square window.
        hidden_dim: int. Channel dimension.

    Returns:
        Tensor of shape `(-1, height, width, hidden_dim)`.
    """
    windows_per_col = height // window_size
    windows_per_row = width // window_size
    grid = ops.reshape(
        x,
        (
            -1,
            windows_per_col,
            windows_per_row,
            window_size,
            window_size,
            hidden_dim,
        ),
    )
    # Interleave the window-grid and intra-window axes back into rows/cols.
    grid = ops.transpose(grid, axes=(0, 1, 3, 2, 4, 5))
    return ops.reshape(grid, (-1, height, width, hidden_dim))
36
+
37
+
38
def box_cxcywh_to_xyxy(boxes):
    """Convert boxes from `(cx, cy, w, h)` to `(x_min, y_min, x_max, y_max)`.

    Args:
        boxes: Tensor whose last axis has size 4, holding center-x, center-y,
            width, and height.

    Returns:
        Tensor of the same shape with corner coordinates on the last axis.
    """
    center_x, center_y, box_w, box_h = ops.unstack(boxes, num=4, axis=-1)
    half_w = ops.multiply(0.5, box_w)
    half_h = ops.multiply(0.5, box_h)
    corners = [
        ops.subtract(center_x, half_w),
        ops.subtract(center_y, half_h),
        ops.add(center_x, half_w),
        ops.add(center_y, half_h),
    ]
    return ops.stack(corners, axis=-1)
49
+
50
+
51
def concatenate_padded_sequences(
    sequence1, mask1, sequence_len1, sequence2, mask2, sequence_len2, hidden_dim
):
    """Concatenate two sequences with padding masks.

    For each batch row, the tokens of `sequence2` are written starting right
    after the last valid token of `sequence1`, so valid tokens of both
    sequences end up packed at the front of the result.
    NOTE(review): this assumes each mask marks a contiguous left-aligned
    prefix of valid tokens — confirm against callers.

    Args:
        sequence1: A tensor of shape (batch_size, sequence_length1, hidden_dim).
        mask1: A boolean tensor of shape (batch_size, sequence_length1).
        sequence_len1: int. The padded length of `sequence1`
            (i.e. sequence_length1).
        sequence2: A tensor of shape (batch_size, sequence_length2, hidden_dim).
        mask2: A boolean tensor of shape (batch_size, sequence_length2).
        sequence_len2: int. The padded length of `sequence2`
            (i.e. sequence_length2).
        hidden_dim: An integer representing the hidden dimension.

    Returns:
        concatenated_sequence: A tensor of shape
            (batch_size, sequence_length1 + sequence_length2, hidden_dim).
        concatenated_mask: A boolean tensor of shape
            (batch_size, sequence_length1 + sequence_length2).
    """
    batch_size = ops.shape(sequence1)[0]
    max_length = sequence_len1 + sequence_len2

    # Count the valid (unmasked) tokens in each row of both sequences.
    actual_sequence_1_lengths = ops.sum(ops.cast(mask1, dtype="int32"), axis=1)
    actual_sequence_2_lengths = ops.sum(ops.cast(mask2, dtype="int32"), axis=1)
    final_lengths = ops.add(
        actual_sequence_1_lengths, actual_sequence_2_lengths
    )

    # Position i is valid iff i < (valid tokens of seq1 + valid tokens of
    # seq2) for that row.
    concatenated_mask = ops.less(
        ops.tile(
            ops.expand_dims(ops.arange(max_length, dtype="int32"), axis=0),
            [batch_size, 1],
        ),
        ops.expand_dims(final_lengths, axis=1),
    )
    # Pad sequence1 on the time axis to the full concatenated length; the
    # zero tail is overwritten by the scatter below.
    concatenated_sequence = ops.concatenate(
        [
            sequence1,
            ops.zeros(
                (batch_size, sequence_len2, hidden_dim), dtype=sequence1.dtype
            ),
        ],
        axis=1,
    )

    # Create the indices.
    indices = ops.tile(
        ops.expand_dims(ops.arange(sequence_len2, dtype="int32"), axis=0),
        [batch_size, 1],
    )
    # Shift each row's indices so sequence2 starts right after sequence1's
    # valid tokens.
    indices = ops.add(
        indices, ops.expand_dims(actual_sequence_1_lengths, axis=-1)
    )
    # Adjust the indices to account for batch dimension.
    to_add = ops.multiply(ops.arange(batch_size, dtype="int32"), max_length)
    indices = ops.add(
        indices, ops.cast(ops.expand_dims(to_add, axis=-1), "int32")
    )
    # `ops.scatter_update` requires 2D indices. We flatten the inputs before
    # scattering and reshape back after.
    flat_concatenated_sequence = ops.scatter_update(
        ops.reshape(concatenated_sequence, (-1, hidden_dim)),
        ops.reshape(indices, (-1, 1)),
        ops.reshape(
            ops.cast(sequence2, concatenated_sequence.dtype), (-1, hidden_dim)
        ),
    )

    concatenated_sequence = ops.reshape(
        flat_concatenated_sequence, (batch_size, -1, hidden_dim)
    )
    return concatenated_sequence, concatenated_mask
122
+
123
+
124
def create_bidirectional_mask(input_embeds, attention_mask):
    """Broadcast a 2D padding mask to a 4D bidirectional attention mask.

    Args:
        input_embeds: Tensor of shape `(batch, seq_len, ...)`; only its
            second axis (sequence length) is read.
        attention_mask: Tensor of shape `(batch, seq_len)`.

    Returns:
        Tensor of shape `(batch, 1, seq_len, seq_len)` where each query
        position carries the same key mask.
    """
    sequence_length = ops.shape(input_embeds)[1]
    expanded_mask = attention_mask[:, None, None, :]
    # Repeat the key mask across the query axis.
    return ops.tile(expanded_mask, (1, 1, sequence_length, 1))
128
+
129
+
130
def inverse_sigmoid(x, eps=1e-3):
    """Numerically stable inverse of the sigmoid (logit) function.

    Args:
        x: Tensor of values, clipped into `[0, 1]` before inversion.
        eps: float. Lower bound applied to both the numerator and the
            denominator to avoid `log(0)` and division by zero.
            Defaults to `1e-3`.

    Returns:
        Tensor of `log(x / (1 - x))` computed on the clipped input.
    """
    clipped = ops.clip(x, 0.0, 1.0)
    numerator = ops.maximum(clipped, eps)
    denominator = ops.maximum(ops.subtract(1.0, clipped), eps)
    return ops.log(ops.divide(numerator, denominator))