tf-models-nightly 2.11.0.dev20230321__py2.py3-none-any.whl → 2.11.0.dev20230322__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,6 @@
15
15
  """Video classification configuration definition."""
16
16
  import dataclasses
17
17
  from typing import Optional, Tuple
18
- from absl import flags
19
18
 
20
19
  from official.core import config_definitions as cfg
21
20
  from official.core import exp_factory
@@ -23,7 +22,6 @@ from official.modeling import hyperparams
23
22
  from official.modeling import optimization
24
23
  from official.vision.configs import common
25
24
 
26
- FLAGS = flags.FLAGS
27
25
 
28
26
  YT8M_TRAIN_EXAMPLES = 3888919
29
27
  YT8M_VAL_EXAMPLES = 1112356
@@ -105,7 +103,6 @@ def yt8m(is_training):
105
103
  class MoeModel(hyperparams.Config):
106
104
  """The model config."""
107
105
  num_mixtures: int = 5
108
- l2_penalty: float = 1e-5
109
106
  use_input_context_gate: bool = False
110
107
  use_output_context_gate: bool = False
111
108
  vocab_as_last_dim: bool = False
@@ -121,7 +118,7 @@ class DbofModel(hyperparams.Config):
121
118
  use_context_gate_cluster_layer: bool = False
122
119
  context_gate_cluster_bottleneck_size: int = 0
123
120
  pooling_method: str = 'average'
124
- yt8m_agg_classifier_model: str = 'MoeModel'
121
+ agg_classifier_model: str = 'MoeModel'
125
122
  agg_model: hyperparams.Config = MoeModel()
126
123
  norm_activation: common.NormActivation = common.NormActivation(
127
124
  activation='relu', use_sync_bn=False)
@@ -13,22 +13,156 @@
13
13
  # limitations under the License.
14
14
 
15
15
  """Contains model definitions."""
16
+
17
+ import functools
16
18
  from typing import Any, Dict, Optional
17
19
 
18
20
  import tensorflow as tf
21
+
22
+ from official.modeling import tf_utils
23
+ from official.projects.yt8m.configs import yt8m as yt8m_cfg
19
24
  from official.projects.yt8m.modeling import yt8m_model_utils as utils
20
25
 
26
+
21
27
  layers = tf.keras.layers
22
28
 
23
29
 
24
- class LogisticModel():
25
- """Logistic model with L2 regularization."""
30
+ class Dbof(tf.keras.Model):
31
+ """A YT8M model class builder.
32
+
33
+ Creates a Deep Bag of Frames model.
34
+ The model projects the features for each frame into a higher dimensional
35
+ 'clustering' space, pools across frames in that space, and then
36
+ uses a configurable video-level model to classify the now aggregated features.
37
+ The model will randomly sample either frames or sequences of frames during
38
+ training to speed up convergence.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ params: yt8m_cfg.DbofModel,
44
+ num_classes: int = 3862,
45
+ input_specs: layers.InputSpec = layers.InputSpec(
46
+ shape=[None, None, 1152]),
47
+ l2_weight_decay: Optional[float] = None,
48
+ **kwargs):
49
+ """YT8M initialization function.
50
+
51
+ Args:
52
+ params: model configuration parameters
53
+ num_classes: `int` number of classes in dataset.
54
+ input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
55
+ [batch_size x num_frames x num_features]
56
+ l2_weight_decay: An optional `float` of kernel regularizer weight decay.
57
+ **kwargs: keyword arguments to be passed.
58
+ """
59
+ self._self_setattr_tracking = False
60
+ self._num_classes = num_classes
61
+ self._input_specs = input_specs
62
+ self._params = params
63
+ self._l2_weight_decay = l2_weight_decay
64
+ self._act_fn = tf_utils.get_activation(params.norm_activation.activation)
65
+ self._norm = functools.partial(
66
+ layers.BatchNormalization,
67
+ momentum=params.norm_activation.norm_momentum,
68
+ epsilon=params.norm_activation.norm_epsilon,
69
+ synchronized=params.norm_activation.use_sync_bn)
70
+
71
+ # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
72
+ # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
73
+ # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
74
+ l2_regularizer = (
75
+ tf.keras.regularizers.l2(l2_weight_decay / 2.0)
76
+ if l2_weight_decay
77
+ else None
78
+ )
79
+
80
+ # [batch_size x num_frames x num_features]
81
+ feature_size = input_specs.shape[-1]
82
+ # shape 'excluding' batch_size
83
+ model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
84
+ # normalize input features
85
+ input_data = tf.nn.l2_normalize(model_input, -1)
86
+ tf.summary.histogram("input_hist", input_data)
87
+
88
+ # configure model
89
+ if params.add_batch_norm:
90
+ input_data = self._norm(name="input_bn")(input_data)
91
+
92
+ # activation = reshaped input * cluster weights
93
+ if params.cluster_size > 0:
94
+ activation = layers.Dense(
95
+ params.cluster_size,
96
+ kernel_regularizer=l2_regularizer,
97
+ kernel_initializer=tf.random_normal_initializer(
98
+ stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
99
+ input_data)
100
+
101
+ if params.add_batch_norm:
102
+ activation = self._norm(name="cluster_bn")(activation)
103
+ else:
104
+ cluster_biases = tf.Variable(
105
+ tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
106
+ shape=[params.cluster_size]),
107
+ name="cluster_biases")
108
+ tf.summary.histogram("cluster_biases", cluster_biases)
109
+ activation += cluster_biases
110
+
111
+ activation = self._act_fn(activation)
112
+ tf.summary.histogram("cluster_output", activation)
26
113
 
27
- def create_model(self, model_input, vocab_size, l2_penalty=1e-8, **kwargs):
114
+ if params.use_context_gate_cluster_layer:
115
+ pooling_method = None
116
+ norm_args = dict(name="context_gate_bn")
117
+ activation = utils.context_gate(
118
+ activation,
119
+ normalizer_fn=self._norm,
120
+ normalizer_params=norm_args,
121
+ pooling_method=pooling_method,
122
+ hidden_layer_size=params.context_gate_cluster_bottleneck_size,
123
+ kernel_regularizer=l2_regularizer)
124
+
125
+ activation = utils.frame_pooling(activation, params.pooling_method)
126
+
127
+ # activation = activation * hidden1_weights
128
+ activation = layers.Dense(
129
+ params.hidden_size,
130
+ kernel_regularizer=l2_regularizer,
131
+ kernel_initializer=tf.random_normal_initializer(
132
+ stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
133
+ activation)
134
+
135
+ if params.add_batch_norm:
136
+ activation = self._norm(name="hidden1_bn")(activation)
137
+
138
+ else:
139
+ hidden1_biases = tf.Variable(
140
+ tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
141
+ name="hidden1_biases")
142
+
143
+ tf.summary.histogram("hidden1_biases", hidden1_biases)
144
+ activation += hidden1_biases
145
+
146
+ activation = self._act_fn(activation)
147
+ tf.summary.histogram("hidden1_output", activation)
148
+
149
+ super().__init__(inputs=model_input, outputs=activation, **kwargs)
150
+
151
+
152
+ class LogisticModel(tf.keras.Model):
153
+ """Logistic prediction head model with L2 regularization."""
154
+
155
+ def __init__(
156
+ self,
157
+ input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
158
+ vocab_size: int = 3862,
159
+ l2_penalty: float = 1e-8,
160
+ **kwargs,
161
+ ):
28
162
  """Creates a logistic model.
29
163
 
30
164
  Args:
31
- model_input: 'batch' x 'num_features' matrix of input features.
165
+ input_specs: 'batch' x 'num_features' matrix of input features.
32
166
  vocab_size: The number of classes in the dataset.
33
167
  l2_penalty: L2 weight regularization ratio.
34
168
  **kwargs: extra key word args.
@@ -38,43 +172,44 @@ class LogisticModel():
38
172
  model in the 'predictions' key. The dimensions of the tensor are
39
173
  batch_size x num_classes.
40
174
  """
41
- del kwargs # Unused.
175
+ inputs = tf.keras.Input(shape=input_specs.shape[1:])
42
176
  output = layers.Dense(
43
177
  vocab_size,
44
178
  activation=tf.nn.sigmoid,
45
179
  kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
46
- model_input)
47
- return {"predictions": output}
180
+ inputs)
181
+
182
+ super().__init__(inputs=inputs, outputs={"predictions": output}, **kwargs)
48
183
 
49
184
 
50
- class MoeModel():
185
+ class MoeModel(tf.keras.Model):
51
186
  """A softmax over a mixture of logistic models (with L2 regularization)."""
52
187
 
53
- def create_model(self,
54
- model_input,
55
- vocab_size,
56
- num_mixtures: int = 2,
57
- use_input_context_gate: bool = False,
58
- use_output_context_gate: bool = False,
59
- normalizer_fn=None,
60
- normalizer_params: Optional[Dict[str, Any]] = None,
61
- vocab_as_last_dim: bool = False,
62
- l2_penalty: float = 1e-5,
63
- **kwargs):
188
+ def __init__(
189
+ self,
190
+ input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
191
+ vocab_size: int = 3862,
192
+ num_mixtures: int = 2,
193
+ use_input_context_gate: bool = False,
194
+ use_output_context_gate: bool = False,
195
+ normalizer_params: Optional[Dict[str, Any]] = None,
196
+ vocab_as_last_dim: bool = False,
197
+ l2_penalty: float = 1e-5,
198
+ **kwargs,
199
+ ):
64
200
  """Creates a Mixture of (Logistic) Experts model.
65
201
 
66
202
  The model consists of a per-class softmax distribution over a
67
203
  configurable number of logistic classifiers. One of the classifiers
68
204
  in the mixture is not trained, and always predicts 0.
69
205
  Args:
70
- model_input: 'batch_size' x 'num_features' matrix of input features.
206
+ input_specs: 'batch_size' x 'num_features' matrix of input features.
71
207
  vocab_size: The number of classes in the dataset.
72
208
  num_mixtures: The number of mixtures (excluding a dummy 'expert' that
73
209
  always predicts the non-existence of an entity).
74
210
  use_input_context_gate: if True apply context gate layer to the input.
75
211
  use_output_context_gate: if True apply context gate layer to the output.
76
- normalizer_fn: normalization op constructor (e.g. batch norm).
77
- normalizer_params: parameters to the `normalizer_fn`.
212
+ normalizer_params: parameters of the batch normalization.
78
213
  vocab_as_last_dim: if True reshape `activations` and make `vocab_size` as
79
214
  the last dimension to avoid small `num_mixtures` as the last dimension.
80
215
  XLA pads up the dimensions of tensors: typically the last dimension will
@@ -88,11 +223,13 @@ class MoeModel():
88
223
  of the model in the 'predictions' key. The dimensions of the tensor
89
224
  are batch_size x num_classes.
90
225
  """
91
- del kwargs # Unused.
226
+ inputs = tf.keras.Input(shape=input_specs.shape[1:])
227
+ model_input = inputs
228
+
92
229
  if use_input_context_gate:
93
230
  model_input = utils.context_gate(
94
231
  model_input,
95
- normalizer_fn=normalizer_fn,
232
+ normalizer_fn=layers.BatchNormalization,
96
233
  normalizer_params=normalizer_params,
97
234
  )
98
235
 
@@ -132,7 +269,11 @@ class MoeModel():
132
269
  if use_output_context_gate:
133
270
  final_probabilities = utils.context_gate(
134
271
  final_probabilities,
135
- normalizer_fn=normalizer_fn,
272
+ normalizer_fn=layers.BatchNormalization,
136
273
  normalizer_params=normalizer_params,
137
274
  )
138
- return {"predictions": final_probabilities}
275
+ super().__init__(
276
+ inputs=inputs,
277
+ outputs={"predictions": final_probabilities},
278
+ **kwargs,
279
+ )
@@ -12,17 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- """YT8M model definition."""
15
+ """YT8M prediction model definition."""
16
16
 
17
17
  import functools
18
- from typing import Optional
18
+ from typing import Any, Optional
19
19
 
20
+ from absl import logging
20
21
  import tensorflow as tf
21
22
 
22
- from official.modeling import tf_utils
23
23
  from official.projects.yt8m.configs import yt8m as yt8m_cfg
24
24
  from official.projects.yt8m.modeling import nn_layers
25
- from official.projects.yt8m.modeling import yt8m_model_utils as utils
26
25
 
27
26
 
28
27
  layers = tf.keras.layers
@@ -45,202 +44,58 @@ class DbofModel(tf.keras.Model):
45
44
  num_classes: int = 3862,
46
45
  input_specs: layers.InputSpec = layers.InputSpec(
47
46
  shape=[None, None, 1152]),
48
- activation: str = "relu",
49
- use_sync_bn: bool = False,
50
- norm_momentum: float = 0.99,
51
- norm_epsilon: float = 0.001,
52
47
  l2_weight_decay: Optional[float] = None,
53
- **kwargs):
54
- """YT8M initialization function.
48
+ **kwargs,
49
+ ):
50
+ """YT8M Dbof model initialization function.
55
51
 
56
52
  Args:
57
53
  params: model configuration parameters
58
54
  num_classes: `int` number of classes in dataset.
59
55
  input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
60
56
  [batch_size x num_frames x num_features]
61
- activation: A `str` of name of the activation function.
62
- use_sync_bn: If True, use synchronized batch normalization.
63
- norm_momentum: A `float` of normalization momentum for the moving average.
64
- norm_epsilon: A `float` added to variance to avoid dividing by zero.
65
57
  l2_weight_decay: An optional `float` of kernel regularizer weight decay.
66
58
  **kwargs: keyword arguments to be passed.
67
59
  """
68
- model_input, activation = self.get_dbof(
69
- params=params,
70
- num_classes=num_classes,
71
- input_specs=input_specs,
72
- activation=activation,
73
- use_sync_bn=use_sync_bn,
74
- norm_momentum=norm_momentum,
75
- norm_epsilon=norm_epsilon,
76
- l2_weight_decay=l2_weight_decay,
77
- **kwargs,
78
- )
79
- output = self.get_aggregation(model_input=activation, **kwargs)
80
- super().__init__(
81
- inputs=model_input, outputs=output.get("predictions"), **kwargs)
82
-
83
- def get_dbof(
84
- self,
85
- params: yt8m_cfg.DbofModel,
86
- num_classes: int = 3862,
87
- input_specs: layers.InputSpec = layers.InputSpec(
88
- shape=[None, None, 1152]),
89
- activation: str = "relu",
90
- use_sync_bn: bool = False,
91
- norm_momentum: float = 0.99,
92
- norm_epsilon: float = 0.001,
93
- l2_weight_decay: Optional[float] = None,
94
- **kwargs):
95
-
96
- del kwargs # Unused and reserved for future extension.
97
- self._self_setattr_tracking = False
60
+ super().__init__()
61
+ self._params = params
62
+ self._num_classes = num_classes
63
+ self._input_specs = input_specs
64
+ self._l2_weight_decay = l2_weight_decay
98
65
  self._config_dict = {
66
+ "params": params,
99
67
  "input_specs": input_specs,
100
68
  "num_classes": num_classes,
101
- "params": params,
102
- "use_sync_bn": use_sync_bn,
103
- "activation": activation,
104
69
  "l2_weight_decay": l2_weight_decay,
105
- "norm_momentum": norm_momentum,
106
- "norm_epsilon": norm_epsilon,
107
70
  }
108
- self._num_classes = num_classes
109
- self._input_specs = input_specs
110
- self._params = params
111
- self._activation = activation
112
- self._l2_weight_decay = l2_weight_decay
113
- self._use_sync_bn = use_sync_bn
114
- self._norm_momentum = norm_momentum
115
- self._norm_epsilon = norm_epsilon
116
- self._act_fn = tf_utils.get_activation(activation)
117
- self._norm = functools.partial(
118
- layers.BatchNormalization, synchronized=use_sync_bn)
119
-
120
- # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
121
- # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
122
- # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
123
- l2_regularizer = (
124
- tf.keras.regularizers.l2(l2_weight_decay / 2.0)
125
- if l2_weight_decay
126
- else None
127
- )
128
71
 
129
- bn_axis = -1
130
- # [batch_size x num_frames x num_features]
131
- feature_size = input_specs.shape[-1]
132
- # shape 'excluding' batch_size
133
- model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
134
- # normalize input features
135
- input_data = tf.nn.l2_normalize(model_input, -1)
136
- tf.summary.histogram("input_hist", input_data)
137
-
138
- # configure model
139
- if params.add_batch_norm:
140
- input_data = self._norm(
141
- axis=bn_axis,
142
- momentum=norm_momentum,
143
- epsilon=norm_epsilon,
144
- name="input_bn")(
145
- input_data)
146
-
147
- # activation = reshaped input * cluster weights
148
- if params.cluster_size > 0:
149
- activation = layers.Dense(
150
- params.cluster_size,
151
- kernel_regularizer=l2_regularizer,
152
- kernel_initializer=tf.random_normal_initializer(
153
- stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
154
- input_data)
155
-
156
- if params.add_batch_norm:
157
- activation = self._norm(
158
- axis=bn_axis,
159
- momentum=norm_momentum,
160
- epsilon=norm_epsilon,
161
- name="cluster_bn")(
162
- activation)
163
- else:
164
- cluster_biases = tf.Variable(
165
- tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
166
- shape=[params.cluster_size]),
167
- name="cluster_biases")
168
- tf.summary.histogram("cluster_biases", cluster_biases)
169
- activation += cluster_biases
170
-
171
- activation = self._act_fn(activation)
172
- tf.summary.histogram("cluster_output", activation)
173
-
174
- if params.use_context_gate_cluster_layer:
175
- pooling_method = None
176
- norm_args = dict(
177
- axis=bn_axis,
178
- momentum=norm_momentum,
179
- epsilon=norm_epsilon,
180
- name="context_gate_bn")
181
- activation = utils.context_gate(
182
- activation,
183
- normalizer_fn=self._norm,
184
- normalizer_params=norm_args,
185
- pooling_method=pooling_method,
186
- hidden_layer_size=params.context_gate_cluster_bottleneck_size,
187
- kernel_regularizer=l2_regularizer)
188
-
189
- activation = utils.frame_pooling(activation, params.pooling_method)
190
-
191
- # activation = activation * hidden1_weights
192
- activation = layers.Dense(
193
- params.hidden_size,
194
- kernel_regularizer=l2_regularizer,
195
- kernel_initializer=tf.random_normal_initializer(
196
- stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
197
- activation)
198
-
199
- if params.add_batch_norm:
200
- activation = self._norm(
201
- axis=bn_axis,
202
- momentum=norm_momentum,
203
- epsilon=norm_epsilon,
204
- name="hidden1_bn")(
205
- activation)
206
-
207
- else:
208
- hidden1_biases = tf.Variable(
209
- tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
210
- name="hidden1_biases")
211
-
212
- tf.summary.histogram("hidden1_biases", hidden1_biases)
213
- activation += hidden1_biases
214
-
215
- activation = self._act_fn(activation)
216
- tf.summary.histogram("hidden1_output", activation)
217
-
218
- return model_input, activation
219
-
220
- def get_aggregation(self, model_input, **kwargs):
221
- del kwargs # Unused and reserved for future extension.
222
- normalizer_fn = functools.partial(
223
- layers.BatchNormalization, synchronized=self._use_sync_bn)
224
- normalizer_params = dict(
225
- axis=-1, momentum=self._norm_momentum, epsilon=self._norm_epsilon)
226
- aggregated_model = getattr(
227
- nn_layers, self._params.yt8m_agg_classifier_model)
228
-
229
- output = aggregated_model().create_model(
230
- model_input=model_input,
231
- vocab_size=self._num_classes,
232
- num_mixtures=self._params.agg_model.num_mixtures,
233
- normalizer_fn=normalizer_fn,
234
- normalizer_params=normalizer_params,
235
- vocab_as_last_dim=self._params.agg_model.vocab_as_last_dim,
236
- l2_penalty=self._params.agg_model.l2_penalty,
72
+ self.dbof_backbone = nn_layers.Dbof(
73
+ params,
74
+ num_classes,
75
+ input_specs,
76
+ l2_weight_decay,
77
+ **kwargs,
237
78
  )
238
- return output
239
79
 
240
- @property
241
- def checkpoint_items(self):
242
- """Returns a dictionary of items to be additionally checkpointed."""
243
- return dict()
80
+ logging.info("Build DbofModel with %s.", params.agg_classifier_model)
81
+ if hasattr(nn_layers, params.agg_classifier_model):
82
+ aggregation_head = getattr(nn_layers, params.agg_classifier_model)
83
+ if params.agg_classifier_model == "MoeModel":
84
+ normalizer_params = dict(
85
+ synchronized=params.norm_activation.use_sync_bn,
86
+ momentum=params.norm_activation.norm_momentum,
87
+ epsilon=params.norm_activation.norm_epsilon,
88
+ )
89
+ aggregation_head = functools.partial(
90
+ aggregation_head, normalizer_params=normalizer_params)
91
+
92
+ if params.agg_model is not None:
93
+ kwargs.update(params.agg_model.as_dict())
94
+ self.head = aggregation_head(
95
+ input_specs=layers.InputSpec(shape=[None, params.hidden_size]),
96
+ vocab_size=num_classes,
97
+ l2_penalty=l2_weight_decay,
98
+ **kwargs)
244
99
 
245
100
  def get_config(self):
246
101
  return self._config_dict
@@ -248,3 +103,10 @@ class DbofModel(tf.keras.Model):
248
103
  @classmethod
249
104
  def from_config(cls, config):
250
105
  return cls(**config)
106
+
107
+ def call(
108
+ self, inputs: tf.Tensor, training: Any = None, mask: Any = None
109
+ ) -> tf.Tensor:
110
+ features = self.dbof_backbone(inputs)
111
+ outputs = self.head(features)
112
+ return outputs["predictions"]
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  """Contains a collection of util functions for model construction."""
16
+
16
17
  from typing import Any, Dict, Optional, Union
17
18
 
18
19
  import tensorflow as tf
@@ -177,8 +178,7 @@ def context_gate(
177
178
  kernel_initializer=kernel_initializer,
178
179
  bias_initializer=bias_initializer,
179
180
  kernel_regularizer=kernel_regularizer,
180
- )(
181
- context_features)
181
+ )(context_features)
182
182
  if normalizer_fn:
183
183
  gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
184
184
  else:
@@ -191,14 +191,13 @@ def context_gate(
191
191
  kernel_initializer=kernel_initializer,
192
192
  bias_initializer=bias_initializer,
193
193
  kernel_regularizer=kernel_regularizer,
194
- )(
195
- gates_bottleneck)
194
+ )(gates_bottleneck)
196
195
  if normalizer_fn:
197
196
  gates = normalizer_fn(**normalizer_params)(gates)
198
197
 
199
198
  if additive_residual:
200
- input_features += gates
199
+ input_features += tf.cast(gates, input_features.dtype)
201
200
  else:
202
- input_features *= gates
201
+ input_features *= tf.cast(gates, input_features.dtype)
203
202
 
204
203
  return input_features
@@ -45,15 +45,10 @@ class YT8MTask(base_task.Task):
45
45
  l2_weight_decay = self.task_config.losses.l2_weight_decay
46
46
  # Model configuration.
47
47
  model_config = self.task_config.model
48
- norm_activation_config = model_config.norm_activation
49
48
  model = DbofModel(
50
49
  params=model_config,
51
50
  input_specs=input_specs,
52
51
  num_classes=train_cfg.num_classes,
53
- activation=norm_activation_config.activation,
54
- use_sync_bn=norm_activation_config.use_sync_bn,
55
- norm_momentum=norm_activation_config.norm_momentum,
56
- norm_epsilon=norm_activation_config.norm_epsilon,
57
52
  l2_weight_decay=l2_weight_decay)
58
53
 
59
54
  non_trainable_batch_norm_variables = []
@@ -66,18 +61,32 @@ class YT8MTask(base_task.Task):
66
61
  non_trainable_extra_variables.append(var)
67
62
 
68
63
  logging.info(
69
- 'Trainable model variables:\n%s', '\n'.join(
70
- [f'{var.name}\t{var.shape}' for var in model.trainable_variables]))
64
+ 'Trainable model variables:\n%s',
65
+ '\n'.join(
66
+ [f'{var.name}\t{var.shape}' for var in model.trainable_variables]
67
+ ),
68
+ )
71
69
  logging.info(
72
- 'Non-trainable batch norm variables (get updated in training mode):\n%s',
73
- '\n'.join([
74
- f'{var.name}\t{var.shape}'
75
- for var in non_trainable_batch_norm_variables
76
- ]))
70
+ (
71
+ 'Non-trainable batch norm variables (get updated in training'
72
+ ' mode):\n%s'
73
+ ),
74
+ '\n'.join(
75
+ [
76
+ f'{var.name}\t{var.shape}'
77
+ for var in non_trainable_batch_norm_variables
78
+ ]
79
+ ),
80
+ )
77
81
  logging.info(
78
- 'Non-trainable frozen model variables:\n%s', '\n'.join([
79
- f'{var.name}\t{var.shape}' for var in non_trainable_extra_variables
80
- ]))
82
+ 'Non-trainable frozen model variables:\n%s',
83
+ '\n'.join(
84
+ [
85
+ f'{var.name}\t{var.shape}'
86
+ for var in non_trainable_extra_variables
87
+ ]
88
+ ),
89
+ )
81
90
  return model
82
91
 
83
92
  def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
@@ -173,7 +182,10 @@ class YT8MTask(base_task.Task):
173
182
  for name in metric_names:
174
183
  metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
175
184
 
176
- if self.task_config.evaluation.average_precision is not None and not training:
185
+ if (
186
+ self.task_config.evaluation.average_precision is not None
187
+ and not training
188
+ ):
177
189
  # Cannot run in train step.
178
190
  num_classes = self.task_config.validation_data.num_classes
179
191
  top_k = self.task_config.evaluation.average_precision.top_k
@@ -183,14 +195,16 @@ class YT8MTask(base_task.Task):
183
195
 
184
196
  return metrics
185
197
 
186
- def process_metrics(self,
187
- metrics: List[tf.keras.metrics.Metric],
188
- labels: tf.Tensor,
189
- outputs: tf.Tensor,
190
- model_losses: Optional[Dict[str, tf.Tensor]] = None,
191
- label_weights: Optional[tf.Tensor] = None,
192
- training: bool = True,
193
- **kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
198
+ def process_metrics(
199
+ self,
200
+ metrics: List[tf.keras.metrics.Metric],
201
+ labels: tf.Tensor,
202
+ outputs: tf.Tensor,
203
+ model_losses: Optional[Dict[str, tf.Tensor]] = None,
204
+ label_weights: Optional[tf.Tensor] = None,
205
+ training: bool = True,
206
+ **kwargs,
207
+ ) -> Dict[str, Tuple[tf.Tensor, ...]]:
194
208
  """Updates metrics.
195
209
 
196
210
  Args:
@@ -210,7 +224,10 @@ class YT8MTask(base_task.Task):
210
224
  model_losses = {}
211
225
 
212
226
  logs = {}
213
- if self.task_config.evaluation.average_precision is not None and not training:
227
+ if (
228
+ self.task_config.evaluation.average_precision is not None
229
+ and not training
230
+ ):
214
231
  logs.update({self.avg_prec_metric.name: (labels, outputs)})
215
232
 
216
233
  for m in metrics:
@@ -211,19 +211,18 @@ class Parser(parser.Parser):
211
211
  image = preprocess_ops.normalize_image(image)
212
212
 
213
213
  # Flips image randomly during training.
214
- if self._aug_rand_hflip:
215
- if self._include_mask:
216
- image, boxes, masks = preprocess_ops.random_horizontal_flip(
217
- image, boxes, masks)
218
- else:
219
- image, boxes, _ = preprocess_ops.random_horizontal_flip(
220
- image, boxes)
221
- if self._aug_rand_vflip:
222
- if self._include_mask:
223
- image, boxes, masks = preprocess_ops.random_vertical_flip(
224
- image, boxes, masks)
225
- else:
226
- image, boxes, _ = preprocess_ops.random_vertical_flip(image, boxes)
214
+ image, boxes, masks = preprocess_ops.random_horizontal_flip(
215
+ image,
216
+ boxes,
217
+ masks=None if not self._include_mask else masks,
218
+ prob=tf.where(self._aug_rand_hflip, 0.5, 0.0),
219
+ )
220
+ image, boxes, masks = preprocess_ops.random_vertical_flip(
221
+ image,
222
+ boxes,
223
+ masks=None if not self._include_mask else masks,
224
+ prob=tf.where(self._aug_rand_vflip, 0.5, 0.0),
225
+ )
227
226
 
228
227
  # Converts boxes from normalized coordinates to pixel coordinates.
229
228
  # Now the coordinates of boxes are w.r.t. the original image.
@@ -35,4 +35,5 @@ TFDS_ID_TO_DECODER_MAP = {
35
35
  'cifar10': ClassificationDecorder,
36
36
  'cifar100': ClassificationDecorder,
37
37
  'imagenet2012': ClassificationDecorder,
38
+ 'imagenet2012_fewshot/10shot': ClassificationDecorder,
38
39
  }
@@ -465,213 +465,6 @@ def _count_detection_type(
465
465
  return count
466
466
 
467
467
 
468
- def _compute_fp_tp_gt_count(
469
- y_true: Dict[str, tf.Tensor],
470
- y_pred: Dict[str, tf.Tensor],
471
- num_classes: int,
472
- mask_output_boundary: Tuple[int, int] = (640, 640),
473
- iou_thresholds: Tuple[float, ...] = (0.5,),
474
- matching_algorithm: Optional[MatchingAlgorithm] = None,
475
- num_confidence_bins: int = 1000,
476
- use_masks: bool = False,
477
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
478
- """Computes the true and false positives."""
479
-
480
- if matching_algorithm is None:
481
- matching_algorithm = COCOMatchingAlgorithm(iou_thresholds)
482
-
483
- # (batch_size, num_detections, 4) in absolute coordinates.
484
- detection_boxes = tf.cast(y_pred['detection_boxes'], tf.float32)
485
- # (batch_size, num_detections)
486
- detection_classes = tf.cast(y_pred['detection_classes'], tf.int32)
487
- # (batch_size, num_detections)
488
- detection_scores = tf.cast(y_pred['detection_scores'], tf.float32)
489
- # (batch_size, num_gts, 4) in absolute coordinates.
490
- gt_boxes = tf.cast(y_true['boxes'], tf.float32)
491
- # (batch_size, num_gts)
492
- gt_classes = tf.cast(y_true['classes'], tf.int32)
493
- # (batch_size, num_gts)
494
- if 'is_crowds' in y_true:
495
- gt_is_crowd = tf.cast(y_true['is_crowds'], tf.bool)
496
- else:
497
- gt_is_crowd = tf.zeros_like(gt_classes, dtype=tf.bool)
498
-
499
- image_scale = tf.tile(y_true['image_info'][:, 2:3, :], multiples=[1, 1, 2])
500
- detection_boxes = detection_boxes / tf.cast(
501
- image_scale, dtype=detection_boxes.dtype
502
- )
503
-
504
- # Step 1: Computes IoUs between the detections and the non-crowd ground
505
- # truths and IoAs between the detections and the crowd ground truths.
506
- if not use_masks:
507
- # (batch_size, num_detections, num_gts)
508
- detection_to_gt_ious = box_ops.bbox_overlap(detection_boxes, gt_boxes)
509
- detection_to_gt_ioas = box_ops.bbox_intersection_over_area(
510
- detection_boxes, gt_boxes
511
- )
512
- else:
513
- # (batch_size, num_detections, mask_height, mask_width)
514
- detection_masks = tf.cast(y_pred['detection_masks'], tf.float32)
515
- # (batch_size, num_gts, gt_mask_height, gt_mask_width)
516
- gt_masks = tf.cast(y_true['masks'], tf.float32)
517
-
518
- num_detections = detection_boxes.get_shape()[1]
519
- # (batch_size, num_detections + num_gts, 4)
520
- all_boxes = _shift_and_rescale_boxes(
521
- tf.concat([detection_boxes, gt_boxes], axis=1),
522
- mask_output_boundary,
523
- )
524
- detection_boxes = all_boxes[:, :num_detections, :]
525
- gt_boxes = all_boxes[:, num_detections:, :]
526
- # (batch_size, num_detections, num_gts)
527
- detection_to_gt_ious, detection_to_gt_ioas = (
528
- mask_ops.instance_masks_overlap(
529
- detection_boxes,
530
- detection_masks,
531
- gt_boxes,
532
- gt_masks,
533
- output_size=mask_output_boundary,
534
- )
535
- )
536
-
537
- # (batch_size, num_detections, num_gts)
538
- detection_to_gt_ious = tf.where(
539
- gt_is_crowd[:, tf.newaxis, :], 0.0, detection_to_gt_ious
540
- )
541
- detection_to_crowd_ioas = tf.where(
542
- gt_is_crowd[:, tf.newaxis, :], detection_to_gt_ioas, 0.0
543
- )
544
-
545
- # Step 2: counts true positives grouped by IoU thresholds, classes and
546
- # confidence bins.
547
-
548
- # (batch_size, num_detections, num_iou_thresholds)
549
- detection_is_tp, _ = matching_algorithm(
550
- detection_to_gt_ious, detection_classes, detection_scores, gt_classes
551
- )
552
- # (batch_size * num_detections,)
553
- flattened_binned_confidence = tf.reshape(
554
- tf.cast(detection_scores * num_confidence_bins, tf.int32), [-1]
555
- )
556
- # (batch_size * num_detections, num_confidence_bins + 1)
557
- flattened_binned_confidence_one_hot = tf.one_hot(
558
- flattened_binned_confidence, num_confidence_bins + 1, axis=1
559
- )
560
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
561
- tp_count = _count_detection_type(
562
- detection_is_tp,
563
- detection_classes,
564
- flattened_binned_confidence_one_hot,
565
- num_classes,
566
- )
567
-
568
- # Step 3: Counts false positives grouped by IoU thresholds, classes and
569
- # confidence bins.
570
- # False positive: detection is not true positive (see above) and not part of
571
- # the crowd ground truth with the same class.
572
-
573
- # (batch_size, num_detections, num_gts, num_iou_thresholds)
574
- detection_matches_crowd = (
575
- (detection_to_crowd_ioas[..., tf.newaxis] > iou_thresholds)
576
- & (
577
- detection_classes[:, :, tf.newaxis, tf.newaxis]
578
- == gt_classes[:, tf.newaxis, :, tf.newaxis]
579
- )
580
- & (detection_classes[:, :, tf.newaxis, tf.newaxis] > 0)
581
- )
582
- # (batch_size, num_detections, num_iou_thresholds)
583
- detection_matches_any_crowd = tf.reduce_any(
584
- detection_matches_crowd & ~detection_is_tp[:, :, tf.newaxis, :], axis=2
585
- )
586
- detection_is_fp = ~detection_is_tp & ~detection_matches_any_crowd
587
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
588
- fp_count = _count_detection_type(
589
- detection_is_fp,
590
- detection_classes,
591
- flattened_binned_confidence_one_hot,
592
- num_classes,
593
- )
594
-
595
- # Step 4: Counts non-crowd groundtruths grouped by classes.
596
- # (num_classes, )
597
- gt_count = tf.reduce_sum(
598
- tf.one_hot(
599
- tf.where(gt_is_crowd, -1, gt_classes), num_classes, axis=-1
600
- ),
601
- axis=[0, 1],
602
- )
603
- # Clears the count of class 0 (background).
604
- gt_count *= 1.0 - tf.eye(1, num_classes, dtype=gt_count.dtype)[0]
605
-
606
- return tp_count, fp_count, gt_count
607
-
608
-
609
- def _compute_metrics(
610
- tp_count: tf.Tensor,
611
- fp_count: tf.Tensor,
612
- gt_count: tf.Tensor,
613
- confidence_thresholds: Tuple[float, ...] = (),
614
- num_confidence_bins: int = 1000,
615
- average_precision_algorithms: Optional[
616
- Dict[str, AveragePrecision]] = None,
617
- ) -> Dict[str, tf.Tensor]:
618
- """Returns the metrics values as a dict."""
619
-
620
- if average_precision_algorithms is None:
621
- average_precision_algorithms = {'ap': COCOAveragePrecision()}
622
-
623
- result = {
624
- # (num_classes,)
625
- 'valid_classes': gt_count != 0,
626
- }
627
-
628
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
629
- tp_count_cum_by_confidence = tf.math.cumsum(
630
- tp_count, axis=-1, reverse=True
631
- )
632
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
633
- fp_count_cum_by_confidence = tf.math.cumsum(
634
- fp_count, axis=-1, reverse=True
635
- )
636
-
637
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
638
- precisions = tf.math.divide_no_nan(
639
- tp_count_cum_by_confidence,
640
- tp_count_cum_by_confidence + fp_count_cum_by_confidence,
641
- )
642
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
643
- recalls = tf.math.divide_no_nan(
644
- tp_count_cum_by_confidence, gt_count[..., tf.newaxis]
645
- )
646
-
647
- if confidence_thresholds:
648
- # If confidence_thresholds is set, reports precision and recall at each
649
- # confidence threshold.
650
- confidence_thresholds = tf.cast(
651
- tf.constant(confidence_thresholds, dtype=tf.float32)
652
- * num_confidence_bins,
653
- dtype=tf.int32,
654
- )
655
- # (num_confidence_thresholds, num_iou_thresholds, num_classes)
656
- result['precisions'] = tf.gather(
657
- tf.transpose(precisions, [2, 0, 1]), confidence_thresholds
658
- )
659
- result['recalls'] = tf.gather(
660
- tf.transpose(recalls, [2, 0, 1]), confidence_thresholds
661
- )
662
-
663
- precisions = tf.reverse(precisions, axis=[-1])
664
- recalls = tf.reverse(recalls, axis=[-1])
665
- result.update(
666
- {
667
- # (num_iou_thresholds, num_classes)
668
- key: ap_algorithm(precisions, recalls)
669
- for key, ap_algorithm in average_precision_algorithms.items()
670
- }
671
- )
672
- return result
673
-
674
-
675
468
  class InstanceMetrics(tf.keras.metrics.Metric):
676
469
  """Reports the metrics of instance detection & segmentation."""
677
470
 
@@ -780,22 +573,138 @@ class InstanceMetrics(tf.keras.metrics.Metric):
780
573
 
781
574
  def reset_state(self):
782
575
  """Resets all of the metric state variables."""
783
- for v in self.variables:
784
- tf.keras.backend.set_value(v, np.zeros(v.shape))
576
+ self.tp_count.assign(tf.zeros_like(self.tp_count))
577
+ self.fp_count.assign(tf.zeros_like(self.fp_count))
578
+ self.gt_count.assign(tf.zeros_like(self.gt_count))
785
579
 
786
580
  def update_state(
787
581
  self, y_true: Dict[str, tf.Tensor], y_pred: Dict[str, tf.Tensor]
788
582
  ):
583
+ # (batch_size, num_detections, 4) in absolute coordinates.
584
+ detection_boxes = tf.cast(y_pred['detection_boxes'], tf.float32)
585
+ # (batch_size, num_detections)
586
+ detection_classes = tf.cast(y_pred['detection_classes'], tf.int32)
587
+ # (batch_size, num_detections)
588
+ detection_scores = tf.cast(y_pred['detection_scores'], tf.float32)
589
+ # (batch_size, num_gts, 4) in absolute coordinates.
590
+ gt_boxes = tf.cast(y_true['boxes'], tf.float32)
591
+ # (batch_size, num_gts)
592
+ gt_classes = tf.cast(y_true['classes'], tf.int32)
593
+ # (batch_size, num_gts)
594
+ if 'is_crowds' in y_true:
595
+ gt_is_crowd = tf.cast(y_true['is_crowds'], tf.bool)
596
+ else:
597
+ gt_is_crowd = tf.zeros_like(gt_classes, dtype=tf.bool)
789
598
 
790
- tp_count, fp_count, gt_count = _compute_fp_tp_gt_count(
791
- y_true=y_true,
792
- y_pred=y_pred,
793
- num_classes=self._num_classes,
794
- mask_output_boundary=self._mask_output_boundary,
795
- iou_thresholds=self._iou_thresholds,
796
- matching_algorithm=self._matching_algorithm,
797
- num_confidence_bins=self._num_confidence_bins,
798
- use_masks=self._use_masks)
599
+ image_scale = tf.tile(y_true['image_info'][:, 2:3, :], multiples=[1, 1, 2])
600
+ detection_boxes = detection_boxes / tf.cast(
601
+ image_scale, dtype=detection_boxes.dtype
602
+ )
603
+
604
+ # Step 1: Computes IoUs between the detections and the non-crowd ground
605
+ # truths and IoAs between the detections and the crowd ground truths.
606
+ if not self._use_masks:
607
+ # (batch_size, num_detections, num_gts)
608
+ detection_to_gt_ious = box_ops.bbox_overlap(detection_boxes, gt_boxes)
609
+ detection_to_gt_ioas = box_ops.bbox_intersection_over_area(
610
+ detection_boxes, gt_boxes
611
+ )
612
+ else:
613
+ # Use outer boxes to generate the masks if available.
614
+ if 'detection_outer_boxes' in y_pred:
615
+ detection_boxes = tf.cast(y_pred['detection_outer_boxes'], tf.float32)
616
+
617
+ # (batch_size, num_detections, mask_height, mask_width)
618
+ detection_masks = tf.cast(y_pred['detection_masks'], tf.float32)
619
+ # (batch_size, num_gts, gt_mask_height, gt_mask_width)
620
+ gt_masks = tf.cast(y_true['masks'], tf.float32)
621
+
622
+ num_detections = detection_boxes.get_shape()[1]
623
+ # (batch_size, num_detections + num_gts, 4)
624
+ all_boxes = _shift_and_rescale_boxes(
625
+ tf.concat([detection_boxes, gt_boxes], axis=1),
626
+ self._mask_output_boundary,
627
+ )
628
+ detection_boxes = all_boxes[:, :num_detections, :]
629
+ gt_boxes = all_boxes[:, num_detections:, :]
630
+ # (batch_size, num_detections, num_gts)
631
+ detection_to_gt_ious, detection_to_gt_ioas = (
632
+ mask_ops.instance_masks_overlap(
633
+ detection_boxes,
634
+ detection_masks,
635
+ gt_boxes,
636
+ gt_masks,
637
+ output_size=self._mask_output_boundary,
638
+ )
639
+ )
640
+ # (batch_size, num_detections, num_gts)
641
+ detection_to_gt_ious = tf.where(
642
+ gt_is_crowd[:, tf.newaxis, :], 0.0, detection_to_gt_ious
643
+ )
644
+ detection_to_crowd_ioas = tf.where(
645
+ gt_is_crowd[:, tf.newaxis, :], detection_to_gt_ioas, 0.0
646
+ )
647
+
648
+ # Step 2: counts true positives grouped by IoU thresholds, classes and
649
+ # confidence bins.
650
+
651
+ # (batch_size, num_detections, num_iou_thresholds)
652
+ detection_is_tp, _ = self._matching_algorithm(
653
+ detection_to_gt_ious, detection_classes, detection_scores, gt_classes
654
+ )
655
+ # (batch_size * num_detections,)
656
+ flattened_binned_confidence = tf.reshape(
657
+ tf.cast(detection_scores * self._num_confidence_bins, tf.int32), [-1]
658
+ )
659
+ # (batch_size * num_detections, num_confidence_bins + 1)
660
+ flattened_binned_confidence_one_hot = tf.one_hot(
661
+ flattened_binned_confidence, self._num_confidence_bins + 1, axis=1
662
+ )
663
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
664
+ tp_count = _count_detection_type(
665
+ detection_is_tp,
666
+ detection_classes,
667
+ flattened_binned_confidence_one_hot,
668
+ self._num_classes,
669
+ )
670
+
671
+ # Step 3: Counts false positives grouped by IoU thresholds, classes and
672
+ # confidence bins.
673
+ # False positive: detection is not true positive (see above) and not part of
674
+ # the crowd ground truth with the same class.
675
+
676
+ # (batch_size, num_detections, num_gts, num_iou_thresholds)
677
+ detection_matches_crowd = (
678
+ (detection_to_crowd_ioas[..., tf.newaxis] > self._iou_thresholds)
679
+ & (
680
+ detection_classes[:, :, tf.newaxis, tf.newaxis]
681
+ == gt_classes[:, tf.newaxis, :, tf.newaxis]
682
+ )
683
+ & (detection_classes[:, :, tf.newaxis, tf.newaxis] > 0)
684
+ )
685
+ # (batch_size, num_detections, num_iou_thresholds)
686
+ detection_matches_any_crowd = tf.reduce_any(
687
+ detection_matches_crowd & ~detection_is_tp[:, :, tf.newaxis, :], axis=2
688
+ )
689
+ detection_is_fp = ~detection_is_tp & ~detection_matches_any_crowd
690
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
691
+ fp_count = _count_detection_type(
692
+ detection_is_fp,
693
+ detection_classes,
694
+ flattened_binned_confidence_one_hot,
695
+ self._num_classes,
696
+ )
697
+
698
+ # Step 4: Counts non-crowd groundtruths grouped by classes.
699
+ # (num_classes, )
700
+ gt_count = tf.reduce_sum(
701
+ tf.one_hot(
702
+ tf.where(gt_is_crowd, -1, gt_classes), self._num_classes, axis=-1
703
+ ),
704
+ axis=[0, 1],
705
+ )
706
+ # Clears the count of class 0 (background).
707
+ gt_count *= 1.0 - tf.eye(1, self._num_classes, dtype=gt_count.dtype)[0]
799
708
 
800
709
  # Accumulates the variables.
801
710
  self.fp_count.assign_add(tf.cast(fp_count, self.fp_count.dtype))
@@ -818,13 +727,55 @@ class InstanceMetrics(tf.keras.metrics.Metric):
818
727
  'valid_classes': a bool tensor in shape (num_classes,). If False, there
819
728
  is no instance of the class in the ground truth.
820
729
  """
821
- result = _compute_metrics(
822
- fp_count=self.fp_count,
823
- tp_count=self.tp_count,
824
- gt_count=self.gt_count,
825
- confidence_thresholds=self._confidence_thresholds,
826
- num_confidence_bins=self._num_confidence_bins,
827
- average_precision_algorithms=self._average_precision_algorithms)
730
+ result = {
731
+ # (num_classes,)
732
+ 'valid_classes': self.gt_count != 0,
733
+ }
734
+
735
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
736
+ tp_count_cum_by_confidence = tf.math.cumsum(
737
+ self.tp_count, axis=-1, reverse=True
738
+ )
739
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
740
+ fp_count_cum_by_confidence = tf.math.cumsum(
741
+ self.fp_count, axis=-1, reverse=True
742
+ )
743
+
744
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
745
+ precisions = tf.math.divide_no_nan(
746
+ tp_count_cum_by_confidence,
747
+ tp_count_cum_by_confidence + fp_count_cum_by_confidence,
748
+ )
749
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
750
+ recalls = tf.math.divide_no_nan(
751
+ tp_count_cum_by_confidence, self.gt_count[..., tf.newaxis]
752
+ )
753
+
754
+ if self._confidence_thresholds:
755
+ # If confidence_thresholds is set, reports precision and recall at each
756
+ # confidence threshold.
757
+ confidence_thresholds = tf.cast(
758
+ tf.constant(self._confidence_thresholds, dtype=tf.float32)
759
+ * self._num_confidence_bins,
760
+ dtype=tf.int32,
761
+ )
762
+ # (num_confidence_thresholds, num_iou_thresholds, num_classes)
763
+ result['precisions'] = tf.gather(
764
+ tf.transpose(precisions, [2, 0, 1]), confidence_thresholds
765
+ )
766
+ result['recalls'] = tf.gather(
767
+ tf.transpose(recalls, [2, 0, 1]), confidence_thresholds
768
+ )
769
+
770
+ precisions = tf.reverse(precisions, axis=[-1])
771
+ recalls = tf.reverse(recalls, axis=[-1])
772
+ result.update(
773
+ {
774
+ # (num_iou_thresholds, num_classes)
775
+ key: ap_algorithm(precisions, recalls)
776
+ for key, ap_algorithm in self._average_precision_algorithms.items()
777
+ }
778
+ )
828
779
  return result
829
780
 
830
781
  def get_average_precision_metrics_keys(self):
@@ -182,7 +182,12 @@ def resize_and_crop_image(image,
182
182
  with tf.name_scope('resize_and_crop_image'):
183
183
  image_size = tf.cast(tf.shape(image)[0:2], tf.float32)
184
184
 
185
- random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)
185
+ random_jittering = (
186
+ isinstance(aug_scale_min, tf.Tensor)
187
+ or isinstance(aug_scale_max, tf.Tensor)
188
+ or not math.isclose(aug_scale_min, 1.0)
189
+ or not math.isclose(aug_scale_max, 1.0)
190
+ )
186
191
 
187
192
  if random_jittering:
188
193
  random_scale = tf.random.uniform(
@@ -292,7 +297,12 @@ def resize_and_crop_image_v2(image,
292
297
  scaled_size)
293
298
  desired_size = scaled_size
294
299
 
295
- random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)
300
+ random_jittering = (
301
+ isinstance(aug_scale_min, tf.Tensor)
302
+ or isinstance(aug_scale_max, tf.Tensor)
303
+ or not math.isclose(aug_scale_min, 1.0)
304
+ or not math.isclose(aug_scale_max, 1.0)
305
+ )
296
306
 
297
307
  if random_jittering:
298
308
  random_scale = tf.random.uniform(
@@ -641,10 +651,12 @@ def horizontal_flip_masks(masks):
641
651
  return masks[:, :, ::-1]
642
652
 
643
653
 
644
- def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
654
+ def random_horizontal_flip(
655
+ image, normalized_boxes=None, masks=None, seed=1, prob=0.5
656
+ ):
645
657
  """Randomly flips input image and bounding boxes horizontally."""
646
658
  with tf.name_scope('random_horizontal_flip'):
647
- do_flip = tf.greater(tf.random.uniform([], seed=seed), 0.5)
659
+ do_flip = tf.less(tf.random.uniform([], seed=seed), prob)
648
660
 
649
661
  image = tf.cond(
650
662
  do_flip,
@@ -713,10 +725,12 @@ def random_horizontal_flip_with_roi(
713
725
  return image, boxes, masks, roi_boxes
714
726
 
715
727
 
716
- def random_vertical_flip(image, normalized_boxes=None, masks=None, seed=1):
728
+ def random_vertical_flip(
729
+ image, normalized_boxes=None, masks=None, seed=1, prob=0.5
730
+ ):
717
731
  """Randomly flips input image and bounding boxes vertically."""
718
732
  with tf.name_scope('random_vertical_flip'):
719
- do_flip = tf.greater(tf.random.uniform([], seed=seed), 0.5)
733
+ do_flip = tf.less(tf.random.uniform([], seed=seed), prob)
720
734
 
721
735
  image = tf.cond(
722
736
  do_flip,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.11.0.dev20230321
3
+ Version: 2.11.0.dev20230322
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -759,15 +759,15 @@ official/projects/yt8m/__init__.py,sha256=XKqEvUISyqNK_cFqr7umxt6r-vnABJ2OqGEKE3
759
759
  official/projects/yt8m/train.py,sha256=NNzjalMAmrRSQlYGZlXNAbnwdXK2OMJXfS6qwv6_cTk,987
760
760
  official/projects/yt8m/train_test.py,sha256=TUqDVIJeDrkbmVUlJ1rw5RqOtRBzyYwguO-09OdLmkk,3892
761
761
  official/projects/yt8m/configs/__init__.py,sha256=sItZXhE5JuW4RRDBhX9EW4eSrp1HdDxhmEv9420aaD0,692
762
- official/projects/yt8m/configs/yt8m.py,sha256=PCez9EyP1ne4_wn6S7uO29Vwxzd2wWdb3HdipAfRsZA,8704
762
+ official/projects/yt8m/configs/yt8m.py,sha256=R-wUAo-udQYYEfS7SJ4JzyCcifC8NJqnpUiiSltUVFM,8629
763
763
  official/projects/yt8m/configs/yt8m_test.py,sha256=5zyL8EVpP127q6yZkBlAve3Fko2ujDCThsWZ9pUDGBM,1561
764
764
  official/projects/yt8m/modeling/__init__.py,sha256=XKqEvUISyqNK_cFqr7umxt6r-vnABJ2OqGEKE395w20,609
765
- official/projects/yt8m/modeling/nn_layers.py,sha256=qMJyW4ZR7FI1yYOJdbltooBGp0q8eF6k7Sf_yZM8fbc,5490
766
- official/projects/yt8m/modeling/yt8m_model.py,sha256=dfF2qhS5-rHdQMuq5ZjcSKBdE7stOAjFIY4o3npiDFA,8697
765
+ official/projects/yt8m/modeling/nn_layers.py,sha256=DK3yKXBC3yPMfpGqYW1Se5AhQcNBul3376uOFpFc0QQ,10392
766
+ official/projects/yt8m/modeling/yt8m_model.py,sha256=QkRW1yktCJPqE7x2cGOkcBzF605D0D1Gb5ndEIENF9E,3720
767
767
  official/projects/yt8m/modeling/yt8m_model_test.py,sha256=k5zPv4lFCJ9UjD2BSJpo4ggVHX5ZAbZ1m7kFWLe-SIM,2131
768
- official/projects/yt8m/modeling/yt8m_model_utils.py,sha256=bXoALSMfyMixux4BDUFAcTVadSz6uav9sIA-dtK4x8E,8408
768
+ official/projects/yt8m/modeling/yt8m_model_utils.py,sha256=lrSjCKjfBTTlhr3ledY5mit9Qy5P5dAnO_IgBX9gUSs,8451
769
769
  official/projects/yt8m/tasks/__init__.py,sha256=tS1Bb__H9G0UtOQMi7sdvaJv4c4ROqNVwaSSQuwCC0Y,692
770
- official/projects/yt8m/tasks/yt8m_task.py,sha256=G9PI6NkMFEc_0tViQez7Jiaks1KKu2wkmaii5PUSZqk,14615
770
+ official/projects/yt8m/tasks/yt8m_task.py,sha256=BJhvjVnaRPYNrScjDlXUxtKL09bZOijP_EfwuO3Zdio,14438
771
771
  official/recommendation/__init__.py,sha256=XKqEvUISyqNK_cFqr7umxt6r-vnABJ2OqGEKE395w20,609
772
772
  official/recommendation/constants.py,sha256=aTDcZc7_1Ir3Wt3NAzM96exWcV41iXDLUAz33kf5z4g,2877
773
773
  official/recommendation/create_ncf_data.py,sha256=oP6ueDrWfS7bDKW099TFdy9LiYfdd40jFp4Zv6WlENE,4008
@@ -851,7 +851,7 @@ official/vision/dataloaders/classification_input.py,sha256=nyAzbzuorMIusOQd2h8bm
851
851
  official/vision/dataloaders/decoder.py,sha256=XGvZHeqJzGr1cgXY4VpEQzGvapkd80u_FDfKqD4LfRs,1016
852
852
  official/vision/dataloaders/input_reader.py,sha256=S7NlBfBcGfdqWrIvQNpL5nM7YDFBQWceZv5PQZRCWME,10410
853
853
  official/vision/dataloaders/input_reader_factory.py,sha256=Gc5eZ4kEUfwp7E5U_Bl9JaEN6WRZCgtku69gP7a6zjw,1623
854
- official/vision/dataloaders/maskrcnn_input.py,sha256=8siryuNYgSNbe_5U5__K2p1LHrD3FjpVHQq6I6ZILWQ,16922
854
+ official/vision/dataloaders/maskrcnn_input.py,sha256=F1BfbUjZpwHzAG7NrKQCPbKYyOvjR1qWQ3DMjTPPqC8,16837
855
855
  official/vision/dataloaders/parser.py,sha256=7TJNrl28Ddf7w1Nx7FMMbt_aX0SRDX8h0UZa-upemo4,2315
856
856
  official/vision/dataloaders/retinanet_input.py,sha256=osZU-_eh70aK5CUD9dwzfoOv8xDfSMQPYxAn3hlQR78,13594
857
857
  official/vision/dataloaders/segmentation_input.py,sha256=xNHHSnX_VTaHE6A1WHyvabGmbPLwpJ3JNOmK8B2-RM0,11828
@@ -859,7 +859,7 @@ official/vision/dataloaders/tf_example_decoder.py,sha256=0J-rQzSkUwgR_xEZ-UR5ToS
859
859
  official/vision/dataloaders/tf_example_decoder_test.py,sha256=sPvAtExClFz7C6MYqg0pEgZzZ5-PcZYipHU0bKk2Fn8,12619
860
860
  official/vision/dataloaders/tf_example_label_map_decoder.py,sha256=SkNyEohAFvUJCVP1XRqy4lTT7bn7agULtgrExfFeLzE,2588
861
861
  official/vision/dataloaders/tf_example_label_map_decoder_test.py,sha256=FAXpKc8BZq2y28XjsbTaxlMRwe9vS53mPLW3r61ShHI,7746
862
- official/vision/dataloaders/tfds_classification_decoders.py,sha256=sQHsswMcbeCixDX0kvii9m_KbE3xxMv5ovFcYFJh1d4,1242
862
+ official/vision/dataloaders/tfds_classification_decoders.py,sha256=X_WUL3QwgGBtE0K0oguONRD38NJs7DNP7yW_gukqimY,1301
863
863
  official/vision/dataloaders/tfds_detection_decoders.py,sha256=-2l3aSCOEoJZO3h0LOrdyUn7TBgFy_iWMzQKE_DuN-Q,2273
864
864
  official/vision/dataloaders/tfds_factory.py,sha256=wEZCTW8h6CRHEO8cniSEgxXsrtrFZS_WSjaDiEFvqHc,2568
865
865
  official/vision/dataloaders/tfds_factory_test.py,sha256=PmgbXmfkmh8erSSii8GYLXu_coShbIdCEjzTeOepd3w,4010
@@ -873,7 +873,7 @@ official/vision/evaluation/__init__.py,sha256=XKqEvUISyqNK_cFqr7umxt6r-vnABJ2OqG
873
873
  official/vision/evaluation/coco_evaluator.py,sha256=LMl4rACazArEGsI5RlgRiQW3q5DeQimhh72NjfXhDLs,15535
874
874
  official/vision/evaluation/coco_utils.py,sha256=9-5uF3ofhhNdlGbdPBYM8ySbtxxXc_NorGFnU9kCxjg,17911
875
875
  official/vision/evaluation/coco_utils_test.py,sha256=PjqRYETQsOAEtcA5Pod8NPmA1xc-7wm_poYU7W_mOXw,1710
876
- official/vision/evaluation/instance_metrics.py,sha256=HJzrhyl2ETX86zXnyCNBaWG4Lz_z27oqxvVDUt2pcAI,30175
876
+ official/vision/evaluation/instance_metrics.py,sha256=X2McJng7V57AZdoo6-k3C117SUhpKGnsuLbtEOxIgKE,29104
877
877
  official/vision/evaluation/instance_metrics_test.py,sha256=GqtAgSpl4hR-wrnk55Efm8dX2b3hHNbHR9hvvR4sJ6s,10813
878
878
  official/vision/evaluation/iou.py,sha256=8VH0AHMquWvmdLtr9Yl0ZWv7efJGhhi1sKtAYUyVhCY,6272
879
879
  official/vision/evaluation/iou_test.py,sha256=1aPmTZmvp6_bL6yr2c9Wa2StUZU2xZ3mICN11pJoojg,5380
@@ -981,7 +981,7 @@ official/vision/ops/iou_similarity_test.py,sha256=SMq2wPl98_HenxAKk4p4By_ZoRy6LY
981
981
  official/vision/ops/mask_ops.py,sha256=9yR6KWC9croKvzMtSFUU3cK87-ileUD8MzHfJ-IvBJo,10260
982
982
  official/vision/ops/mask_ops_test.py,sha256=34mwpK2-GRS0BWLcOz2Dk2h9kaNBu5zBV70pElpAGsI,2825
983
983
  official/vision/ops/nms.py,sha256=Q5iGXJ-f_hcLNzyCu7CXKigCK4yHSq40ed7ornqZfVQ,8115
984
- official/vision/ops/preprocess_ops.py,sha256=xiQgsetqA380Po751xaKPMZYUG_3YAmyWIBA0lyfwPw,39149
984
+ official/vision/ops/preprocess_ops.py,sha256=NEcNSEZP1mvWD-XmS8nBvsGxE7EpY_i5MyIsmjM1yp4,39477
985
985
  official/vision/ops/preprocess_ops_3d.py,sha256=aUn1OTLkr1046sAfVefybRxHzhaVAtywlmP0R1UGbIg,15378
986
986
  official/vision/ops/preprocess_ops_3d_test.py,sha256=mrs0IVby6WFc2gcD8pIYU__KQFfhmP4m9U000hrpIJQ,7239
987
987
  official/vision/ops/preprocess_ops_test.py,sha256=_6ozeiWsDMBxSQslNJ2y5SCqqTLwbGzgLBjU7dohqmE,11710
@@ -1069,9 +1069,9 @@ tensorflow_models/__init__.py,sha256=021FKgqdPz3ds1xxfV67FWL7e5ECQ7WHbo67D37vAQI
1069
1069
  tensorflow_models/tensorflow_models_test.py,sha256=3oRV5seq-V1La0eY0IFpGLD7AKkiemylW8GyvZIRtmo,1385
1070
1070
  tensorflow_models/nlp/__init__.py,sha256=ro-1L0G8Z1wby8D1Jbaa3No-n73tiNEx7C4f8pAUNlk,807
1071
1071
  tensorflow_models/vision/__init__.py,sha256=3qeLW_6HkgH5hEclFog2DIRu1FSzOr3JynyM23zGhu8,833
1072
- tf_models_nightly-2.11.0.dev20230321.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1073
- tf_models_nightly-2.11.0.dev20230321.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1074
- tf_models_nightly-2.11.0.dev20230321.dist-info/METADATA,sha256=ONfrJ2f_hyv4tO-2Dj8AVq-IBYx3v_a9Q8ViV5oJfAQ,1426
1075
- tf_models_nightly-2.11.0.dev20230321.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1076
- tf_models_nightly-2.11.0.dev20230321.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1077
- tf_models_nightly-2.11.0.dev20230321.dist-info/RECORD,,
1072
+ tf_models_nightly-2.11.0.dev20230322.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1073
+ tf_models_nightly-2.11.0.dev20230322.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1074
+ tf_models_nightly-2.11.0.dev20230322.dist-info/METADATA,sha256=nDQJeoSeIrhNpJ_zLaHs0TTJh6nTufa15zPzqMjIZUk,1426
1075
+ tf_models_nightly-2.11.0.dev20230322.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1076
+ tf_models_nightly-2.11.0.dev20230322.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1077
+ tf_models_nightly-2.11.0.dev20230322.dist-info/RECORD,,