tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250109__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tf-models-nightly might be problematic. Click here for more details.
- official/projects/detr/__init__.py +14 -0
- official/projects/detr/configs/__init__.py +14 -0
- official/projects/detr/configs/detr.py +277 -0
- official/projects/detr/configs/detr_test.py +51 -0
- official/projects/detr/dataloaders/__init__.py +14 -0
- official/projects/detr/dataloaders/coco.py +157 -0
- official/projects/detr/dataloaders/coco_test.py +111 -0
- official/projects/detr/dataloaders/detr_input.py +175 -0
- official/projects/detr/experiments/__init__.py +14 -0
- official/projects/detr/modeling/__init__.py +14 -0
- official/projects/detr/modeling/detr.py +345 -0
- official/projects/detr/modeling/detr_test.py +70 -0
- official/projects/detr/modeling/transformer.py +849 -0
- official/projects/detr/modeling/transformer_test.py +263 -0
- official/projects/detr/ops/__init__.py +14 -0
- official/projects/detr/ops/matchers.py +489 -0
- official/projects/detr/ops/matchers_test.py +95 -0
- official/projects/detr/optimization.py +151 -0
- official/projects/detr/serving/__init__.py +14 -0
- official/projects/detr/serving/export_module.py +103 -0
- official/projects/detr/serving/export_module_test.py +98 -0
- official/projects/detr/serving/export_saved_model.py +109 -0
- official/projects/detr/tasks/__init__.py +14 -0
- official/projects/detr/tasks/detection.py +421 -0
- official/projects/detr/tasks/detection_test.py +203 -0
- official/projects/detr/train.py +70 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/RECORD +32 -6
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
@@ -0,0 +1,277 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""DETR configurations."""
|
16
|
+
|
17
|
+
import dataclasses
|
18
|
+
import os
|
19
|
+
from typing import List, Optional, Union
|
20
|
+
|
21
|
+
from official.core import config_definitions as cfg
|
22
|
+
from official.core import exp_factory
|
23
|
+
from official.modeling import hyperparams
|
24
|
+
from official.projects.detr import optimization
|
25
|
+
from official.projects.detr.dataloaders import coco
|
26
|
+
from official.vision.configs import backbones
|
27
|
+
from official.vision.configs import common
|
28
|
+
|
29
|
+
|
30
|
+
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input pipeline settings for DETR training and evaluation data."""
  # Data source: either a file pattern (`input_path`) or a TFDS dataset
  # (`tfds_name` / `tfds_split`); the experiment factories in this module set
  # one or the other.
  input_path: str = ''
  tfds_name: str = ''
  tfds_split: str = 'train'
  global_batch_size: int = 0
  is_training: bool = False
  dtype: str = 'bfloat16'
  decoder: common.DataDecoder = dataclasses.field(
      default_factory=common.DataDecoder)
  shuffle_buffer_size: int = 10000
  file_type: str = 'tfrecord'
  drop_remainder: bool = True
|
43
|
+
|
44
|
+
|
45
|
+
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related hyperparameters of the DETR task."""
  # NOTE(review): the meaning of these fields is inferred from their names;
  # confirm against the task implementation that consumes them.
  class_offset: int = 0
  # Presumably the weights of the classification, box and GIoU loss terms.
  lambda_cls: float = 1.0
  lambda_box: float = 5.0
  lambda_giou: float = 2.0
  background_cls_weight: float = 0.1
  l2_weight_decay: float = 1e-4
|
53
|
+
|
54
|
+
|
55
|
+
@dataclasses.dataclass
class Detr(hyperparams.Config):
  """DETR model definition."""
  num_queries: int = 100
  hidden_size: int = 256
  num_classes: int = 91  # 0: background
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  input_size: List[int] = dataclasses.field(default_factory=list)
  # The default backbone is a ResNet with model_id=50 and bn_trainable=False.
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet',
          resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
  backbone_endpoint_name: str = '5'
|
68
|
+
|
69
|
+
|
70
|
+
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
  """Task config bundling the DETR model, its data inputs and its losses."""
  model: Detr = dataclasses.field(default_factory=Detr)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  losses: Losses = dataclasses.field(default_factory=Losses)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone
  annotation_file: Optional[str] = None
  per_category_metrics: bool = False
|
80
|
+
|
81
|
+
|
82
|
+
# Base directory joined with the record patterns and the annotation file name
# by the TFRecord experiment below.
COCO_INPUT_PATH_BASE = 'coco'
# Example counts used to convert epochs into step counts.
COCO_TRAIN_EXAMPLES = 118_287
COCO_VAL_EXAMPLES = 5_000
|
85
|
+
|
86
|
+
|
87
|
+
@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
  """Builds the `detr_coco` experiment (TFDS COCO via `coco.COCODataConfig`).

  Trains for 500 epochs at a global batch size of 64; the stepwise schedule
  drops the learning rate from 1e-4 to 1e-5 for the final 100 epochs.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  batch_size_train = 64
  batch_size_eval = 64
  epoch_steps = COCO_TRAIN_EXAMPLES // batch_size_train
  total_steps = 500 * epoch_steps  # 500 epochs
  lr_drop_step = total_steps - 100 * epoch_steps  # after 400 epochs

  task = DetrTask(
      init_checkpoint='',
      init_checkpoint_modules='backbone',
      model=Detr(
          num_classes=81,
          input_size=[1333, 1333, 3],
          norm_activation=common.NormActivation()),
      losses=Losses(),
      train_data=coco.COCODataConfig(
          tfds_name='coco/2017',
          tfds_split='train',
          is_training=True,
          global_batch_size=batch_size_train,
          shuffle_buffer_size=1000,
      ),
      validation_data=coco.COCODataConfig(
          tfds_name='coco/2017',
          tfds_split='validation',
          is_training=False,
          global_batch_size=batch_size_eval,
          drop_remainder=False))

  adamw_params = {
      'weight_decay_rate': 1e-4,
      'global_clipnorm': 0.1,
      # Avoid AdamW legacy behavior.
      'gradient_clip_norm': 0.0
  }
  lr_schedule = {
      'boundaries': [lr_drop_step],
      'values': [0.0001, 1.0e-05]
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'detr_adamw',
          'detr_adamw': adamw_params
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': lr_schedule
      },
  })

  trainer = cfg.TrainerConfig(
      train_steps=total_steps,
      validation_steps=-1,
      steps_per_loop=10000,
      summary_interval=10000,
      checkpoint_interval=10000,
      validation_interval=10000,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
      ])
|
150
|
+
|
151
|
+
|
152
|
+
@exp_factory.register_config_factory('detr_coco_tfrecord')
def detr_coco_tfrecord() -> cfg.ExperimentConfig:
  """Builds the `detr_coco_tfrecord` experiment (COCO read from file patterns).

  Inputs are the `train*` / `val*` patterns under `COCO_INPUT_PATH_BASE`, and
  the annotation file is `instances_val2017.json` in the same directory. The
  model keeps the default `num_classes` of `Detr`. Trains for 300 epochs at a
  global batch size of 64; the learning rate drops from 1e-4 to 1e-5 for the
  final 100 epochs.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  batch_size_train = 64
  batch_size_eval = 64
  epoch_steps = COCO_TRAIN_EXAMPLES // batch_size_train
  total_steps = 300 * epoch_steps  # 300 epochs
  lr_drop_step = total_steps - 100 * epoch_steps  # after 200 epochs

  task = DetrTask(
      init_checkpoint='',
      init_checkpoint_modules='backbone',
      annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                   'instances_val2017.json'),
      model=Detr(
          input_size=[1333, 1333, 3],
          norm_activation=common.NormActivation()),
      losses=Losses(),
      train_data=DataConfig(
          input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=batch_size_train,
          shuffle_buffer_size=1000,
      ),
      validation_data=DataConfig(
          input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
          is_training=False,
          global_batch_size=batch_size_eval,
          drop_remainder=False,
      ))

  adamw_params = {
      'weight_decay_rate': 1e-4,
      'global_clipnorm': 0.1,
      # Avoid AdamW legacy behavior.
      'gradient_clip_norm': 0.0
  }
  lr_schedule = {
      'boundaries': [lr_drop_step],
      'values': [0.0001, 1.0e-05]
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'detr_adamw',
          'detr_adamw': adamw_params
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': lr_schedule
      },
  })

  trainer = cfg.TrainerConfig(
      train_steps=total_steps,
      validation_steps=COCO_VAL_EXAMPLES // batch_size_eval,
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      validation_interval=5 * epoch_steps,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
      ])
|
214
|
+
|
215
|
+
|
216
|
+
@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco_tfds() -> cfg.ExperimentConfig:
  """Builds the `detr_coco_tfds` experiment (TFDS COCO via `DataConfig`).

  Uses 81 classes and `Losses(class_offset=1)`. Trains for 300 epochs at a
  global batch size of 64; the learning rate drops from 1e-4 to 1e-5 for the
  final 100 epochs.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  batch_size_train = 64
  batch_size_eval = 64
  epoch_steps = COCO_TRAIN_EXAMPLES // batch_size_train
  total_steps = 300 * epoch_steps  # 300 epochs
  lr_drop_step = total_steps - 100 * epoch_steps  # after 200 epochs

  task = DetrTask(
      init_checkpoint='',
      init_checkpoint_modules='backbone',
      model=Detr(
          num_classes=81,
          input_size=[1333, 1333, 3],
          norm_activation=common.NormActivation()),
      losses=Losses(class_offset=1),
      train_data=DataConfig(
          tfds_name='coco/2017',
          tfds_split='train',
          is_training=True,
          global_batch_size=batch_size_train,
          shuffle_buffer_size=1000,
      ),
      validation_data=DataConfig(
          tfds_name='coco/2017',
          tfds_split='validation',
          is_training=False,
          global_batch_size=batch_size_eval,
          drop_remainder=False))

  adamw_params = {
      'weight_decay_rate': 1e-4,
      'global_clipnorm': 0.1,
      # Avoid AdamW legacy behavior.
      'gradient_clip_norm': 0.0
  }
  lr_schedule = {
      'boundaries': [lr_drop_step],
      'values': [0.0001, 1.0e-05]
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'detr_adamw',
          'detr_adamw': adamw_params
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': lr_schedule
      },
  })

  trainer = cfg.TrainerConfig(
      train_steps=total_steps,
      validation_steps=COCO_VAL_EXAMPLES // batch_size_eval,
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      validation_interval=5 * epoch_steps,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
      ])
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Tests for detr."""
|
16
|
+
|
17
|
+
# pylint: disable=unused-import
|
18
|
+
from absl.testing import parameterized
|
19
|
+
import tensorflow as tf, tf_keras
|
20
|
+
|
21
|
+
from official.core import config_definitions as cfg
|
22
|
+
from official.core import exp_factory
|
23
|
+
from official.projects.detr.configs import detr as exp_cfg
|
24
|
+
from official.projects.detr.dataloaders import coco
|
25
|
+
|
26
|
+
|
27
|
+
class DetrTest(tf.test.TestCase, parameterized.TestCase):
  """Checks the DETR experiment configs registered with `exp_factory`."""

  def _check_experiment(self, config_name, train_data_cls):
    """Asserts the config types, then that clearing `is_training` fails."""
    experiment = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(experiment, cfg.ExperimentConfig)
    self.assertIsInstance(experiment.task, exp_cfg.DetrTask)
    self.assertIsInstance(experiment.task.train_data, train_data_cls)
    # With `is_training` unset, `validate()` is expected to raise KeyError.
    experiment.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      experiment.validate()

  @parameterized.parameters(('detr_coco',))
  def test_detr_configs_tfds(self, config_name):
    self._check_experiment(config_name, coco.COCODataConfig)

  @parameterized.parameters(('detr_coco_tfrecord',), ('detr_coco_tfds',))
  def test_detr_configs(self, config_name):
    self._check_experiment(config_name, cfg.DataConfig)
|
48
|
+
|
49
|
+
|
50
|
+
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
|
51
|
+
tf.test.main()
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
@@ -0,0 +1,157 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""COCO data loader for DETR."""
|
16
|
+
|
17
|
+
import dataclasses
|
18
|
+
from typing import Optional, Tuple
|
19
|
+
import tensorflow as tf, tf_keras
|
20
|
+
|
21
|
+
from official.core import config_definitions as cfg
|
22
|
+
from official.core import input_reader
|
23
|
+
from official.vision.ops import box_ops
|
24
|
+
from official.vision.ops import preprocess_ops
|
25
|
+
|
26
|
+
|
27
|
+
@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
  """Data config consumed by `COCODataLoader`."""
  # Height and width every image is padded to after resizing.
  output_size: Tuple[int, int] = (1333, 1333)
  # Label tensors are clipped or padded to this many entries.
  max_num_boxes: int = 100
  # Candidate resize scales: training samples one at random, evaluation
  # always uses the last entry.
  resize_scales: Tuple[int, ...] = (
      480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800,
  )
|
34
|
+
|
35
|
+
|
36
|
+
class COCODataLoader:
  """Loads and preprocesses COCO detection examples for DETR.

  The loader is configured by a `COCODataConfig`. `load()` hands `preprocess`
  and batching to an `input_reader.InputReader`.
  """

  def __init__(self, params: COCODataConfig):
    """Stores the data config read by the other methods."""
    self._params = params

  def preprocess(self, inputs):
    """Turns one decoded example into an `(image, labels)` pair.

    Training applies a random horizontal flip, then with probability 0.5 a
    random resize followed by a random crop, and finally resizes to a scale
    drawn uniformly from `resize_scales`. Evaluation always resizes to the
    last entry of `resize_scales`.

    Args:
      inputs: Dict holding `image` and `objects` (`bbox`, `label`,
        `is_crowd`). `image/id` is also read when not training.

    Returns:
      A tuple `(image, labels)`. `image` is normalized, resized and padded to
      `output_size`. `labels` holds `classes` and `boxes`, each passed through
      `clip_or_pad_to_fixed_size` with `max_num_boxes`; when not training it
      also holds `id`, `image_info`, `is_crowd` and `gt_boxes`.
    """
    image = inputs['image']
    boxes = inputs['objects']['bbox']
    # Labels are shifted by one; presumably index 0 is reserved for the
    # background class — confirm against the model config.
    classes = inputs['objects']['label'] + 1
    is_crowd = inputs['objects']['is_crowd']

    image = preprocess_ops.normalize_image(image)
    if self._params.is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

      do_crop = tf.greater(tf.random.uniform([]), 0.5)
      if do_crop:
        # Rescale so the short side is one of 400 / 500 / 600.
        boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
        index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
        scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
        short_side = scales[0]
        image, image_info = preprocess_ops.resize_image(image, short_side)
        boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                     image_info[2, :],
                                                     image_info[1, :],
                                                     image_info[3, :])
        boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

        # Crop a random window with sides in [384, min(side, 600)).
        shape = tf.cast(image_info[1], dtype=tf.int32)
        h = tf.random.uniform(
            [], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
        w = tf.random.uniform(
            [], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
        i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, i, j, h, w)
        # Re-express the normalized boxes relative to the crop window and
        # clip whatever falls outside it.
        full_size = tf.cast(
            tf.stack([shape[0], shape[1], shape[0], shape[1]]),
            dtype=tf.float32)
        crop_origin = tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)
        crop_size = tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32)
        boxes = tf.clip_by_value(
            (boxes * full_size - crop_origin) / crop_size, 0.0, 1.0)
      # One draw over all configured scales. The logit count follows the
      # config, so a `resize_scales` of any length is sampled correctly
      # (it was previously hard-coded to 11, the default length).
      num_scales = len(self._params.resize_scales)
      scales = tf.constant(
          self._params.resize_scales,
          dtype=tf.float32)
      index = tf.random.categorical(tf.zeros([1, num_scales]), 1)[0]
      scales = tf.gather(scales, index, axis=0)
    else:
      scales = tf.constant([self._params.resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    # Snapshot of the boxes before the final resize; only returned for
    # evaluation.
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(
        image,
        short_side,
        max(self._params.output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                 image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(
        image, 0, 0, self._params.output_size[0], self._params.output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                classes, self._params.max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                boxes, self._params.max_num_boxes)
    }
    if not self._params.is_training:
      labels.update({
          'id':
              inputs['image/id'],
          'image_info':
              image_info,
          'is_crowd':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  is_crowd, self._params.max_num_boxes),
          'gt_boxes':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  gt_boxes, self._params.max_num_boxes),
      })

    return image, labels

  def _transform_and_batch_fn(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Maps `preprocess` over `dataset` and batches the result.

    Args:
      dataset: Dataset of decoded examples accepted by `preprocess`.
      input_context: Optional distribution context. When given, the global
        batch size is converted to a per-replica batch size.

    Returns:
      The preprocessed and batched dataset.
    """
    dataset = dataset.map(
        self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._params.global_batch_size
    ) if input_context else self._params.global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.drop_remainder)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a `tf.data.Dataset` of preprocessed, batched examples."""
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=None,
        transform_and_batch_fn=self._transform_and_batch_fn)
    return reader.read(input_context)
|
@@ -0,0 +1,111 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Tests for tensorflow_models.official.projects.detr.dataloaders.coco."""
|
16
|
+
|
17
|
+
from absl.testing import parameterized
|
18
|
+
import numpy as np
|
19
|
+
import tensorflow as tf, tf_keras
|
20
|
+
import tensorflow_datasets as tfds
|
21
|
+
|
22
|
+
from official.projects.detr.dataloaders import coco
|
23
|
+
|
24
|
+
|
25
|
+
def _gen_fn():
  """Generates one fake example with the keys `COCODataLoader` reads.

  Returns:
    A dict with an all-ones uint8 `image` of random size, a random
    `image/id`, a fixed `image/filename`, and an `objects` dict whose arrays
    all share a random leading dimension (`num_boxes`, possibly zero).
  """
  # The lower bound is 1 because `np.random.randint` includes it; a side of 0
  # would yield an empty image and make the preprocessing tests flaky.
  h = np.random.randint(1, 300)
  w = np.random.randint(1, 300)
  num_boxes = np.random.randint(0, 50)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'image/id': np.random.randint(0, 100),
      'image/filename': 'test',
      'objects': {
          'is_crowd': np.ones(shape=(num_boxes), dtype=bool),
          'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
          'label': np.ones(shape=(num_boxes), dtype=np.int64),
          'id': np.ones(shape=(num_boxes), dtype=np.int64),
          'area': np.ones(shape=(num_boxes), dtype=np.int64),
      }
  }
|
41
|
+
|
42
|
+
|
43
|
+
class CocoDataloaderTest(tf.test.TestCase, parameterized.TestCase):
  """Shape and dtype checks for `coco.COCODataLoader`."""

  def test_load_dataset(self):
    """Loads a mocked TFDS dataset and checks the batched output shapes."""
    output_size = 1280
    max_num_boxes = 100
    batch_size = 2
    data_config = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=False,
        global_batch_size=batch_size,
        output_size=(output_size, output_size),
        max_num_boxes=max_num_boxes,
    )

    num_examples = 10
    # NOTE(review): `self` below is whatever object `tfds.testing.mock_data`
    # binds this function to (presumably the dataset builder), not the test
    # case — confirm against the TFDS docs.
    def as_dataset(self, *args, **kwargs):
      del args
      del kwargs
      return tf.data.Dataset.from_generator(
          lambda: (_gen_fn() for _ in range(num_examples)),
          output_types=self.info.features.dtype,
          output_shapes=self.info.features.shape,
      )

    with tfds.testing.mock_data(num_examples=num_examples,
                                as_dataset_fn=as_dataset):
      dataset = coco.COCODataLoader(data_config).load()
      dataset_iter = iter(dataset)
      images, labels = next(dataset_iter)
      self.assertEqual(images.shape, (batch_size, output_size, output_size, 3))
      self.assertEqual(labels['classes'].shape, (batch_size, max_num_boxes))
      self.assertEqual(labels['boxes'].shape, (batch_size, max_num_boxes, 4))
      self.assertEqual(labels['id'].shape, (batch_size,))
      self.assertEqual(
          labels['image_info'].shape, (batch_size, 4, 2))
      self.assertEqual(labels['is_crowd'].shape, (batch_size, max_num_boxes))

  @parameterized.named_parameters(
      ('training', True),
      ('validation', False))
  def test_preprocess(self, is_training):
    """Runs `preprocess` on one fake example and checks the output shapes."""
    output_size = 1280
    max_num_boxes = 100
    batch_size = 2
    data_config = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=is_training,
        global_batch_size=batch_size,
        output_size=(output_size, output_size),
        max_num_boxes=max_num_boxes,
    )

    dl = coco.COCODataLoader(data_config)
    inputs = _gen_fn()
    image, label = dl.preprocess(inputs)
    self.assertEqual(image.shape, (output_size, output_size, 3))
    # Shapes are written as real 1-tuples; `(max_num_boxes)` without the
    # trailing comma is just an int.
    self.assertEqual(label['classes'].shape, (max_num_boxes,))
    self.assertEqual(label['boxes'].shape, (max_num_boxes, 4))
    if not is_training:
      self.assertDTypeEqual(label['id'], int)
      self.assertEqual(
          label['image_info'].shape, (4, 2))
      self.assertEqual(label['is_crowd'].shape, (max_num_boxes,))
|
108
|
+
|
109
|
+
|
110
|
+
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
|
111
|
+
tf.test.main()
|