tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250110__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/projects/detr/__init__.py +14 -0
- official/projects/detr/configs/__init__.py +14 -0
- official/projects/detr/configs/detr.py +277 -0
- official/projects/detr/configs/detr_test.py +51 -0
- official/projects/detr/dataloaders/__init__.py +14 -0
- official/projects/detr/dataloaders/coco.py +157 -0
- official/projects/detr/dataloaders/coco_test.py +111 -0
- official/projects/detr/dataloaders/detr_input.py +175 -0
- official/projects/detr/experiments/__init__.py +14 -0
- official/projects/detr/modeling/__init__.py +14 -0
- official/projects/detr/modeling/detr.py +345 -0
- official/projects/detr/modeling/detr_test.py +70 -0
- official/projects/detr/modeling/transformer.py +849 -0
- official/projects/detr/modeling/transformer_test.py +263 -0
- official/projects/detr/ops/__init__.py +14 -0
- official/projects/detr/ops/matchers.py +489 -0
- official/projects/detr/ops/matchers_test.py +95 -0
- official/projects/detr/optimization.py +151 -0
- official/projects/detr/serving/__init__.py +14 -0
- official/projects/detr/serving/export_module.py +103 -0
- official/projects/detr/serving/export_module_test.py +98 -0
- official/projects/detr/serving/export_saved_model.py +109 -0
- official/projects/detr/tasks/__init__.py +14 -0
- official/projects/detr/tasks/detection.py +421 -0
- official/projects/detr/tasks/detection_test.py +203 -0
- official/projects/detr/train.py +70 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/RECORD +32 -6
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250110.dist-info}/top_level.txt +0 -0
official/projects/detr/tasks/detection.py (new file, +421 lines):

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DETR detection task definition."""
from typing import Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.dataloaders import detr_input
from official.projects.detr.modeling import detr
from official.projects.detr.ops import matchers
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling import backbones
from official.vision.ops import box_ops


@task_factory.register_task_cls(detr_cfg.DetrTask)
class DetectionTask(base_task.Task):
  """A single-replica view of the training procedure.

  The DETR task provides artifacts for training/evaluation procedures,
  including loading/iterating over datasets, initializing the model,
  calculating the loss, post-processing, and customized metrics with reduction.
  """

  def build_model(self):
    """Builds the DETR model."""

    input_specs = tf_keras.layers.InputSpec(shape=[None] +
                                            self._task_config.model.input_size)

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=self._task_config.model.backbone,
        norm_activation_config=self._task_config.model.norm_activation)

    model = detr.DETR(backbone,
                      self._task_config.model.backbone_endpoint_name,
                      self._task_config.model.num_queries,
                      self._task_config.model.hidden_size,
                      self._task_config.model.num_classes,
                      self._task_config.model.num_encoder_layers,
                      self._task_config.model.num_decoder_layers)
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads a pretrained checkpoint, if one is configured."""
    if not self._task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self._task_config.init_checkpoint

    # Restoring checkpoint.
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    if self._task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self._task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
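The two restore branches above differ in strictness: the 'all' branch asserts that every value in the checkpoint is consumed, while the 'backbone' branch tolerates checkpoint values that have no counterpart in the restored object. A minimal standalone sketch of that partial-restore pattern (toy objects for illustration, not package code):

import tensorflow as tf

class Toy(tf.Module):

  def __init__(self):
    self.backbone = tf.Variable(1.0)
    self.head = tf.Variable(2.0)

src = Toy()
# Save both components.
path = tf.train.Checkpoint(backbone=src.backbone, head=src.head).save(
    '/tmp/toy_ckpt')

dst = Toy()
# Restore only the backbone: expect_partial() silences the unmatched 'head'
# in the checkpoint, while assert_existing_objects_matched() still verifies
# that everything we asked for was actually restored.
status = tf.train.Checkpoint(backbone=dst.backbone).restore(path)
status.expect_partial().assert_existing_objects_matched()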
  def build_inputs(self,
                   params,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the input dataset."""
    if isinstance(params, coco.COCODataConfig):
      dataset = coco.COCODataLoader(params).load(input_context)
    else:
      if params.tfds_name:
        decoder = tfds_factory.get_detection_decoder(params.tfds_name)
      else:
        decoder_cfg = params.decoder.get()
        if params.decoder.type == 'simple_decoder':
          decoder = tf_example_decoder.TfExampleDecoder(
              regenerate_source_id=decoder_cfg.regenerate_source_id)
        elif params.decoder.type == 'label_map_decoder':
          decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
              label_map=decoder_cfg.label_map,
              regenerate_source_id=decoder_cfg.regenerate_source_id)
        else:
          raise ValueError('Unknown decoder type: {}!'.format(
              params.decoder.type))

      parser = detr_input.Parser(
          class_offset=self._task_config.losses.class_offset,
          output_size=self._task_config.model.input_size[:2],
      )

      reader = input_reader_factory.input_reader_generator(
          params,
          dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
          decoder_fn=decoder.decode,
          parser_fn=parser.parse_fn(params.is_training))
      dataset = reader.read(input_context=input_context)

    return dataset

  def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
    # Approximate classification cost with 1 - prob[target class].
    # The 1 is a constant that doesn't change the matching, so it can be
    # omitted. Background class index: 0.
    cls_cost = self._task_config.losses.lambda_cls * tf.gather(
        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)

    # Compute the L1 cost between boxes.
    paired_differences = self._task_config.losses.lambda_box * tf.abs(
        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
    box_cost = tf.reduce_sum(paired_differences, axis=-1)

    # Compute the GIoU cost between boxes.
    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
        box_ops.cycxhw_to_yxyx(box_outputs),
        box_ops.cycxhw_to_yxyx(box_targets))

    total_cost = cls_cost + box_cost + giou_cost

    max_cost = (
        self._task_config.losses.lambda_cls * 0.0 +
        self._task_config.losses.lambda_box * 4. +
        self._task_config.losses.lambda_giou * 0.0)

    # Set padded targets to a large constant cost.
    valid = tf.expand_dims(
        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1)
    total_cost = (1 - valid) * max_cost + valid * total_cost

    # Set inf or nan entries to a large constant cost.
    total_cost = tf.where(
        tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
        max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
        total_cost)

    return total_cost
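For intuition, the cost matrix built by _compute_cost (shape [batch, num_queries, num_targets]) is exactly what feeds the Hungarian matcher used below in build_losses. A minimal SciPy sketch of the same one-to-one assignment (illustrative only; the package ships its own TF-compatible implementation in official/projects/detr/ops/matchers.py):

import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: rows are predicted queries, columns are ground-truth boxes.
cost = np.array([[0.9, 0.1, 0.5],
                 [0.4, 0.8, 0.2],
                 [0.3, 0.6, 0.7]])

# Hungarian algorithm: a one-to-one assignment minimizing the total cost.
row_ind, col_ind = linear_sum_assignment(cost)
print(list(zip(row_ind, col_ind)))   # (query, target) pairs: (0,1), (1,2), (2,0)
print(cost[row_ind, col_ind].sum())  # total matching cost: 0.1 + 0.2 + 0.3 = 0.6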
  def build_losses(self, outputs, labels, aux_losses=None):
    """Builds DETR losses."""
    cls_outputs = outputs['cls_outputs']
    box_outputs = outputs['box_outputs']
    cls_targets = labels['classes']
    box_targets = labels['boxes']

    cost = self._compute_cost(
        cls_outputs, box_outputs, cls_targets, box_targets)

    _, indices = matchers.hungarian_matching(cost)
    indices = tf.stop_gradient(indices)

    target_index = tf.math.argmax(indices, axis=1)
    cls_assigned = tf.gather(cls_outputs, target_index, batch_dims=1, axis=1)
    box_assigned = tf.gather(box_outputs, target_index, batch_dims=1, axis=1)

    background = tf.equal(cls_targets, 0)
    num_boxes = tf.reduce_sum(
        tf.cast(tf.logical_not(background), tf.float32), axis=-1)

    # Down-weight background to account for class imbalance.
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=cls_targets, logits=cls_assigned)
    cls_loss = self._task_config.losses.lambda_cls * tf.where(
        background, self._task_config.losses.background_cls_weight * xentropy,
        xentropy)
    cls_weights = tf.where(
        background,
        self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
        tf.ones_like(cls_loss))

    # Box loss is only calculated on non-background classes.
    l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
    box_loss = self._task_config.losses.lambda_box * tf.where(
        background, tf.zeros_like(l_1), l_1)

    # GIoU loss is only calculated on non-background classes.
    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
        box_ops.cycxhw_to_yxyx(box_assigned),
        box_ops.cycxhw_to_yxyx(box_targets)
    ))
    giou_loss = self._task_config.losses.lambda_giou * tf.where(
        background, tf.zeros_like(giou), giou)

    # Consider doing the all-reduce once in train_step to speed this up.
    num_boxes_per_replica = tf.reduce_sum(num_boxes)
    cls_weights_per_replica = tf.reduce_sum(cls_weights)
    replica_context = tf.distribute.get_replica_context()
    num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
        tf.distribute.ReduceOp.SUM,
        [num_boxes_per_replica, cls_weights_per_replica])
    cls_loss = tf.math.divide_no_nan(
        tf.reduce_sum(cls_loss), cls_weights_sum)
    box_loss = tf.math.divide_no_nan(
        tf.reduce_sum(box_loss), num_boxes_sum)
    giou_loss = tf.math.divide_no_nan(
        tf.reduce_sum(giou_loss), num_boxes_sum)

    aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0

    total_loss = cls_loss + box_loss + giou_loss + aux_losses
    return total_loss, cls_loss, box_loss, giou_loss
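The tf.gather(..., batch_dims=1, axis=1) calls in build_losses reorder each image's predictions so that the query matched to target slot j lands in slot j, aligning predictions with their targets. A small standalone sketch of that batched reordering (toy shapes, not package code):

import tensorflow as tf

# One image in the batch, three queries with 2-dim box outputs.
box_outputs = tf.constant([[[0.1, 0.1], [0.5, 0.5], [0.9, 0.9]]])
# target_index[b][j]: which query is matched to target slot j.
target_index = tf.constant([[2, 0, 1]])

box_assigned = tf.gather(box_outputs, target_index, batch_dims=1, axis=1)
print(box_assigned.numpy())
# [[[0.9 0.9]
#   [0.1 0.1]
#   [0.5 0.5]]] -- queries reordered into target order.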
  def build_metrics(self, training=True):
    """Builds detection metrics."""
    metrics = []
    metric_names = ['cls_loss', 'box_loss', 'giou_loss']
    for name in metric_names:
      metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))

    if not training:
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self._task_config.annotation_file,
          include_mask=False,
          need_rescale_bboxes=True,
          per_category_metrics=self._task_config.per_category_metrics)
    return metrics

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)

      loss = 0.0
      cls_loss = 0.0
      box_loss = 0.0
      giou_loss = 0.0

      for output in outputs:
        # Computes per-replica loss.
        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = self.build_losses(
            outputs=output, labels=labels, aux_losses=model.losses)
        loss += layer_loss
        cls_loss += layer_cls_loss
        box_loss += layer_box_loss
        giou_loss += layer_giou_loss

      # Consider moving the scaling logic from build_losses to here.
      scaled_loss = loss
      # For a mixed_precision policy, when LossScaleOptimizer is used, the
      # loss is scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back the gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # Multiply for logging.
    # Since we expect the gradient replica sum to happen in the optimizer,
    # the loss is scaled with global num_boxes and weights.
    # To make it more interpretable/comparable we scale it back when logging.
    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
    loss *= num_replicas_in_sync
    cls_loss *= num_replicas_in_sync
    box_loss *= num_replicas_in_sync
    giou_loss *= num_replicas_in_sync

    # The trainer class handles the loss metric for you.
    logs = {self.loss: loss}

    all_losses = {
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'giou_loss': giou_loss,
    }

    # Metric results will be added to the logs for you.
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    outputs = model(features, training=False)[-1]
    loss, cls_loss, box_loss, giou_loss = self.build_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)

    # Multiply for logging.
    # Since we expect the gradient replica sum to happen in the optimizer,
    # the loss is scaled with global num_boxes and weights.
    # To make it more interpretable/comparable we scale it back when logging.
    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
    loss *= num_replicas_in_sync
    cls_loss *= num_replicas_in_sync
    box_loss *= num_replicas_in_sync
    giou_loss *= num_replicas_in_sync

    # The evaluator class handles the loss metric for you.
    logs = {self.loss: loss}

    # This is for backward compatibility.
    if 'detection_boxes' not in outputs:
      detection_boxes = box_ops.cycxhw_to_yxyx(
          outputs['box_outputs']) * tf.expand_dims(
              tf.concat([
                  labels['image_info'][:, 1:2, 0],
                  labels['image_info'][:, 1:2, 1],
                  labels['image_info'][:, 1:2, 0],
                  labels['image_info'][:, 1:2, 1],
              ], axis=1),
              axis=1)
    else:
      detection_boxes = outputs['detection_boxes']

    detection_scores = tf.math.reduce_max(
        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
    ) if 'detection_scores' not in outputs else outputs['detection_scores']

    if 'detection_classes' not in outputs:
      detection_classes = tf.math.argmax(
          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
    else:
      detection_classes = outputs['detection_classes']

    if 'num_detections' not in outputs:
      num_detections = tf.reduce_sum(
          tf.cast(
              tf.math.greater(
                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
              tf.int32),
          axis=-1)
    else:
      num_detections = outputs['num_detections']

    predictions = {
        'detection_boxes': detection_boxes,
        'detection_scores': detection_scores,
        'detection_classes': detection_classes,
        'num_detections': num_detections,
        'source_id': labels['id'],
        'image_info': labels['image_info']
    }

    ground_truths = {
        'source_id': labels['id'],
        'height': labels['image_info'][:, 0:1, 0],
        'width': labels['image_info'][:, 0:1, 1],
        'num_detections': tf.reduce_sum(
            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
        'boxes': labels['gt_boxes'],
        'classes': labels['classes'],
        'is_crowds': labels['is_crowd']
    }
    logs.update({'predictions': predictions,
                 'ground_truths': ground_truths})

    all_losses = {
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'giou_loss': giou_loss,
    }

    # Metric results will be added to the logs for you.
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    if state is None:
      self.coco_metric.reset_states()
      state = self.coco_metric

    state.update_state(
        step_outputs['ground_truths'],
        step_outputs['predictions'])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    return aggregated_logs.result()
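In the backward-compatibility branch of validation_step, boxes leave the model in normalized (cy, cx, h, w) form and are converted to (ymin, xmin, ymax, xmax), then scaled by the resized image height/width taken from image_info. A toy sketch of that conversion; the helper below is a hypothetical reimplementation written for illustration, whereas the package uses box_ops.cycxhw_to_yxyx:

import tensorflow as tf

def cycxhw_to_yxyx(boxes):
  # (cy, cx, h, w) -> (ymin, xmin, ymax, xmax); illustrative reimplementation.
  cy, cx, h, w = tf.unstack(boxes, axis=-1)
  return tf.stack(
      [cy - h / 2.0, cx - w / 2.0, cy + h / 2.0, cx + w / 2.0], axis=-1)

# A normalized box centered at (0.5, 0.5) covering half the image in each dim.
boxes = tf.constant([[0.5, 0.5, 0.5, 0.5]])
yxyx = cycxhw_to_yxyx(boxes)  # [[0.25, 0.25, 0.75, 0.75]]

# Scale by the resized (height, width), as validation_step does via
# image_info; 640x640 is an assumed example size.
scale = tf.constant([640.0, 640.0, 640.0, 640.0])
print((yxyx * scale).numpy())  # [[160. 160. 480. 480.]]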
official/projects/detr/tasks/detection_test.py (new file, +203 lines):

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for detection."""

import numpy as np
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection
from official.vision.configs import backbones


_NUM_EXAMPLES = 10


def _gen_fn():
  h = np.random.randint(0, 300)
  w = np.random.randint(0, 300)
  num_boxes = np.random.randint(0, 50)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'image/id': np.random.randint(0, 100),
      'image/filename': 'test',
      'objects': {
          'is_crowd': np.ones(shape=(num_boxes), dtype=bool),
          'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
          'label': np.ones(shape=(num_boxes), dtype=np.int64),
          'id': np.ones(shape=(num_boxes), dtype=np.int64),
          'area': np.ones(shape=(num_boxes), dtype=np.int64),
      }
  }


def _as_dataset(self, *args, **kwargs):
  del args
  del kwargs
  return tf.data.Dataset.from_generator(
      lambda: (_gen_fn() for i in range(_NUM_EXAMPLES)),
      output_types=self.info.features.dtype,
      output_shapes=self.info.features.shape,
  )


class DetectionTest(tf.test.TestCase):

  def test_train_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            num_classes=81,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        train_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=True,
            global_batch_size=2,
        ))
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DetectionTask(config)
      model = task.build_model()
      dataset = task.build_inputs(config.train_data)
      iterator = iter(dataset)
      opt_cfg = optimization.OptimizationConfig({
          'optimizer': {
              'type': 'detr_adamw',
              'detr_adamw': {
                  'weight_decay_rate': 1e-4,
                  'global_clipnorm': 0.1,
              }
          },
          'learning_rate': {
              'type': 'stepwise',
              'stepwise': {
                  'boundaries': [120000],
                  'values': [0.0001, 1.0e-05]
              }
          },
      })
      optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
      task.train_step(next(iterator), model, optimizer)

  def test_validation_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            num_classes=81,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        validation_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=False,
            global_batch_size=2,
        ))

    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DetectionTask(config)
      model = task.build_model()
      metrics = task.build_metrics(training=False)
      dataset = task.build_inputs(config.validation_data)
      iterator = iter(dataset)
      logs = task.validation_step(next(iterator), model, metrics)
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)


class DetectionTFDSTest(tf.test.TestCase):

  def test_train_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        losses=detr_cfg.Losses(class_offset=1),
        train_data=detr_cfg.DataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=True,
            global_batch_size=2,
        ))
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DetectionTask(config)
      model = task.build_model()
      dataset = task.build_inputs(config.train_data)
      iterator = iter(dataset)
      opt_cfg = optimization.OptimizationConfig({
          'optimizer': {
              'type': 'detr_adamw',
              'detr_adamw': {
                  'weight_decay_rate': 1e-4,
                  'global_clipnorm': 0.1,
              }
          },
          'learning_rate': {
              'type': 'stepwise',
              'stepwise': {
                  'boundaries': [120000],
                  'values': [0.0001, 1.0e-05]
              }
          },
      })
      optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
      task.train_step(next(iterator), model, optimizer)

  def test_validation_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        losses=detr_cfg.Losses(class_offset=1),
        validation_data=detr_cfg.DataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=False,
            global_batch_size=2,
        ))

    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DetectionTask(config)
      model = task.build_model()
      metrics = task.build_metrics(training=False)
      dataset = task.build_inputs(config.validation_data)
      iterator = iter(dataset)
      logs = task.validation_step(next(iterator), model, metrics)
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)

if __name__ == '__main__':
  tf.test.main()
official/projects/detr/train.py (new file, +70 lines):

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorFlow Model Garden Vision training driver."""

from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.detr.configs import detr
from official.projects.detr.tasks import detection
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise a continuous eval
    # job may race against the train job when writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 on GPUs
  # and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
  # float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)

if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
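For reference, the driver is launched with the standard Model Garden flags it marks as required above. A hypothetical invocation (the experiment name detr_coco is assumed to be one of the factories registered in official/projects/detr/configs/detr.py; substitute the experiment you actually use, and any writable model directory):

python3 -m official.projects.detr.train \
  --experiment=detr_coco \
  --mode=train_and_eval \
  --model_dir=/tmp/detr_coco_model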