tf-models-nightly 2.19.0.dev20250107__py2.py3-none-any.whl → 2.19.0.dev20250109__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tf-models-nightly might be problematic. Click here for more details.

Files changed (33)
  1. official/projects/detr/__init__.py +14 -0
  2. official/projects/detr/configs/__init__.py +14 -0
  3. official/projects/detr/configs/detr.py +277 -0
  4. official/projects/detr/configs/detr_test.py +51 -0
  5. official/projects/detr/dataloaders/__init__.py +14 -0
  6. official/projects/detr/dataloaders/coco.py +157 -0
  7. official/projects/detr/dataloaders/coco_test.py +111 -0
  8. official/projects/detr/dataloaders/detr_input.py +175 -0
  9. official/projects/detr/experiments/__init__.py +14 -0
  10. official/projects/detr/modeling/__init__.py +14 -0
  11. official/projects/detr/modeling/detr.py +345 -0
  12. official/projects/detr/modeling/detr_test.py +70 -0
  13. official/projects/detr/modeling/transformer.py +849 -0
  14. official/projects/detr/modeling/transformer_test.py +263 -0
  15. official/projects/detr/ops/__init__.py +14 -0
  16. official/projects/detr/ops/matchers.py +489 -0
  17. official/projects/detr/ops/matchers_test.py +95 -0
  18. official/projects/detr/optimization.py +151 -0
  19. official/projects/detr/serving/__init__.py +14 -0
  20. official/projects/detr/serving/export_module.py +103 -0
  21. official/projects/detr/serving/export_module_test.py +98 -0
  22. official/projects/detr/serving/export_saved_model.py +109 -0
  23. official/projects/detr/tasks/__init__.py +14 -0
  24. official/projects/detr/tasks/detection.py +421 -0
  25. official/projects/detr/tasks/detection_test.py +203 -0
  26. official/projects/detr/train.py +70 -0
  27. official/vision/ops/augment.py +1 -13
  28. {tf_models_nightly-2.19.0.dev20250107.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/METADATA +1 -1
  29. {tf_models_nightly-2.19.0.dev20250107.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/RECORD +33 -7
  30. {tf_models_nightly-2.19.0.dev20250107.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/AUTHORS +0 -0
  31. {tf_models_nightly-2.19.0.dev20250107.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/LICENSE +0 -0
  32. {tf_models_nightly-2.19.0.dev20250107.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/WHEEL +0 -0
  33. {tf_models_nightly-2.19.0.dev20250107.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
@@ -0,0 +1,277 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """DETR configurations."""
16
+
17
+ import dataclasses
18
+ import os
19
+ from typing import List, Optional, Union
20
+
21
+ from official.core import config_definitions as cfg
22
+ from official.core import exp_factory
23
+ from official.modeling import hyperparams
24
+ from official.projects.detr import optimization
25
+ from official.projects.detr.dataloaders import coco
26
+ from official.vision.configs import backbones
27
+ from official.vision.configs import common
28
+
29
+
30
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for DETR training and evaluation data.

  The experiments in this module use it for both `train_data` and
  `validation_data`, populating either `input_path` (file patterns) or
  `tfds_name`/`tfds_split`.
  """
  input_path: str = ''
  tfds_name: str = ''
  tfds_split: str = 'train'
  global_batch_size: int = 0
  is_training: bool = False
  # NOTE(review): presumably the dtype of the produced inputs; confirm in the
  # task/input reader that consumes this config.
  dtype: str = 'bfloat16'
  decoder: common.DataDecoder = dataclasses.field(default_factory=common.DataDecoder)
  shuffle_buffer_size: int = 10000
  file_type: str = 'tfrecord'
  drop_remainder: bool = True
43
+
44
+
45
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related hyperparameters for DETR.

  These values are consumed outside this module (presumably by the DETR
  detection task); the per-field notes are inferred from names and should be
  confirmed there.
  """
  # `detr_coco_tfds` sets this to 1 together with num_classes=81; presumably an
  # offset applied to dataset class ids. TODO confirm in the task.
  class_offset: int = 0
  lambda_cls: float = 1.0  # presumably the classification loss weight
  lambda_box: float = 5.0  # presumably the box regression loss weight
  lambda_giou: float = 2.0  # presumably the generalized-IoU loss weight
  # Presumably the classification weight of the background ("no object") class.
  background_cls_weight: float = 0.1
  l2_weight_decay: float = 1e-4
53
+
54
+
55
@dataclasses.dataclass
class Detr(hyperparams.Config):
  """Detr model definitions."""
  num_queries: int = 100
  hidden_size: int = 256
  num_classes: int = 91  # 0: background
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  # Every experiment in this module sets [1333, 1333, 3]; presumably
  # (height, width, channels) — confirm in the model builder.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Default backbone: ResNet-50 with frozen batch-norm (bn_trainable=False).
  backbone: backbones.Backbone = dataclasses.field(default_factory=lambda: backbones.Backbone(
      type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
  norm_activation: common.NormActivation = dataclasses.field(default_factory=common.NormActivation)
  # NOTE(review): presumably the backbone output level fed to the transformer.
  backbone_endpoint_name: str = '5'
68
+
69
+
70
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
  """Task config for DETR: model, data, losses, checkpoint and eval options."""
  model: Detr = dataclasses.field(default_factory=Detr)
  # The experiments in this module pass `DataConfig` or `coco.COCODataConfig`
  # instances (both `cfg.DataConfig` subclasses) here.
  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
  losses: Losses = dataclasses.field(default_factory=Losses)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone
  # `detr_coco_tfrecord` points this at instances_val2017.json; presumably
  # used for COCO evaluation.
  annotation_file: Optional[str] = None
  per_category_metrics: bool = False
80
+
81
+
82
# Root prefix that the TFRecord experiment joins with file patterns.
COCO_INPUT_PATH_BASE = 'coco'
# Split sizes used to convert epochs into steps below.
COCO_TRAIN_EXAMPLES = 118_287
COCO_VAL_EXAMPLES = 5_000
85
+
86
+
87
def _detr_optimizer_config(decay_at: int) -> optimization.OptimizationConfig:
  """Returns the AdamW + stepwise LR setup shared by every DETR experiment.

  Args:
    decay_at: Training step at which the learning rate drops from 1e-4 to
      1e-5.

  Returns:
    A freshly built `optimization.OptimizationConfig`.
  """
  return optimization.OptimizationConfig({
      'optimizer': {
          'type': 'detr_adamw',
          'detr_adamw': {
              'weight_decay_rate': 1e-4,
              'global_clipnorm': 0.1,
              # Avoid AdamW legacy behavior.
              'gradient_clip_norm': 0.0
          }
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': [decay_at],
              'values': [0.0001, 1.0e-05]
          }
      },
  })


def _epoch_based_trainer_config(train_steps: int, steps_per_epoch: int,
                                decay_at: int,
                                eval_batch_size: int) -> cfg.TrainerConfig:
  """Returns a trainer config that loops/logs/checkpoints once per epoch.

  Validation runs every 5 epochs over `COCO_VAL_EXAMPLES // eval_batch_size`
  steps, and the best checkpoint (by AP) is exported to `best_ckpt`.

  Args:
    train_steps: Total number of training steps.
    steps_per_epoch: Training steps in one epoch.
    decay_at: Step of the learning-rate drop (see `_detr_optimizer_config`).
    eval_batch_size: Global validation batch size.
  """
  return cfg.TrainerConfig(
      train_steps=train_steps,
      validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      validation_interval=5 * steps_per_epoch,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=_detr_optimizer_config(decay_at))


@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
  """DETR on TFDS COCO 2017 with the 500-epoch schedule (LR drop at 400).

  Uses `coco.COCODataConfig` inputs and 81 classes.
  """
  train_batch_size = 64
  eval_batch_size = 64
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  train_steps = 500 * steps_per_epoch  # 500 epochs
  decay_at = train_steps - 100 * steps_per_epoch  # 400 epochs
  return cfg.ExperimentConfig(
      task=DetrTask(
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          model=Detr(
              num_classes=81,
              input_size=[1333, 1333, 3],
              norm_activation=common.NormActivation()),
          losses=Losses(),
          train_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='train',
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='validation',
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False)),
      # Step-count based intervals, unlike the epoch-based configs below.
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=10000,
          summary_interval=10000,
          checkpoint_interval=10000,
          validation_interval=10000,
          max_to_keep=1,
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_eval_metric='AP',
          optimizer_config=_detr_optimizer_config(decay_at)),
      restrictions=[
          'task.train_data.is_training != None',
      ])


@exp_factory.register_config_factory('detr_coco_tfrecord')
def detr_coco_tfrecord() -> cfg.ExperimentConfig:
  """DETR on COCO TFRecords with the 300-epoch schedule (LR drop at 200).

  Reads `train*` / `val*` files under `COCO_INPUT_PATH_BASE` and uses the
  default `Detr.num_classes`.
  """
  train_batch_size = 64
  eval_batch_size = 64
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  train_steps = 300 * steps_per_epoch  # 300 epochs
  decay_at = train_steps - 100 * steps_per_epoch  # 200 epochs
  return cfg.ExperimentConfig(
      task=DetrTask(
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                       'instances_val2017.json'),
          model=Detr(
              input_size=[1333, 1333, 3],
              norm_activation=common.NormActivation()),
          losses=Losses(),
          train_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False,
          )),
      trainer=_epoch_based_trainer_config(
          train_steps=train_steps,
          steps_per_epoch=steps_per_epoch,
          decay_at=decay_at,
          eval_batch_size=eval_batch_size),
      restrictions=[
          'task.train_data.is_training != None',
      ])


@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco_tfds() -> cfg.ExperimentConfig:
  """DETR on TFDS COCO 2017 with the 300-epoch schedule (LR drop at 200).

  Uses the generic `DataConfig`, 81 classes and `Losses(class_offset=1)`.
  """
  train_batch_size = 64
  eval_batch_size = 64
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  train_steps = 300 * steps_per_epoch  # 300 epochs
  decay_at = train_steps - 100 * steps_per_epoch  # 200 epochs
  return cfg.ExperimentConfig(
      task=DetrTask(
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          model=Detr(
              num_classes=81,
              input_size=[1333, 1333, 3],
              norm_activation=common.NormActivation()),
          losses=Losses(class_offset=1),
          train_data=DataConfig(
              tfds_name='coco/2017',
              tfds_split='train',
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=DataConfig(
              tfds_name='coco/2017',
              tfds_split='validation',
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False)),
      trainer=_epoch_based_trainer_config(
          train_steps=train_steps,
          steps_per_epoch=steps_per_epoch,
          decay_at=decay_at,
          eval_batch_size=eval_batch_size),
      restrictions=[
          'task.train_data.is_training != None',
      ])
@@ -0,0 +1,51 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for detr."""
16
+
17
+ # pylint: disable=unused-import
18
+ from absl.testing import parameterized
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.core import config_definitions as cfg
22
+ from official.core import exp_factory
23
+ from official.projects.detr.configs import detr as exp_cfg
24
+ from official.projects.detr.dataloaders import coco
25
+
26
+
27
class DetrTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that the registered DETR experiment configs build and validate."""

  def _assert_valid_detr_experiment(self, config_name, train_data_cls):
    """Builds `config_name` and checks its types and its restriction.

    Args:
      config_name: Name registered with `exp_factory`.
      train_data_cls: Expected (base) class of `task.train_data`.
    """
    experiment = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(experiment, cfg.ExperimentConfig)
    self.assertIsInstance(experiment.task, exp_cfg.DetrTask)
    self.assertIsInstance(experiment.task.train_data, train_data_cls)
    # Clearing is_training must violate the config's restriction, which
    # validate() is expected to report as a KeyError.
    experiment.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      experiment.validate()

  @parameterized.parameters(('detr_coco',))
  def test_detr_configs_tfds(self, config_name):
    self._assert_valid_detr_experiment(config_name, coco.COCODataConfig)

  @parameterized.parameters(('detr_coco_tfrecord',), ('detr_coco_tfds',))
  def test_detr_configs(self, config_name):
    self._assert_valid_detr_experiment(config_name, cfg.DataConfig)
48
+
49
+
50
if __name__ == '__main__':
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
@@ -0,0 +1,157 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """COCO data loader for DETR."""
16
+
17
+ import dataclasses
18
+ from typing import Optional, Tuple
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.core import config_definitions as cfg
22
+ from official.core import input_reader
23
+ from official.vision.ops import box_ops
24
+ from official.vision.ops import preprocess_ops
25
+
26
+
27
@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
  """Data config for COCO."""
  # (height, width) the preprocessed image is zero-padded to; its max also
  # caps the long side in the final resize.
  output_size: Tuple[int, int] = (1333, 1333)
  # Per-image labels are clipped or padded to this many boxes.
  max_num_boxes: int = 100
  # Candidate short-side lengths for the final resize: training samples one
  # uniformly at random, evaluation uses the last entry.
  resize_scales: Tuple[int, ...] = (
      480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
34
+
35
+
36
class COCODataLoader():
  """A class to load dataset for COCO detection task.

  Examples are read through `input_reader.InputReader`, preprocessed with
  `preprocess` and batched per replica.
  """

  # Short-side lengths sampled before the random crop during training.
  _CROP_SHORT_SIDES = (400.0, 500.0, 600.0)

  def __init__(self, params: COCODataConfig):
    """Initializes the loader.

    Args:
      params: Data config. `is_training`, `resize_scales`, `output_size`,
        `max_num_boxes`, `global_batch_size` and `drop_remainder` are read.
    """
    self._params = params

  def preprocess(self, inputs):
    """Preprocess COCO for DETR.

    Training applies a random horizontal flip, then with probability 0.5 a
    random resize (short side from `_CROP_SHORT_SIDES`) followed by a random
    crop, and finally resizes to a short side sampled uniformly from
    `resize_scales`. Evaluation resizes to `resize_scales[-1]`. In both cases
    the long side is capped at `max(output_size)`.

    Args:
      inputs: A decoded COCO example. Reads `image`, `objects/bbox`,
        `objects/label`, `objects/is_crowd` and, in evaluation, `image/id`.

    Returns:
      A tuple `(image, labels)`. `image` is normalized, resized and
      zero-padded to `output_size`. `labels['classes']` and
      `labels['boxes']` (converted with `box_ops.yxyx_to_cycxhw`) are
      clipped or padded to `max_num_boxes`. In evaluation `labels` also has
      `id`, `image_info`, `is_crowd` and `gt_boxes` (boxes denormalized to
      the pre-resize image size).
    """
    image = inputs['image']
    boxes = inputs['objects']['bbox']
    # Shift labels by one so 0 stays free (0 is background in the DETR config).
    classes = inputs['objects']['label'] + 1
    is_crowd = inputs['objects']['is_crowd']
    resize_scales = self._params.resize_scales

    image = preprocess_ops.normalize_image(image)
    if self._params.is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

      do_crop = tf.greater(tf.random.uniform([]), 0.5)
      if do_crop:
        # Rescale to a random short side before cropping.
        boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
        index = tf.random.categorical(
            tf.zeros([1, len(self._CROP_SHORT_SIDES)]), 1)[0]
        scales = tf.gather(self._CROP_SHORT_SIDES, index, axis=0)
        short_side = scales[0]
        image, image_info = preprocess_ops.resize_image(image, short_side)
        boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                     image_info[2, :],
                                                     image_info[1, :],
                                                     image_info[3, :])
        boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

        # Random crop of size in [384, min(side, 600)) per dimension.
        shape = tf.cast(image_info[1], dtype=tf.int32)
        h = tf.random.uniform(
            [], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
        w = tf.random.uniform(
            [], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
        i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, i, j, h, w)
        # Re-express the normalized (y, x, y, x) boxes relative to the crop
        # window and clip them into it.
        boxes = tf.clip_by_value(
            (boxes[..., :] * tf.cast(
                tf.stack([shape[0], shape[1], shape[0], shape[1]]),
                dtype=tf.float32) -
             tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
            tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
      # Sample the final short side uniformly from all configured scales; the
      # category count must follow the config, not a fixed constant.
      candidate_scales = tf.constant(resize_scales, dtype=tf.float32)
      index = tf.random.categorical(
          tf.zeros([1, len(resize_scales)]), 1)[0]
      scales = tf.gather(candidate_scales, index, axis=0)
    else:
      scales = tf.constant([resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(
        image,
        short_side,
        max(self._params.output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                 image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(
        image, 0, 0, self._params.output_size[0], self._params.output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                classes, self._params.max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                boxes, self._params.max_num_boxes)
    }
    if not self._params.is_training:
      labels.update({
          'id':
              inputs['image/id'],
          'image_info':
              image_info,
          'is_crowd':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  is_crowd, self._params.max_num_boxes),
          'gt_boxes':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  gt_boxes, self._params.max_num_boxes),
      })

    return image, labels

  def _transform_and_batch_fn(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Maps `preprocess` over `dataset` and batches it per replica.

    Args:
      dataset: Dataset of decoded COCO examples.
      input_context: Optional distribution context; when given, the global
        batch size is split across replicas.

    Returns:
      The preprocessed, batched dataset.
    """
    dataset = dataset.map(
        self.preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    if input_context:
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          self._params.global_batch_size)
    else:
      per_replica_batch_size = self._params.global_batch_size
    return dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.drop_remainder)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset of `(image, labels)` batches.

    Args:
      input_context: Optional distribution context forwarded to the reader.
    """
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=None,
        transform_and_batch_fn=self._transform_and_batch_fn)
    return reader.read(input_context)
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for tensorflow_models.official.projects.detr.dataloaders.coco."""
16
+
17
+ from absl.testing import parameterized
18
+ import numpy as np
19
+ import tensorflow as tf, tf_keras
20
+ import tensorflow_datasets as tfds
21
+
22
+ from official.projects.detr.dataloaders import coco
23
+
24
+
25
+ def _gen_fn():
26
+ h = np.random.randint(0, 300)
27
+ w = np.random.randint(0, 300)
28
+ num_boxes = np.random.randint(0, 50)
29
+ return {
30
+ 'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
31
+ 'image/id': np.random.randint(0, 100),
32
+ 'image/filename': 'test',
33
+ 'objects': {
34
+ 'is_crowd': np.ones(shape=(num_boxes), dtype=bool),
35
+ 'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
36
+ 'label': np.ones(shape=(num_boxes), dtype=np.int64),
37
+ 'id': np.ones(shape=(num_boxes), dtype=np.int64),
38
+ 'area': np.ones(shape=(num_boxes), dtype=np.int64),
39
+ }
40
+ }
41
+
42
+
43
class CocoDataloaderTest(tf.test.TestCase, parameterized.TestCase):
  """Shape checks for `coco.COCODataLoader` on fake COCO examples."""

  _OUTPUT_SIZE = 1280
  _MAX_NUM_BOXES = 100
  _BATCH_SIZE = 2

  def _make_config(self, is_training):
    """Returns a small COCO config on the validation split."""
    return coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=is_training,
        global_batch_size=self._BATCH_SIZE,
        output_size=(self._OUTPUT_SIZE, self._OUTPUT_SIZE),
        max_num_boxes=self._MAX_NUM_BOXES,
    )

  def test_load_dataset(self):
    num_examples = 10
    batch = self._BATCH_SIZE
    side = self._OUTPUT_SIZE
    boxes_cap = self._MAX_NUM_BOXES

    def fake_as_dataset(builder, *args, **kwargs):
      # Ignore split/read options and stream generated examples instead.
      del args, kwargs
      return tf.data.Dataset.from_generator(
          lambda: (_gen_fn() for _ in range(num_examples)),
          output_types=builder.info.features.dtype,
          output_shapes=builder.info.features.shape,
      )

    with tfds.testing.mock_data(num_examples=num_examples,
                                as_dataset_fn=fake_as_dataset):
      loader = coco.COCODataLoader(self._make_config(is_training=False))
      images, labels = next(iter(loader.load()))
      self.assertEqual(images.shape, (batch, side, side, 3))
      self.assertEqual(labels['classes'].shape, (batch, boxes_cap))
      self.assertEqual(labels['boxes'].shape, (batch, boxes_cap, 4))
      self.assertEqual(labels['id'].shape, (batch,))
      self.assertEqual(labels['image_info'].shape, (batch, 4, 2))
      self.assertEqual(labels['is_crowd'].shape, (batch, boxes_cap))

  @parameterized.named_parameters(
      ('training', True),
      ('validation', False))
  def test_preprocess(self, is_training):
    side = self._OUTPUT_SIZE
    boxes_cap = self._MAX_NUM_BOXES
    loader = coco.COCODataLoader(self._make_config(is_training))

    image, label = loader.preprocess(_gen_fn())

    self.assertEqual(image.shape, (side, side, 3))
    self.assertEqual(label['classes'].shape, (boxes_cap,))
    self.assertEqual(label['boxes'].shape, (boxes_cap, 4))
    if is_training:
      return
    # Evaluation-only labels.
    self.assertDTypeEqual(label['id'], int)
    self.assertEqual(label['image_info'].shape, (4, 2))
    self.assertEqual(label['is_crowd'].shape, (boxes_cap,))
108
+
109
+
110
if __name__ == '__main__':
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()