tf-models-nightly 2.11.0.dev20230320__py2.py3-none-any.whl → 2.11.0.dev20230322__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,6 @@
 """Video classification configuration definition."""
 import dataclasses
 from typing import Optional, Tuple
-from absl import flags
 
 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -23,7 +22,6 @@ from official.modeling import hyperparams
 from official.modeling import optimization
 from official.vision.configs import common
 
-FLAGS = flags.FLAGS
 
 YT8M_TRAIN_EXAMPLES = 3888919
 YT8M_VAL_EXAMPLES = 1112356
@@ -53,7 +51,7 @@ class DataConfig(cfg.DataConfig):
     temporal_stride: Not used. To be deprecated.
     max_frames: Maximum number of frames in an input example. It is used to
       crop the input in the temporal dimension.
-    num_frames: Number of frames in a single input example.
+    num_sample_frames: Number of frames to sample for each input example.
     num_classes: Number of classes to classify. Assuming it is a classification
       task.
     num_devices: Not used. To be deprecated.
@@ -77,7 +75,7 @@ class DataConfig(cfg.DataConfig):
   include_video_id: bool = False
   temporal_stride: int = 1
   max_frames: int = 300
-  num_frames: int = 300  # set smaller to allow random sample (Parser)
+  num_sample_frames: int = 300  # set smaller to allow random sample (Parser)
   num_classes: int = 3862
   num_devices: int = 1
   input_path: str = ''
@@ -90,7 +88,6 @@ def yt8m(is_training):
   """YT8M dataset configs."""
   # pylint: disable=unexpected-keyword-arg
   return DataConfig(
-      num_frames=30,
       temporal_stride=1,
      segment_labels=False,
      segment_size=5,
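
Note: the hunks above rename `num_frames` to `num_sample_frames` and drop the explicit `num_frames=30` override from the `yt8m()` factory. A minimal sketch (ours, not from the package) of setting the renamed field; the values are illustrative, not release defaults:

    # Hypothetical override of the DataConfig defined in the hunks above.
    from official.projects.yt8m.configs import yt8m as yt8m_cfg

    data = yt8m_cfg.DataConfig(
        num_sample_frames=30,  # renamed from `num_frames` in this release
        max_frames=300,        # temporal crop length, unchanged
    )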
@@ -106,7 +103,6 @@ def yt8m(is_training):
 class MoeModel(hyperparams.Config):
   """The model config."""
   num_mixtures: int = 5
-  l2_penalty: float = 1e-5
   use_input_context_gate: bool = False
   use_output_context_gate: bool = False
   vocab_as_last_dim: bool = False
@@ -122,7 +118,7 @@ class DbofModel(hyperparams.Config):
   use_context_gate_cluster_layer: bool = False
   context_gate_cluster_bottleneck_size: int = 0
   pooling_method: str = 'average'
-  yt8m_agg_classifier_model: str = 'MoeModel'
+  agg_classifier_model: str = 'MoeModel'
   agg_model: hyperparams.Config = MoeModel()
   norm_activation: common.NormActivation = common.NormActivation(
       activation='relu', use_sync_bn=False)
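
Note: per the model hunks below, the renamed `agg_classifier_model` value is resolved against the `nn_layers` module by name. A hedged sketch of that lookup:

    # Assumed illustration of the name-based head lookup used in DbofModel.
    import tensorflow as tf
    from official.projects.yt8m.modeling import nn_layers

    head_cls = getattr(nn_layers, "MoeModel")  # `agg_classifier_model` value
    assert issubclass(head_cls, tf.keras.Model)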
@@ -13,81 +13,223 @@
 # limitations under the License.
 
 """Contains model definitions."""
+
+import functools
 from typing import Any, Dict, Optional
 
 import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.projects.yt8m.configs import yt8m as yt8m_cfg
 from official.projects.yt8m.modeling import yt8m_model_utils as utils
 
+
 layers = tf.keras.layers
 
 
-class LogisticModel():
-  """Logistic model with L2 regularization."""
+class Dbof(tf.keras.Model):
+  """A YT8M model class builder.
+
+  Creates a Deep Bag of Frames model.
+  The model projects the features for each frame into a higher dimensional
+  'clustering' space, pools across frames in that space, and then uses a
+  configurable video-level model to classify the aggregated features.
+  The model will randomly sample either frames or sequences of frames during
+  training to speed up convergence.
+  """
+
+  def __init__(
+      self,
+      params: yt8m_cfg.DbofModel,
+      num_classes: int = 3862,
+      input_specs: layers.InputSpec = layers.InputSpec(
+          shape=[None, None, 1152]),
+      l2_weight_decay: Optional[float] = None,
+      **kwargs):
+    """YT8M initialization function.
+
+    Args:
+      params: model configuration parameters.
+      num_classes: `int` number of classes in the dataset.
+      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
+        [batch_size x num_frames x num_features]
+      l2_weight_decay: An optional `float` of kernel regularizer weight decay.
+      **kwargs: keyword arguments to be passed.
+    """
+    self._self_setattr_tracking = False
+    self._num_classes = num_classes
+    self._input_specs = input_specs
+    self._params = params
+    self._l2_weight_decay = l2_weight_decay
+    self._act_fn = tf_utils.get_activation(params.norm_activation.activation)
+    self._norm = functools.partial(
+        layers.BatchNormalization,
+        momentum=params.norm_activation.norm_momentum,
+        epsilon=params.norm_activation.norm_epsilon,
+        synchronized=params.norm_activation.use_sync_bn)
+
+    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
+    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
+    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
+    l2_regularizer = (
+        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
+        if l2_weight_decay
+        else None
+    )
+
+    # [batch_size x num_frames x num_features]
+    feature_size = input_specs.shape[-1]
+    # shape 'excluding' batch_size
+    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
+    # normalize input features
+    input_data = tf.nn.l2_normalize(model_input, -1)
+    tf.summary.histogram("input_hist", input_data)
+
+    # configure model
+    if params.add_batch_norm:
+      input_data = self._norm(name="input_bn")(input_data)
+
+    # activation = reshaped input * cluster weights
+    if params.cluster_size > 0:
+      activation = layers.Dense(
+          params.cluster_size,
+          kernel_regularizer=l2_regularizer,
+          kernel_initializer=tf.random_normal_initializer(
+              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
+                  input_data)
+
+    if params.add_batch_norm:
+      activation = self._norm(name="cluster_bn")(activation)
+    else:
+      cluster_biases = tf.Variable(
+          tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
+              shape=[params.cluster_size]),
+          name="cluster_biases")
+      tf.summary.histogram("cluster_biases", cluster_biases)
+      activation += cluster_biases
+
+    activation = self._act_fn(activation)
+    tf.summary.histogram("cluster_output", activation)
 
-  def create_model(self, model_input, vocab_size, l2_penalty=1e-8):
+    if params.use_context_gate_cluster_layer:
+      pooling_method = None
+      norm_args = dict(name="context_gate_bn")
+      activation = utils.context_gate(
+          activation,
+          normalizer_fn=self._norm,
+          normalizer_params=norm_args,
+          pooling_method=pooling_method,
+          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
+          kernel_regularizer=l2_regularizer)
+
+    activation = utils.frame_pooling(activation, params.pooling_method)
+
+    # activation = activation * hidden1_weights
+    activation = layers.Dense(
+        params.hidden_size,
+        kernel_regularizer=l2_regularizer,
+        kernel_initializer=tf.random_normal_initializer(
+            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
+                activation)
+
+    if params.add_batch_norm:
+      activation = self._norm(name="hidden1_bn")(activation)
+
+    else:
+      hidden1_biases = tf.Variable(
+          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
+          name="hidden1_biases")
+
+      tf.summary.histogram("hidden1_biases", hidden1_biases)
+      activation += hidden1_biases
+
+    activation = self._act_fn(activation)
+    tf.summary.histogram("hidden1_output", activation)
+
+    super().__init__(inputs=model_input, outputs=activation, **kwargs)
+
+
+class LogisticModel(tf.keras.Model):
+  """Logistic prediction head model with L2 regularization."""
+
+  def __init__(
+      self,
+      input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
+      vocab_size: int = 3862,
+      l2_penalty: float = 1e-8,
+      **kwargs,
+  ):
     """Creates a logistic model.
 
     Args:
-      model_input: 'batch' x 'num_features' matrix of input features.
+      input_specs: 'batch' x 'num_features' matrix of input features.
       vocab_size: The number of classes in the dataset.
       l2_penalty: L2 weight regularization ratio.
+      **kwargs: extra keyword args.
 
     Returns:
       A dictionary with a tensor containing the probability predictions of the
       model in the 'predictions' key. The dimensions of the tensor are
       batch_size x num_classes.
     """
+    inputs = tf.keras.Input(shape=input_specs.shape[1:])
     output = layers.Dense(
         vocab_size,
         activation=tf.nn.sigmoid,
         kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
-            model_input)
-    return {"predictions": output}
+            inputs)
+
+    super().__init__(inputs=inputs, outputs={"predictions": output}, **kwargs)
 
 
-class MoeModel():
+class MoeModel(tf.keras.Model):
   """A softmax over a mixture of logistic models (with L2 regularization)."""
 
-  def create_model(self,
-                   model_input,
-                   vocab_size,
-                   num_mixtures: int = 2,
-                   use_input_context_gate: bool = False,
-                   use_output_context_gate: bool = False,
-                   normalizer_fn=None,
-                   normalizer_params: Optional[Dict[str, Any]] = None,
-                   vocab_as_last_dim: bool = False,
-                   l2_penalty: float = 1e-5):
+  def __init__(
+      self,
+      input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
+      vocab_size: int = 3862,
+      num_mixtures: int = 2,
+      use_input_context_gate: bool = False,
+      use_output_context_gate: bool = False,
+      normalizer_params: Optional[Dict[str, Any]] = None,
+      vocab_as_last_dim: bool = False,
+      l2_penalty: float = 1e-5,
+      **kwargs,
+  ):
     """Creates a Mixture of (Logistic) Experts model.
 
     The model consists of a per-class softmax distribution over a
     configurable number of logistic classifiers. One of the classifiers
     in the mixture is not trained, and always predicts 0.
 
     Args:
-      model_input: 'batch_size' x 'num_features' matrix of input features.
+      input_specs: 'batch_size' x 'num_features' matrix of input features.
       vocab_size: The number of classes in the dataset.
       num_mixtures: The number of mixtures (excluding a dummy 'expert' that
        always predicts the non-existence of an entity).
       use_input_context_gate: if True, apply a context gate layer to the input.
       use_output_context_gate: if True, apply a context gate layer to the
        output.
-      normalizer_fn: normalization op constructor (e.g. batch norm).
-      normalizer_params: parameters to the `normalizer_fn`.
+      normalizer_params: parameters of the batch normalization.
       vocab_as_last_dim: if True, reshape `activations` and make `vocab_size`
        the last dimension to avoid a small `num_mixtures` as the last
        dimension. XLA pads up the dimensions of tensors: typically the last
        dimension will be padded to 128, and the second to last will be padded
        to 8.
       l2_penalty: How much to penalize the squared magnitudes of parameter
        values.
+      **kwargs: extra keyword args.
 
     Returns:
       A dictionary with a tensor containing the probability predictions
       of the model in the 'predictions' key. The dimensions of the tensor
       are batch_size x num_classes.
     """
+    inputs = tf.keras.Input(shape=input_specs.shape[1:])
+    model_input = inputs
+
     if use_input_context_gate:
       model_input = utils.context_gate(
           model_input,
-          normalizer_fn=normalizer_fn,
+          normalizer_fn=layers.BatchNormalization,
           normalizer_params=normalizer_params,
       )
 
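Note: the `Dbof` backbone above passes `l2_weight_decay / 2.0` to `tf.keras.regularizers.l2`, citing `tf.nn.l2_loss`. A quick numeric check (our sketch, not package code) of the equivalence: Keras `l2(l)` adds `l * sum(w**2)`, while `tf.nn.l2_loss(w)` is `sum(w**2) / 2`, so halving the decay reproduces `decay * tf.nn.l2_loss(w)`:

    import numpy as np
    import tensorflow as tf

    w = tf.constant(np.random.rand(4, 3), dtype=tf.float32)
    decay = 1e-4
    keras_penalty = tf.keras.regularizers.l2(decay / 2.0)(w)  # (decay/2)*sum(w^2)
    ref_penalty = decay * tf.nn.l2_loss(w)                    # decay*sum(w^2)/2
    assert np.isclose(keras_penalty.numpy(), ref_penalty.numpy())
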
@@ -127,7 +269,11 @@ class MoeModel():
     if use_output_context_gate:
       final_probabilities = utils.context_gate(
           final_probabilities,
-          normalizer_fn=normalizer_fn,
+          normalizer_fn=layers.BatchNormalization,
           normalizer_params=normalizer_params,
       )
-    return {"predictions": final_probabilities}
+    super().__init__(
+        inputs=inputs,
+        outputs={"predictions": final_probabilities},
+        **kwargs,
+    )
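
Note: with these hunks the heads become functional `tf.keras.Model`s whose graph is built in `__init__`, replacing the old `create_model` pattern. A hedged usage sketch using the default shapes from the signatures above:

    import tensorflow as tf
    from official.projects.yt8m.modeling import nn_layers

    head = nn_layers.MoeModel(
        input_specs=tf.keras.layers.InputSpec(shape=[None, 128]),
        vocab_size=3862,
        num_mixtures=2,
    )
    # The output dict mirrors the documented 'predictions' key.
    probs = head(tf.random.uniform([2, 128]))["predictions"]  # shape [2, 3862]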
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""YT8M model definition."""
-from typing import Optional
+"""YT8M prediction model definition."""
 
+import functools
+from typing import Any, Optional
+
+from absl import logging
 import tensorflow as tf
 
-from official.modeling import tf_utils
 from official.projects.yt8m.configs import yt8m as yt8m_cfg
 from official.projects.yt8m.modeling import nn_layers
-from official.projects.yt8m.modeling import yt8m_model_utils as utils
+
 
 layers = tf.keras.layers
 
@@ -39,155 +41,61 @@ class DbofModel(tf.keras.Model):
   def __init__(
       self,
       params: yt8m_cfg.DbofModel,
-      num_frames: int = 30,
       num_classes: int = 3862,
       input_specs: layers.InputSpec = layers.InputSpec(
           shape=[None, None, 1152]),
-      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
-      activation: str = "relu",
-      use_sync_bn: bool = False,
-      norm_momentum: float = 0.99,
-      norm_epsilon: float = 0.001,
-      **kwargs):
-    """YT8M initialization function.
+      l2_weight_decay: Optional[float] = None,
+      **kwargs,
+  ):
+    """YT8M Dbof model initialization function.
 
     Args:
       params: model configuration parameters.
-      num_frames: `int` number of frames in a single input.
       num_classes: `int` number of classes in the dataset.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
         [batch_size x num_frames x num_features]
-      kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
-        None.
-      activation: A `str` name of the activation function.
-      use_sync_bn: If True, use synchronized batch normalization.
-      norm_momentum: A `float` of normalization momentum for the moving average.
-      norm_epsilon: A `float` added to variance to avoid dividing by zero.
+      l2_weight_decay: An optional `float` of kernel regularizer weight decay.
       **kwargs: keyword arguments to be passed.
     """
-    del num_frames
-    self._self_setattr_tracking = False
+    super().__init__()
+    self._params = params
+    self._num_classes = num_classes
+    self._input_specs = input_specs
+    self._l2_weight_decay = l2_weight_decay
     self._config_dict = {
+        "params": params,
         "input_specs": input_specs,
         "num_classes": num_classes,
-        "params": params
+        "l2_weight_decay": l2_weight_decay,
     }
-    self._num_classes = num_classes
-    self._input_specs = input_specs
-    self._act_fn = tf_utils.get_activation(activation)
-    if use_sync_bn:
-      self._norm = layers.experimental.SyncBatchNormalization
-    else:
-      self._norm = layers.BatchNormalization
-
-    bn_axis = -1
-    # [batch_size x num_frames x num_features]
-    feature_size = input_specs.shape[-1]
-    # shape 'excluding' batch_size
-    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
-    # normalize input features
-    input_data = tf.nn.l2_normalize(model_input, -1)
-    tf.summary.histogram("input_hist", input_data)
-
-    # configure model
-    if params.add_batch_norm:
-      input_data = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="input_bn")(
-              input_data)
-
-    # activation = reshaped input * cluster weights
-    if params.cluster_size > 0:
-      activation = layers.Dense(
-          params.cluster_size,
-          kernel_regularizer=kernel_regularizer,
-          kernel_initializer=tf.random_normal_initializer(
-              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
-                  input_data)
-
-    if params.add_batch_norm:
-      activation = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="cluster_bn")(
-              activation)
-    else:
-      cluster_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
-              shape=[params.cluster_size]),
-          name="cluster_biases")
-      tf.summary.histogram("cluster_biases", cluster_biases)
-      activation += cluster_biases
-
-    activation = self._act_fn(activation)
-    tf.summary.histogram("cluster_output", activation)
-
-    if params.use_context_gate_cluster_layer:
-      pooling_method = None
-      norm_args = dict(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="context_gate_bn")
-      activation = utils.context_gate(
-          activation,
-          normalizer_fn=self._norm,
-          normalizer_params=norm_args,
-          pooling_method=pooling_method,
-          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
-          kernel_regularizer=kernel_regularizer)
-
-    activation = utils.frame_pooling(activation, params.pooling_method)
-
-    # activation = activation * hidden1_weights
-    activation = layers.Dense(
-        params.hidden_size,
-        kernel_regularizer=kernel_regularizer,
-        kernel_initializer=tf.random_normal_initializer(
-            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
-                activation)
-
-    if params.add_batch_norm:
-      activation = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="hidden1_bn")(
-              activation)
-
-    else:
-      hidden1_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
-          name="hidden1_biases")
-
-      tf.summary.histogram("hidden1_biases", hidden1_biases)
-      activation += hidden1_biases
-
-    activation = self._act_fn(activation)
-    tf.summary.histogram("hidden1_output", activation)
-
-    aggregated_model = getattr(nn_layers,
-                               params.yt8m_agg_classifier_model)
-    norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
-    output = aggregated_model().create_model(
-        model_input=activation,
-        vocab_size=self._num_classes,
-        num_mixtures=params.agg_model.num_mixtures,
-        normalizer_fn=self._norm,
-        normalizer_params=norm_args,
-        vocab_as_last_dim=params.agg_model.vocab_as_last_dim,
-        l2_penalty=params.agg_model.l2_penalty)
-
-    super().__init__(
-        inputs=model_input, outputs=output.get("predictions"), **kwargs)
-
-  @property
-  def checkpoint_items(self):
-    """Returns a dictionary of items to be additionally checkpointed."""
-    return dict()
+
+    self.dbof_backbone = nn_layers.Dbof(
+        params,
+        num_classes,
+        input_specs,
+        l2_weight_decay,
+        **kwargs,
+    )
+
+    logging.info("Build DbofModel with %s.", params.agg_classifier_model)
+    if hasattr(nn_layers, params.agg_classifier_model):
+      aggregation_head = getattr(nn_layers, params.agg_classifier_model)
+      if params.agg_classifier_model == "MoeModel":
+        normalizer_params = dict(
+            synchronized=params.norm_activation.use_sync_bn,
+            momentum=params.norm_activation.norm_momentum,
+            epsilon=params.norm_activation.norm_epsilon,
+        )
+        aggregation_head = functools.partial(
+            aggregation_head, normalizer_params=normalizer_params)
+
+      if params.agg_model is not None:
+        kwargs.update(params.agg_model.as_dict())
+      self.head = aggregation_head(
+          input_specs=layers.InputSpec(shape=[None, params.hidden_size]),
+          vocab_size=num_classes,
+          l2_penalty=l2_weight_decay,
+          **kwargs)
 
   def get_config(self):
     return self._config_dict
@@ -195,3 +103,10 @@ class DbofModel(tf.keras.Model):
   @classmethod
   def from_config(cls, config):
     return cls(**config)
+
+  def call(
+      self, inputs: tf.Tensor, training: Any = None, mask: Any = None
+  ) -> tf.Tensor:
+    features = self.dbof_backbone(inputs)
+    outputs = self.head(features)
+    return outputs["predictions"]
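
Note: `DbofModel` is now a thin composition: `call()` runs the `nn_layers.Dbof` backbone, feeds the pooled features to the aggregation head, and returns only the head's 'predictions' tensor. A sketch mirroring the updated test below (module paths inferred from the imports in this diff; default config and shapes assumed):

    import numpy as np
    from official.projects.yt8m.configs import yt8m as yt8m_cfg
    from official.projects.yt8m.modeling import yt8m_model

    model = yt8m_model.DbofModel(params=yt8m_cfg.YT8MTask.model)
    inputs = np.random.rand(2, 24, 1152)  # batch x frames x (1024 + 128) features
    logits = model(inputs)                # head(backbone(x))["predictions"]
    assert logits.shape == (2, 3862)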
@@ -26,7 +26,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
   """Class for testing yt8m network."""
 
   # test_yt8m_network_creation arbitrary params
-  @parameterized.parameters((32, 1152))  # 1152 = 1024 + 128
+  @parameterized.parameters((32, 1152), (24, 1152))  # 1152 = 1024 + 128
   def test_yt8m_network_creation(self, num_frames, feature_dims):
     """Test for creation of a YT8M Model.
 
@@ -39,11 +39,10 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
     num_classes = 3862
     model = yt8m_model.DbofModel(
         params=yt8m_cfg.YT8MTask.model,
-        num_frames=num_frames,
         num_classes=num_classes,
         input_specs=input_specs)
 
-    # batch = 2 -> arbitrary value for test
+    # batch = 2 -> arbitrary value for test.
     inputs = np.random.rand(2, num_frames, feature_dims)
     logits = model(inputs)
     self.assertAllEqual([2, num_classes], logits.numpy().shape)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Contains a collection of util functions for model construction."""
+
 from typing import Any, Dict, Optional, Union
 
 import tensorflow as tf
@@ -177,8 +178,7 @@ def context_gate(
       kernel_initializer=kernel_initializer,
       bias_initializer=bias_initializer,
       kernel_regularizer=kernel_regularizer,
-  )(
-      context_features)
+  )(context_features)
   if normalizer_fn:
     gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
   else:
@@ -191,14 +191,13 @@ def context_gate(
       kernel_initializer=kernel_initializer,
       bias_initializer=bias_initializer,
       kernel_regularizer=kernel_regularizer,
-  )(
-      gates_bottleneck)
+  )(gates_bottleneck)
   if normalizer_fn:
     gates = normalizer_fn(**normalizer_params)(gates)
 
   if additive_residual:
-    input_features += gates
+    input_features += tf.cast(gates, input_features.dtype)
   else:
-    input_features *= gates
+    input_features *= tf.cast(gates, input_features.dtype)
 
   return input_features
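
Note: the added `tf.cast` calls guard the gating ops when `gates` and `input_features` carry different dtypes, e.g. under mixed precision where batch norm can emit float32 while the features are float16 (an assumed motivation; the diff itself does not state one). A minimal sketch of the failure mode:

    import tensorflow as tf

    input_features = tf.random.uniform([2, 8], dtype=tf.float16)
    gates = tf.random.uniform([2, 8], dtype=tf.float32)
    # input_features * gates would raise an error (dtype mismatch); casting
    # to the features' dtype, as the diff now does, works:
    gated = input_features * tf.cast(gates, input_features.dtype)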