tf-models-nightly 2.11.0.dev20230321__py2.py3-none-any.whl → 2.11.0.dev20230323__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

official/modeling/hyperparams/base_config.py
@@ -95,6 +95,20 @@ class Config(params_dict.ParamsDict):
   def BUILDER(self):
     return self._BUILDER
 
+  @classmethod
+  def _get_annotations(cls):
+    """Returns valid annotations.
+
+    Note: this is similar to dataclasses.__annotations__ except it also
+    includes annotations from its parent classes.
+    """
+    all_annotations = typing.get_type_hints(cls)
+    # Removes Config class annotation from the value, e.g., default_params,
+    # restrictions, etc.
+    for k in Config.__annotations__:
+      del all_annotations[k]
+    return all_annotations
+
   @classmethod
   def _isvalidsequence(cls, v):
     """Check if the input values are valid sequences.
@@ -175,9 +189,10 @@ class Config(params_dict.ParamsDict):
     if not subconfig_type:
       subconfig_type = Config
 
-    if k in cls.__annotations__:
+    annotations = cls._get_annotations()
+    if k in annotations:
       # Directly Config subtype.
-      type_annotation = cls.__annotations__[k]  # pytype: disable=invalid-annotation
+      type_annotation = annotations[k]
       i = 0
       # Loop for stripping the Optional annotation.
       traverse_in = True
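
The "stripping the Optional annotation" loop matters here because `typing.get_type_hints` reports an optional field as `Optional[T]` (i.e. `Union[T, None]`) rather than `T`. A standalone sketch of one way to unwrap it, using only `typing`'s public helpers (the library's own loop may differ in detail):

```python
import typing
from typing import Optional


def strip_optional(annotation):
  """Returns T for Optional[T]; leaves other annotations untouched."""
  if typing.get_origin(annotation) is typing.Union:
    non_none = [a for a in typing.get_args(annotation) if a is not type(None)]
    if len(non_none) == 1:
      return non_none[0]
  return annotation


assert strip_optional(Optional[int]) is int
assert strip_optional(str) is str
```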
@@ -326,6 +341,9 @@ class Config(params_dict.ParamsDict):
   @classmethod
   def from_args(cls, *args, **kwargs):
     """Builds a config from the given list of arguments."""
+    # Note we intend to keep `__annotations__` instead of `_get_annotations`.
+    # Assuming a parent class of (a, b) with a sub-class of (c, d), the
+    # sub-class will take (c, d) for args, rather than starting from (a, b).
     attributes = list(cls.__annotations__.keys())
     default_params = {a: p for a, p in zip(attributes, args)}
     default_params.update(kwargs)
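
The comment's point can be made concrete with a hypothetical parent/child pair (the names below are illustrative, not from the library, and assume `base_config` is importable as in the test file):

```python
import dataclasses


@dataclasses.dataclass
class ParentConfig(base_config.Config):
  a: int = 0
  b: int = 0


@dataclasses.dataclass
class ChildConfig(ParentConfig):
  c: int = 0
  d: int = 0


# Because from_args reads cls.__annotations__ (own fields only), positional
# arguments bind to the subclass's own fields:
config = ChildConfig.from_args(1, 2)  # binds c=1 and d=2, not a=1 and b=2
```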

official/modeling/hyperparams/base_config_test.py
@@ -33,6 +33,7 @@ class DumpConfig2(base_config.Config):
   c: int = 2
   d: str = 'text'
   e: DumpConfig1 = DumpConfig1()
+  optional_e: Optional[DumpConfig1] = None
 
 
 @dataclasses.dataclass
@@ -348,6 +349,34 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
         ]),
         "['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
 
+  def test_with_superclass_override(self):
+    config = DumpConfig2()
+    config.override({'optional_e': {'a': 2}})
+    self.assertEqual(
+        config.optional_e.as_dict(),
+        {
+            'a': 2,
+            'b': 'text',
+        },
+    )
+
+    # Previously, the following would fail. See b/274696969 for context.
+    config = DumpConfig3()
+    config.override({'optional_e': {'a': 2}})
+    self.assertEqual(
+        config.optional_e.as_dict(),
+        {
+            'a': 2,
+            'b': 'text',
+        },
+    )
+
+  def test_get_annotations_without_base_config_leak(self):
+    with self.assertRaisesRegex(
+        KeyError, "The key 'restrictions' does not exist"
+    ):
+      DumpConfig3().override({'restrictions': None})
+
   def test_with_restrictions(self):
     restrictions = ['e.a<c']
     config = DumpConfig2(restrictions=restrictions)

official/projects/yt8m/configs/yt8m.py
@@ -15,7 +15,6 @@
 """Video classification configuration definition."""
 import dataclasses
 from typing import Optional, Tuple
-from absl import flags
 
 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -23,7 +22,6 @@ from official.modeling import hyperparams
 from official.modeling import optimization
 from official.vision.configs import common
 
-FLAGS = flags.FLAGS
 
 YT8M_TRAIN_EXAMPLES = 3888919
 YT8M_VAL_EXAMPLES = 1112356
@@ -105,7 +103,6 @@ def yt8m(is_training):
 class MoeModel(hyperparams.Config):
   """The model config."""
   num_mixtures: int = 5
-  l2_penalty: float = 1e-5
   use_input_context_gate: bool = False
   use_output_context_gate: bool = False
   vocab_as_last_dim: bool = False
@@ -121,7 +118,7 @@ class DbofModel(hyperparams.Config):
   use_context_gate_cluster_layer: bool = False
   context_gate_cluster_bottleneck_size: int = 0
   pooling_method: str = 'average'
-  yt8m_agg_classifier_model: str = 'MoeModel'
+  agg_classifier_model: str = 'MoeModel'
   agg_model: hyperparams.Config = MoeModel()
   norm_activation: common.NormActivation = common.NormActivation(
       activation='relu', use_sync_bn=False)
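
The rename is user-visible: any experiment override that set `yt8m_agg_classifier_model` must now use `agg_classifier_model`. A sketch of the new spelling through the standard `override` mechanism:

```python
from official.projects.yt8m.configs import yt8m as yt8m_cfg

model_cfg = yt8m_cfg.DbofModel()
# The old key 'yt8m_agg_classifier_model' would now raise a KeyError.
model_cfg.override({'agg_classifier_model': 'LogisticModel'})
assert model_cfg.agg_classifier_model == 'LogisticModel'
```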

official/projects/yt8m/modeling/nn_layers.py
@@ -13,22 +13,156 @@
 # limitations under the License.
 
 """Contains model definitions."""
+
+import functools
 from typing import Any, Dict, Optional
 
 import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.projects.yt8m.configs import yt8m as yt8m_cfg
 from official.projects.yt8m.modeling import yt8m_model_utils as utils
 
+
 layers = tf.keras.layers
 
 
-class LogisticModel():
-  """Logistic model with L2 regularization."""
+class Dbof(tf.keras.Model):
+  """A YT8M model class builder.
+
+  Creates a Deep Bag of Frames model.
+  The model projects the features for each frame into a higher dimensional
+  'clustering' space, pools across frames in that space, and then
+  uses a configurable video-level model to classify the now aggregated features.
+  The model will randomly sample either frames or sequences of frames during
+  training to speed up convergence.
+  """
+
+  def __init__(
+      self,
+      params: yt8m_cfg.DbofModel,
+      num_classes: int = 3862,
+      input_specs: layers.InputSpec = layers.InputSpec(
+          shape=[None, None, 1152]),
+      l2_weight_decay: Optional[float] = None,
+      **kwargs):
+    """YT8M initialization function.
+
+    Args:
+      params: model configuration parameters
+      num_classes: `int` number of classes in dataset.
+      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
+        [batch_size x num_frames x num_features]
+      l2_weight_decay: An optional `float` of kernel regularizer weight decay.
+      **kwargs: keyword arguments to be passed.
+    """
+    self._self_setattr_tracking = False
+    self._num_classes = num_classes
+    self._input_specs = input_specs
+    self._params = params
+    self._l2_weight_decay = l2_weight_decay
+    self._act_fn = tf_utils.get_activation(params.norm_activation.activation)
+    self._norm = functools.partial(
+        layers.BatchNormalization,
+        momentum=params.norm_activation.norm_momentum,
+        epsilon=params.norm_activation.norm_epsilon,
+        synchronized=params.norm_activation.use_sync_bn)
+
+    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
+    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
+    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
+    l2_regularizer = (
+        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
+        if l2_weight_decay
+        else None
+    )
+
+    # [batch_size x num_frames x num_features]
+    feature_size = input_specs.shape[-1]
+    # shape 'excluding' batch_size
+    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
+    # normalize input features
+    input_data = tf.nn.l2_normalize(model_input, -1)
+    tf.summary.histogram("input_hist", input_data)
+
+    # configure model
+    if params.add_batch_norm:
+      input_data = self._norm(name="input_bn")(input_data)
+
+    # activation = reshaped input * cluster weights
+    if params.cluster_size > 0:
+      activation = layers.Dense(
+          params.cluster_size,
+          kernel_regularizer=l2_regularizer,
+          kernel_initializer=tf.random_normal_initializer(
+              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
+                  input_data)
+
+    if params.add_batch_norm:
+      activation = self._norm(name="cluster_bn")(activation)
+    else:
+      cluster_biases = tf.Variable(
+          tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
+              shape=[params.cluster_size]),
+          name="cluster_biases")
+      tf.summary.histogram("cluster_biases", cluster_biases)
+      activation += cluster_biases
+
+    activation = self._act_fn(activation)
+    tf.summary.histogram("cluster_output", activation)
 
-  def create_model(self, model_input, vocab_size, l2_penalty=1e-8, **kwargs):
+    if params.use_context_gate_cluster_layer:
+      pooling_method = None
+      norm_args = dict(name="context_gate_bn")
+      activation = utils.context_gate(
+          activation,
+          normalizer_fn=self._norm,
+          normalizer_params=norm_args,
+          pooling_method=pooling_method,
+          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
+          kernel_regularizer=l2_regularizer)
+
+    activation = utils.frame_pooling(activation, params.pooling_method)
+
+    # activation = activation * hidden1_weights
+    activation = layers.Dense(
+        params.hidden_size,
+        kernel_regularizer=l2_regularizer,
+        kernel_initializer=tf.random_normal_initializer(
+            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
+                activation)
+
+    if params.add_batch_norm:
+      activation = self._norm(name="hidden1_bn")(activation)
+
+    else:
+      hidden1_biases = tf.Variable(
+          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
+          name="hidden1_biases")
+
+      tf.summary.histogram("hidden1_biases", hidden1_biases)
+      activation += hidden1_biases
+
+    activation = self._act_fn(activation)
+    tf.summary.histogram("hidden1_output", activation)
+
+    super().__init__(inputs=model_input, outputs=activation, **kwargs)
+
+
+class LogisticModel(tf.keras.Model):
+  """Logistic prediction head model with L2 regularization."""
+
+  def __init__(
+      self,
+      input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
+      vocab_size: int = 3862,
+      l2_penalty: float = 1e-8,
+      **kwargs,
+  ):
     """Creates a logistic model.
 
     Args:
-      model_input: 'batch' x 'num_features' matrix of input features.
+      input_specs: 'batch' x 'num_features' matrix of input features.
       vocab_size: The number of classes in the dataset.
       l2_penalty: L2 weight regularization ratio.
       **kwargs: extra key word args.
@@ -38,43 +172,44 @@ class LogisticModel():
       model in the 'predictions' key. The dimensions of the tensor are
       batch_size x num_classes.
     """
-    del kwargs  # Unused.
+    inputs = tf.keras.Input(shape=input_specs.shape[1:])
     output = layers.Dense(
         vocab_size,
         activation=tf.nn.sigmoid,
         kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
-            model_input)
-    return {"predictions": output}
+            inputs)
+
+    super().__init__(inputs=inputs, outputs={"predictions": output}, **kwargs)
 
 
-class MoeModel():
+class MoeModel(tf.keras.Model):
   """A softmax over a mixture of logistic models (with L2 regularization)."""
 
-  def create_model(self,
-                   model_input,
-                   vocab_size,
-                   num_mixtures: int = 2,
-                   use_input_context_gate: bool = False,
-                   use_output_context_gate: bool = False,
-                   normalizer_fn=None,
-                   normalizer_params: Optional[Dict[str, Any]] = None,
-                   vocab_as_last_dim: bool = False,
-                   l2_penalty: float = 1e-5,
-                   **kwargs):
+  def __init__(
+      self,
+      input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
+      vocab_size: int = 3862,
+      num_mixtures: int = 2,
+      use_input_context_gate: bool = False,
+      use_output_context_gate: bool = False,
+      normalizer_params: Optional[Dict[str, Any]] = None,
+      vocab_as_last_dim: bool = False,
+      l2_penalty: float = 1e-5,
+      **kwargs,
+  ):
     """Creates a Mixture of (Logistic) Experts model.
 
     The model consists of a per-class softmax distribution over a
     configurable number of logistic classifiers. One of the classifiers
     in the mixture is not trained, and always predicts 0.
     Args:
-      model_input: 'batch_size' x 'num_features' matrix of input features.
+      input_specs: 'batch_size' x 'num_features' matrix of input features.
      vocab_size: The number of classes in the dataset.
       num_mixtures: The number of mixtures (excluding a dummy 'expert' that
        always predicts the non-existence of an entity).
       use_input_context_gate: if True apply context gate layer to the input.
       use_output_context_gate: if True apply context gate layer to the output.
-      normalizer_fn: normalization op constructor (e.g. batch norm).
-      normalizer_params: parameters to the `normalizer_fn`.
+      normalizer_params: parameters of the batch normalization.
       vocab_as_last_dim: if True reshape `activations` and make `vocab_size` as
        the last dimension to avoid small `num_mixtures` as the last dimension.
        XLA pads up the dimensions of tensors: typically the last dimension will
@@ -88,11 +223,13 @@ class MoeModel():
       of the model in the 'predictions' key. The dimensions of the tensor
       are batch_size x num_classes.
     """
-    del kwargs  # Unused.
+    inputs = tf.keras.Input(shape=input_specs.shape[1:])
+    model_input = inputs
+
     if use_input_context_gate:
       model_input = utils.context_gate(
           model_input,
-          normalizer_fn=normalizer_fn,
+          normalizer_fn=layers.BatchNormalization,
           normalizer_params=normalizer_params,
       )
 
@@ -132,7 +269,11 @@ class MoeModel():
     if use_output_context_gate:
       final_probabilities = utils.context_gate(
           final_probabilities,
-          normalizer_fn=normalizer_fn,
+          normalizer_fn=layers.BatchNormalization,
           normalizer_params=normalizer_params,
       )
-    return {"predictions": final_probabilities}
+    super().__init__(
+        inputs=inputs,
+        outputs={"predictions": final_probabilities},
+        **kwargs,
+    )
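
Since both heads now subclass `tf.keras.Model` and are built from an `InputSpec`, they are constructed once and then called like any Keras model, instead of being rebuilt through `create_model()`. A usage sketch under the signatures in this diff:

```python
import tensorflow as tf

from official.projects.yt8m.modeling import nn_layers

head = nn_layers.LogisticModel(
    input_specs=tf.keras.layers.InputSpec(shape=[None, 128]),
    vocab_size=3862,
    l2_penalty=1e-8)

features = tf.random.uniform([4, 128])        # [batch, hidden]
predictions = head(features)["predictions"]   # [4, 3862]
```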

official/projects/yt8m/modeling/yt8m_model.py
@@ -12,17 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""YT8M model definition."""
+"""YT8M prediction model definition."""
 
 import functools
-from typing import Optional
+from typing import Any, Optional
 
+from absl import logging
 import tensorflow as tf
 
-from official.modeling import tf_utils
 from official.projects.yt8m.configs import yt8m as yt8m_cfg
 from official.projects.yt8m.modeling import nn_layers
-from official.projects.yt8m.modeling import yt8m_model_utils as utils
 
 
 layers = tf.keras.layers
@@ -45,202 +44,58 @@ class DbofModel(tf.keras.Model):
       num_classes: int = 3862,
       input_specs: layers.InputSpec = layers.InputSpec(
           shape=[None, None, 1152]),
-      activation: str = "relu",
-      use_sync_bn: bool = False,
-      norm_momentum: float = 0.99,
-      norm_epsilon: float = 0.001,
       l2_weight_decay: Optional[float] = None,
-      **kwargs):
-    """YT8M initialization function.
+      **kwargs,
+  ):
+    """YT8M Dbof model initialization function.
 
     Args:
       params: model configuration parameters
       num_classes: `int` number of classes in dataset.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
        [batch_size x num_frames x num_features]
-      activation: A `str` of name of the activation function.
-      use_sync_bn: If True, use synchronized batch normalization.
-      norm_momentum: A `float` of normalization momentum for the moving average.
-      norm_epsilon: A `float` added to variance to avoid dividing by zero.
       l2_weight_decay: An optional `float` of kernel regularizer weight decay.
       **kwargs: keyword arguments to be passed.
     """
-    model_input, activation = self.get_dbof(
-        params=params,
-        num_classes=num_classes,
-        input_specs=input_specs,
-        activation=activation,
-        use_sync_bn=use_sync_bn,
-        norm_momentum=norm_momentum,
-        norm_epsilon=norm_epsilon,
-        l2_weight_decay=l2_weight_decay,
-        **kwargs,
-    )
-    output = self.get_aggregation(model_input=activation, **kwargs)
-    super().__init__(
-        inputs=model_input, outputs=output.get("predictions"), **kwargs)
-
-  def get_dbof(
-      self,
-      params: yt8m_cfg.DbofModel,
-      num_classes: int = 3862,
-      input_specs: layers.InputSpec = layers.InputSpec(
-          shape=[None, None, 1152]),
-      activation: str = "relu",
-      use_sync_bn: bool = False,
-      norm_momentum: float = 0.99,
-      norm_epsilon: float = 0.001,
-      l2_weight_decay: Optional[float] = None,
-      **kwargs):
-
-    del kwargs  # Unused and reserved for future extension.
-    self._self_setattr_tracking = False
+    super().__init__()
+    self._params = params
+    self._num_classes = num_classes
+    self._input_specs = input_specs
+    self._l2_weight_decay = l2_weight_decay
     self._config_dict = {
+        "params": params,
         "input_specs": input_specs,
         "num_classes": num_classes,
-        "params": params,
-        "use_sync_bn": use_sync_bn,
-        "activation": activation,
         "l2_weight_decay": l2_weight_decay,
-        "norm_momentum": norm_momentum,
-        "norm_epsilon": norm_epsilon,
     }
-    self._num_classes = num_classes
-    self._input_specs = input_specs
-    self._params = params
-    self._activation = activation
-    self._l2_weight_decay = l2_weight_decay
-    self._use_sync_bn = use_sync_bn
-    self._norm_momentum = norm_momentum
-    self._norm_epsilon = norm_epsilon
-    self._act_fn = tf_utils.get_activation(activation)
-    self._norm = functools.partial(
-        layers.BatchNormalization, synchronized=use_sync_bn)
-
-    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
-    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
-    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
-    l2_regularizer = (
-        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
-        if l2_weight_decay
-        else None
-    )
 
-    bn_axis = -1
-    # [batch_size x num_frames x num_features]
-    feature_size = input_specs.shape[-1]
-    # shape 'excluding' batch_size
-    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
-    # normalize input features
-    input_data = tf.nn.l2_normalize(model_input, -1)
-    tf.summary.histogram("input_hist", input_data)
-
-    # configure model
-    if params.add_batch_norm:
-      input_data = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="input_bn")(
-              input_data)
-
-    # activation = reshaped input * cluster weights
-    if params.cluster_size > 0:
-      activation = layers.Dense(
-          params.cluster_size,
-          kernel_regularizer=l2_regularizer,
-          kernel_initializer=tf.random_normal_initializer(
-              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
-                  input_data)
-
-    if params.add_batch_norm:
-      activation = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="cluster_bn")(
-              activation)
-    else:
-      cluster_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
-              shape=[params.cluster_size]),
-          name="cluster_biases")
-      tf.summary.histogram("cluster_biases", cluster_biases)
-      activation += cluster_biases
-
-    activation = self._act_fn(activation)
-    tf.summary.histogram("cluster_output", activation)
-
-    if params.use_context_gate_cluster_layer:
-      pooling_method = None
-      norm_args = dict(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="context_gate_bn")
-      activation = utils.context_gate(
-          activation,
-          normalizer_fn=self._norm,
-          normalizer_params=norm_args,
-          pooling_method=pooling_method,
-          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
-          kernel_regularizer=l2_regularizer)
-
-    activation = utils.frame_pooling(activation, params.pooling_method)
-
-    # activation = activation * hidden1_weights
-    activation = layers.Dense(
-        params.hidden_size,
-        kernel_regularizer=l2_regularizer,
-        kernel_initializer=tf.random_normal_initializer(
-            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
-                activation)
-
-    if params.add_batch_norm:
-      activation = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="hidden1_bn")(
-              activation)
-
-    else:
-      hidden1_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
-          name="hidden1_biases")
-
-      tf.summary.histogram("hidden1_biases", hidden1_biases)
-      activation += hidden1_biases
-
-    activation = self._act_fn(activation)
-    tf.summary.histogram("hidden1_output", activation)
-
-    return model_input, activation
-
-  def get_aggregation(self, model_input, **kwargs):
-    del kwargs  # Unused and reserved for future extension.
-    normalizer_fn = functools.partial(
-        layers.BatchNormalization, synchronized=self._use_sync_bn)
-    normalizer_params = dict(
-        axis=-1, momentum=self._norm_momentum, epsilon=self._norm_epsilon)
-    aggregated_model = getattr(
-        nn_layers, self._params.yt8m_agg_classifier_model)
-
-    output = aggregated_model().create_model(
-        model_input=model_input,
-        vocab_size=self._num_classes,
-        num_mixtures=self._params.agg_model.num_mixtures,
-        normalizer_fn=normalizer_fn,
-        normalizer_params=normalizer_params,
-        vocab_as_last_dim=self._params.agg_model.vocab_as_last_dim,
-        l2_penalty=self._params.agg_model.l2_penalty,
+    self.dbof_backbone = nn_layers.Dbof(
+        params,
+        num_classes,
+        input_specs,
+        l2_weight_decay,
+        **kwargs,
     )
-    return output
 
-  @property
-  def checkpoint_items(self):
-    """Returns a dictionary of items to be additionally checkpointed."""
-    return dict()
+    logging.info("Build DbofModel with %s.", params.agg_classifier_model)
+    if hasattr(nn_layers, params.agg_classifier_model):
+      aggregation_head = getattr(nn_layers, params.agg_classifier_model)
+      if params.agg_classifier_model == "MoeModel":
+        normalizer_params = dict(
+            synchronized=params.norm_activation.use_sync_bn,
+            momentum=params.norm_activation.norm_momentum,
+            epsilon=params.norm_activation.norm_epsilon,
+        )
+        aggregation_head = functools.partial(
+            aggregation_head, normalizer_params=normalizer_params)
+
+      if params.agg_model is not None:
+        kwargs.update(params.agg_model.as_dict())
+      self.head = aggregation_head(
+          input_specs=layers.InputSpec(shape=[None, params.hidden_size]),
+          vocab_size=num_classes,
+          l2_penalty=l2_weight_decay,
+          **kwargs)
 
   def get_config(self):
     return self._config_dict
@@ -248,3 +103,10 @@ class DbofModel(tf.keras.Model):
   @classmethod
   def from_config(cls, config):
     return cls(**config)
+
+  def call(
+      self, inputs: tf.Tensor, training: Any = None, mask: Any = None
+  ) -> tf.Tensor:
+    features = self.dbof_backbone(inputs)
+    outputs = self.head(features)
+    return outputs["predictions"]
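
With the backbone and aggregation head now separate attributes, `call()` is a simple two-stage pipeline. A forward-pass sketch using the default config from this diff (module paths as in the repo; shapes follow the defaults above):

```python
import tensorflow as tf

from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import yt8m_model

model = yt8m_model.DbofModel(params=yt8m_cfg.DbofModel())
frames = tf.random.uniform([2, 24, 1152])  # [batch, num_frames, features]
predictions = model(frames)                # [2, 3862] per-class predictions
```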

official/projects/yt8m/modeling/yt8m_model_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Contains a collection of util functions for model construction."""
+
 from typing import Any, Dict, Optional, Union
 
 import tensorflow as tf
@@ -177,8 +178,7 @@ def context_gate(
       kernel_initializer=kernel_initializer,
       bias_initializer=bias_initializer,
       kernel_regularizer=kernel_regularizer,
-  )(
-      context_features)
+  )(context_features)
   if normalizer_fn:
     gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
   else:
@@ -191,14 +191,13 @@ def context_gate(
       kernel_initializer=kernel_initializer,
       bias_initializer=bias_initializer,
       kernel_regularizer=kernel_regularizer,
-  )(
-      gates_bottleneck)
+  )(gates_bottleneck)
   if normalizer_fn:
     gates = normalizer_fn(**normalizer_params)(gates)
 
   if additive_residual:
-    input_features += gates
+    input_features += tf.cast(gates, input_features.dtype)
   else:
-    input_features *= gates
+    input_features *= tf.cast(gates, input_features.dtype)
 
   return input_features
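
The added `tf.cast` guards the residual/product against a dtype mismatch, e.g. when the gating branch is computed in float32 while `input_features` arrives in float16 under a mixed-precision policy. A minimal reproduction of the failure mode with synthetic tensors (not the real layers):

```python
import tensorflow as tf

input_features = tf.random.uniform([2, 8], dtype=tf.float16)
gates = tf.random.uniform([2, 8], dtype=tf.float32)

# input_features * gates would raise an error (float16 vs float32 operands);
# casting the gates to the input dtype, as the fix does, keeps it valid:
out = input_features * tf.cast(gates, input_features.dtype)
assert out.dtype == tf.float16
```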