tf-models-nightly 2.14.0.dev20231010__py2.py3-none-any.whl → 2.14.0.dev20231011__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -198,6 +198,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
198
198
  init_checkpoint_modules: Union[
199
199
  str, List[str]] = 'all' # all, backbone, and/or decoder
200
200
  export_config: ExportConfig = dataclasses.field(default_factory=ExportConfig)
201
+ allow_image_summary: bool = True
201
202
 
202
203
 
203
204
  @exp_factory.register_config_factory('semantic_segmentation')
@@ -29,6 +29,7 @@ from official.vision.dataloaders import tfds_factory
29
29
  from official.vision.evaluation import segmentation_metrics
30
30
  from official.vision.losses import segmentation_losses
31
31
  from official.vision.modeling import factory
32
+ from official.vision.utils.object_detection import visualization_utils
32
33
 
33
34
 
34
35
  @task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
@@ -321,6 +322,14 @@ class SemanticSegmentationTask(base_task.Task):
321
322
  if metrics:
322
323
  self.process_metrics(metrics, labels, outputs)
323
324
 
325
+ if (
326
+ hasattr(self.task_config, 'allow_image_summary')
327
+ and self.task_config.allow_image_summary
328
+ ):
329
+ logs.update(
330
+ {'visualization': (tf.cast(features, dtype=tf.float32), outputs)}
331
+ )
332
+
324
333
  return logs
325
334
 
326
335
  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
@@ -330,17 +339,37 @@ class SemanticSegmentationTask(base_task.Task):
330
339
  def aggregate_logs(self, state=None, step_outputs=None):
331
340
  if state is None and self.iou_metric is not None:
332
341
  self.iou_metric.reset_states()
333
- state = self.iou_metric
342
+
343
+ if 'visualization' in step_outputs:
344
+ # Update segmentation state for writing summary if there are artifacts for
345
+ # visualization.
346
+ if state is None:
347
+ state = {}
348
+ state.update(visualization_utils.update_segmentation_state(step_outputs))
349
+
350
+ if state is None:
351
+ # Create an arbitrary state to indicate it's not the first step in the
352
+ # following calls to this function.
353
+ state = True
354
+
334
355
  return state
335
356
 
336
357
  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
337
- result = {}
358
+ logs = {}
338
359
  if self.iou_metric is not None:
339
360
  ious = self.iou_metric.result()
340
361
  # TODO(arashwan): support loading class name from a label map file.
341
362
  if self.task_config.evaluation.report_per_class_iou:
342
363
  for i, value in enumerate(ious.numpy()):
343
- result.update({'iou/{}'.format(i): value})
364
+ logs.update({'iou/{}'.format(i): value})
344
365
  # Computes mean IoU
345
- result.update({'mean_iou': tf.reduce_mean(ious)})
346
- return result
366
+ logs.update({'mean_iou': tf.reduce_mean(ious)})
367
+
368
+ # Add visualization for summary.
369
+ if isinstance(aggregated_logs, dict) and 'image' in aggregated_logs:
370
+ validation_outputs = visualization_utils.visualize_segmentation_outputs(
371
+ logs=aggregated_logs, task_config=self.task_config
372
+ )
373
+ logs.update(validation_outputs)
374
+
375
+ return logs
@@ -894,3 +894,160 @@ def update_detection_state(step_outputs=None) -> Dict[str, Any]:
894
894
  state['detection_masks'] = tf.concat(detection_masks, axis=0)
895
895
 
896
896
  return state
897
+
898
+
899
def update_segmentation_state(step_outputs=None) -> Dict[str, Any]:
  """Updates segmentation state to optionally add input image and predictions.

  Args:
    step_outputs: The logs of a single step. When it has a 'visualization'
      entry, that entry is read as a pair whose first item holds the input
      images and whose second item is a mapping with a 'logits' entry.

  Returns:
    A dict with 'image' and 'logits', each concatenated along axis 0, or an
    empty dict when `step_outputs` is empty/None or carries no
    'visualization' entry.
  """
  # Visualization artifacts are optional: report "nothing to add" instead of
  # raising a KeyError when they are absent.
  if not step_outputs or 'visualization' not in step_outputs:
    return {}

  visualization = step_outputs['visualization']
  return {
      'image': tf.concat(visualization[0], axis=0),
      'logits': tf.concat(visualization[1]['logits'], axis=0),
  }
908
+
909
+
910
def visualize_segmentation_outputs(
    logs,
    task_config,
    original_image_spatial_shape=None,
    true_image_shape=None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    key: str = 'image/validation_outputs',
) -> Dict[str, Any]:
  """Visualizes the semantic segmentation outputs.

  It extracts images and logits from logs, takes the argmax of the logits over
  the last axis as the predicted class mask, denormalizes the images and draws
  a color-coded mask on top of each of them.

  Args:
    logs: A dictionary of logs that contains `image` and `logits`.
    task_config: A task config; `task_config.model.num_classes` is the number
      of classes that get a color.
    original_image_spatial_shape: A [N, 2] tensor containing the spatial size of
      the original image.
    true_image_shape: A [N, 3] tensor containing the spatial size of unpadded
      original_image.
    image_mean: An optional number or list of numbers used as the mean pixel
      value to normalize images.
    image_std: An optional number or list of numbers used as the std to
      normalize images.
    key: A string specifying the key prefix of the returned dictionary.

  Returns:
    A dictionary of images with visualization drawn on it. Entry `<key>/<i>`
    holds the i-th image with the segments drawn on it and a leading axis of
    size 1.

  Raises:
    ValueError: If only one of `image_mean` and `image_std` is set, or if they
      are not both scalars or both sequences.
  """
  images = logs['image']
  masks = np.argmax(logs['logits'], axis=-1)
  num_classes = task_config.model.num_classes

  def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
    """Undoes the mean/std normalization and casts the result to uint8."""
    if image_mean is None and image_std is None:
      # Fall back to the default normalization constants.
      images *= tf.constant(
          preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype
      )
      images += tf.constant(
          preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype
      )
    elif image_mean is not None and image_std is not None:
      # Accept ints next to floats and tuples next to lists so that values
      # such as `image_mean=0` are not rejected as a type mismatch.
      if isinstance(image_mean, (int, float)) and isinstance(
          image_std, (int, float)
      ):
        images = images * image_std + image_mean
      elif isinstance(image_mean, (list, tuple)) and isinstance(
          image_std, (list, tuple)
      ):
        images *= tf.constant(image_std, shape=[1, 1, 3], dtype=images.dtype)
        images += tf.constant(image_mean, shape=[1, 1, 3], dtype=images.dtype)
      else:
        raise ValueError(
            '`image_mean` and `image_std` should be the same type.'
        )
    else:
      raise ValueError(
          'Both `image_mean` and `image_std` should be set or None at the same '
          'time.'
      )
    return tf.cast(images, dtype=tf.uint8)

  images = tf.nest.map_structure(
      tf.identity,
      tf.map_fn(
          _denormalize_images,
          elems=images,
          fn_output_signature=tf.TensorSpec(
              shape=images.shape.as_list()[1:], dtype=tf.uint8
          ),
          parallel_iterations=32,
      ),
  )

  # The drawing code works on 3-channel images: drop extra channels and expand
  # grayscale inputs.
  if images.shape[3] > 3:
    images = images[:, :, :, 0:3]
  elif images.shape[3] == 1:
    images = tf.image.grayscale_to_rgb(images)

  # `tf.map_fn` needs one entry per image for every element, so placeholder
  # shapes filled with -1 stand in for the shapes that were not provided.
  if true_image_shape is None:
    true_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 3])
  else:
    true_shapes = true_image_shape
  if original_image_spatial_shape is None:
    original_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 2])
  else:
    original_shapes = original_image_spatial_shape

  visualize_fn = functools.partial(_visualize_masks, num_classes=num_classes)
  elems = [true_shapes, original_shapes, images, masks]

  def draw_segments(image_and_segments):
    """Draws the segmentation mask on a single image."""
    true_shape = image_and_segments[0]
    original_shape = image_and_segments[1]
    # Start from the unmodified image so that the resize below always has an
    # input, even when `true_image_shape` was not provided.
    image = image_and_segments[2]
    if true_image_shape is not None:
      image = shape_utils.pad_or_clip_nd(
          image, [true_shape[0], true_shape[1], 3]
      )
    # NOTE(review): the padded/clipped image is only used when a resize to the
    # original shape follows; this matches the previous behavior.
    if original_image_spatial_shape is not None:
      image_and_segments[2] = _resize_original_image(image, original_shape)

    return tf.compat.v1.py_func(
        visualize_fn, image_and_segments[2:], tf.uint8
    )

  # `fn_output_signature` replaces the deprecated `dtype` argument and matches
  # the `tf.map_fn` call above.
  images_with_segments = tf.map_fn(
      draw_segments, elems, fn_output_signature=tf.uint8, back_prop=False
  )

  outputs = {}
  for i, image in enumerate(images_with_segments):
    outputs[key + f'/{i}'] = image[None, ...]

  return outputs
1025
+
1026
+
1027
def _visualize_masks(image, mask, num_classes, alpha=0.4):
  """Blends a color-coded semantic segmentation mask onto an image.

  Args:
    image: The image array to draw on; it is overwritten in place with the
      blended result.
    mask: An array of class indices; pixels equal to class `i` are painted
      with `STANDARD_COLORS[i % len(STANDARD_COLORS)]`.
    num_classes: The number of class indices (0 .. num_classes - 1) that get a
      color.
    alpha: The opacity of the color overlay, applied uniformly to all pixels.

  Returns:
    `image`, after the blended picture has been copied into it.
  """
  palette_size = len(STANDARD_COLORS)

  # Per-pixel RGB overlay, built up one class at a time.
  overlay = np.repeat(np.expand_dims(np.zeros_like(mask), axis=2), 3, axis=2)
  for class_index in range(num_classes):
    class_rgb = ImageColor.getrgb(STANDARD_COLORS[class_index % palette_size])
    membership = np.where(mask == class_index, 1, 0)
    class_layer = np.expand_dims(membership, axis=2) * np.reshape(
        list(class_rgb), [1, 1, 3]
    )
    overlay = overlay + class_layer

  base_image = Image.fromarray(image)
  target_size = base_image.size

  # The overlay and the blend mask are resized to the image size before
  # compositing.
  overlay_image = Image.fromarray(np.uint8(overlay)).convert('RGBA')
  overlay_image = overlay_image.resize(target_size)

  opacity = np.uint8(255.0 * alpha * np.ones_like(mask))
  blend_mask = Image.fromarray(opacity).convert('L').resize(target_size)

  blended = Image.composite(overlay_image, base_image, blend_mask)
  np.copyto(image, np.array(blended.convert('RGB')))
  return image
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.14.0.dev20231010
3
+ Version: 2.14.0.dev20231011
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -874,7 +874,7 @@ official/vision/configs/maskrcnn.py,sha256=k5TT6AheXFJ1y_o-7Qxn0lsVyfkRngg0pERWc
874
874
  official/vision/configs/maskrcnn_test.py,sha256=t3Yx1GV4lCrQANwacC99NJf3Cy_guDIOii5qUvidG9M,1723
875
875
  official/vision/configs/retinanet.py,sha256=jYXROb2d3c1796V7AaAeQJQyfDLM2MQdD3I6NAODN9A,17719
876
876
  official/vision/configs/retinanet_test.py,sha256=yeNLguvsPCpxv1BQAJJCQrrDkveNA1gmZlDya9CIwck,1689
877
- official/vision/configs/semantic_segmentation.py,sha256=JkdKl_FKhSZq1wUcQUc_ersiOOgnPTytavanBX2ppis,30581
877
+ official/vision/configs/semantic_segmentation.py,sha256=PJsmPt4bdORFFpVpswouOyK_yGiXbXLxRTxxklQBFzA,30616
878
878
  official/vision/configs/semantic_segmentation_test.py,sha256=EXXKMekGrik0uJ1O4zZ5AeN7BPDjm3a2uPDUTeDoL-A,1857
879
879
  official/vision/configs/video_classification.py,sha256=eLWT3ClzQAoJFxayR1AdGnFfiI0XYVOzrccYOGYx_FE,14513
880
880
  official/vision/configs/video_classification_test.py,sha256=tEprp-PLabvA0d7x52FtcCT6AmmRGPlP4dOEJ63fojU,1869
@@ -1059,7 +1059,7 @@ official/vision/tasks/__init__.py,sha256=pwY_FnD2dwCOuYUHfZoAvnONwDWalX5E1J27tSJ
1059
1059
  official/vision/tasks/image_classification.py,sha256=Q7TgcKEhHLP_64jlLbpu42Y8Hzdf3VhOm4Xb5dlEx1o,16689
1060
1060
  official/vision/tasks/maskrcnn.py,sha256=Gc5cu_W48bMz_5d8HL8UUR3uJTL0PMvKm3ZnyO2ALk4,25421
1061
1061
  official/vision/tasks/retinanet.py,sha256=UKiwdO7Id53nqpigW9Hmd-7xGtdEb0il_y5K3Zz3Fkw,18104
1062
- official/vision/tasks/semantic_segmentation.py,sha256=rbUWzGJoI0ho_b0WYEGwpAidQ7vi-IDN7a1cN7V8eQ8,13246
1062
+ official/vision/tasks/semantic_segmentation.py,sha256=QA6qyxjdxptBRneGhECOtubvGo4ZyG17fFBV6Yz0hlY,14233
1063
1063
  official/vision/tasks/video_classification.py,sha256=E_AQwPvSzMx2MhnQctHKGKKshot-R9lwYLE6HFEux5M,14308
1064
1064
  official/vision/utils/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
1065
1065
  official/vision/utils/ops_test.py,sha256=g2a5Hij3N-nnA5vSBhootzeapYAMrZRvadXNUXtwnpc,4384
@@ -1078,7 +1078,7 @@ official/vision/utils/object_detection/preprocessor.py,sha256=Fklxck111TQEKsrWSM
1078
1078
  official/vision/utils/object_detection/region_similarity_calculator.py,sha256=yqQLLRT80IdAu3K_fILli_u1aL37lv0FpDtFcyRrPzs,4544
1079
1079
  official/vision/utils/object_detection/shape_utils.py,sha256=p3Q7e9gTTQNv1gnMAkuTkfXc6DYVB8mjk_Vjlq7bRlg,3608
1080
1080
  official/vision/utils/object_detection/target_assigner.py,sha256=fTkjedzhp_-RTUGR27tWbmnnLpW7F3lVCkZlr7Nv_9o,24198
1081
- official/vision/utils/object_detection/visualization_utils.py,sha256=RzK96moeYTp90jn2uJ4l9NBI9-hSKuzM8vKhy4ZIl5Y,34757
1081
+ official/vision/utils/object_detection/visualization_utils.py,sha256=-g75oXcrcchAiJC4_63Z4CifJLN4Km6jUIyqIa4VY4Y,40256
1082
1082
  orbit/__init__.py,sha256=aQRo8zqIQ0Dw4JQReZeiB6MmuJLvvw4DbYHYti5AGys,1117
1083
1083
  orbit/controller.py,sha256=pQo060KNNWekzyyRuGBTNhVS0_9D4AJWylaW_Govhuc,25082
1084
1084
  orbit/controller_test.py,sha256=5b8JCx3s8mfzqYoOa1D3xbD-Qr5uIXcTXM3pLEGjRjo,31792
@@ -1112,9 +1112,9 @@ tensorflow_models/__init__.py,sha256=Ciz_YBke6teb6y42QyQTUBDdXJAiV7Qdu1zOoZvYiKw
1112
1112
  tensorflow_models/tensorflow_models_test.py,sha256=Kz2y4V-rtBhZFFfKD2soCq52hviSfJVV1L2ztqS-9oM,1385
1113
1113
  tensorflow_models/nlp/__init__.py,sha256=3dULDpUBpDi9vljpXadq6oJrWH4y6z42Bz2d3hopYZw,807
1114
1114
  tensorflow_models/vision/__init__.py,sha256=4y77XkHaH8qLls3-6ta4tMp3Xj8CLbB0ihH91HsQ9z4,833
1115
- tf_models_nightly-2.14.0.dev20231010.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1116
- tf_models_nightly-2.14.0.dev20231010.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1117
- tf_models_nightly-2.14.0.dev20231010.dist-info/METADATA,sha256=X4quZ2URqJhXw8000TrfZuteeGcvPfJa99--oMdC5yM,1390
1118
- tf_models_nightly-2.14.0.dev20231010.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1119
- tf_models_nightly-2.14.0.dev20231010.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1120
- tf_models_nightly-2.14.0.dev20231010.dist-info/RECORD,,
1115
+ tf_models_nightly-2.14.0.dev20231011.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1116
+ tf_models_nightly-2.14.0.dev20231011.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1117
+ tf_models_nightly-2.14.0.dev20231011.dist-info/METADATA,sha256=v1HZmRzGLdvqOpSYUi9Su-F1bbv4p1zqkuXEGmUXSCo,1390
1118
+ tf_models_nightly-2.14.0.dev20231011.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1119
+ tf_models_nightly-2.14.0.dev20231011.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1120
+ tf_models_nightly-2.14.0.dev20231011.dist-info/RECORD,,