tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250110__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tf-models-nightly might be problematic. Click here for more details.

Files changed (32)
  1. official/projects/detr/__init__.py +14 -0
  2. official/projects/detr/configs/__init__.py +14 -0
  3. official/projects/detr/configs/detr.py +277 -0
  4. official/projects/detr/configs/detr_test.py +51 -0
  5. official/projects/detr/dataloaders/__init__.py +14 -0
  6. official/projects/detr/dataloaders/coco.py +157 -0
  7. official/projects/detr/dataloaders/coco_test.py +111 -0
  8. official/projects/detr/dataloaders/detr_input.py +175 -0
  9. official/projects/detr/experiments/__init__.py +14 -0
  10. official/projects/detr/modeling/__init__.py +14 -0
  11. official/projects/detr/modeling/detr.py +345 -0
  12. official/projects/detr/modeling/detr_test.py +70 -0
  13. official/projects/detr/modeling/transformer.py +849 -0
  14. official/projects/detr/modeling/transformer_test.py +263 -0
  15. official/projects/detr/ops/__init__.py +14 -0
  16. official/projects/detr/ops/matchers.py +489 -0
  17. official/projects/detr/ops/matchers_test.py +95 -0
  18. official/projects/detr/optimization.py +151 -0
  19. official/projects/detr/serving/__init__.py +14 -0
  20. official/projects/detr/serving/export_module.py +103 -0
  21. official/projects/detr/serving/export_module_test.py +98 -0
  22. official/projects/detr/serving/export_saved_model.py +109 -0
  23. official/projects/detr/tasks/__init__.py +14 -0
  24. official/projects/detr/tasks/detection.py +421 -0
  25. official/projects/detr/tasks/detection_test.py +203 -0
  26. official/projects/detr/train.py +70 -0
  27. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/METADATA +1 -1
  28. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/RECORD +32 -6
  29. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/AUTHORS +0 -0
  30. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/LICENSE +0 -0
  31. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/WHEEL +0 -0
  32. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,175 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """COCO data loader for DETR."""
16
+
17
+ from typing import Tuple
18
+ import tensorflow as tf, tf_keras
19
+
20
+ from official.vision.dataloaders import parser
21
+
22
+ from official.vision.ops import box_ops
23
+ from official.vision.ops import preprocess_ops
24
+
25
RESIZE_SCALES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)


class Parser(parser.Parser):
  """Parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               class_offset: int = 0,
               output_size: Tuple[int, int] = (1333, 1333),
               max_num_boxes: int = 100,
               resize_scales: Tuple[int, ...] = RESIZE_SCALES,
               aug_rand_hflip=True):
    """Initializes the DETR parser.

    Args:
      class_offset: Integer added to `groundtruth_classes` for training data.
      output_size: (height, width) that the resized image is zero-padded to.
        `max(output_size)` is also passed to `preprocess_ops.resize_image` as
        its size limit (presumably the long-side cap; TODO confirm).
      max_num_boxes: Number of entries every per-box label is clipped or
        padded to.
      resize_scales: Candidate short-side sizes. Training samples one
        uniformly at random; evaluation uses the last entry.
      aug_rand_hflip: Whether training applies a random horizontal flip.

    Raises:
      ValueError: if `resize_scales` is empty.
    """
    if not resize_scales:
      raise ValueError('resize_scales must be non-empty.')
    self._class_offset = class_offset
    self._output_size = output_size
    self._max_num_boxes = max_num_boxes
    self._resize_scales = resize_scales
    self._aug_rand_hflip = aug_rand_hflip

  def _parse_train_data(self, data):
    """Parses data for training.

    Steps:
      1. Normalizes the image.
      2. Optionally applies a random horizontal flip.
      3. With probability 0.5, rescales the image to a random short side in
         {400, 500, 600} and takes a random crop.
      4. Resizes to a short side drawn uniformly from `resize_scales`.
      5. Zero-pads to `output_size`.

    Args:
      data: Decoded example dict with 'image', 'groundtruth_classes' and
        'groundtruth_boxes'. Boxes are presumably normalized yxyx; TODO
        confirm against the decoder.

    Returns:
      A tuple of the padded image and a labels dict with 'classes' and
      'boxes' (normalized cycxhw). Each is clipped or padded to
      `max_num_boxes`.
    """
    classes = data['groundtruth_classes'] + self._class_offset
    boxes = data['groundtruth_boxes']

    # Gets original image.
    image = data['image']

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image)
    # Only flip when the augmentation is enabled (the flag used to be ignored).
    if self._aug_rand_hflip:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

    do_crop = tf.greater(tf.random.uniform([]), 0.5)
    if do_crop:
      # Rescale
      boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
      index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
      scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
      short_side = scales[0]
      image, image_info = preprocess_ops.resize_image(image, short_side)
      boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                   image_info[1, :],
                                                   image_info[3, :])
      boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

      # Do croping: a random (h, w) window at a random (i, j) offset.
      shape = tf.cast(image_info[1], dtype=tf.int32)
      h = tf.random.uniform([],
                            384,
                            tf.math.minimum(shape[0], 600),
                            dtype=tf.int32)
      w = tf.random.uniform([],
                            384,
                            tf.math.minimum(shape[1], 600),
                            dtype=tf.int32)
      i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
      j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
      image = tf.image.crop_to_bounding_box(image, i, j, h, w)
      # Convert boxes to pixels, shift by the crop origin, re-normalize by the
      # crop size, then clip to the crop window.
      boxes = tf.clip_by_value(
          (boxes * tf.cast(
              tf.stack([shape[0], shape[1], shape[0], shape[1]]),
              dtype=tf.float32) -
           tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
          tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)

    # Multi-scale resize: pick one of `resize_scales` uniformly at random.
    # The logits width follows the configured scales rather than a hard-coded
    # 11, so custom scale lists are sampled over their full range.
    num_scales = len(self._resize_scales)
    scales = tf.constant(self._resize_scales, dtype=tf.float32)
    index = tf.random.categorical(tf.zeros([1, num_scales]), 1)[0]
    scales = tf.gather(scales, index, axis=0)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(image, short_side,
                                                    max(self._output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
                                         self._output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(classes,
                                                     self._max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
    }

    return image, labels

  def _parse_eval_data(self, data):
    """Parses data for evaluation.

    Resizes the image deterministically, using the last entry of
    `resize_scales` as the short side, then zero-pads it to `output_size`.

    Args:
      data: Decoded example dict with 'image', 'groundtruth_classes',
        'groundtruth_boxes', 'groundtruth_is_crowd' and 'source_id'.

    Returns:
      A tuple of the padded image and a labels dict with these keys:
        'classes', 'boxes': normalized cycxhw boxes.
        'id': the source id.
        'image_info': resize metadata.
        'is_crowd': crowd flags.
        'gt_boxes': boxes in original-image pixel coordinates, presumably
          yxyx.
      Per-box entries are clipped or padded to `max_num_boxes`.
    """
    # NOTE(review): unlike training, `class_offset` is not applied here;
    # confirm the evaluator expects raw class ids.
    classes = data['groundtruth_classes']
    boxes = data['groundtruth_boxes']
    is_crowd = data['groundtruth_is_crowd']

    # Gets original image and its size.
    image = data['image']

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image)

    scales = tf.constant([self._resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(image, short_side,
                                                    max(self._output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros. `gt_boxes` is
    # filtered with the same indices so it stays aligned with `classes` and
    # `is_crowd` after padding.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    gt_boxes = tf.gather(gt_boxes, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
                                         self._output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(classes,
                                                     self._max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
    }
    labels.update({
        'id':
            int(data['source_id']),
        'image_info':
            image_info,
        'is_crowd':
            preprocess_ops.clip_or_pad_to_fixed_size(is_crowd,
                                                     self._max_num_boxes),
        'gt_boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(gt_boxes,
                                                     self._max_num_boxes),
    })

    return image, labels
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
@@ -0,0 +1,345 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Implements End-to-End Object Detection with Transformers.
16
+
17
+ Model paper: https://arxiv.org/abs/2005.12872
18
+ This module does not support Keras de/serialization. Please use
19
+ tf.train.Checkpoint for object based saving and loading and tf.saved_model.save
20
+ for graph serialization.
21
+ """
22
+ import math
23
+ from typing import Any, List
24
+
25
+ import tensorflow as tf, tf_keras
26
+
27
+ from official.modeling import tf_utils
28
+ from official.projects.detr.modeling import transformer
29
+ from official.vision.ops import box_ops
30
+
31
+
32
def position_embedding_sine(attention_mask,
                            num_pos_features=256,
                            temperature=10000.,
                            normalize=True,
                            scale=2 * math.pi):
  """Sine-based positional embeddings for 2D images.

  Args:
    attention_mask: a `bool` (or 0/1 numeric) Tensor of size [batch_size,
      height, width] specifying the size of the input image to the
      Transformer and which elements are padded. It is cast to float32 and
      cumulatively summed along height and width.
    num_pos_features: a `int` specifying the number of positional features.
      It must be even (half encode rows, half encode columns) and should
      equal the hidden size of the Transformer network.
    temperature: a `float` specifying the temperature of the positional
      embedding. Any type that is converted to a `float` can also be
      accepted; it is cast to the embedding dtype before use.
    normalize: a `bool` determining whether the positional embeddings should be
      normalized between [0, scale] before application of the sine and cos
      functions.
    scale: a `float` if normalize is True specifying the scale embeddings before
      application of the embedding function.

  Returns:
    embeddings: a float32 tensor of shape
      [batch_size, height, width, num_pos_features]. The first half of the
      last axis encodes the row (height) position and the second half the
      column (width) position. Within each half, sine and cosine values are
      interleaved.

  Raises:
    ValueError: if `num_pos_features` is odd.
  """
  if num_pos_features % 2 != 0:
    raise ValueError(
        "Number of embedding features (num_pos_features) must be even when "
        "column and row embeddings are concatenated.")
  num_pos_features = num_pos_features // 2

  # Produce row and column embeddings based on total size of the image
  # <tf.float>[batch_size, height, width]
  attention_mask = tf.cast(attention_mask, tf.float32)
  row_embedding = tf.cumsum(attention_mask, 1)
  col_embedding = tf.cumsum(attention_mask, 2)

  if normalize:
    eps = 1e-6
    # The last cumsum entry is the count of valid positions along that axis.
    row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
    col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale

  dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
  # Cast so integer temperatures do not clash with the float exponent dtype.
  temperature = tf.cast(temperature, row_embedding.dtype)
  dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)

  # Creates positional embeddings for each row and column position
  # <tf.float>[batch_size, height, width, num_pos_features]
  pos_row = tf.expand_dims(row_embedding, -1) / dim_t
  pos_col = tf.expand_dims(col_embedding, -1) / dim_t
  # Stacking on a new last axis and flattening interleaves sin/cos values.
  pos_row = tf.stack(
      [tf.sin(pos_row[:, :, :, 0::2]),
       tf.cos(pos_row[:, :, :, 1::2])], axis=4)
  pos_col = tf.stack(
      [tf.sin(pos_col[:, :, :, 0::2]),
       tf.cos(pos_col[:, :, :, 1::2])], axis=4)

  final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
  pos_row = tf.reshape(pos_row, final_shape)
  pos_col = tf.reshape(pos_col, final_shape)
  output = tf.concat([pos_row, pos_col], -1)

  embeddings = tf.cast(output, tf.float32)
  return embeddings
96
+
97
+
98
def postprocess(outputs: dict[str, tf.Tensor]) -> dict[str, tf.Tensor]:
  """Turns raw DETR head outputs into detection-style predictions.

  Args:
    outputs: Dict with "cls_outputs" and "box_outputs".
      "cls_outputs": rank-3 class logits, presumably
        [batch, num_queries, num_classes].
      "box_outputs": boxes in cycxhw form.

  Returns:
    Dict with these keys:
      "detection_boxes": yxyx boxes (relative coordinates).
      "detection_scores": best softmax probability among classes 1..N.
      "detection_classes": index of that best class. Class 0 is excluded;
        it is presumably the background / no-object class — TODO confirm.
      "num_detections": see the TODO below.
  """
  logits = outputs["cls_outputs"]
  # Softmax is taken over every class, and only then is column 0 dropped.
  foreground_probs = tf.nn.softmax(logits)[:, :, 1:]
  foreground_logits = logits[:, :, 1:]
  # TODO: Fix this. It's not being used at the moment. It counts queries whose
  # largest raw logit (over all classes, including 0) is positive.
  has_positive_logit = tf.math.greater(tf.math.reduce_max(logits, axis=-1), 0)
  return {
      # Box coordinates are relative values here.
      "detection_boxes": box_ops.cycxhw_to_yxyx(outputs["box_outputs"]),
      "detection_scores": tf.math.reduce_max(foreground_probs, axis=-1),
      "detection_classes": tf.math.argmax(foreground_logits, axis=-1) + 1,
      "num_detections": tf.reduce_sum(
          tf.cast(has_positive_logit, tf.int32), axis=-1),
  }
125
+
126
+
127
class DETR(tf_keras.Model):
  """DETR model with Keras.

  DETR consists of backbone, query embedding, DETRTransformer,
  class and box heads.

  The input projection, transformer, query embeddings and heads are created
  in `build`. NOTE(review): object-based checkpoints presumably bind to the
  attribute names used below (`_input_proj`, `_transformer`,
  `_query_embeddings`, `_class_embed`, `_bbox_embed`). Renaming them would
  likely break restoring existing checkpoints.
  """

  def __init__(self,
               backbone,
               backbone_endpoint_name,
               num_queries,
               hidden_size,
               num_classes,
               num_encoder_layers=6,
               num_decoder_layers=6,
               dropout_rate=0.1,
               **kwargs):
    """Initializes the DETR model.

    Args:
      backbone: Model called on the input images. Its output is indexed with
        `backbone_endpoint_name` to select the feature map.
      backbone_endpoint_name: Key of the backbone output to use.
      num_queries: Number of object queries. The query-embedding weight has
        shape [num_queries, hidden_size].
      hidden_size: Width of the input projection, query embeddings and box
        head hidden layers. Must be even.
      num_classes: Output size of the classification head.
      num_encoder_layers: Number of transformer encoder layers.
      num_decoder_layers: Number of transformer decoder layers.
      dropout_rate: Dropout rate passed to the transformer.
      **kwargs: Forwarded to `tf_keras.Model`.

    Raises:
      ValueError: if `hidden_size` is odd.
    """
    super().__init__(**kwargs)
    self._num_queries = num_queries
    self._hidden_size = hidden_size
    self._num_classes = num_classes
    self._num_encoder_layers = num_encoder_layers
    self._num_decoder_layers = num_decoder_layers
    self._dropout_rate = dropout_rate
    # The sine positional embedding splits features into row/column halves.
    if hidden_size % 2 != 0:
      raise ValueError("hidden_size must be a multiple of 2.")
    self._backbone = backbone
    self._backbone_endpoint_name = backbone_endpoint_name

  def build(self, input_shape=None):
    """Creates the 1x1 input projection and the detection decoder."""
    self._input_proj = tf_keras.layers.Conv2D(
        self._hidden_size, 1, name="detr/conv2d")
    self._build_detection_decoder()
    super().build(input_shape)

  def _build_detection_decoder(self):
    """Builds detection decoder."""
    self._transformer = DETRTransformer(
        num_encoder_layers=self._num_encoder_layers,
        num_decoder_layers=self._num_decoder_layers,
        dropout_rate=self._dropout_rate)
    # Learned object queries, one row per query.
    self._query_embeddings = self.add_weight(
        "detr/query_embeddings",
        shape=[self._num_queries, self._hidden_size],
        initializer=tf_keras.initializers.RandomNormal(mean=0., stddev=1.),
        dtype=tf.float32)
    # Heads use U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)) kernels.
    # Presumably this mirrors the reference PyTorch Linear default init; TODO
    # confirm.
    sqrt_k = math.sqrt(1.0 / self._hidden_size)
    self._class_embed = tf_keras.layers.Dense(
        self._num_classes,
        kernel_initializer=tf_keras.initializers.RandomUniform(-sqrt_k, sqrt_k),
        name="detr/cls_dense")
    # Box head: two ReLU layers followed by a 4-unit projection (sigmoid is
    # applied separately in `call`).
    self._bbox_embed = [
        tf_keras.layers.Dense(
            self._hidden_size, activation="relu",
            kernel_initializer=tf_keras.initializers.RandomUniform(
                -sqrt_k, sqrt_k),
            name="detr/box_dense_0"),
        tf_keras.layers.Dense(
            self._hidden_size, activation="relu",
            kernel_initializer=tf_keras.initializers.RandomUniform(
                -sqrt_k, sqrt_k),
            name="detr/box_dense_1"),
        tf_keras.layers.Dense(
            4, kernel_initializer=tf_keras.initializers.RandomUniform(
                -sqrt_k, sqrt_k),
            name="detr/box_dense_2")]
    self._sigmoid = tf_keras.layers.Activation("sigmoid")

  @property
  def backbone(self) -> tf_keras.Model:
    """The backbone model passed at construction."""
    return self._backbone

  def get_config(self):
    """Returns constructor arguments. The backbone is returned as an object."""
    return {
        "backbone": self._backbone,
        "backbone_endpoint_name": self._backbone_endpoint_name,
        "num_queries": self._num_queries,
        "hidden_size": self._hidden_size,
        "num_classes": self._num_classes,
        "num_encoder_layers": self._num_encoder_layers,
        "num_decoder_layers": self._num_decoder_layers,
        "dropout_rate": self._dropout_rate,
    }

  @classmethod
  def from_config(cls, config):
    """Re-creates the model from `get_config()` output."""
    return cls(**config)

  def _generate_image_mask(self, inputs: tf.Tensor,
                           target_shape: tf.Tensor) -> tf.Tensor:
    """Generates image mask from input image.

    A pixel is marked 1 when its values summed over the last axis are
    non-zero, else 0. The mask is then nearest-neighbor resized to
    `target_shape`. NOTE(review): this presumably relies on padding being
    all zeros; a real pixel whose channels sum to exactly 0 would also be
    treated as padding.
    """
    mask = tf.expand_dims(
        tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype),
        axis=-1)
    mask = tf.image.resize(
        mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return mask

  def call(self, inputs: tf.Tensor, training: bool = None) -> List[Any]:  # pytype: disable=annotation-type-mismatch,signature-mismatch
    """Runs DETR on a batch of images.

    Args:
      inputs: Rank-4 image batch, presumably [batch, height, width, channels]
        (the mask treats the last axis as channels).
      training: When falsy (including None), `postprocess` outputs are added
        to each result dict.

    Returns:
      One dict per decoded output returned by the transformer. Each dict has
      "cls_outputs" (class logits) and "box_outputs" (sigmoid-activated
      boxes), plus the post-processed keys when not training.
    """
    batch_size = tf.shape(inputs)[0]
    features = self._backbone(inputs)[self._backbone_endpoint_name]
    shape = tf.shape(features)
    # Padding mask at the feature-map spatial resolution.
    mask = self._generate_image_mask(inputs, shape[1: 3])

    pos_embed = position_embedding_sine(
        mask[:, :, :, 0], num_pos_features=self._hidden_size)
    pos_embed = tf.reshape(pos_embed, [batch_size, -1, self._hidden_size])

    # Flatten spatial positions into a sequence for the transformer.
    features = tf.reshape(
        self._input_proj(features), [batch_size, -1, self._hidden_size])
    mask = tf.reshape(mask, [batch_size, -1])

    decoded_list = self._transformer({
        "inputs":
            features,
        "targets":
            tf.tile(
                tf.expand_dims(self._query_embeddings, axis=0),
                (batch_size, 1, 1)),
        "pos_embed": pos_embed,
        "mask": mask,
    })
    out_list = []
    for decoded in decoded_list:
      decoded = tf.stack(decoded)
      output_class = self._class_embed(decoded)
      box_out = decoded
      for layer in self._bbox_embed:
        box_out = layer(box_out)
      output_coord = self._sigmoid(box_out)
      out = {"cls_outputs": output_class, "box_outputs": output_coord}
      if not training:
        out.update(postprocess(out))
      out_list.append(out)

    return out_list
263
+
264
+
265
class DETRTransformer(tf_keras.layers.Layer):
  """Encoder and Decoder of DETR.

  Wraps an optional `transformer.TransformerEncoder`, which is skipped when
  `num_encoder_layers` is 0, followed by a `transformer.TransformerDecoder`.
  """

  def __init__(
      self,
      num_encoder_layers=6,
      num_decoder_layers=6,
      num_attention_heads=8,
      intermediate_size=2048,
      dropout_rate=0.1,
      **kwargs
  ):
    """Stores hyper-parameters; sub-layers are created in `build`."""
    super().__init__(**kwargs)
    self._num_encoder_layers = num_encoder_layers
    self._num_decoder_layers = num_decoder_layers
    self._num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._dropout_rate = dropout_rate

  def build(self, input_shape=None):
    """Creates the encoder (if any) and the decoder, in that order."""
    # Both stacks share every hyper-parameter except their depth.
    common = dict(
        attention_dropout_rate=self._dropout_rate,
        dropout_rate=self._dropout_rate,
        intermediate_dropout=self._dropout_rate,
        norm_first=False,
        num_attention_heads=self._num_attention_heads,
        intermediate_size=self._intermediate_size,
    )
    self._encoder = (
        transformer.TransformerEncoder(
            num_layers=self._num_encoder_layers, **common)
        if self._num_encoder_layers > 0 else None)
    self._decoder = transformer.TransformerDecoder(
        num_layers=self._num_decoder_layers, **common)
    super().build(input_shape)

  def get_config(self):
    """Returns the serializable constructor arguments.

    NOTE(review): `num_attention_heads` and `intermediate_size` are not
    included, so non-default values are lost on a from_config round-trip.
    """
    return dict(
        num_encoder_layers=self._num_encoder_layers,
        num_decoder_layers=self._num_decoder_layers,
        dropout_rate=self._dropout_rate,
    )

  def call(self, inputs):
    """Runs the encoder (when present) and then the decoder.

    Args:
      inputs: Dict with these entries:
        "inputs": source sequence.
        "targets": query embeddings. The decoder's input is zeros, and these
          embeddings are supplied as its `input_pos_embed`.
        "pos_embed": source positional embedding.
        "mask": rank-2 padding mask over source positions.

    Returns:
      The decoder output with `return_all_decoder_outputs=True`, presumably
      one entry per decoder layer.
    """
    sources = inputs["inputs"]
    queries = inputs["targets"]
    source_pos_embed = inputs["pos_embed"]
    # [batch, 1, num_sources]; tiled so every row attends to the same keys.
    key_mask = tf.expand_dims(inputs["mask"], axis=1)

    num_sources = tf_utils.get_shape_list(sources)[1]
    encoder_mask = tf.tile(key_mask, [1, num_sources, 1])
    if self._encoder is None:
      memory = sources
    else:
      memory = self._encoder(
          sources, attention_mask=encoder_mask, pos_embed=source_pos_embed)

    num_queries = tf_utils.get_shape_list(queries)[1]
    cross_mask = tf.tile(key_mask, [1, num_queries, 1])
    dyn_shape = tf.shape(queries)
    # TODO(b/199545430): self_attention_mask could be set to None when this
    # bug is resolved. Passing ones for now.
    full_self_mask = tf.ones((dyn_shape[0], dyn_shape[1], dyn_shape[1]))
    return self._decoder(
        tf.zeros_like(queries),
        memory,
        self_attention_mask=full_self_mask,
        cross_attention_mask=cross_mask,
        return_all_decoder_outputs=True,
        input_pos_embed=queries,
        memory_pos_embed=source_pos_embed)
@@ -0,0 +1,70 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for tensorflow_models.official.projects.detr.detr."""
16
+ import tensorflow as tf, tf_keras
17
+ from official.projects.detr.modeling import detr
18
+ from official.vision.modeling.backbones import resnet
19
+
20
+
21
class DetrTest(tf.test.TestCase):
  """Forward-pass and config round-trip tests for the DETR modules."""

  def _assert_config_round_trip(self, cls, config):
    """Asserts `cls.from_config(config).get_config()` equals `config`."""
    self.assertEqual(config, cls.from_config(config).get_config())

  def test_forward(self):
    batch_size, image_size = 2, 640
    num_queries, hidden_size, num_classes = 10, 128, 10
    model = detr.DETR(
        resnet.ResNet(50, bn_trainable=False),
        '5',
        num_queries,
        hidden_size,
        num_classes,
    )
    outs = model(tf.ones((batch_size, image_size, image_size, 3)))

    # One dict per intermediate decoded output (default: 6 decoder layers).
    self.assertLen(outs, 6)
    expected_shapes = {
        'cls_outputs': (batch_size, num_queries, num_classes),
        'box_outputs': (batch_size, num_queries, 4),
    }
    for out in outs:
      for key, expected_shape in expected_shapes.items():
        self.assertAllEqual(tf.shape(out[key]), expected_shape)

  def test_get_from_config_detr_transformer(self):
    self._assert_config_round_trip(
        detr.DETRTransformer,
        dict(num_encoder_layers=1, num_decoder_layers=2, dropout_rate=0.5),
    )

  def test_get_from_config_detr(self):
    self._assert_config_round_trip(
        detr.DETR,
        dict(
            backbone=resnet.ResNet(50, bn_trainable=False),
            backbone_endpoint_name='5',
            num_queries=2,
            hidden_size=4,
            num_classes=10,
            num_encoder_layers=4,
            num_decoder_layers=5,
            dropout_rate=0.5,
        ),
    )


if __name__ == '__main__':
  tf.test.main()