tf-models-nightly 2.14.0.dev20230929__py2.py3-none-any.whl → 2.14.0.dev20231001__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -34,7 +34,7 @@ class CenterNetDetectionGenerator(tf.keras.layers.Layer):
34
34
  """CenterNet Detection Generator."""
35
35
 
36
36
  def __init__(self,
37
- input_image_dims: int = 512,
37
+ input_image_dims: tuple[int, int] | int = 512,
38
38
  net_down_scale: int = 4,
39
39
  max_detections: int = 100,
40
40
  peak_error: float = 1e-6,
@@ -47,7 +47,10 @@ class CenterNetDetectionGenerator(tf.keras.layers.Layer):
47
47
  """Initialize CenterNet Detection Generator.
48
48
 
49
49
  Args:
50
- input_image_dims: An `int` that specifies the input image size.
50
+ input_image_dims: The input image size. If it is a tuple of two `int`s, it
51
+ is the size (height, width) of the input images. If it is an `int`, the
52
+ input images are supposed to be squared images whose height and width
53
+ are equal.
51
54
  net_down_scale: An `int` that specifies stride of the output.
52
55
  max_detections: An `int` specifying the maximum number of bounding
53
56
  boxes generated. This is an upper bound, so the number of generated
@@ -67,6 +70,9 @@ class CenterNetDetectionGenerator(tf.keras.layers.Layer):
67
70
  """
68
71
  super(CenterNetDetectionGenerator, self).__init__(**kwargs)
69
72
 
73
+ if isinstance(input_image_dims, int):
74
+ input_image_dims = (input_image_dims, input_image_dims)
75
+
70
76
  # Object center selection parameters
71
77
  self._max_detections = max_detections
72
78
  self._peak_error = peak_error
@@ -246,10 +252,28 @@ class CenterNetDetectionGenerator(tf.keras.layers.Layer):
246
252
  return boxes, detection_classes
247
253
 
248
254
  def convert_strided_predictions_to_normalized_boxes(self, boxes: tf.Tensor):
255
+ """Converts strided predictions to normalized boxes.
256
+
257
+ Args:
258
+ boxes: A tf.Tensor of shape [batch_size, num_predictions, 4], representing
259
+ the strided predictions of the detected objects.
260
+
261
+ Returns:
262
+ A tf.Tensor of shape [batch_size, num_predictions, 4], representing
263
+ the normalized boxes of the detected objects.
264
+ """
249
265
  boxes = boxes * tf.cast(self._net_down_scale, boxes.dtype)
250
- boxes = boxes / tf.cast(self._input_image_dims, boxes.dtype)
251
- boxes = tf.clip_by_value(boxes, 0.0, 1.0)
252
- return boxes
266
+
267
+ height = tf.cast(self._input_image_dims[0], boxes.dtype)
268
+ width = tf.cast(self._input_image_dims[1], boxes.dtype)
269
+ ymin = boxes[..., 0:1] / height
270
+ xmin = boxes[..., 1:2] / width
271
+ ymax = boxes[..., 2:3] / height
272
+ xmax = boxes[..., 3:4] / width
273
+
274
+ normalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
275
+ normalized_boxes = tf.clip_by_value(normalized_boxes, 0.0, 1.0)
276
+ return normalized_boxes
253
277
 
254
278
  def __call__(self, inputs):
255
279
  # Get heatmaps from decoded outputs via final hourglass stack output
@@ -308,8 +332,7 @@ class CenterNetDetectionGenerator(tf.keras.layers.Layer):
308
332
  nms_thresh=0.4)
309
333
 
310
334
  num_det = tf.reduce_sum(tf.cast(scores > 0, dtype=tf.int32), axis=1)
311
- boxes = box_ops.denormalize_boxes(
312
- boxes, [self._input_image_dims, self._input_image_dims])
335
+ boxes = box_ops.denormalize_boxes(boxes, self._input_image_dims)
313
336
 
314
337
  return {
315
338
  'boxes': boxes,
@@ -0,0 +1,152 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for Centernet detection_generator."""
16
+
17
+ from collections.abc import Mapping, Sequence
18
+
19
+ from absl.testing import parameterized
20
+ import tensorflow as tf
21
+
22
+ from official.projects.centernet.modeling.layers import detection_generator
23
+
24
+
25
+ def _build_input_example(
26
+ batch_size: int, height: int, width: int, num_classes: int, num_outputs: int
27
+ ) -> Mapping[str, Sequence[tf.Tensor]]:
28
+ """Builds a random input example for CenterNetDetectionGenerator.
29
+
30
+ Args:
31
+ batch_size: The batch size.
32
+ height: The height of the feature_map.
33
+ width: The width of the feature_map.
34
+ num_classes: The number of classes to detect.
35
+ num_outputs: The number of output heatmaps, which corresponds to the length
36
+ of CenterNetHead's input_levels.
37
+
38
+ Returns:
39
+ A dictionary, mapping from feature names to sequences of tensors.
40
+ """
41
+ return {
42
+ 'ct_heatmaps': [
43
+ tf.random.normal([batch_size, height, width, num_classes])
44
+ for _ in range(num_outputs)
45
+ ],
46
+ 'ct_size': [
47
+ tf.random.normal([batch_size, height, width, 2])
48
+ for _ in range(num_outputs)
49
+ ],
50
+ 'ct_offset': [
51
+ tf.random.normal([batch_size, height, width, 2])
52
+ for _ in range(num_outputs)
53
+ ],
54
+ }
55
+
56
+
57
+ class CenterNetDetectionGeneratorTest(parameterized.TestCase, tf.test.TestCase):
58
+
59
+ @parameterized.parameters(
60
+ (1, 256),
61
+ (1, 512),
62
+ (2, 256),
63
+ (2, 512),
64
+ )
65
+ def test_squered_image_forward(self, batch_size, input_image_dims):
66
+ max_detections = 128
67
+ num_classes = 80
68
+ generator = detection_generator.CenterNetDetectionGenerator(
69
+ input_image_dims=input_image_dims, max_detections=max_detections
70
+ )
71
+ test_input = _build_input_example(
72
+ batch_size=batch_size,
73
+ height=input_image_dims,
74
+ width=input_image_dims,
75
+ num_classes=num_classes,
76
+ num_outputs=2,
77
+ )
78
+
79
+ output = generator(test_input)
80
+
81
+ self.assert_detection_generator_output_shapes(
82
+ output, batch_size, max_detections
83
+ )
84
+
85
+ @parameterized.parameters(
86
+ (1, (256, 512)),
87
+ (1, (512, 256)),
88
+ (2, (256, 512)),
89
+ (2, (512, 256)),
90
+ )
91
+ def test_rectangular_image_forward(self, batch_size, input_image_dims):
92
+ max_detections = 128
93
+ num_classes = 80
94
+ generator = detection_generator.CenterNetDetectionGenerator(
95
+ input_image_dims=input_image_dims, max_detections=max_detections
96
+ )
97
+ test_input = _build_input_example(
98
+ batch_size=batch_size,
99
+ height=input_image_dims[0],
100
+ width=input_image_dims[1],
101
+ num_classes=num_classes,
102
+ num_outputs=2,
103
+ )
104
+
105
+ output = generator(test_input)
106
+
107
+ self.assert_detection_generator_output_shapes(
108
+ output, batch_size, max_detections
109
+ )
110
+
111
+ def assert_detection_generator_output_shapes(
112
+ self,
113
+ output: Mapping[str, tf.Tensor],
114
+ batch_size: int,
115
+ max_detections: int,
116
+ ):
117
+ self.assertAllEqual(output['boxes'].shape, (batch_size, max_detections, 4))
118
+ self.assertAllEqual(output['classes'].shape, (batch_size, max_detections))
119
+ self.assertAllEqual(
120
+ output['confidence'].shape, (batch_size, max_detections)
121
+ )
122
+ self.assertAllEqual(output['num_detections'].shape, (batch_size,))
123
+
124
+ @parameterized.parameters(
125
+ (256,),
126
+ (512,),
127
+ ((256, 512),),
128
+ ((512, 256),),
129
+ )
130
+ def test_serialize_deserialize(self, input_image_dims):
131
+ kwargs = {
132
+ 'input_image_dims': input_image_dims,
133
+ 'net_down_scale': 4,
134
+ 'max_detections': 128,
135
+ 'peak_error': 1e-6,
136
+ 'peak_extract_kernel_size': 3,
137
+ 'class_offset': 1,
138
+ 'use_nms': False,
139
+ 'nms_pre_thresh': 0.1,
140
+ 'nms_thresh': 0.5,
141
+ }
142
+
143
+ generator = detection_generator.CenterNetDetectionGenerator(**kwargs)
144
+ new_generator = detection_generator.CenterNetDetectionGenerator.from_config(
145
+ generator.get_config()
146
+ )
147
+
148
+ self.assertAllEqual(generator.get_config(), new_generator.get_config())
149
+
150
+
151
+ if __name__ == '__main__':
152
+ tf.test.main()
@@ -130,7 +130,10 @@ class CenterNetTask(base_task.Task):
130
130
  peak_extract_kernel_size=dg_config.peak_extract_kernel_size,
131
131
  class_offset=dg_config.class_offset,
132
132
  net_down_scale=self._net_down_scale,
133
- input_image_dims=model_config.input_size[0],
133
+ input_image_dims=(
134
+ model_config.input_size[0],
135
+ model_config.input_size[1],
136
+ ),
134
137
  use_nms=dg_config.use_nms,
135
138
  nms_pre_thresh=dg_config.nms_pre_thresh,
136
139
  nms_thresh=dg_config.nms_thresh)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.14.0.dev20230929
3
+ Version: 2.14.0.dev20231001
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -488,7 +488,8 @@ official/projects/centernet/modeling/heads/centernet_head_test.py,sha256=01sGcBm
488
488
  official/projects/centernet/modeling/layers/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
489
489
  official/projects/centernet/modeling/layers/cn_nn_blocks.py,sha256=VJW6EZNk90wtdRshOJX13io5zvq02YI8yHw9gyHgGo0,12200
490
490
  official/projects/centernet/modeling/layers/cn_nn_blocks_test.py,sha256=UlKFV4npO_VLVyAhArb5Ece2Q7KMKU6ClnD4t4wH0K0,5148
491
- official/projects/centernet/modeling/layers/detection_generator.py,sha256=oOzg_K3RvHq4mWTQlNOkkdLiAE5PzTGCHV6spC4cNyk,13299
491
+ official/projects/centernet/modeling/layers/detection_generator.py,sha256=kgwYcDIchYPPMEF9y3eYEYOgJf_23KF1Tc7Yc5teM4w,14232
492
+ official/projects/centernet/modeling/layers/detection_generator_test.py,sha256=rvYmhe5d03fBeaaXqQ3K9-UV506oYe7VGMHY__P5Ajw,4618
492
493
  official/projects/centernet/ops/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
493
494
  official/projects/centernet/ops/box_list.py,sha256=7SbXrXlQ7eMMkL5Ig7TfBpMdGNHqhy2MyvBQOVUHNRg,6813
494
495
  official/projects/centernet/ops/box_list_ops.py,sha256=g0o8JML6Avd4tFnxQ6rd3ewx5di2OmtUnFpmYkHO6v0,12856
@@ -498,7 +499,7 @@ official/projects/centernet/ops/preprocess_ops.py,sha256=0gveVTznzP7hqFx2Dc8lifO
498
499
  official/projects/centernet/ops/target_assigner.py,sha256=MbK5lYE3x55zTfbQAam06PNopuh-TsDnH-rFjkXs0z4,16235
499
500
  official/projects/centernet/ops/target_assigner_test.py,sha256=ARCfMwJBgEcJmxGSU6KLvrb-YcGgjjKTon-vfHN54BY,6817
500
501
  official/projects/centernet/tasks/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
501
- official/projects/centernet/tasks/centernet.py,sha256=I56SLNBdQIP6SlQROM9NdfiStrD4BfRY3RsqK6Bt6eI,16719
502
+ official/projects/centernet/tasks/centernet.py,sha256=7ujw_qpykYOwxkaqYNKj_qr8qC8gMr0-d-c4hG2rlPE,16784
502
503
  official/projects/centernet/utils/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
503
504
  official/projects/centernet/utils/tf2_centernet_checkpoint_converter.py,sha256=zWwtApJWaXplRbrEUFZEOg0p_VKCwT9--bwb6ImtrJM,5327
504
505
  official/projects/centernet/utils/checkpoints/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
@@ -1111,9 +1112,9 @@ tensorflow_models/__init__.py,sha256=Ciz_YBke6teb6y42QyQTUBDdXJAiV7Qdu1zOoZvYiKw
1111
1112
  tensorflow_models/tensorflow_models_test.py,sha256=Kz2y4V-rtBhZFFfKD2soCq52hviSfJVV1L2ztqS-9oM,1385
1112
1113
  tensorflow_models/nlp/__init__.py,sha256=3dULDpUBpDi9vljpXadq6oJrWH4y6z42Bz2d3hopYZw,807
1113
1114
  tensorflow_models/vision/__init__.py,sha256=4y77XkHaH8qLls3-6ta4tMp3Xj8CLbB0ihH91HsQ9z4,833
1114
- tf_models_nightly-2.14.0.dev20230929.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1115
- tf_models_nightly-2.14.0.dev20230929.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1116
- tf_models_nightly-2.14.0.dev20230929.dist-info/METADATA,sha256=_xNPUlabJbjQNoDCdLKnS6iowEnnJF5PnntkemfxWVQ,1390
1117
- tf_models_nightly-2.14.0.dev20230929.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1118
- tf_models_nightly-2.14.0.dev20230929.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1119
- tf_models_nightly-2.14.0.dev20230929.dist-info/RECORD,,
1115
+ tf_models_nightly-2.14.0.dev20231001.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1116
+ tf_models_nightly-2.14.0.dev20231001.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1117
+ tf_models_nightly-2.14.0.dev20231001.dist-info/METADATA,sha256=zbnYnWYSbbEvohvshHzY_zsP-u6zrIZ039HHbd4lOz8,1390
1118
+ tf_models_nightly-2.14.0.dev20231001.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1119
+ tf_models_nightly-2.14.0.dev20231001.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1120
+ tf_models_nightly-2.14.0.dev20231001.dist-info/RECORD,,