tf-models-nightly 2.11.0.dev20230321__py2.py3-none-any.whl → 2.11.0.dev20230323__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- official/modeling/hyperparams/base_config.py +20 -2
- official/modeling/hyperparams/base_config_test.py +29 -0
- official/projects/yt8m/configs/yt8m.py +1 -4
- official/projects/yt8m/modeling/nn_layers.py +167 -26
- official/projects/yt8m/modeling/yt8m_model.py +44 -182
- official/projects/yt8m/modeling/yt8m_model_utils.py +5 -6
- official/projects/yt8m/tasks/yt8m_task.py +42 -25
- official/vision/dataloaders/maskrcnn_input.py +12 -13
- official/vision/dataloaders/tfds_classification_decoders.py +1 -0
- official/vision/evaluation/instance_metrics.py +176 -225
- official/vision/ops/augment.py +45 -33
- official/vision/ops/augment_test.py +9 -0
- official/vision/ops/preprocess_ops.py +20 -6
- official/vision/serving/export_tflite_lib.py +20 -8
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/RECORD +20 -20
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/top_level.txt +0 -0
official/modeling/hyperparams/base_config.py (+20 -2):

```diff
@@ -95,6 +95,20 @@ class Config(params_dict.ParamsDict):
   def BUILDER(self):
     return self._BUILDER
 
+  @classmethod
+  def _get_annotations(cls):
+    """Returns valid annotations.
+
+    Note: this is similar to dataclasses.__annotations__ except it also includes
+    annotations from its parent classes.
+    """
+    all_annotations = typing.get_type_hints(cls)
+    # Removes Config class annotation from the value, e.g., default_params,
+    # restrictions, etc.
+    for k in Config.__annotations__:
+      del all_annotations[k]
+    return all_annotations
+
   @classmethod
   def _isvalidsequence(cls, v):
     """Check if the input values are valid sequences.
@@ -175,9 +189,10 @@ class Config(params_dict.ParamsDict):
     if not subconfig_type:
       subconfig_type = Config
 
-    if k in cls.__annotations__:
+    annotations = cls._get_annotations()
+    if k in annotations:
       # Directly Config subtype.
-      type_annotation = cls.__annotations__[k]
+      type_annotation = annotations[k]
       i = 0
       # Loop for striping the Optional annotation.
       traverse_in = True
@@ -326,6 +341,9 @@ class Config(params_dict.ParamsDict):
   @classmethod
   def from_args(cls, *args, **kwargs):
     """Builds a config from the given list of arguments."""
+    # Note we intend to keep `__annotations__` instead of `_get_annotations`.
+    # Assuming a parent class of (a, b) with the sub-class of (c, d), the
+    # sub-class will take (c, d) for args, rather than starting from (a, b).
     attributes = list(cls.__annotations__.keys())
    default_params = {a: p for a, p in zip(attributes, args)}
    default_params.update(kwargs)
```
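The heart of this change is the difference between `cls.__annotations__`, which only lists fields declared on the class itself, and `typing.get_type_hints(cls)`, which also walks the parent classes. A minimal standalone sketch (plain dataclasses, not part of the package) of the two behaviors:

```python
import dataclasses
import typing


@dataclasses.dataclass
class Parent:
  a: int = 0


@dataclasses.dataclass
class Child(Parent):
  b: str = 'x'


# Only the class's own fields: this is why `from_args` keeps using
# `__annotations__` -- positional args for Child bind to (b,), not (a, b).
print(Child.__annotations__)         # {'b': <class 'str'>}

# The full MRO view: this is what `override` needs so that a field inherited
# from a parent config can still be resolved on a subclass.
print(typing.get_type_hints(Child))  # {'a': <class 'int'>, 'b': <class 'str'>}
```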
official/modeling/hyperparams/base_config_test.py (+29 -0):

```diff
@@ -33,6 +33,7 @@ class DumpConfig2(base_config.Config):
   c: int = 2
   d: str = 'text'
   e: DumpConfig1 = DumpConfig1()
+  optional_e: Optional[DumpConfig1] = None
 
 
 @dataclasses.dataclass
@@ -348,6 +349,34 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
         ]),
         "['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
 
+  def test_with_superclass_override(self):
+    config = DumpConfig2()
+    config.override({'optional_e': {'a': 2}})
+    self.assertEqual(
+        config.optional_e.as_dict(),
+        {
+            'a': 2,
+            'b': 'text',
+        },
+    )
+
+    # Previously, the following will fail. See b/274696969 for context.
+    config = DumpConfig3()
+    config.override({'optional_e': {'a': 2}})
+    self.assertEqual(
+        config.optional_e.as_dict(),
+        {
+            'a': 2,
+            'b': 'text',
+        },
+    )
+
+  def test_get_annotations_without_base_config_leak(self):
+    with self.assertRaisesRegex(
+        KeyError, "The key 'restrictions' does not exist"
+    ):
+      DumpConfig3().override({'restrictions': None})
+
   def test_with_restrictions(self):
     restrictions = ['e.a<c']
     config = DumpConfig2(restrictions=restrictions)
```
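For reference, a sketch of the fixtures these tests rely on; `DumpConfig3` is not shown in this diff, but the new tests only make sense if it is a subclass of `DumpConfig2` that *inherits* `optional_e` rather than declaring it itself (the b/274696969 failure mode), so its exact fields here are hypothetical:

```python
import dataclasses
from typing import Optional

from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class DumpConfig1(base_config.Config):
  a: int = 1
  b: str = 'text'


@dataclasses.dataclass
class DumpConfig2(base_config.Config):
  c: int = 2
  d: str = 'text'
  e: DumpConfig1 = DumpConfig1()
  optional_e: Optional[DumpConfig1] = None


# Hypothetical shape: inherits `optional_e`. Before this fix, override() could
# not see the inherited annotation and failed to build the sub-config.
@dataclasses.dataclass
class DumpConfig3(DumpConfig2):
  f: int = 3
```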
official/projects/yt8m/configs/yt8m.py (+1 -4):

```diff
@@ -15,7 +15,6 @@
 """Video classification configuration definition."""
 import dataclasses
 from typing import Optional, Tuple
-from absl import flags
 
 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -23,7 +22,6 @@ from official.modeling import hyperparams
 from official.modeling import optimization
 from official.vision.configs import common
 
-FLAGS = flags.FLAGS
 
 YT8M_TRAIN_EXAMPLES = 3888919
 YT8M_VAL_EXAMPLES = 1112356
@@ -105,7 +103,6 @@ def yt8m(is_training):
 class MoeModel(hyperparams.Config):
   """The model config."""
   num_mixtures: int = 5
-  l2_penalty: float = 1e-5
   use_input_context_gate: bool = False
   use_output_context_gate: bool = False
   vocab_as_last_dim: bool = False
@@ -121,7 +118,7 @@ class DbofModel(hyperparams.Config):
   use_context_gate_cluster_layer: bool = False
   context_gate_cluster_bottleneck_size: int = 0
   pooling_method: str = 'average'
-  yt8m_agg_classifier_model: str = 'MoeModel'
+  agg_classifier_model: str = 'MoeModel'
   agg_model: hyperparams.Config = MoeModel()
   norm_activation: common.NormActivation = common.NormActivation(
       activation='relu', use_sync_bn=False)
```
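A short sketch of what the rename means for users of this config (field values assumed): the aggregation head is looked up by class name in `nn_layers`, so the value must match a class defined there.

```python
from official.projects.yt8m.configs import yt8m as yt8m_cfg

# 'MoeModel' is the default; 'LogisticModel' is the other head defined in
# official/projects/yt8m/modeling/nn_layers.py.
params = yt8m_cfg.DbofModel(agg_classifier_model='LogisticModel')

# Note: `l2_penalty` was dropped from the MoeModel config above; the head's
# L2 penalty is now wired from the model's `l2_weight_decay` argument instead.
```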
official/projects/yt8m/modeling/nn_layers.py (+167 -26):

```diff
@@ -13,22 +13,156 @@
 # limitations under the License.
 
 """Contains model definitions."""
+
+import functools
 from typing import Any, Dict, Optional
 
 import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.projects.yt8m.configs import yt8m as yt8m_cfg
 from official.projects.yt8m.modeling import yt8m_model_utils as utils
 
+
 layers = tf.keras.layers
 
 
-class LogisticModel():
-  """Logistic model with L2 regularization."""
+class Dbof(tf.keras.Model):
+  """A YT8M model class builder.
+
+  Creates a Deep Bag of Frames model.
+  The model projects the features for each frame into a higher dimensional
+  'clustering' space, pools across frames in that space, and then
+  uses a configurable video-level model to classify the now aggregated features.
+  The model will randomly sample either frames or sequences of frames during
+  training to speed up convergence.
+  """
+
+  def __init__(
+      self,
+      params: yt8m_cfg.DbofModel,
+      num_classes: int = 3862,
+      input_specs: layers.InputSpec = layers.InputSpec(
+          shape=[None, None, 1152]),
+      l2_weight_decay: Optional[float] = None,
+      **kwargs):
+    """YT8M initialization function.
+
+    Args:
+      params: model configuration parameters
+      num_classes: `int` number of classes in dataset.
+      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
+        [batch_size x num_frames x num_features]
+      l2_weight_decay: An optional `float` of kernel regularizer weight decay.
+      **kwargs: keyword arguments to be passed.
+    """
+    self._self_setattr_tracking = False
+    self._num_classes = num_classes
+    self._input_specs = input_specs
+    self._params = params
+    self._l2_weight_decay = l2_weight_decay
+    self._act_fn = tf_utils.get_activation(params.norm_activation.activation)
+    self._norm = functools.partial(
+        layers.BatchNormalization,
+        momentum=params.norm_activation.norm_momentum,
+        epsilon=params.norm_activation.norm_epsilon,
+        synchronized=params.norm_activation.use_sync_bn)
+
+    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
+    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
+    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
+    l2_regularizer = (
+        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
+        if l2_weight_decay
+        else None
+    )
+
+    # [batch_size x num_frames x num_features]
+    feature_size = input_specs.shape[-1]
+    # shape 'excluding' batch_size
+    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
+    # normalize input features
+    input_data = tf.nn.l2_normalize(model_input, -1)
+    tf.summary.histogram("input_hist", input_data)
+
+    # configure model
+    if params.add_batch_norm:
+      input_data = self._norm(name="input_bn")(input_data)
+
+    # activation = reshaped input * cluster weights
+    if params.cluster_size > 0:
+      activation = layers.Dense(
+          params.cluster_size,
+          kernel_regularizer=l2_regularizer,
+          kernel_initializer=tf.random_normal_initializer(
+              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
+                  input_data)
+
+    if params.add_batch_norm:
+      activation = self._norm(name="cluster_bn")(activation)
+    else:
+      cluster_biases = tf.Variable(
+          tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
+              shape=[params.cluster_size]),
+          name="cluster_biases")
+      tf.summary.histogram("cluster_biases", cluster_biases)
+      activation += cluster_biases
+
+    activation = self._act_fn(activation)
+    tf.summary.histogram("cluster_output", activation)
 
-  def create_model(self, model_input, vocab_size, l2_penalty=1e-8, **kwargs):
+    if params.use_context_gate_cluster_layer:
+      pooling_method = None
+      norm_args = dict(name="context_gate_bn")
+      activation = utils.context_gate(
+          activation,
+          normalizer_fn=self._norm,
+          normalizer_params=norm_args,
+          pooling_method=pooling_method,
+          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
+          kernel_regularizer=l2_regularizer)
+
+    activation = utils.frame_pooling(activation, params.pooling_method)
+
+    # activation = activation * hidden1_weights
+    activation = layers.Dense(
+        params.hidden_size,
+        kernel_regularizer=l2_regularizer,
+        kernel_initializer=tf.random_normal_initializer(
+            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
+                activation)
+
+    if params.add_batch_norm:
+      activation = self._norm(name="hidden1_bn")(activation)
+
+    else:
+      hidden1_biases = tf.Variable(
+          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
+          name="hidden1_biases")
+
+      tf.summary.histogram("hidden1_biases", hidden1_biases)
+      activation += hidden1_biases
+
+    activation = self._act_fn(activation)
+    tf.summary.histogram("hidden1_output", activation)
+
+    super().__init__(inputs=model_input, outputs=activation, **kwargs)
+
+
+class LogisticModel(tf.keras.Model):
+  """Logistic prediction head model with L2 regularization."""
+
+  def __init__(
+      self,
+      input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
+      vocab_size: int = 3862,
+      l2_penalty: float = 1e-8,
+      **kwargs,
+  ):
     """Creates a logistic model.
 
     Args:
-      model_input: 'batch' x 'num_features' matrix of input features.
+      input_specs: 'batch' x 'num_features' matrix of input features.
       vocab_size: The number of classes in the dataset.
       l2_penalty: L2 weight regularization ratio.
       **kwargs: extra key word args.
@@ -38,43 +172,44 @@ class LogisticModel():
       model in the 'predictions' key. The dimensions of the tensor are
       batch_size x num_classes.
     """
-    del kwargs  # Unused.
+    inputs = tf.keras.Input(shape=input_specs.shape[1:])
     output = layers.Dense(
         vocab_size,
         activation=tf.nn.sigmoid,
         kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
-            model_input)
-    return {"predictions": output}
+            inputs)
+
+    super().__init__(inputs=inputs, outputs={"predictions": output}, **kwargs)
 
 
-class MoeModel():
+class MoeModel(tf.keras.Model):
   """A softmax over a mixture of logistic models (with L2 regularization)."""
 
-  def create_model(self,
-                   model_input,
-                   vocab_size,
-                   num_mixtures=2,
-                   use_input_context_gate=False,
-                   use_output_context_gate=False,
-                   normalizer_fn=None,
-                   normalizer_params=None,
-                   vocab_as_last_dim=False,
-                   l2_penalty=1e-5,
-                   **kwargs):
+  def __init__(
+      self,
+      input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
+      vocab_size: int = 3862,
+      num_mixtures: int = 2,
+      use_input_context_gate: bool = False,
+      use_output_context_gate: bool = False,
+      normalizer_params: Optional[Dict[str, Any]] = None,
+      vocab_as_last_dim: bool = False,
+      l2_penalty: float = 1e-5,
+      **kwargs,
+  ):
     """Creates a Mixture of (Logistic) Experts model.
 
     The model consists of a per-class softmax distribution over a
     configurable number of logistic classifiers. One of the classifiers
     in the mixture is not trained, and always predicts 0.
     Args:
-      model_input: 'batch_size' x 'num_features' matrix of input features.
+      input_specs: 'batch_size' x 'num_features' matrix of input features.
      vocab_size: The number of classes in the dataset.
       num_mixtures: The number of mixtures (excluding a dummy 'expert' that
        always predicts the non-existence of an entity).
      use_input_context_gate: if True apply context gate layer to the input.
       use_output_context_gate: if True apply context gate layer to the output.
-      normalizer_fn: normalization op constructor (e.g. batch norm).
-      normalizer_params: parameters to the `normalizer_fn`.
+      normalizer_params: parameters of the batch normalization.
       vocab_as_last_dim: if True reshape `activations` and make `vocab_size` as
        the last dimension to avoid small `num_mixtures` as the last dimension.
        XLA pads up the dimensions of tensors: typically the last dimension will
@@ -88,11 +223,13 @@ class MoeModel():
       of the model in the 'predictions' key. The dimensions of the tensor
       are batch_size x num_classes.
     """
-    del kwargs  # Unused.
+    inputs = tf.keras.Input(shape=input_specs.shape[1:])
+    model_input = inputs
+
     if use_input_context_gate:
       model_input = utils.context_gate(
           model_input,
-          normalizer_fn=normalizer_fn,
+          normalizer_fn=layers.BatchNormalization,
           normalizer_params=normalizer_params,
       )
 
@@ -132,7 +269,11 @@ class MoeModel():
     if use_output_context_gate:
       final_probabilities = utils.context_gate(
          final_probabilities,
-          normalizer_fn=normalizer_fn,
+          normalizer_fn=layers.BatchNormalization,
           normalizer_params=normalizer_params,
       )
-    return {"predictions": final_probabilities}
+    super().__init__(
+        inputs=inputs,
+        outputs={"predictions": final_probabilities},
+        **kwargs,
+    )
```

(Several removed lines above, e.g. the old `create_model` signatures and the `del kwargs  # Unused.` bodies, were truncated in the diff viewer; they are reconstructed here from the surrounding context and the replacement lines.)
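Both classes are now functional `tf.keras.Model`s assembled entirely in `__init__`. A minimal usage sketch (batch and frame-count values assumed, not from the package):

```python
import tensorflow as tf

from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import nn_layers

params = yt8m_cfg.DbofModel()
backbone = nn_layers.Dbof(params=params, num_classes=3862)
head = nn_layers.MoeModel(
    input_specs=tf.keras.layers.InputSpec(shape=[None, params.hidden_size]))

frames = tf.random.normal([2, 30, 1152])  # batch x num_frames x num_features
pooled = backbone(frames)                 # batch x hidden_size
probs = head(pooled)["predictions"]       # batch x vocab_size
```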
official/projects/yt8m/modeling/yt8m_model.py (+44 -182):

```diff
@@ -12,17 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""YT8M model definition."""
+"""YT8M prediction model definition."""
 
 import functools
-from typing import Optional
+from typing import Any, Optional
 
+from absl import logging
 import tensorflow as tf
 
-from official.modeling import tf_utils
 from official.projects.yt8m.configs import yt8m as yt8m_cfg
 from official.projects.yt8m.modeling import nn_layers
-from official.projects.yt8m.modeling import yt8m_model_utils as utils
 
 
 layers = tf.keras.layers
@@ -45,202 +44,58 @@ class DbofModel(tf.keras.Model):
       num_classes: int = 3862,
       input_specs: layers.InputSpec = layers.InputSpec(
           shape=[None, None, 1152]),
-      activation: str = "relu",
-      use_sync_bn: bool = False,
-      norm_momentum: float = 0.99,
-      norm_epsilon: float = 0.001,
       l2_weight_decay: Optional[float] = None,
-      **kwargs):
-    """YT8M initialization function.
+      **kwargs,
+  ):
+    """YT8M Dbof model initialization function.
 
     Args:
       params: model configuration parameters
       num_classes: `int` number of classes in dataset.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
         [batch_size x num_frames x num_features]
-      activation: A `str` of name of the activation function.
-      use_sync_bn: If True, use synchronized batch normalization.
-      norm_momentum: A `float` of normalization momentum for the moving average.
-      norm_epsilon: A `float` added to variance to avoid dividing by zero.
       l2_weight_decay: An optional `float` of kernel regularizer weight decay.
       **kwargs: keyword arguments to be passed.
     """
-    model_input, activation = self.get_dbof(
-        params=params,
-        num_classes=num_classes,
-        input_specs=input_specs,
-        activation=activation,
-        use_sync_bn=use_sync_bn,
-        norm_momentum=norm_momentum,
-        norm_epsilon=norm_epsilon,
-        l2_weight_decay=l2_weight_decay,
-        **kwargs,
-    )
-    output = self.get_aggregation(model_input=activation, **kwargs)
-    super().__init__(
-        inputs=model_input, outputs=output.get("predictions"), **kwargs)
-
-  def get_dbof(
-      self,
-      params: yt8m_cfg.DbofModel,
-      num_classes: int = 3862,
-      input_specs: layers.InputSpec = layers.InputSpec(
-          shape=[None, None, 1152]),
-      activation: str = "relu",
-      use_sync_bn: bool = False,
-      norm_momentum: float = 0.99,
-      norm_epsilon: float = 0.001,
-      l2_weight_decay: Optional[float] = None,
-      **kwargs):
-
-    del kwargs  # Unused and reserved for future extension.
-    self._self_setattr_tracking = False
+    super().__init__()
+    self._params = params
+    self._num_classes = num_classes
+    self._input_specs = input_specs
+    self._l2_weight_decay = l2_weight_decay
     self._config_dict = {
+        "params": params,
         "input_specs": input_specs,
         "num_classes": num_classes,
-        "params": params,
-        "use_sync_bn": use_sync_bn,
-        "activation": activation,
         "l2_weight_decay": l2_weight_decay,
-        "norm_momentum": norm_momentum,
-        "norm_epsilon": norm_epsilon,
     }
-    self._num_classes = num_classes
-    self._input_specs = input_specs
-    self._params = params
-    self._activation = activation
-    self._l2_weight_decay = l2_weight_decay
-    self._use_sync_bn = use_sync_bn
-    self._norm_momentum = norm_momentum
-    self._norm_epsilon = norm_epsilon
-    self._act_fn = tf_utils.get_activation(activation)
-    self._norm = functools.partial(
-        layers.BatchNormalization, synchronized=use_sync_bn)
-
-    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
-    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
-    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
-    l2_regularizer = (
-        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
-        if l2_weight_decay
-        else None
-    )
 
-    bn_axis = -1
-    # [batch_size x num_frames x num_features]
-    feature_size = input_specs.shape[-1]
-    # shape 'excluding' batch_size
-    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
-    # normalize input features
-    input_data = tf.nn.l2_normalize(model_input, -1)
-    tf.summary.histogram("input_hist", input_data)
-
-    # configure model
-    if params.add_batch_norm:
-      input_data = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="input_bn")(
-              input_data)
-
-    # activation = reshaped input * cluster weights
-    if params.cluster_size > 0:
-      activation = layers.Dense(
-          params.cluster_size,
-          kernel_regularizer=l2_regularizer,
-          kernel_initializer=tf.random_normal_initializer(
-              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
-                  input_data)
-
-    if params.add_batch_norm:
-      activation = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="cluster_bn")(
-              activation)
-    else:
-      cluster_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
-              shape=[params.cluster_size]),
-          name="cluster_biases")
-      tf.summary.histogram("cluster_biases", cluster_biases)
-      activation += cluster_biases
-
-    activation = self._act_fn(activation)
-    tf.summary.histogram("cluster_output", activation)
-
-    if params.use_context_gate_cluster_layer:
-      pooling_method = None
-      norm_args = dict(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="context_gate_bn")
-      activation = utils.context_gate(
-          activation,
-          normalizer_fn=self._norm,
-          normalizer_params=norm_args,
-          pooling_method=pooling_method,
-          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
-          kernel_regularizer=l2_regularizer)
-
-    activation = utils.frame_pooling(activation, params.pooling_method)
-
-    # activation = activation * hidden1_weights
-    activation = layers.Dense(
-        params.hidden_size,
-        kernel_regularizer=l2_regularizer,
-        kernel_initializer=tf.random_normal_initializer(
-            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
-                activation)
-
-    if params.add_batch_norm:
-      activation = self._norm(
-          axis=bn_axis,
-          momentum=norm_momentum,
-          epsilon=norm_epsilon,
-          name="hidden1_bn")(
-              activation)
-
-    else:
-      hidden1_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
-          name="hidden1_biases")
-
-      tf.summary.histogram("hidden1_biases", hidden1_biases)
-      activation += hidden1_biases
-
-    activation = self._act_fn(activation)
-    tf.summary.histogram("hidden1_output", activation)
-
-    return model_input, activation
-
-  def get_aggregation(self, model_input, **kwargs):
-    del kwargs  # Unused and reserved for future extension.
-    normalizer_fn = functools.partial(
-        layers.BatchNormalization, synchronized=self._use_sync_bn)
-    normalizer_params = dict(
-        axis=-1, momentum=self._norm_momentum, epsilon=self._norm_epsilon)
-    aggregated_model = getattr(
-        nn_layers, self._params.yt8m_agg_classifier_model)
-
-    output = aggregated_model().create_model(
-        model_input=model_input,
-        vocab_size=self._num_classes,
-        num_mixtures=self._params.agg_model.num_mixtures,
-        normalizer_fn=normalizer_fn,
-        normalizer_params=normalizer_params,
-        vocab_as_last_dim=self._params.agg_model.vocab_as_last_dim,
-        l2_penalty=self._params.agg_model.l2_penalty,
+    self.dbof_backbone = nn_layers.Dbof(
+        params,
+        num_classes,
+        input_specs,
+        l2_weight_decay,
+        **kwargs,
     )
-    return output
 
-
-
-
-
+    logging.info("Build DbofModel with %s.", params.agg_classifier_model)
+    if hasattr(nn_layers, params.agg_classifier_model):
+      aggregation_head = getattr(nn_layers, params.agg_classifier_model)
+      if params.agg_classifier_model == "MoeModel":
+        normalizer_params = dict(
+            synchronized=params.norm_activation.use_sync_bn,
+            momentum=params.norm_activation.norm_momentum,
+            epsilon=params.norm_activation.norm_epsilon,
+        )
+        aggregation_head = functools.partial(
+            aggregation_head, normalizer_params=normalizer_params)
+
+      if params.agg_model is not None:
+        kwargs.update(params.agg_model.as_dict())
+      self.head = aggregation_head(
+          input_specs=layers.InputSpec(shape=[None, params.hidden_size]),
+          vocab_size=num_classes,
+          l2_penalty=l2_weight_decay,
+          **kwargs)
 
   def get_config(self):
     return self._config_dict
@@ -248,3 +103,10 @@ class DbofModel(tf.keras.Model):
   @classmethod
   def from_config(cls, config):
     return cls(**config)
+
+  def call(
+      self, inputs: tf.Tensor, training: Any = None, mask: Any = None
+  ) -> tf.Tensor:
+    features = self.dbof_backbone(inputs)
+    outputs = self.head(features)
+    return outputs["predictions"]
```

(A handful of removed lines, e.g. the `get_dbof` call in `__init__` and a few setup lines, were truncated in the diff viewer and are reconstructed here from the kept removed lines and the matching code that moved into nn_layers.py.)
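After the refactor, `DbofModel` is a thin composition of the `Dbof` backbone and a head chosen by `params.agg_classifier_model`. A usage sketch (argument values assumed):

```python
import tensorflow as tf

from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling.yt8m_model import DbofModel

model = DbofModel(params=yt8m_cfg.DbofModel(), num_classes=3862)
frames = tf.random.normal([2, 30, 1152])
# call() runs the backbone, then the head, and unwraps the predictions dict:
predictions = model(frames)  # batch x num_classes
```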
official/projects/yt8m/modeling/yt8m_model_utils.py (+5 -6):

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Contains a collection of util functions for model construction."""
+
 from typing import Any, Dict, Optional, Union
 
 import tensorflow as tf
@@ -177,8 +178,7 @@ def context_gate(
       kernel_initializer=kernel_initializer,
       bias_initializer=bias_initializer,
       kernel_regularizer=kernel_regularizer,
-  )(
-      context_features)
+  )(context_features)
   if normalizer_fn:
     gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
   else:
@@ -191,14 +191,13 @@ def context_gate(
       kernel_initializer=kernel_initializer,
       bias_initializer=bias_initializer,
       kernel_regularizer=kernel_regularizer,
-  )(
-      gates_bottleneck)
+  )(gates_bottleneck)
   if normalizer_fn:
     gates = normalizer_fn(**normalizer_params)(gates)
 
   if additive_residual:
-    input_features += gates
+    input_features += tf.cast(gates, input_features.dtype)
   else:
-    input_features *= gates
+    input_features *= tf.cast(gates, input_features.dtype)
 
   return input_features
```
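The new `tf.cast` guards the gating update against a dtype mismatch, e.g. when `input_features` run in float16 under mixed precision while the gates come out of a float32 path. A standalone sketch (not from the package) of the failure the cast avoids:

```python
import tensorflow as tf

input_features = tf.random.normal([2, 8], dtype=tf.float16)
gates = tf.sigmoid(tf.random.normal([2, 8]))  # float32 by default

# input_features *= gates  # InvalidArgumentError: half vs. float mismatch
input_features *= tf.cast(gates, input_features.dtype)  # float16 * float16
```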