keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +21 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +10 -15
- keras_hub/src/models/d_fine/__init__.py +0 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/parseq/__init__.py +0 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/tests/test_case.py +37 -1
- keras_hub/src/utils/preset_utils.py +49 -0
- keras_hub/src/utils/tensor_utils.py +23 -1
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
keras_hub/src/models/d_fine/d_fine_attention.py
@@ -0,0 +1,461 @@
import math

import keras

from keras_hub.src.models.d_fine.d_fine_utils import (
    multi_scale_deformable_attention_v2,
)


class DFineMultiscaleDeformableAttention(keras.layers.Layer):
    """Multi-scale deformable attention layer for D-FINE models.

    This layer implements the multi-scale deformable attention mechanism, which
    is the core of the cross-attention in each `DFineDecoderLayer`. It allows
    the model to attend to a small set of key sampling points around a reference
    point across multiple feature levels from the encoder.

    The layer computes sampling locations and attention weights based on the
    input queries, enabling the model to focus on relevant features across
    multiple feature levels and spatial positions.

    Args:
        hidden_dim: int, Hidden dimension size for the attention mechanism.
        decoder_attention_heads: int, Number of attention heads.
        num_feature_levels: int, Number of feature levels to attend to.
        decoder_offset_scale: float, Scaling factor for sampling offsets.
        decoder_method: str, Method used for deformable attention computation.
        decoder_n_points: int or list, Number of sampling points per level.
            If int, the same number of points is used for all levels.
            If list, specifies points for each level individually.
        num_queries: int, Number of queries in the attention mechanism.
        spatial_shapes: list, List of spatial shapes for different
            feature levels.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_dim,
        decoder_attention_heads,
        num_feature_levels,
        decoder_offset_scale,
        decoder_method,
        decoder_n_points,
        num_queries,
        spatial_shapes,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.n_heads = decoder_attention_heads
        self.n_levels = num_feature_levels
        self.offset_scale = decoder_offset_scale
        self.decoder_method = decoder_method
        self.decoder_n_points = decoder_n_points
        self.spatial_shapes = spatial_shapes
        if isinstance(self.decoder_n_points, list):
            self.num_points = self.decoder_n_points
        else:
            self.num_points = [
                self.decoder_n_points for _ in range(self.n_levels)
            ]
        self._num_points_scale = [
            1.0 / n_points_at_level
            for n_points_at_level in self.num_points
            for _ in range(n_points_at_level)
        ]
        self.total_points = self.n_heads * sum(self.num_points)
        self.ms_deformable_attn_core = multi_scale_deformable_attention_v2

    def build(self, input_shape):
        sampling_offsets_output_shape = (
            input_shape[1],
            self.n_heads,
            sum(self.num_points),
            2,
        )
        self.sampling_offsets = keras.layers.EinsumDense(
            "abc,cdef->abdef",
            output_shape=sampling_offsets_output_shape,
            bias_axes="def",
            kernel_initializer="zeros",
            bias_initializer="zeros",
            name="sampling_offsets",
            dtype=self.dtype_policy,
        )
        self.sampling_offsets.build(input_shape)
        attention_weights_output_shape = (
            input_shape[1],
            self.n_heads,
            sum(self.num_points),
        )
        self.attention_weights = keras.layers.EinsumDense(
            "abc,cde->abde",
            output_shape=attention_weights_output_shape,
            bias_axes="de",
            kernel_initializer="zeros",
            bias_initializer="zeros",
            name="attention_weights",
            dtype=self.dtype_policy,
        )
        self.attention_weights.build(input_shape)
        if self.sampling_offsets.bias is not None:
            thetas = keras.ops.arange(
                self.n_heads, dtype=self.variable_dtype
            ) * (2.0 * math.pi / self.n_heads)
            grid_init = keras.ops.stack(
                [keras.ops.cos(thetas), keras.ops.sin(thetas)], axis=-1
            )
            grid_init = grid_init / keras.ops.max(
                keras.ops.abs(grid_init), axis=-1, keepdims=True
            )
            grid_init = keras.ops.reshape(grid_init, (self.n_heads, 1, 2))
            grid_init = keras.ops.tile(grid_init, [1, sum(self.num_points), 1])
            scaling = []
            for n in self.num_points:
                scaling.append(
                    keras.ops.arange(1, n + 1, dtype=self.variable_dtype)
                )
            scaling = keras.ops.concatenate(scaling, axis=0)
            scaling = keras.ops.reshape(scaling, (1, -1, 1))
            grid_init *= scaling
            self.sampling_offsets.bias.assign(grid_init)
        self.num_points_scale = self.add_weight(
            name="num_points_scale",
            shape=(len(self._num_points_scale),),
            initializer=keras.initializers.Constant(self._num_points_scale),
            trainable=False,
        )
        super().build(input_shape)

    def compute_attention(
        self, hidden_states, reference_points, spatial_shapes
    ):
        batch_size = keras.ops.shape(hidden_states)[0]
        num_queries = keras.ops.shape(hidden_states)[1]
        sampling_offsets = self.sampling_offsets(hidden_states)
        attention_weights = self.attention_weights(hidden_states)
        attention_weights = keras.ops.softmax(attention_weights, axis=-1)

        if keras.ops.shape(reference_points)[-1] == 2:
            offset_normalizer = keras.ops.cast(
                spatial_shapes, dtype=hidden_states.dtype
            )
            offset_normalizer = keras.ops.flip(offset_normalizer, axis=1)
            offset_normalizer = keras.ops.reshape(
                offset_normalizer, (1, 1, 1, self.n_levels, 1, 2)
            )
            sampling_locations = (
                keras.ops.reshape(
                    reference_points,
                    (batch_size, num_queries, 1, self.n_levels, 1, 2),
                )
                + sampling_offsets / offset_normalizer
            )
        elif keras.ops.shape(reference_points)[-1] == 4:
            num_points_scale_t = keras.ops.cast(
                self.num_points_scale, dtype=hidden_states.dtype
            )
            num_points_scale_t = keras.ops.expand_dims(
                num_points_scale_t, axis=-1
            )
            offset = (
                sampling_offsets
                * num_points_scale_t
                * keras.ops.expand_dims(reference_points[..., 2:], axis=-2)
                * self.offset_scale
            )
            sampling_locations = (
                keras.ops.expand_dims(reference_points[..., :2], axis=-2)
                + offset
            )
        else:
            raise ValueError(
                f"Last dim of reference_points must be 2 or 4, but get "
                f"{keras.ops.shape(reference_points)[-1]} instead."
            )
        return sampling_locations, attention_weights

    def call(
        self,
        hidden_states,
        encoder_hidden_states,
        reference_points,
        spatial_shapes,
    ):
        batch_size = keras.ops.shape(hidden_states)[0]
        num_queries = keras.ops.shape(hidden_states)[1]
        sequence_length = keras.ops.shape(encoder_hidden_states)[1]
        value = keras.ops.reshape(
            encoder_hidden_states,
            (
                batch_size,
                sequence_length,
                self.n_heads,
                self.hidden_dim // self.n_heads,
            ),
        )
        sampling_locations, attention_weights = self.compute_attention(
            hidden_states, reference_points, spatial_shapes
        )

        # NOTE: slice_sizes_values passed down to ms_deformable_attn_core
        # since JAX tracing doesn't support dynamic shapes.
        slice_sizes = [h * w for h, w in self.spatial_shapes]
        output = self.ms_deformable_attn_core(
            value,
            spatial_shapes,
            sampling_locations,
            attention_weights,
            self.num_points,
            slice_sizes,
            self.spatial_shapes,
            self.n_levels,
            num_queries,
            self.decoder_method,
        )
        return output, attention_weights

    def compute_output_spec(
        self,
        hidden_states,
        encoder_hidden_states,
        reference_points,
        spatial_shapes,
    ):
        input_shape = hidden_states.shape
        batch_size = input_shape[0] if len(input_shape) > 0 else None
        num_queries = input_shape[1] if len(input_shape) > 1 else None
        output_shape = (batch_size, num_queries, self.hidden_dim)
        output_spec = keras.KerasTensor(output_shape, dtype=self.compute_dtype)
        attention_weights_shape = (
            batch_size,
            num_queries,
            self.n_heads,
            sum(self.num_points),
        )
        attention_weights_spec = keras.KerasTensor(
            attention_weights_shape, dtype=self.compute_dtype
        )
        return output_spec, attention_weights_spec

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "decoder_attention_heads": self.n_heads,
                "num_feature_levels": self.n_levels,
                "decoder_offset_scale": self.offset_scale,
                "decoder_method": self.decoder_method,
                "decoder_n_points": self.decoder_n_points,
                "num_queries": self.num_queries,
                "spatial_shapes": self.spatial_shapes,
            }
        )
        return config


class DFineMultiheadAttention(keras.layers.Layer):
    """Multi-head attention layer for D-FINE models.

    This layer implements a standard multi-head attention mechanism. It is used
    in two key places within the D-FINE architecture:
    1. In `DFineEncoderLayer` as the self-attention mechanism to process the
        sequence of image features from the `HGNetV2Backbone` class.
    2. In `DFineDecoderLayer` as the self-attention mechanism to allow object
        queries to interact with each other.

    It supports position embeddings to incorporate positional information and
    attention masking to prevent attending to certain positions.

    Args:
        embedding_dim: int, Embedding dimension size.
        num_heads: int, Number of attention heads.
        dropout: float, optional, Dropout probability for attention weights.
            Defaults to `0.0`.
        bias: bool, optional, Whether to include bias in projection layers.
            Defaults to `True`.
        kernel_initializer: str or initializer, optional, Initializer for
            kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or initializer, optional, Initializer for
            bias weights. Defaults to `"zeros"`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        embedding_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.head_dim = embedding_dim // num_heads
        if self.head_dim * self.num_heads != self.embedding_dim:
            raise ValueError(
                f"embedding_dim must be divisible by num_heads (got "
                f"`embedding_dim`: {self.embedding_dim} and `num_heads`: "
                f"{self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.bias = bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.dropout = keras.layers.Dropout(
            self.dropout_rate, dtype=self.dtype_policy
        )

    def build(self, input_shape):
        embedding_dim = self.embedding_dim
        proj_equation = "abc,cde->abde"
        proj_bias_axes = "de"
        proj_output_shape = (None, self.num_heads, self.head_dim)
        proj_input_shape = (None, None, embedding_dim)
        self.q_proj = keras.layers.EinsumDense(
            proj_equation,
            output_shape=proj_output_shape,
            bias_axes=proj_bias_axes if self.bias else None,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer if self.bias else None,
            dtype=self.dtype_policy,
            name="q_proj",
        )
        self.q_proj.build(proj_input_shape)
        self.k_proj = keras.layers.EinsumDense(
            proj_equation,
            output_shape=proj_output_shape,
            bias_axes=proj_bias_axes if self.bias else None,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer if self.bias else None,
            dtype=self.dtype_policy,
            name="k_proj",
        )
        self.k_proj.build(proj_input_shape)
        self.v_proj = keras.layers.EinsumDense(
            proj_equation,
            output_shape=proj_output_shape,
            bias_axes=proj_bias_axes if self.bias else None,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer if self.bias else None,
            dtype=self.dtype_policy,
            name="v_proj",
        )
        self.v_proj.build(proj_input_shape)
        out_proj_input_shape = (None, None, self.num_heads * self.head_dim)
        out_proj_output_shape = (None, self.embedding_dim)
        self.out_proj = keras.layers.EinsumDense(
            "abc,cd->abd",
            output_shape=out_proj_output_shape,
            bias_axes="d" if self.bias else None,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer if self.bias else None,
            dtype=self.dtype_policy,
            name="out_proj",
        )
        self.out_proj.build(out_proj_input_shape)
        super().build(input_shape)

    def call(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
        output_attentions=False,
        training=None,
    ):
        batch_size = keras.ops.shape(hidden_states)[0]
        target_len = keras.ops.shape(hidden_states)[1]

        def with_pos_embed(tensor, position_embeddings_k):
            return (
                tensor
                if position_embeddings_k is None
                else tensor + position_embeddings_k
            )

        hidden_states_with_pos = with_pos_embed(
            hidden_states, position_embeddings
        )
        query_states = self.q_proj(hidden_states_with_pos)
        key_states = self.k_proj(hidden_states_with_pos)
        value_states = self.v_proj(hidden_states)
        attn_weights = keras.ops.einsum(
            "bthd,bshd->bhts", query_states * self.scaling, key_states
        )
        if attention_mask is not None:
            if keras.ops.ndim(attention_mask) == 2:
                attention_mask = keras.ops.expand_dims(attention_mask, axis=0)
                attention_mask = keras.ops.expand_dims(attention_mask, axis=1)
            attn_weights = attn_weights + attention_mask
        attn_weights = keras.ops.softmax(attn_weights, axis=-1)
        attn_weights_for_output = attn_weights if output_attentions else None
        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = keras.ops.einsum(
            "bhts,bshd->bthd", attn_probs, value_states
        )
        attn_output = keras.ops.reshape(
            attn_output, (batch_size, target_len, self.embedding_dim)
        )
        attn_output = self.out_proj(attn_output)
        if output_attentions:
            return attn_output, attn_weights_for_output
        else:
            return attn_output

    def compute_output_spec(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
        output_attentions=False,
        training=None,
    ):
        input_shape = hidden_states.shape
        batch_size = input_shape[0] if len(input_shape) > 0 else None
        target_len = input_shape[1] if len(input_shape) > 1 else None
        source_len = target_len
        attn_output_shape = (batch_size, target_len, self.embedding_dim)
        attn_output_spec = keras.KerasTensor(
            attn_output_shape, dtype=self.compute_dtype
        )
        if output_attentions:
            attn_weights_shape = (
                batch_size,
                self.num_heads,
                target_len,
                source_len,
            )
            attn_weights_spec = keras.KerasTensor(
                attn_weights_shape, dtype=self.compute_dtype
            )
            return attn_output_spec, attn_weights_spec
        return attn_output_spec

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout_rate,
                "bias": self.bias,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config