tf-models-nightly 2.11.0.dev20230320__py2.py3-none-any.whl → 2.11.0.dev20230322__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/projects/yt8m/configs/yt8m.py +3 -7
- official/projects/yt8m/modeling/nn_layers.py +169 -23
- official/projects/yt8m/modeling/yt8m_model.py +53 -138
- official/projects/yt8m/modeling/yt8m_model_test.py +2 -3
- official/projects/yt8m/modeling/yt8m_model_utils.py +5 -6
- official/projects/yt8m/tasks/yt8m_task.py +48 -37
- official/vision/dataloaders/maskrcnn_input.py +12 -13
- official/vision/dataloaders/tfds_classification_decoders.py +1 -0
- official/vision/evaluation/instance_metrics.py +176 -225
- official/vision/ops/preprocess_ops.py +20 -6
- {tf_models_nightly-2.11.0.dev20230320.dist-info → tf_models_nightly-2.11.0.dev20230322.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.11.0.dev20230320.dist-info → tf_models_nightly-2.11.0.dev20230322.dist-info}/RECORD +16 -16
- {tf_models_nightly-2.11.0.dev20230320.dist-info → tf_models_nightly-2.11.0.dev20230322.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.11.0.dev20230320.dist-info → tf_models_nightly-2.11.0.dev20230322.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.11.0.dev20230320.dist-info → tf_models_nightly-2.11.0.dev20230322.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.11.0.dev20230320.dist-info → tf_models_nightly-2.11.0.dev20230322.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,6 @@
|
|
15
15
|
"""Video classification configuration definition."""
|
16
16
|
import dataclasses
|
17
17
|
from typing import Optional, Tuple
|
18
|
-
from absl import flags
|
19
18
|
|
20
19
|
from official.core import config_definitions as cfg
|
21
20
|
from official.core import exp_factory
|
@@ -23,7 +22,6 @@ from official.modeling import hyperparams
|
|
23
22
|
from official.modeling import optimization
|
24
23
|
from official.vision.configs import common
|
25
24
|
|
26
|
-
FLAGS = flags.FLAGS
|
27
25
|
|
28
26
|
YT8M_TRAIN_EXAMPLES = 3888919
|
29
27
|
YT8M_VAL_EXAMPLES = 1112356
|
@@ -53,7 +51,7 @@ class DataConfig(cfg.DataConfig):
|
|
53
51
|
temporal_stride: Not used. Need to deprecated.
|
54
52
|
max_frames: Maxim Number of frames in a input example. It is used to crop
|
55
53
|
the input in the temporal dimension.
|
56
|
-
|
54
|
+
num_sample_frames: Number of frames to sample for each input example.
|
57
55
|
num_classes: Number of classes to classify. Assuming it is a classification
|
58
56
|
task.
|
59
57
|
num_devices: Not used. To be deprecated.
|
@@ -77,7 +75,7 @@ class DataConfig(cfg.DataConfig):
|
|
77
75
|
include_video_id: bool = False
|
78
76
|
temporal_stride: int = 1
|
79
77
|
max_frames: int = 300
|
80
|
-
|
78
|
+
num_sample_frames: int = 300 # set smaller to allow random sample (Parser)
|
81
79
|
num_classes: int = 3862
|
82
80
|
num_devices: int = 1
|
83
81
|
input_path: str = ''
|
@@ -90,7 +88,6 @@ def yt8m(is_training):
|
|
90
88
|
"""YT8M dataset configs."""
|
91
89
|
# pylint: disable=unexpected-keyword-arg
|
92
90
|
return DataConfig(
|
93
|
-
num_frames=30,
|
94
91
|
temporal_stride=1,
|
95
92
|
segment_labels=False,
|
96
93
|
segment_size=5,
|
@@ -106,7 +103,6 @@ def yt8m(is_training):
|
|
106
103
|
class MoeModel(hyperparams.Config):
|
107
104
|
"""The model config."""
|
108
105
|
num_mixtures: int = 5
|
109
|
-
l2_penalty: float = 1e-5
|
110
106
|
use_input_context_gate: bool = False
|
111
107
|
use_output_context_gate: bool = False
|
112
108
|
vocab_as_last_dim: bool = False
|
@@ -122,7 +118,7 @@ class DbofModel(hyperparams.Config):
|
|
122
118
|
use_context_gate_cluster_layer: bool = False
|
123
119
|
context_gate_cluster_bottleneck_size: int = 0
|
124
120
|
pooling_method: str = 'average'
|
125
|
-
|
121
|
+
agg_classifier_model: str = 'MoeModel'
|
126
122
|
agg_model: hyperparams.Config = MoeModel()
|
127
123
|
norm_activation: common.NormActivation = common.NormActivation(
|
128
124
|
activation='relu', use_sync_bn=False)
|
@@ -13,81 +13,223 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
"""Contains model definitions."""
|
16
|
+
|
17
|
+
import functools
|
16
18
|
from typing import Any, Dict, Optional
|
17
19
|
|
18
20
|
import tensorflow as tf
|
21
|
+
|
22
|
+
from official.modeling import tf_utils
|
23
|
+
from official.projects.yt8m.configs import yt8m as yt8m_cfg
|
19
24
|
from official.projects.yt8m.modeling import yt8m_model_utils as utils
|
20
25
|
|
26
|
+
|
21
27
|
layers = tf.keras.layers
|
22
28
|
|
23
29
|
|
24
|
-
class
|
25
|
-
"""
|
30
|
+
class Dbof(tf.keras.Model):
|
31
|
+
"""A YT8M model class builder.
|
32
|
+
|
33
|
+
Creates a Deep Bag of Frames model.
|
34
|
+
The model projects the features for each frame into a higher dimensional
|
35
|
+
'clustering' space, pools across frames in that space, and then
|
36
|
+
uses a configurable video-level model to classify the now aggregated features.
|
37
|
+
The model will randomly sample either frames or sequences of frames during
|
38
|
+
training to speed up convergence.
|
39
|
+
"""
|
40
|
+
|
41
|
+
def __init__(
|
42
|
+
self,
|
43
|
+
params: yt8m_cfg.DbofModel,
|
44
|
+
num_classes: int = 3862,
|
45
|
+
input_specs: layers.InputSpec = layers.InputSpec(
|
46
|
+
shape=[None, None, 1152]),
|
47
|
+
l2_weight_decay: Optional[float] = None,
|
48
|
+
**kwargs):
|
49
|
+
"""YT8M initialization function.
|
50
|
+
|
51
|
+
Args:
|
52
|
+
params: model configuration parameters
|
53
|
+
num_classes: `int` number of classes in dataset.
|
54
|
+
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
|
55
|
+
[batch_size x num_frames x num_features]
|
56
|
+
l2_weight_decay: An optional `float` of kernel regularizer weight decay.
|
57
|
+
**kwargs: keyword arguments to be passed.
|
58
|
+
"""
|
59
|
+
self._self_setattr_tracking = False
|
60
|
+
self._num_classes = num_classes
|
61
|
+
self._input_specs = input_specs
|
62
|
+
self._params = params
|
63
|
+
self._l2_weight_decay = l2_weight_decay
|
64
|
+
self._act_fn = tf_utils.get_activation(params.norm_activation.activation)
|
65
|
+
self._norm = functools.partial(
|
66
|
+
layers.BatchNormalization,
|
67
|
+
momentum=params.norm_activation.norm_momentum,
|
68
|
+
epsilon=params.norm_activation.norm_epsilon,
|
69
|
+
synchronized=params.norm_activation.use_sync_bn)
|
70
|
+
|
71
|
+
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
|
72
|
+
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
|
73
|
+
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
|
74
|
+
l2_regularizer = (
|
75
|
+
tf.keras.regularizers.l2(l2_weight_decay / 2.0)
|
76
|
+
if l2_weight_decay
|
77
|
+
else None
|
78
|
+
)
|
79
|
+
|
80
|
+
# [batch_size x num_frames x num_features]
|
81
|
+
feature_size = input_specs.shape[-1]
|
82
|
+
# shape 'excluding' batch_size
|
83
|
+
model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
|
84
|
+
# normalize input features
|
85
|
+
input_data = tf.nn.l2_normalize(model_input, -1)
|
86
|
+
tf.summary.histogram("input_hist", input_data)
|
87
|
+
|
88
|
+
# configure model
|
89
|
+
if params.add_batch_norm:
|
90
|
+
input_data = self._norm(name="input_bn")(input_data)
|
91
|
+
|
92
|
+
# activation = reshaped input * cluster weights
|
93
|
+
if params.cluster_size > 0:
|
94
|
+
activation = layers.Dense(
|
95
|
+
params.cluster_size,
|
96
|
+
kernel_regularizer=l2_regularizer,
|
97
|
+
kernel_initializer=tf.random_normal_initializer(
|
98
|
+
stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
|
99
|
+
input_data)
|
100
|
+
|
101
|
+
if params.add_batch_norm:
|
102
|
+
activation = self._norm(name="cluster_bn")(activation)
|
103
|
+
else:
|
104
|
+
cluster_biases = tf.Variable(
|
105
|
+
tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
|
106
|
+
shape=[params.cluster_size]),
|
107
|
+
name="cluster_biases")
|
108
|
+
tf.summary.histogram("cluster_biases", cluster_biases)
|
109
|
+
activation += cluster_biases
|
110
|
+
|
111
|
+
activation = self._act_fn(activation)
|
112
|
+
tf.summary.histogram("cluster_output", activation)
|
26
113
|
|
27
|
-
|
114
|
+
if params.use_context_gate_cluster_layer:
|
115
|
+
pooling_method = None
|
116
|
+
norm_args = dict(name="context_gate_bn")
|
117
|
+
activation = utils.context_gate(
|
118
|
+
activation,
|
119
|
+
normalizer_fn=self._norm,
|
120
|
+
normalizer_params=norm_args,
|
121
|
+
pooling_method=pooling_method,
|
122
|
+
hidden_layer_size=params.context_gate_cluster_bottleneck_size,
|
123
|
+
kernel_regularizer=l2_regularizer)
|
124
|
+
|
125
|
+
activation = utils.frame_pooling(activation, params.pooling_method)
|
126
|
+
|
127
|
+
# activation = activation * hidden1_weights
|
128
|
+
activation = layers.Dense(
|
129
|
+
params.hidden_size,
|
130
|
+
kernel_regularizer=l2_regularizer,
|
131
|
+
kernel_initializer=tf.random_normal_initializer(
|
132
|
+
stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
|
133
|
+
activation)
|
134
|
+
|
135
|
+
if params.add_batch_norm:
|
136
|
+
activation = self._norm(name="hidden1_bn")(activation)
|
137
|
+
|
138
|
+
else:
|
139
|
+
hidden1_biases = tf.Variable(
|
140
|
+
tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
|
141
|
+
name="hidden1_biases")
|
142
|
+
|
143
|
+
tf.summary.histogram("hidden1_biases", hidden1_biases)
|
144
|
+
activation += hidden1_biases
|
145
|
+
|
146
|
+
activation = self._act_fn(activation)
|
147
|
+
tf.summary.histogram("hidden1_output", activation)
|
148
|
+
|
149
|
+
super().__init__(inputs=model_input, outputs=activation, **kwargs)
|
150
|
+
|
151
|
+
|
152
|
+
class LogisticModel(tf.keras.Model):
|
153
|
+
"""Logistic prediction head model with L2 regularization."""
|
154
|
+
|
155
|
+
def __init__(
|
156
|
+
self,
|
157
|
+
input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
|
158
|
+
vocab_size: int = 3862,
|
159
|
+
l2_penalty: float = 1e-8,
|
160
|
+
**kwargs,
|
161
|
+
):
|
28
162
|
"""Creates a logistic model.
|
29
163
|
|
30
164
|
Args:
|
31
|
-
|
165
|
+
input_specs: 'batch' x 'num_features' matrix of input features.
|
32
166
|
vocab_size: The number of classes in the dataset.
|
33
167
|
l2_penalty: L2 weight regularization ratio.
|
168
|
+
**kwargs: extra key word args.
|
34
169
|
|
35
170
|
Returns:
|
36
171
|
A dictionary with a tensor containing the probability predictions of the
|
37
172
|
model in the 'predictions' key. The dimensions of the tensor are
|
38
173
|
batch_size x num_classes.
|
39
174
|
"""
|
175
|
+
inputs = tf.keras.Input(shape=input_specs.shape[1:])
|
40
176
|
output = layers.Dense(
|
41
177
|
vocab_size,
|
42
178
|
activation=tf.nn.sigmoid,
|
43
179
|
kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
|
44
|
-
|
45
|
-
|
180
|
+
inputs)
|
181
|
+
|
182
|
+
super().__init__(inputs=inputs, outputs={"predictions": output}, **kwargs)
|
46
183
|
|
47
184
|
|
48
|
-
class MoeModel():
|
185
|
+
class MoeModel(tf.keras.Model):
|
49
186
|
"""A softmax over a mixture of logistic models (with L2 regularization)."""
|
50
187
|
|
51
|
-
def
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
188
|
+
def __init__(
|
189
|
+
self,
|
190
|
+
input_specs: layers.InputSpec = layers.InputSpec(shape=[None, 128]),
|
191
|
+
vocab_size: int = 3862,
|
192
|
+
num_mixtures: int = 2,
|
193
|
+
use_input_context_gate: bool = False,
|
194
|
+
use_output_context_gate: bool = False,
|
195
|
+
normalizer_params: Optional[Dict[str, Any]] = None,
|
196
|
+
vocab_as_last_dim: bool = False,
|
197
|
+
l2_penalty: float = 1e-5,
|
198
|
+
**kwargs,
|
199
|
+
):
|
61
200
|
"""Creates a Mixture of (Logistic) Experts model.
|
62
201
|
|
63
202
|
The model consists of a per-class softmax distribution over a
|
64
203
|
configurable number of logistic classifiers. One of the classifiers
|
65
204
|
in the mixture is not trained, and always predicts 0.
|
66
205
|
Args:
|
67
|
-
|
206
|
+
input_specs: 'batch_size' x 'num_features' matrix of input features.
|
68
207
|
vocab_size: The number of classes in the dataset.
|
69
208
|
num_mixtures: The number of mixtures (excluding a dummy 'expert' that
|
70
209
|
always predicts the non-existence of an entity).
|
71
210
|
use_input_context_gate: if True apply context gate layer to the input.
|
72
211
|
use_output_context_gate: if True apply context gate layer to the output.
|
73
|
-
|
74
|
-
normalizer_params: parameters to the `normalizer_fn`.
|
212
|
+
normalizer_params: parameters of the batch normalization.
|
75
213
|
vocab_as_last_dim: if True reshape `activations` and make `vocab_size` as
|
76
214
|
the last dimension to avoid small `num_mixtures` as the last dimension.
|
77
215
|
XLA pads up the dimensions of tensors: typically the last dimension will
|
78
216
|
be padded to 128, and the second to last will be padded to 8.
|
79
217
|
l2_penalty: How much to penalize the squared magnitudes of parameter
|
80
218
|
values.
|
219
|
+
**kwargs: extra key word args.
|
81
220
|
|
82
221
|
Returns:
|
83
222
|
A dictionary with a tensor containing the probability predictions
|
84
223
|
of the model in the 'predictions' key. The dimensions of the tensor
|
85
224
|
are batch_size x num_classes.
|
86
225
|
"""
|
226
|
+
inputs = tf.keras.Input(shape=input_specs.shape[1:])
|
227
|
+
model_input = inputs
|
228
|
+
|
87
229
|
if use_input_context_gate:
|
88
230
|
model_input = utils.context_gate(
|
89
231
|
model_input,
|
90
|
-
normalizer_fn=
|
232
|
+
normalizer_fn=layers.BatchNormalization,
|
91
233
|
normalizer_params=normalizer_params,
|
92
234
|
)
|
93
235
|
|
@@ -127,7 +269,11 @@ class MoeModel():
|
|
127
269
|
if use_output_context_gate:
|
128
270
|
final_probabilities = utils.context_gate(
|
129
271
|
final_probabilities,
|
130
|
-
normalizer_fn=
|
272
|
+
normalizer_fn=layers.BatchNormalization,
|
131
273
|
normalizer_params=normalizer_params,
|
132
274
|
)
|
133
|
-
|
275
|
+
super().__init__(
|
276
|
+
inputs=inputs,
|
277
|
+
outputs={"predictions": final_probabilities},
|
278
|
+
**kwargs,
|
279
|
+
)
|
@@ -12,15 +12,17 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
"""YT8M model definition."""
|
16
|
-
from typing import Optional
|
15
|
+
"""YT8M prediction model definition."""
|
17
16
|
|
17
|
+
import functools
|
18
|
+
from typing import Any, Optional
|
19
|
+
|
20
|
+
from absl import logging
|
18
21
|
import tensorflow as tf
|
19
22
|
|
20
|
-
from official.modeling import tf_utils
|
21
23
|
from official.projects.yt8m.configs import yt8m as yt8m_cfg
|
22
24
|
from official.projects.yt8m.modeling import nn_layers
|
23
|
-
|
25
|
+
|
24
26
|
|
25
27
|
layers = tf.keras.layers
|
26
28
|
|
@@ -39,155 +41,61 @@ class DbofModel(tf.keras.Model):
|
|
39
41
|
def __init__(
|
40
42
|
self,
|
41
43
|
params: yt8m_cfg.DbofModel,
|
42
|
-
num_frames: int = 30,
|
43
44
|
num_classes: int = 3862,
|
44
45
|
input_specs: layers.InputSpec = layers.InputSpec(
|
45
46
|
shape=[None, None, 1152]),
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
norm_epsilon: float = 0.001,
|
51
|
-
**kwargs):
|
52
|
-
"""YT8M initialization function.
|
47
|
+
l2_weight_decay: Optional[float] = None,
|
48
|
+
**kwargs,
|
49
|
+
):
|
50
|
+
"""YT8M Dbof model initialization function.
|
53
51
|
|
54
52
|
Args:
|
55
53
|
params: model configuration parameters
|
56
|
-
num_frames: `int` number of frames in a single input.
|
57
54
|
num_classes: `int` number of classes in dataset.
|
58
55
|
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
|
59
56
|
[batch_size x num_frames x num_features]
|
60
|
-
|
61
|
-
None.
|
62
|
-
activation: A `str` of name of the activation function.
|
63
|
-
use_sync_bn: If True, use synchronized batch normalization.
|
64
|
-
norm_momentum: A `float` of normalization momentum for the moving average.
|
65
|
-
norm_epsilon: A `float` added to variance to avoid dividing by zero.
|
57
|
+
l2_weight_decay: An optional `float` of kernel regularizer weight decay.
|
66
58
|
**kwargs: keyword arguments to be passed.
|
67
59
|
"""
|
68
|
-
|
69
|
-
self.
|
60
|
+
super().__init__()
|
61
|
+
self._params = params
|
62
|
+
self._num_classes = num_classes
|
63
|
+
self._input_specs = input_specs
|
64
|
+
self._l2_weight_decay = l2_weight_decay
|
70
65
|
self._config_dict = {
|
66
|
+
"params": params,
|
71
67
|
"input_specs": input_specs,
|
72
68
|
"num_classes": num_classes,
|
73
|
-
"
|
69
|
+
"l2_weight_decay": l2_weight_decay,
|
74
70
|
}
|
75
|
-
|
76
|
-
self.
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
activation = layers.Dense(
|
104
|
-
params.cluster_size,
|
105
|
-
kernel_regularizer=kernel_regularizer,
|
106
|
-
kernel_initializer=tf.random_normal_initializer(
|
107
|
-
stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
|
108
|
-
input_data)
|
109
|
-
|
110
|
-
if params.add_batch_norm:
|
111
|
-
activation = self._norm(
|
112
|
-
axis=bn_axis,
|
113
|
-
momentum=norm_momentum,
|
114
|
-
epsilon=norm_epsilon,
|
115
|
-
name="cluster_bn")(
|
116
|
-
activation)
|
117
|
-
else:
|
118
|
-
cluster_biases = tf.Variable(
|
119
|
-
tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
|
120
|
-
shape=[params.cluster_size]),
|
121
|
-
name="cluster_biases")
|
122
|
-
tf.summary.histogram("cluster_biases", cluster_biases)
|
123
|
-
activation += cluster_biases
|
124
|
-
|
125
|
-
activation = self._act_fn(activation)
|
126
|
-
tf.summary.histogram("cluster_output", activation)
|
127
|
-
|
128
|
-
if params.use_context_gate_cluster_layer:
|
129
|
-
pooling_method = None
|
130
|
-
norm_args = dict(
|
131
|
-
axis=bn_axis,
|
132
|
-
momentum=norm_momentum,
|
133
|
-
epsilon=norm_epsilon,
|
134
|
-
name="context_gate_bn")
|
135
|
-
activation = utils.context_gate(
|
136
|
-
activation,
|
137
|
-
normalizer_fn=self._norm,
|
138
|
-
normalizer_params=norm_args,
|
139
|
-
pooling_method=pooling_method,
|
140
|
-
hidden_layer_size=params.context_gate_cluster_bottleneck_size,
|
141
|
-
kernel_regularizer=kernel_regularizer)
|
142
|
-
|
143
|
-
activation = utils.frame_pooling(activation, params.pooling_method)
|
144
|
-
|
145
|
-
# activation = activation * hidden1_weights
|
146
|
-
activation = layers.Dense(
|
147
|
-
params.hidden_size,
|
148
|
-
kernel_regularizer=kernel_regularizer,
|
149
|
-
kernel_initializer=tf.random_normal_initializer(
|
150
|
-
stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
|
151
|
-
activation)
|
152
|
-
|
153
|
-
if params.add_batch_norm:
|
154
|
-
activation = self._norm(
|
155
|
-
axis=bn_axis,
|
156
|
-
momentum=norm_momentum,
|
157
|
-
epsilon=norm_epsilon,
|
158
|
-
name="hidden1_bn")(
|
159
|
-
activation)
|
160
|
-
|
161
|
-
else:
|
162
|
-
hidden1_biases = tf.Variable(
|
163
|
-
tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
|
164
|
-
name="hidden1_biases")
|
165
|
-
|
166
|
-
tf.summary.histogram("hidden1_biases", hidden1_biases)
|
167
|
-
activation += hidden1_biases
|
168
|
-
|
169
|
-
activation = self._act_fn(activation)
|
170
|
-
tf.summary.histogram("hidden1_output", activation)
|
171
|
-
|
172
|
-
aggregated_model = getattr(nn_layers,
|
173
|
-
params.yt8m_agg_classifier_model)
|
174
|
-
norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
|
175
|
-
output = aggregated_model().create_model(
|
176
|
-
model_input=activation,
|
177
|
-
vocab_size=self._num_classes,
|
178
|
-
num_mixtures=params.agg_model.num_mixtures,
|
179
|
-
normalizer_fn=self._norm,
|
180
|
-
normalizer_params=norm_args,
|
181
|
-
vocab_as_last_dim=params.agg_model.vocab_as_last_dim,
|
182
|
-
l2_penalty=params.agg_model.l2_penalty)
|
183
|
-
|
184
|
-
super().__init__(
|
185
|
-
inputs=model_input, outputs=output.get("predictions"), **kwargs)
|
186
|
-
|
187
|
-
@property
|
188
|
-
def checkpoint_items(self):
|
189
|
-
"""Returns a dictionary of items to be additionally checkpointed."""
|
190
|
-
return dict()
|
71
|
+
|
72
|
+
self.dbof_backbone = nn_layers.Dbof(
|
73
|
+
params,
|
74
|
+
num_classes,
|
75
|
+
input_specs,
|
76
|
+
l2_weight_decay,
|
77
|
+
**kwargs,
|
78
|
+
)
|
79
|
+
|
80
|
+
logging.info("Build DbofModel with %s.", params.agg_classifier_model)
|
81
|
+
if hasattr(nn_layers, params.agg_classifier_model):
|
82
|
+
aggregation_head = getattr(nn_layers, params.agg_classifier_model)
|
83
|
+
if params.agg_classifier_model == "MoeModel":
|
84
|
+
normalizer_params = dict(
|
85
|
+
synchronized=params.norm_activation.use_sync_bn,
|
86
|
+
momentum=params.norm_activation.norm_momentum,
|
87
|
+
epsilon=params.norm_activation.norm_epsilon,
|
88
|
+
)
|
89
|
+
aggregation_head = functools.partial(
|
90
|
+
aggregation_head, normalizer_params=normalizer_params)
|
91
|
+
|
92
|
+
if params.agg_model is not None:
|
93
|
+
kwargs.update(params.agg_model.as_dict())
|
94
|
+
self.head = aggregation_head(
|
95
|
+
input_specs=layers.InputSpec(shape=[None, params.hidden_size]),
|
96
|
+
vocab_size=num_classes,
|
97
|
+
l2_penalty=l2_weight_decay,
|
98
|
+
**kwargs)
|
191
99
|
|
192
100
|
def get_config(self):
|
193
101
|
return self._config_dict
|
@@ -195,3 +103,10 @@ class DbofModel(tf.keras.Model):
|
|
195
103
|
@classmethod
|
196
104
|
def from_config(cls, config):
|
197
105
|
return cls(**config)
|
106
|
+
|
107
|
+
def call(
|
108
|
+
self, inputs: tf.Tensor, training: Any = None, mask: Any = None
|
109
|
+
) -> tf.Tensor:
|
110
|
+
features = self.dbof_backbone(inputs)
|
111
|
+
outputs = self.head(features)
|
112
|
+
return outputs["predictions"]
|
@@ -26,7 +26,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
|
|
26
26
|
"""Class for testing yt8m network."""
|
27
27
|
|
28
28
|
# test_yt8m_network_creation arbitrary params
|
29
|
-
@parameterized.parameters((32, 1152)) # 1152 = 1024 + 128
|
29
|
+
@parameterized.parameters((32, 1152), (24, 1152)) # 1152 = 1024 + 128
|
30
30
|
def test_yt8m_network_creation(self, num_frames, feature_dims):
|
31
31
|
"""Test for creation of a YT8M Model.
|
32
32
|
|
@@ -39,11 +39,10 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
|
|
39
39
|
num_classes = 3862
|
40
40
|
model = yt8m_model.DbofModel(
|
41
41
|
params=yt8m_cfg.YT8MTask.model,
|
42
|
-
num_frames=num_frames,
|
43
42
|
num_classes=num_classes,
|
44
43
|
input_specs=input_specs)
|
45
44
|
|
46
|
-
# batch = 2 -> arbitrary value for test
|
45
|
+
# batch = 2 -> arbitrary value for test.
|
47
46
|
inputs = np.random.rand(2, num_frames, feature_dims)
|
48
47
|
logits = model(inputs)
|
49
48
|
self.assertAllEqual([2, num_classes], logits.numpy().shape)
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
"""Contains a collection of util functions for model construction."""
|
16
|
+
|
16
17
|
from typing import Any, Dict, Optional, Union
|
17
18
|
|
18
19
|
import tensorflow as tf
|
@@ -177,8 +178,7 @@ def context_gate(
|
|
177
178
|
kernel_initializer=kernel_initializer,
|
178
179
|
bias_initializer=bias_initializer,
|
179
180
|
kernel_regularizer=kernel_regularizer,
|
180
|
-
)(
|
181
|
-
context_features)
|
181
|
+
)(context_features)
|
182
182
|
if normalizer_fn:
|
183
183
|
gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
|
184
184
|
else:
|
@@ -191,14 +191,13 @@ def context_gate(
|
|
191
191
|
kernel_initializer=kernel_initializer,
|
192
192
|
bias_initializer=bias_initializer,
|
193
193
|
kernel_regularizer=kernel_regularizer,
|
194
|
-
)(
|
195
|
-
gates_bottleneck)
|
194
|
+
)(gates_bottleneck)
|
196
195
|
if normalizer_fn:
|
197
196
|
gates = normalizer_fn(**normalizer_params)(gates)
|
198
197
|
|
199
198
|
if additive_residual:
|
200
|
-
input_features += gates
|
199
|
+
input_features += tf.cast(gates, input_features.dtype)
|
201
200
|
else:
|
202
|
-
input_features *= gates
|
201
|
+
input_features *= tf.cast(gates, input_features.dtype)
|
203
202
|
|
204
203
|
return input_features
|