tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250110__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tf-models-nightly might be problematic.
- official/projects/detr/__init__.py +14 -0
- official/projects/detr/configs/__init__.py +14 -0
- official/projects/detr/configs/detr.py +277 -0
- official/projects/detr/configs/detr_test.py +51 -0
- official/projects/detr/dataloaders/__init__.py +14 -0
- official/projects/detr/dataloaders/coco.py +157 -0
- official/projects/detr/dataloaders/coco_test.py +111 -0
- official/projects/detr/dataloaders/detr_input.py +175 -0
- official/projects/detr/experiments/__init__.py +14 -0
- official/projects/detr/modeling/__init__.py +14 -0
- official/projects/detr/modeling/detr.py +345 -0
- official/projects/detr/modeling/detr_test.py +70 -0
- official/projects/detr/modeling/transformer.py +849 -0
- official/projects/detr/modeling/transformer_test.py +263 -0
- official/projects/detr/ops/__init__.py +14 -0
- official/projects/detr/ops/matchers.py +489 -0
- official/projects/detr/ops/matchers_test.py +95 -0
- official/projects/detr/optimization.py +151 -0
- official/projects/detr/serving/__init__.py +14 -0
- official/projects/detr/serving/export_module.py +103 -0
- official/projects/detr/serving/export_module_test.py +98 -0
- official/projects/detr/serving/export_saved_model.py +109 -0
- official/projects/detr/tasks/__init__.py +14 -0
- official/projects/detr/tasks/detection.py +421 -0
- official/projects/detr/tasks/detection_test.py +203 -0
- official/projects/detr/train.py +70 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/RECORD +32 -6
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/top_level.txt +0 -0
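For orientation before the hunks below, here is a minimal sketch of how the newly added DETR modules fit together, mirroring the usage in modeling/detr_test.py from this diff (the backbone choice and the sizes are illustrative, not prescribed by the package):

import tensorflow as tf

from official.projects.detr.modeling import detr
from official.vision.modeling.backbones import resnet

# Illustrative sizes, taken from the test below; not tuned for real training.
backbone = resnet.ResNet(50, bn_trainable=False)
model = detr.DETR(backbone, backbone_endpoint_name='5', num_queries=10,
                  hidden_size=128, num_classes=10)
# One forward pass; each element of `outs` holds 'cls_outputs'/'box_outputs'
# for one decoder layer (plus postprocessed detections outside training).
outs = model(tf.ones((2, 640, 640, 3)))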
@@ -0,0 +1,175 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""COCO data loader for DETR."""
+
+from typing import Tuple
+import tensorflow as tf, tf_keras
+
+from official.vision.dataloaders import parser
+
+from official.vision.ops import box_ops
+from official.vision.ops import preprocess_ops
+
+RESIZE_SCALES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+
+
+class Parser(parser.Parser):
+  """Parse an image and its annotations into a dictionary of tensors."""
+
+  def __init__(self,
+               class_offset: int = 0,
+               output_size: Tuple[int, int] = (1333, 1333),
+               max_num_boxes: int = 100,
+               resize_scales: Tuple[int, ...] = RESIZE_SCALES,
+               aug_rand_hflip=True):
+    self._class_offset = class_offset
+    self._output_size = output_size
+    self._max_num_boxes = max_num_boxes
+    self._resize_scales = resize_scales
+    self._aug_rand_hflip = aug_rand_hflip
+
+  def _parse_train_data(self, data):
+    """Parses data for training."""
+    classes = data['groundtruth_classes'] + self._class_offset
+    boxes = data['groundtruth_boxes']
+    is_crowd = data['groundtruth_is_crowd']
+
+    # Gets original image.
+    image = data['image']
+
+    # Normalizes image with mean and std pixel values.
+    image = preprocess_ops.normalize_image(image)
+    image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
+
+    do_crop = tf.greater(tf.random.uniform([]), 0.5)
+    if do_crop:
+      # Rescale.
+      boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
+      index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
+      scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
+      short_side = scales[0]
+      image, image_info = preprocess_ops.resize_image(image, short_side)
+      boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
+                                                   image_info[1, :],
+                                                   image_info[3, :])
+      boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
+
+      # Do cropping.
+      shape = tf.cast(image_info[1], dtype=tf.int32)
+      h = tf.random.uniform([],
+                            384,
+                            tf.math.minimum(shape[0], 600),
+                            dtype=tf.int32)
+      w = tf.random.uniform([],
+                            384,
+                            tf.math.minimum(shape[1], 600),
+                            dtype=tf.int32)
+      i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
+      j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
+      image = tf.image.crop_to_bounding_box(image, i, j, h, w)
+      boxes = tf.clip_by_value(
+          (boxes[..., :] * tf.cast(
+              tf.stack([shape[0], shape[1], shape[0], shape[1]]),
+              dtype=tf.float32) -
+           tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
+          tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
+    scales = tf.constant(self._resize_scales, dtype=tf.float32)
+    index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
+    scales = tf.gather(scales, index, axis=0)
+
+    image_shape = tf.shape(image)[:2]
+    boxes = box_ops.denormalize_boxes(boxes, image_shape)
+    short_side = scales[0]
+    image, image_info = preprocess_ops.resize_image(image, short_side,
+                                                    max(self._output_size))
+    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
+                                                 image_info[1, :],
+                                                 image_info[3, :])
+    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
+
+    # Filters out ground truth boxes that are all zeros.
+    indices = box_ops.get_non_empty_box_indices(boxes)
+    boxes = tf.gather(boxes, indices)
+    classes = tf.gather(classes, indices)
+    is_crowd = tf.gather(is_crowd, indices)
+    boxes = box_ops.yxyx_to_cycxhw(boxes)
+
+    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
+                                         self._output_size[1])
+    labels = {
+        'classes':
+            preprocess_ops.clip_or_pad_to_fixed_size(classes,
+                                                     self._max_num_boxes),
+        'boxes':
+            preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
+    }
+
+    return image, labels
+
+  def _parse_eval_data(self, data):
+    """Parses data for evaluation."""
+    classes = data['groundtruth_classes']
+    boxes = data['groundtruth_boxes']
+    is_crowd = data['groundtruth_is_crowd']
+
+    # Gets original image and its size.
+    image = data['image']
+
+    # Normalizes image with mean and std pixel values.
+    image = preprocess_ops.normalize_image(image)
+
+    scales = tf.constant([self._resize_scales[-1]], tf.float32)
+
+    image_shape = tf.shape(image)[:2]
+    boxes = box_ops.denormalize_boxes(boxes, image_shape)
+    gt_boxes = boxes
+    short_side = scales[0]
+    image, image_info = preprocess_ops.resize_image(image, short_side,
+                                                    max(self._output_size))
+    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
+                                                 image_info[1, :],
+                                                 image_info[3, :])
+    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
+
+    # Filters out ground truth boxes that are all zeros.
+    indices = box_ops.get_non_empty_box_indices(boxes)
+    boxes = tf.gather(boxes, indices)
+    classes = tf.gather(classes, indices)
+    is_crowd = tf.gather(is_crowd, indices)
+    boxes = box_ops.yxyx_to_cycxhw(boxes)
+
+    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
+                                         self._output_size[1])
+    labels = {
+        'classes':
+            preprocess_ops.clip_or_pad_to_fixed_size(classes,
+                                                     self._max_num_boxes),
+        'boxes':
+            preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
+    }
+    labels.update({
+        'id':
+            int(data['source_id']),
+        'image_info':
+            image_info,
+        'is_crowd':
+            preprocess_ops.clip_or_pad_to_fixed_size(is_crowd,
+                                                     self._max_num_boxes),
+        'gt_boxes':
+            preprocess_ops.clip_or_pad_to_fixed_size(gt_boxes,
+                                                     self._max_num_boxes),
+    })
+
+    return image, labels
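One idiom in _parse_train_data above is worth a note: a resize scale is drawn uniformly by sampling tf.random.categorical over all-zero logits, since zero logits give a uniform distribution over the indices. A standalone sketch of just that idiom, with toy values not tied to this package:

import tensorflow as tf

scales = tf.constant([400.0, 500.0, 600.0])
# Zero logits => each of the 3 indices is equally likely.
index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]  # shape [1], int64
short_side = tf.gather(scales, index, axis=0)[0]
print(float(short_side))  # 400.0, 500.0, or 600.0, each with probability 1/3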
@@ -0,0 +1,14 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
@@ -0,0 +1,14 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
@@ -0,0 +1,345 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implements End-to-End Object Detection with Transformers.
+
+Model paper: https://arxiv.org/abs/2005.12872
+This module does not support Keras de/serialization. Please use
+tf.train.Checkpoint for object-based saving and loading and tf.saved_model.save
+for graph serialization.
+"""
+import math
+from typing import Any, List
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import tf_utils
+from official.projects.detr.modeling import transformer
+from official.vision.ops import box_ops
+
+
+def position_embedding_sine(attention_mask,
+                            num_pos_features=256,
+                            temperature=10000.,
+                            normalize=True,
+                            scale=2 * math.pi):
+  """Sine-based positional embeddings for 2D images.
+
+  Args:
+    attention_mask: a `bool` Tensor specifying the size of the input image to
+      the Transformer and which elements are padded, of size [batch_size,
+      height, width]
+    num_pos_features: an `int` specifying the number of positional features;
+      should be equal to the hidden size of the Transformer network
+    temperature: a `float` specifying the temperature of the positional
+      embedding. Any type that is converted to a `float` can also be accepted.
+    normalize: a `bool` determining whether the positional embeddings should be
+      normalized between [0, scale] before application of the sine and cos
+      functions.
+    scale: a `float` used to scale the embeddings before the sine and cos
+      functions are applied; only used when `normalize` is True.
+
+  Returns:
+    embeddings: a `float` tensor of shape [batch_size, height, width,
+      num_pos_features] containing the positional embeddings based on sine
+      features.
+  """
+  if num_pos_features % 2 != 0:
+    raise ValueError(
+        "Number of embedding features (num_pos_features) must be even when "
+        "column and row embeddings are concatenated.")
+  num_pos_features = num_pos_features // 2
+
+  # Produce row and column embeddings based on total size of the image
+  # <tf.float>[batch_size, height, width]
+  attention_mask = tf.cast(attention_mask, tf.float32)
+  row_embedding = tf.cumsum(attention_mask, 1)
+  col_embedding = tf.cumsum(attention_mask, 2)
+
+  if normalize:
+    eps = 1e-6
+    row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
+    col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale
+
+  dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
+  dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)
+
+  # Creates positional embeddings for each row and column position
+  # <tf.float>[batch_size, height, width, num_pos_features]
+  pos_row = tf.expand_dims(row_embedding, -1) / dim_t
+  pos_col = tf.expand_dims(col_embedding, -1) / dim_t
+  pos_row = tf.stack(
+      [tf.sin(pos_row[:, :, :, 0::2]),
+       tf.cos(pos_row[:, :, :, 1::2])], axis=4)
+  pos_col = tf.stack(
+      [tf.sin(pos_col[:, :, :, 0::2]),
+       tf.cos(pos_col[:, :, :, 1::2])], axis=4)
+
+  # final_shape = pos_row.shape.as_list()[:3] + [-1]
+  final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
+  pos_row = tf.reshape(pos_row, final_shape)
+  pos_col = tf.reshape(pos_col, final_shape)
+  output = tf.concat([pos_row, pos_col], -1)
+
+  embeddings = tf.cast(output, tf.float32)
+  return embeddings
+
+
+def postprocess(outputs: dict[str, tf.Tensor]) -> dict[str, tf.Tensor]:
+  """Performs post-processing on model output.
+
+  Args:
+    outputs: The raw model output.
+
+  Returns:
+    Postprocessed model output.
+  """
+  predictions = {
+      "detection_boxes":  # Box coordinates are relative values here.
+          box_ops.cycxhw_to_yxyx(outputs["box_outputs"]),
+      "detection_scores":
+          tf.math.reduce_max(
+              tf.nn.softmax(outputs["cls_outputs"])[:, :, 1:], axis=-1),
+      "detection_classes":
+          tf.math.argmax(outputs["cls_outputs"][:, :, 1:], axis=-1) + 1,
+      # Fix this. It's not being used at the moment.
+      "num_detections":
+          tf.reduce_sum(
+              tf.cast(
+                  tf.math.greater(
+                      tf.math.reduce_max(outputs["cls_outputs"], axis=-1), 0),
+                  tf.int32),
+              axis=-1)
+  }
+  return predictions
+
+
+class DETR(tf_keras.Model):
+  """DETR model with Keras.
+
+  DETR consists of backbone, query embedding, DETRTransformer,
+  class and box heads.
+  """
+
+  def __init__(self,
+               backbone,
+               backbone_endpoint_name,
+               num_queries,
+               hidden_size,
+               num_classes,
+               num_encoder_layers=6,
+               num_decoder_layers=6,
+               dropout_rate=0.1,
+               **kwargs):
+    super().__init__(**kwargs)
+    self._num_queries = num_queries
+    self._hidden_size = hidden_size
+    self._num_classes = num_classes
+    self._num_encoder_layers = num_encoder_layers
+    self._num_decoder_layers = num_decoder_layers
+    self._dropout_rate = dropout_rate
+    if hidden_size % 2 != 0:
+      raise ValueError("hidden_size must be a multiple of 2.")
+    self._backbone = backbone
+    self._backbone_endpoint_name = backbone_endpoint_name
+
+  def build(self, input_shape=None):
+    self._input_proj = tf_keras.layers.Conv2D(
+        self._hidden_size, 1, name="detr/conv2d")
+    self._build_detection_decoder()
+    super().build(input_shape)
+
+  def _build_detection_decoder(self):
+    """Builds detection decoder."""
+    self._transformer = DETRTransformer(
+        num_encoder_layers=self._num_encoder_layers,
+        num_decoder_layers=self._num_decoder_layers,
+        dropout_rate=self._dropout_rate)
+    self._query_embeddings = self.add_weight(
+        "detr/query_embeddings",
+        shape=[self._num_queries, self._hidden_size],
+        initializer=tf_keras.initializers.RandomNormal(mean=0., stddev=1.),
+        dtype=tf.float32)
+    sqrt_k = math.sqrt(1.0 / self._hidden_size)
+    self._class_embed = tf_keras.layers.Dense(
+        self._num_classes,
+        kernel_initializer=tf_keras.initializers.RandomUniform(-sqrt_k, sqrt_k),
+        name="detr/cls_dense")
+    self._bbox_embed = [
+        tf_keras.layers.Dense(
+            self._hidden_size, activation="relu",
+            kernel_initializer=tf_keras.initializers.RandomUniform(
+                -sqrt_k, sqrt_k),
+            name="detr/box_dense_0"),
+        tf_keras.layers.Dense(
+            self._hidden_size, activation="relu",
+            kernel_initializer=tf_keras.initializers.RandomUniform(
+                -sqrt_k, sqrt_k),
+            name="detr/box_dense_1"),
+        tf_keras.layers.Dense(
+            4, kernel_initializer=tf_keras.initializers.RandomUniform(
+                -sqrt_k, sqrt_k),
+            name="detr/box_dense_2")]
+    self._sigmoid = tf_keras.layers.Activation("sigmoid")
+
+  @property
+  def backbone(self) -> tf_keras.Model:
+    return self._backbone
+
+  def get_config(self):
+    return {
+        "backbone": self._backbone,
+        "backbone_endpoint_name": self._backbone_endpoint_name,
+        "num_queries": self._num_queries,
+        "hidden_size": self._hidden_size,
+        "num_classes": self._num_classes,
+        "num_encoder_layers": self._num_encoder_layers,
+        "num_decoder_layers": self._num_decoder_layers,
+        "dropout_rate": self._dropout_rate,
+    }
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)
+
+  def _generate_image_mask(self, inputs: tf.Tensor,
+                           target_shape: tf.Tensor) -> tf.Tensor:
+    """Generates image mask from input image."""
+    mask = tf.expand_dims(
+        tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype),
+        axis=-1)
+    mask = tf.image.resize(
+        mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+    return mask
+
+  def call(self, inputs: tf.Tensor, training: bool = None) -> List[Any]:  # pytype: disable=annotation-type-mismatch,signature-mismatch
+    batch_size = tf.shape(inputs)[0]
+    features = self._backbone(inputs)[self._backbone_endpoint_name]
+    shape = tf.shape(features)
+    mask = self._generate_image_mask(inputs, shape[1: 3])
+
+    pos_embed = position_embedding_sine(
+        mask[:, :, :, 0], num_pos_features=self._hidden_size)
+    pos_embed = tf.reshape(pos_embed, [batch_size, -1, self._hidden_size])
+
+    features = tf.reshape(
+        self._input_proj(features), [batch_size, -1, self._hidden_size])
+    mask = tf.reshape(mask, [batch_size, -1])
+
+    decoded_list = self._transformer({
+        "inputs":
+            features,
+        "targets":
+            tf.tile(
+                tf.expand_dims(self._query_embeddings, axis=0),
+                (batch_size, 1, 1)),
+        "pos_embed": pos_embed,
+        "mask": mask,
+    })
+    out_list = []
+    for decoded in decoded_list:
+      decoded = tf.stack(decoded)
+      output_class = self._class_embed(decoded)
+      box_out = decoded
+      for layer in self._bbox_embed:
+        box_out = layer(box_out)
+      output_coord = self._sigmoid(box_out)
+      out = {"cls_outputs": output_class, "box_outputs": output_coord}
+      if not training:
+        out.update(postprocess(out))
+      out_list.append(out)
+
+    return out_list
+
+
+class DETRTransformer(tf_keras.layers.Layer):
+  """Encoder and Decoder of DETR."""
+
+  def __init__(
+      self,
+      num_encoder_layers=6,
+      num_decoder_layers=6,
+      num_attention_heads=8,
+      intermediate_size=2048,
+      dropout_rate=0.1,
+      **kwargs
+  ):
+    super().__init__(**kwargs)
+    self._dropout_rate = dropout_rate
+    self._num_encoder_layers = num_encoder_layers
+    self._num_decoder_layers = num_decoder_layers
+    self._num_attention_heads = num_attention_heads
+    self._intermediate_size = intermediate_size
+
+  def build(self, input_shape=None):
+    if self._num_encoder_layers > 0:
+      self._encoder = transformer.TransformerEncoder(
+          attention_dropout_rate=self._dropout_rate,
+          dropout_rate=self._dropout_rate,
+          intermediate_dropout=self._dropout_rate,
+          norm_first=False,
+          num_layers=self._num_encoder_layers,
+          num_attention_heads=self._num_attention_heads,
+          intermediate_size=self._intermediate_size,
+      )
+    else:
+      self._encoder = None
+
+    self._decoder = transformer.TransformerDecoder(
+        attention_dropout_rate=self._dropout_rate,
+        dropout_rate=self._dropout_rate,
+        intermediate_dropout=self._dropout_rate,
+        norm_first=False,
+        num_layers=self._num_decoder_layers,
+        num_attention_heads=self._num_attention_heads,
+        intermediate_size=self._intermediate_size,
+    )
+    super().build(input_shape)
+
+  def get_config(self):
+    return {
+        "num_encoder_layers": self._num_encoder_layers,
+        "num_decoder_layers": self._num_decoder_layers,
+        "dropout_rate": self._dropout_rate,
+    }
+
+  def call(self, inputs):
+    sources = inputs["inputs"]
+    targets = inputs["targets"]
+    pos_embed = inputs["pos_embed"]
+    mask = inputs["mask"]
+    input_shape = tf_utils.get_shape_list(sources)
+    source_attention_mask = tf.tile(
+        tf.expand_dims(mask, axis=1), [1, input_shape[1], 1])
+    if self._encoder is not None:
+      memory = self._encoder(
+          sources, attention_mask=source_attention_mask, pos_embed=pos_embed)
+    else:
+      memory = sources
+
+    target_shape = tf_utils.get_shape_list(targets)
+    cross_attention_mask = tf.tile(
+        tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
+    target_shape = tf.shape(targets)
+    decoded = self._decoder(
+        tf.zeros_like(targets),
+        memory,
+        # TODO(b/199545430): self_attention_mask could be set to None when this
+        # bug is resolved. Passing ones for now.
+        self_attention_mask=tf.ones(
+            (target_shape[0], target_shape[1], target_shape[1])),
+        cross_attention_mask=cross_attention_mask,
+        return_all_decoder_outputs=True,
+        input_pos_embed=targets,
+        memory_pos_embed=pos_embed)
+    return decoded
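To make the shape contract of position_embedding_sine above concrete, a small sketch (assuming the official.projects.detr package added in this release is importable):

import tensorflow as tf

from official.projects.detr.modeling import detr

# A batch of two 4x6 attention masks with no padded positions.
mask = tf.ones([2, 4, 6], dtype=tf.bool)
embeddings = detr.position_embedding_sine(mask, num_pos_features=256)
print(embeddings.shape)  # (2, 4, 6, 256): row and column features concatenated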
@@ -0,0 +1,70 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tensorflow_models.official.projects.detr.detr."""
+import tensorflow as tf, tf_keras
+from official.projects.detr.modeling import detr
+from official.vision.modeling.backbones import resnet
+
+
+class DetrTest(tf.test.TestCase):
+
+  def test_forward(self):
+    num_queries = 10
+    hidden_size = 128
+    num_classes = 10
+    image_size = 640
+    batch_size = 2
+    backbone = resnet.ResNet(50, bn_trainable=False)
+    backbone_endpoint_name = '5'
+    model = detr.DETR(backbone, backbone_endpoint_name, num_queries,
+                      hidden_size, num_classes)
+    outs = model(tf.ones((batch_size, image_size, image_size, 3)))
+    self.assertLen(outs, 6)  # intermediate decoded outputs.
+    for out in outs:
+      self.assertAllEqual(
+          tf.shape(out['cls_outputs']), (batch_size, num_queries, num_classes))
+      self.assertAllEqual(
+          tf.shape(out['box_outputs']), (batch_size, num_queries, 4))
+
+  def test_get_from_config_detr_transformer(self):
+    config = {
+        'num_encoder_layers': 1,
+        'num_decoder_layers': 2,
+        'dropout_rate': 0.5,
+    }
+    detr_model = detr.DETRTransformer.from_config(config)
+    retrieved_config = detr_model.get_config()
+
+    self.assertEqual(config, retrieved_config)
+
+  def test_get_from_config_detr(self):
+    config = {
+        'backbone': resnet.ResNet(50, bn_trainable=False),
+        'backbone_endpoint_name': '5',
+        'num_queries': 2,
+        'hidden_size': 4,
+        'num_classes': 10,
+        'num_encoder_layers': 4,
+        'num_decoder_layers': 5,
+        'dropout_rate': 0.5,
+    }
+    detr_model = detr.DETR.from_config(config)
+    retrieved_config = detr_model.get_config()
+
+    self.assertEqual(config, retrieved_config)
+
+
+if __name__ == '__main__':
+  tf.test.main()
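A closing note on the label convention implied by postprocess in the modeling hunk above: class index 0 is treated as background, detection scores are computed over the remaining logits only, and the argmax is shifted back by one. A toy illustration with hypothetical logits:

import tensorflow as tf

# One image, one query, 4 classes; index 0 is background.
cls_outputs = tf.constant([[[0.1, 2.0, 0.3, -1.0]]])
scores = tf.math.reduce_max(tf.nn.softmax(cls_outputs)[:, :, 1:], axis=-1)
classes = tf.math.argmax(cls_outputs[:, :, 1:], axis=-1) + 1
print(classes.numpy())  # [[1]] -> the best foreground class, never background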