tf-models-nightly 2.12.0.dev20230425__py2.py3-none-any.whl → 2.12.0.dev20230501__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/core/test_utils.py +1 -1
- official/modeling/multitask/test_utils.py +2 -2
- official/modeling/optimization/ema_optimizer.py +3 -3
- official/modeling/optimization/lr_schedule.py +0 -9
- official/nlp/modeling/layers/bigbird_attention.py +1 -1
- official/nlp/modeling/layers/per_dim_scale_attention.py +1 -1
- official/nlp/modeling/layers/relative_attention.py +1 -1
- official/nlp/modeling/layers/spectral_normalization.py +2 -2
- official/nlp/modeling/models/bert_pretrainer.py +1 -1
- official/nlp/modeling/models/electra_pretrainer.py +1 -1
- official/nlp/modeling/models/seq2seq_transformer.py +1 -1
- official/nlp/modeling/models/xlnet.py +3 -3
- official/projects/centernet/modeling/centernet_model.py +1 -1
- official/projects/deepmac_maskrcnn/modeling/heads/hourglass_network.py +1 -1
- official/projects/roformer/roformer_attention.py +1 -1
- official/projects/teams/teams_pretrainer.py +1 -1
- official/utils/testing/mock_task.py +1 -1
- official/vision/configs/backbones.py +2 -0
- official/vision/modeling/backbones/vit.py +28 -20
- official/vision/modeling/heads/segmentation_heads.py +1 -1
- official/vision/modeling/layers/nn_blocks.py +24 -9
- official/vision/modeling/layers/nn_layers.py +18 -1
- official/vision/modeling/maskrcnn_model.py +1 -1
- official/vision/modeling/segmentation_model.py +1 -1
- official/vision/ops/anchor.py +1 -1
- {tf_models_nightly-2.12.0.dev20230425.dist-info → tf_models_nightly-2.12.0.dev20230501.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.12.0.dev20230425.dist-info → tf_models_nightly-2.12.0.dev20230501.dist-info}/RECORD +31 -32
- official/vision/serving/export_tflite_lib_test.py +0 -192
- {tf_models_nightly-2.12.0.dev20230425.dist-info → tf_models_nightly-2.12.0.dev20230501.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.12.0.dev20230425.dist-info → tf_models_nightly-2.12.0.dev20230501.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.12.0.dev20230425.dist-info → tf_models_nightly-2.12.0.dev20230501.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.12.0.dev20230425.dist-info → tf_models_nightly-2.12.0.dev20230501.dist-info}/top_level.txt +0 -0
official/core/test_utils.py
CHANGED
@@ -25,7 +25,7 @@ class FakeKerasModel(tf.keras.Model):
|
|
25
25
|
self.dense = tf.keras.layers.Dense(4, activation=tf.nn.relu)
|
26
26
|
self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
|
27
27
|
|
28
|
-
def call(self, inputs):
|
28
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
29
29
|
return self.dense2(self.dense(inputs))
|
30
30
|
|
31
31
|
|
@@ -31,7 +31,7 @@ class MockFooModel(tf.keras.Model):
|
|
31
31
|
self.inputs = {"foo": tf.keras.Input(shape=(2,), dtype=tf.float32),
|
32
32
|
"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
|
33
33
|
|
34
|
-
def call(self, inputs):
|
34
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
35
35
|
self.add_loss(tf.zeros((1,), dtype=tf.float32))
|
36
36
|
if "foo" in inputs:
|
37
37
|
input_tensor = inputs["foo"]
|
@@ -49,7 +49,7 @@ class MockBarModel(tf.keras.Model):
|
|
49
49
|
self._bar_specific_layer = tf.keras.layers.Dense(1)
|
50
50
|
self.inputs = {"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
|
51
51
|
|
52
|
-
def call(self, inputs):
|
52
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
53
53
|
self.add_loss(tf.zeros((2,), dtype=tf.float32))
|
54
54
|
return self._bar_specific_layer(self._share_layer(inputs["bar"]))
|
55
55
|
|
@@ -14,7 +14,7 @@
|
|
14
14
|
|
15
15
|
"""Exponential moving average optimizer."""
|
16
16
|
|
17
|
-
from typing import List, Optional
|
17
|
+
from typing import List, Optional
|
18
18
|
|
19
19
|
import tensorflow as tf
|
20
20
|
|
@@ -79,7 +79,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.legacy.Optimizer):
|
|
79
79
|
average_decay: float = 0.99,
|
80
80
|
start_step: int = 0,
|
81
81
|
dynamic_decay: bool = True,
|
82
|
-
name:
|
82
|
+
name: str = 'ExponentialMovingAverage',
|
83
83
|
**kwargs):
|
84
84
|
"""Construct a new ExponentialMovingAverage optimizer.
|
85
85
|
|
@@ -107,7 +107,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.legacy.Optimizer):
|
|
107
107
|
self._start_step = tf.constant(start_step, tf.float32)
|
108
108
|
self._dynamic_decay = dynamic_decay
|
109
109
|
self._optimizer = optimizer
|
110
|
-
self._track_trackable(self._optimizer, '
|
110
|
+
self._track_trackable(self._optimizer, 'ema_base_optimizer')
|
111
111
|
self._average_weights = None
|
112
112
|
self._model_weights = None
|
113
113
|
|
@@ -460,10 +460,6 @@ class StepCosineDecayWithOffset(
|
|
460
460
|
tf.constant(math.pi) * (global_step) /
|
461
461
|
(init_total_steps)) + 1.0) / 2.0 + next_init_lr)
|
462
462
|
learning_rate = cosine_learning_rate
|
463
|
-
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
|
464
|
-
cosine_learning_rate)
|
465
|
-
tf.compat.v1.logging.info("DEBUG lr %r next lr %r inittotalstep %r",
|
466
|
-
init_lr, next_init_lr, init_total_steps)
|
467
463
|
|
468
464
|
for i in range(1, num_levels):
|
469
465
|
next_init_lr = lr_levels[i]
|
@@ -471,9 +467,6 @@ class StepCosineDecayWithOffset(
|
|
471
467
|
next_total_steps = level_total_steps[i]
|
472
468
|
next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.
|
473
469
|
|
474
|
-
tf.compat.v1.logging.info(
|
475
|
-
"DEBUG step %r nilr %r nss %r nts %r nnilr %r", global_step,
|
476
|
-
next_init_lr, next_start_step, next_total_steps, next_next_init_lr)
|
477
470
|
next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) *
|
478
471
|
(tf.cos(
|
479
472
|
tf.constant(math.pi) *
|
@@ -482,8 +475,6 @@ class StepCosineDecayWithOffset(
|
|
482
475
|
next_next_init_lr)
|
483
476
|
learning_rate = tf.where(global_step >= next_start_step,
|
484
477
|
next_cosine_learning_rate, learning_rate)
|
485
|
-
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
|
486
|
-
next_cosine_learning_rate)
|
487
478
|
|
488
479
|
return learning_rate
|
489
480
|
|
@@ -458,7 +458,7 @@ class BigBirdAttention(tf.keras.layers.MultiHeadAttention):
|
|
458
458
|
to_block_size=self._to_block_size,
|
459
459
|
rand_attn=rand_attn)
|
460
460
|
|
461
|
-
def call(self, query, value, key=None, attention_mask=None, **kwargs):
|
461
|
+
def call(self, query, value, key=None, attention_mask=None, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
462
462
|
if not self._built_from_signature:
|
463
463
|
self._build_from_signature(query=query, value=value, key=key)
|
464
464
|
if key is None:
|
@@ -67,7 +67,7 @@ class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
|
|
67
67
|
attention_scores_dropout, value)
|
68
68
|
return attention_output, attention_scores
|
69
69
|
|
70
|
-
def call(
|
70
|
+
def call( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
71
71
|
self,
|
72
72
|
query,
|
73
73
|
value,
|
@@ -228,7 +228,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
|
|
228
228
|
value)
|
229
229
|
return attention_output
|
230
230
|
|
231
|
-
def call(self,
|
231
|
+
def call(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
232
232
|
query,
|
233
233
|
value,
|
234
234
|
content_attention_bias,
|
@@ -77,7 +77,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
|
|
77
77
|
super().__init__(
|
78
78
|
layer, name=wrapper_name, **kwargs)
|
79
79
|
|
80
|
-
def build(self, input_shape):
|
80
|
+
def build(self, input_shape): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
81
81
|
super().build(input_shape)
|
82
82
|
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
|
83
83
|
self._dtype = self.layer.kernel.dtype
|
@@ -195,7 +195,7 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
|
|
195
195
|
.format(input=layer))
|
196
196
|
super().__init__(layer, **kwargs)
|
197
197
|
|
198
|
-
def build(self, input_shape):
|
198
|
+
def build(self, input_shape): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
199
199
|
if not self.layer.built:
|
200
200
|
self.layer.build(input_shape)
|
201
201
|
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
|
@@ -226,7 +226,7 @@ class BertPretrainerV2(tf.keras.Model):
|
|
226
226
|
inputs.append(masked_lm_positions)
|
227
227
|
self.inputs = inputs
|
228
228
|
|
229
|
-
def call(self, inputs):
|
229
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
230
230
|
if isinstance(inputs, list):
|
231
231
|
logging.warning('List inputs to BertPretrainer are discouraged.')
|
232
232
|
inputs = dict([
|
@@ -113,7 +113,7 @@ class ElectraPretrainer(tf.keras.Model):
|
|
113
113
|
units=1,
|
114
114
|
kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
|
115
115
|
|
116
|
-
def call(self, inputs):
|
116
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
117
117
|
"""ELECTRA forward pass.
|
118
118
|
|
119
119
|
Args:
|
@@ -144,7 +144,7 @@ class Seq2SeqTransformer(tf.keras.Model):
|
|
144
144
|
|
145
145
|
return embedded_inputs, boolean_mask, input_shape, source_dtype
|
146
146
|
|
147
|
-
def call(self, inputs):
|
147
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
148
148
|
"""Calculate target logits or inferred target sequences.
|
149
149
|
|
150
150
|
Args:
|
@@ -117,7 +117,7 @@ class XLNetPretrainer(tf.keras.Model):
|
|
117
117
|
hidden_size=self._hidden_size,
|
118
118
|
initializer=self._initializer)
|
119
119
|
|
120
|
-
def call(self, inputs: Mapping[str, Any]):
|
120
|
+
def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
121
121
|
input_word_ids = inputs['input_word_ids']
|
122
122
|
input_type_ids = inputs['input_type_ids']
|
123
123
|
masked_tokens = inputs['masked_tokens']
|
@@ -212,7 +212,7 @@ class XLNetClassifier(tf.keras.Model):
|
|
212
212
|
cls_token_idx=cls_token_idx,
|
213
213
|
name=head_name)
|
214
214
|
|
215
|
-
def call(self, inputs: Mapping[str, Any]):
|
215
|
+
def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
216
216
|
input_ids = inputs['input_word_ids']
|
217
217
|
segment_ids = inputs['input_type_ids']
|
218
218
|
input_mask = tf.cast(inputs['input_mask'], tf.float32)
|
@@ -305,7 +305,7 @@ class XLNetSpanLabeler(tf.keras.Model):
|
|
305
305
|
dropout_rate=self._dropout_rate,
|
306
306
|
initializer=self._initializer)
|
307
307
|
|
308
|
-
def call(self, inputs: Mapping[str, Any]):
|
308
|
+
def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
309
309
|
input_word_ids = inputs['input_word_ids']
|
310
310
|
input_type_ids = inputs['input_type_ids']
|
311
311
|
input_mask = inputs['input_mask']
|
@@ -41,7 +41,7 @@ class CenterNetModel(tf.keras.Model):
|
|
41
41
|
self._detection_generator = detection_generator
|
42
42
|
self._head = head
|
43
43
|
|
44
|
-
def call(self,
|
44
|
+
def call(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
45
45
|
inputs: tf.Tensor,
|
46
46
|
training: bool = None,
|
47
47
|
**kwargs) -> Mapping[str, tf.Tensor]:
|
@@ -439,7 +439,7 @@ class HourglassNetwork(tf.keras.Model):
|
|
439
439
|
|
440
440
|
self.intermediate_relu = tf.keras.layers.ReLU()
|
441
441
|
|
442
|
-
def call(self, inputs):
|
442
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
443
443
|
|
444
444
|
if self.initial_downsample:
|
445
445
|
inputs = self.downsample_input(inputs)
|
@@ -90,7 +90,7 @@ class RoformerAttention(tf.keras.layers.MultiHeadAttention):
|
|
90
90
|
...] + k2 * self.k_sin_vec[:, 0:k_len, ...]
|
91
91
|
return ret_q, ret_w, v
|
92
92
|
|
93
|
-
def call(self,
|
93
|
+
def call(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
94
94
|
query,
|
95
95
|
value,
|
96
96
|
key=None,
|
@@ -299,7 +299,7 @@ class TeamsPretrainer(tf.keras.Model):
|
|
299
299
|
output=output_type,
|
300
300
|
name='discriminator_mws')
|
301
301
|
|
302
|
-
def call(self, inputs):
|
302
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
303
303
|
"""TEAMS forward pass.
|
304
304
|
|
305
305
|
Args:
|
@@ -31,7 +31,7 @@ class MockModel(tf.keras.Model):
|
|
31
31
|
super().__init__()
|
32
32
|
self.network = network
|
33
33
|
|
34
|
-
def call(self, inputs):
|
34
|
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
35
35
|
outputs = self.network(inputs)
|
36
36
|
self.add_loss(tf.reduce_mean(outputs))
|
37
37
|
return outputs
|
@@ -50,6 +50,8 @@ class VisionTransformer(hyperparams.Config):
|
|
50
50
|
|
51
51
|
# Adding Layerscale to each Encoder block https://arxiv.org/abs/2204.07118
|
52
52
|
layer_scale_init_value: float = 0.0
|
53
|
+
# Transformer encoder spatial partition dimensions.
|
54
|
+
transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None
|
53
55
|
|
54
56
|
|
55
57
|
@dataclasses.dataclass
|
@@ -122,6 +122,7 @@ class Encoder(layers.Layer):
|
|
122
122
|
pos_embed_origin_shape=None,
|
123
123
|
pos_embed_target_shape=None,
|
124
124
|
layer_scale_init_value=0.0,
|
125
|
+
transformer_partition_dims=None,
|
125
126
|
**kwargs):
|
126
127
|
super().__init__(**kwargs)
|
127
128
|
self._num_layers = num_layers
|
@@ -137,6 +138,7 @@ class Encoder(layers.Layer):
|
|
137
138
|
self._pos_embed_origin_shape = pos_embed_origin_shape
|
138
139
|
self._pos_embed_target_shape = pos_embed_target_shape
|
139
140
|
self._layer_scale_init_value = layer_scale_init_value
|
141
|
+
self._transformer_partition_dims = transformer_partition_dims
|
140
142
|
|
141
143
|
def build(self, input_shape):
|
142
144
|
if self._add_pos_embed:
|
@@ -163,7 +165,8 @@ class Encoder(layers.Layer):
|
|
163
165
|
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
|
164
166
|
self._init_stochastic_depth_rate, i + 1, self._num_layers),
|
165
167
|
norm_epsilon=1e-6,
|
166
|
-
layer_scale_init_value=self._layer_scale_init_value,
|
168
|
+
layer_scale_init_value=self._layer_scale_init_value,
|
169
|
+
transformer_partition_dims=self._transformer_partition_dims)
|
167
170
|
self._encoder_layers.append(encoder_layer)
|
168
171
|
self._norm = layers.LayerNormalization(epsilon=1e-6)
|
169
172
|
super().build(input_shape)
|
@@ -195,6 +198,7 @@ class Encoder(layers.Layer):
|
|
195
198
|
'pos_embed_origin_shape': self._pos_embed_origin_shape,
|
196
199
|
'pos_embed_target_shape': self._pos_embed_target_shape,
|
197
200
|
'layer_scale_init_value': self._layer_scale_init_value,
|
201
|
+
'transformer_partition_dims': self._transformer_partition_dims,
|
198
202
|
}
|
199
203
|
config.update(updates)
|
200
204
|
return config
|
@@ -203,24 +207,27 @@ class Encoder(layers.Layer):
|
|
203
207
|
class VisionTransformer(tf.keras.Model):
|
204
208
|
"""Class to build VisionTransformer family model."""
|
205
209
|
|
206
|
-
def __init__(
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
210
|
+
def __init__(
|
211
|
+
self,
|
212
|
+
mlp_dim=3072,
|
213
|
+
num_heads=12,
|
214
|
+
num_layers=12,
|
215
|
+
attention_dropout_rate=0.0,
|
216
|
+
dropout_rate=0.1,
|
217
|
+
init_stochastic_depth_rate=0.0,
|
218
|
+
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
|
219
|
+
patch_size=16,
|
220
|
+
hidden_size=768,
|
221
|
+
representation_size=0,
|
222
|
+
pooler='token',
|
223
|
+
kernel_regularizer=None,
|
224
|
+
original_init: bool = True,
|
225
|
+
output_encoded_tokens: bool = True,
|
226
|
+
output_2d_feature_maps: bool = False,
|
227
|
+
pos_embed_shape: Optional[Tuple[int, int]] = None,
|
228
|
+
layer_scale_init_value: float = 0.0,
|
229
|
+
transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
|
230
|
+
):
|
224
231
|
"""VisionTransformer initialization function."""
|
225
232
|
self._mlp_dim = mlp_dim
|
226
233
|
self._num_heads = num_heads
|
@@ -368,4 +375,5 @@ def build_vit(input_specs,
|
|
368
375
|
output_encoded_tokens=backbone_cfg.output_encoded_tokens,
|
369
376
|
output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
|
370
377
|
layer_scale_init_value=backbone_cfg.layer_scale_init_value,
|
371
|
-
pos_embed_shape=backbone_cfg.pos_embed_shape
|
378
|
+
pos_embed_shape=backbone_cfg.pos_embed_shape,
|
379
|
+
transformer_partition_dims=backbone_cfg.transformer_partition_dims)
|
@@ -167,7 +167,7 @@ class MaskScoring(tf.keras.Model):
|
|
167
167
|
|
168
168
|
super(MaskScoring, self).build(input_shape)
|
169
169
|
|
170
|
-
def call(self, inputs: tf.Tensor, training: bool = None):
|
170
|
+
def call(self, inputs: tf.Tensor, training: bool = None): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
171
171
|
"""Forward pass mask scoring head.
|
172
172
|
|
173
173
|
Args:
|
@@ -1578,6 +1578,7 @@ class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
|
|
1578
1578
|
*args,
|
1579
1579
|
stochastic_depth_drop_rate=0.0,
|
1580
1580
|
layer_scale_init_value=0.0,
|
1581
|
+
transformer_partition_dims=None,
|
1581
1582
|
max_attention_inference_parallelism=None,
|
1582
1583
|
**kwargs
|
1583
1584
|
):
|
@@ -1587,6 +1588,7 @@ class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
|
|
1587
1588
|
*args: positional arguments passed to super().__init__.
|
1588
1589
|
stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
|
1589
1590
|
layer_scale_init_value:
|
1591
|
+
transformer_partition_dims: transformer spatial partition dimenstions.
|
1590
1592
|
max_attention_inference_parallelism: the number of examples to run in
|
1591
1593
|
parallel in the attention blocks during inference. Set this limit to
|
1592
1594
|
reduce the peak memory usage. If None, use vectorized operations to run
|
@@ -1596,11 +1598,14 @@ class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
|
|
1596
1598
|
super().__init__(*args, **kwargs)
|
1597
1599
|
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
|
1598
1600
|
self._layer_scale_init_value = layer_scale_init_value
|
1601
|
+
self._transformer_partition_dims = transformer_partition_dims
|
1599
1602
|
self._max_attention_inference_parallelism = (
|
1600
1603
|
max_attention_inference_parallelism
|
1601
1604
|
)
|
1602
1605
|
|
1603
1606
|
def build(self, input_shape):
|
1607
|
+
super().build(input_shape)
|
1608
|
+
|
1604
1609
|
if self._stochastic_depth_drop_rate:
|
1605
1610
|
self._stochastic_depth = nn_layers.StochasticDepth(
|
1606
1611
|
self._stochastic_depth_drop_rate)
|
@@ -1615,22 +1620,32 @@ class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
|
|
1615
1620
|
else:
|
1616
1621
|
self._layer_scale_attn = lambda x, *args, **kwargs: tf.identity(x)
|
1617
1622
|
self._layer_scale_mlp = lambda x, *args, **kwargs: tf.identity(x)
|
1618
|
-
super().build(input_shape)
|
1619
1623
|
|
1620
|
-
|
1621
|
-
|
1622
|
-
|
1623
|
-
|
1624
|
-
|
1625
|
-
|
1626
|
-
|
1627
|
-
|
1624
|
+
self._attention_layer = nn_layers.MultiHeadAttention(
|
1625
|
+
num_heads=self._num_heads,
|
1626
|
+
key_dim=self._key_dim,
|
1627
|
+
value_dim=self._value_dim,
|
1628
|
+
dropout=self._attention_dropout_rate,
|
1629
|
+
use_bias=self._use_bias,
|
1630
|
+
kernel_initializer=self._attention_initializer,
|
1631
|
+
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
|
1632
|
+
attention_axes=self._attention_axes,
|
1633
|
+
output_shape=self._output_last_dim,
|
1634
|
+
bias_regularizer=self._bias_regularizer,
|
1635
|
+
activity_regularizer=self._activity_regularizer,
|
1636
|
+
kernel_constraint=self._kernel_constraint,
|
1637
|
+
bias_constraint=self._bias_constraint,
|
1638
|
+
max_inference_parallelism=self._max_attention_inference_parallelism,
|
1639
|
+
partition_dims=self._transformer_partition_dims,
|
1640
|
+
name='self_attention',
|
1641
|
+
)
|
1628
1642
|
|
1629
1643
|
def get_config(self):
|
1630
1644
|
config = super().get_config()
|
1631
1645
|
config.update({
|
1632
1646
|
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
|
1633
1647
|
'layer_scale_init_value': self._layer_scale_init_value,
|
1648
|
+
'transformer_partition_dims': self._transformer_partition_dims,
|
1634
1649
|
'max_attention_inference_parallelism': (
|
1635
1650
|
self._max_attention_inference_parallelism
|
1636
1651
|
),
|
@@ -1291,23 +1291,30 @@ class MultiHeadAttention(tf.keras.layers.MultiHeadAttention):
|
|
1291
1291
|
"""
|
1292
1292
|
|
1293
1293
|
def __init__(
|
1294
|
-
self,
|
1294
|
+
self,
|
1295
|
+
*args,
|
1296
|
+
partition_dims: Optional[Tuple[int, int, int, int]] = None,
|
1297
|
+
max_inference_parallelism: Optional[int] = None,
|
1298
|
+
**kwargs,
|
1295
1299
|
):
|
1296
1300
|
"""Initializes MultiHeadAttention.
|
1297
1301
|
|
1298
1302
|
Args:
|
1299
1303
|
*args: Positional arguments passed to super().__init__.
|
1304
|
+
partition_dims: Spatial partition dimensions.
|
1300
1305
|
max_inference_parallelism: The number of examples to run in parallel
|
1301
1306
|
during inference. Set this limit to reduce the peak memory usage. If
|
1302
1307
|
None, use vectorized operations to run the whole batch in parallel.
|
1303
1308
|
**kwargs: Keyword arguments passed to super().__init__.
|
1304
1309
|
"""
|
1305
1310
|
super().__init__(*args, **kwargs)
|
1311
|
+
self._partition_dims = partition_dims
|
1306
1312
|
self._max_inference_parallelism = max_inference_parallelism
|
1307
1313
|
|
1308
1314
|
def get_config(self):
|
1309
1315
|
config = super().get_config()
|
1310
1316
|
config.update({
|
1317
|
+
'partition_dims': self._partition_dims,
|
1311
1318
|
'max_inference_parallelism': self._max_inference_parallelism,
|
1312
1319
|
})
|
1313
1320
|
return config
|
@@ -1336,6 +1343,16 @@ class MultiHeadAttention(tf.keras.layers.MultiHeadAttention):
|
|
1336
1343
|
attention_output: Multi-headed outputs of attention computation.
|
1337
1344
|
attention_scores: Multi-headed attention weights.
|
1338
1345
|
"""
|
1346
|
+
if self._partition_dims is not None:
|
1347
|
+
strategy = tf.distribute.get_strategy()
|
1348
|
+
# `query` = [B, T, N ,H]
|
1349
|
+
query = strategy.experimental_split_to_logical_devices(
|
1350
|
+
query, self._partition_dims)
|
1351
|
+
key = strategy.experimental_split_to_logical_devices(
|
1352
|
+
key, self._partition_dims)
|
1353
|
+
value = strategy.experimental_split_to_logical_devices(
|
1354
|
+
value, self._partition_dims)
|
1355
|
+
|
1339
1356
|
batch_size = query.get_shape().as_list()[0] # None if dynamic.
|
1340
1357
|
|
1341
1358
|
if (
|
@@ -59,7 +59,7 @@ class SegmentationModel(tf.keras.Model):
|
|
59
59
|
self.head = head
|
60
60
|
self.mask_scoring_head = mask_scoring_head
|
61
61
|
|
62
|
-
def call(self, inputs: tf.Tensor, training: bool = None
|
62
|
+
def call(self, inputs: tf.Tensor, training: bool = None # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
63
63
|
) -> Dict[str, tf.Tensor]:
|
64
64
|
backbone_features = self.backbone(inputs)
|
65
65
|
|
official/vision/ops/anchor.py
CHANGED
@@ -304,7 +304,7 @@ class RpnAnchorLabeler(AnchorLabeler):
|
|
304
304
|
return (ignore_labels + positive_labels + negative_labels,
|
305
305
|
positive_labels, negative_labels)
|
306
306
|
|
307
|
-
def label_anchors(
|
307
|
+
def label_anchors( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
308
308
|
self, anchor_boxes: Dict[str, tf.Tensor], gt_boxes: tf.Tensor,
|
309
309
|
gt_labels: tf.Tensor
|
310
310
|
) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
|
@@ -24,7 +24,7 @@ official/core/registry_test.py,sha256=8iQvUwOajmJ_ajHMjV7GkhhCX6ymNCTRcq1i7rQnmu
|
|
24
24
|
official/core/savedmodel_checkpoint_manager.py,sha256=6q0IBaNsQ3IsxF08dAsUGcQoQ7Vo74ELJHH_znT6Dj0,9225
|
25
25
|
official/core/savedmodel_checkpoint_manager_test.py,sha256=XGLlc_CimGMJLHL1-vtznBqEJ4j7S9IE7iS5Obo1lJg,4045
|
26
26
|
official/core/task_factory.py,sha256=mFwJ3jHEj2FjWiEf3w22V5n2qwPywmXFeR9FZvYZLMM,2513
|
27
|
-
official/core/test_utils.py,sha256=
|
27
|
+
official/core/test_utils.py,sha256=bsjidTNPtPOTGyDQRmi2MCgi1OjDw7PW9BmAYnCiQyM,1877
|
28
28
|
official/core/tf_example_builder.py,sha256=eI34jkqX378mvtIVsSldtIlsBSfonqNwleh5pQrQmpc,4623
|
29
29
|
official/core/tf_example_builder_test.py,sha256=1g9iB17VqsO36hyXbfkprPHUzggU6dF5mX1XsoLlXFQ,5793
|
30
30
|
official/core/tf_example_feature_key.py,sha256=WJC1HAQ2y2h7goJ2Pbf2ejUYvi2CQ5jUGbRv7vJcQQ0,2095
|
@@ -223,17 +223,17 @@ official/modeling/multitask/interleaving_trainer_test.py,sha256=fw7wJRtghPMIuAOf
|
|
223
223
|
official/modeling/multitask/multitask.py,sha256=pciaApxYL0X1Y6hj3wrSWHzT5Q-T7T__47vBNnYzhlM,5938
|
224
224
|
official/modeling/multitask/task_sampler.py,sha256=unXwSKOqg1XoJYdF7c5Nxd0QFUtxjNSzOqhPJxBUp6Y,4887
|
225
225
|
official/modeling/multitask/task_sampler_test.py,sha256=yUqlPXYAyH5JmyeVCPVJ2efQZfVKMO7y5m4KEeGp090,3027
|
226
|
-
official/modeling/multitask/test_utils.py,sha256=
|
226
|
+
official/modeling/multitask/test_utils.py,sha256=CxvKg4YTP1B6WynXudcZrlrD-rXFj3zRGUz0bj2RGXE,4305
|
227
227
|
official/modeling/multitask/train_lib.py,sha256=_5KBdUlUhotyiwWsdJnTm2e9MnR2UI0mCTAy1VRkgpI,10832
|
228
228
|
official/modeling/multitask/train_lib_test.py,sha256=YGS0Aa1PLfH2FQtwm8v6PYoqeSSdusvllrnunUfREpk,4756
|
229
229
|
official/modeling/optimization/__init__.py,sha256=mKS7Ujww2sec6_n_KXFktt5iJdAVIGHPTAaMYWdQa3c,1201
|
230
230
|
official/modeling/optimization/adafactor_optimizer.py,sha256=Qnw1VSjRI5HJbsKBv3Q83tRhFno35Vfu0G1jNxA8Ofc,792
|
231
|
-
official/modeling/optimization/ema_optimizer.py,sha256=
|
231
|
+
official/modeling/optimization/ema_optimizer.py,sha256=ecf1C_y_GcyVKr_MrEMlbjVUzkt9b3src8o8J-1qUwA,10508
|
232
232
|
official/modeling/optimization/lamb.py,sha256=bANup6_NWFaITBMComKKCXpzU8_hBDcr3AARZiNP0m4,10137
|
233
233
|
official/modeling/optimization/lamb_test.py,sha256=hTLunlnOr-VWzRUTRdkeDWKdSBMCmQJcrEH-gP75Fuo,5928
|
234
234
|
official/modeling/optimization/lars.py,sha256=UO5YPPlaGR5lyrBH1FzUvVGzMefklioJXxGxWhHGWgk,7338
|
235
235
|
official/modeling/optimization/legacy_adamw.py,sha256=8d0Zw4gDsDOub2vCU4t9rgFEZ7qVTUOdT82BGnq0SR8,5953
|
236
|
-
official/modeling/optimization/lr_schedule.py,sha256=
|
236
|
+
official/modeling/optimization/lr_schedule.py,sha256=skKKqQ51aEO5xRZWLrElqLUJEB2-Wd_6WLhA0wurPk8,18498
|
237
237
|
official/modeling/optimization/lr_schedule_test.py,sha256=G_IV6daj1S5cKsypVaR8NfLbrB7DvtFZ4HQCtP0I-oY,3951
|
238
238
|
official/modeling/optimization/optimizer_factory.py,sha256=8s1hMSgyQZKk-QoNI1WunOSUyTu5Sri2guLvsgl9Mqg,10475
|
239
239
|
official/modeling/optimization/optimizer_factory_test.py,sha256=MiHFcKvsSPsgHnBxk8T_CwzakcGnRLvw3bJ7PO1FlLI,17015
|
@@ -301,7 +301,7 @@ official/nlp/modeling/__init__.py,sha256=bL8szAeBlvUmdY-yNRk-Nsurs7USPP3K4_pnyjj
|
|
301
301
|
official/nlp/modeling/layers/__init__.py,sha256=0YkQXKDdGZl1H8K0pKb93ex3L6nY_vP4XFNXPKH8OJE,4796
|
302
302
|
official/nlp/modeling/layers/attention.py,sha256=fJj_bsWgosN8p4Yy1eFToKaAuQICef_L_iwkHeg5gfw,3896
|
303
303
|
official/nlp/modeling/layers/attention_test.py,sha256=vMBf1RRgx3tg3CXNyi6XW9nwG6KKnea8IzJ58udwN1o,3526
|
304
|
-
official/nlp/modeling/layers/bigbird_attention.py,sha256=
|
304
|
+
official/nlp/modeling/layers/bigbird_attention.py,sha256=XOcuap9IKvpXSkWtiAztqsutIhdJ66bylPKhDvzNA5A,21101
|
305
305
|
official/nlp/modeling/layers/bigbird_attention_test.py,sha256=RVo4pNazC8ktRE7YI4HdyrM0X69pFmz9b8B-9lp7ix0,2196
|
306
306
|
official/nlp/modeling/layers/block_diag_feedforward.py,sha256=KYzPAyZDCyA1jTL7P9gTNAJVRypngbkZ5Sv5QAbfTRA,7233
|
307
307
|
official/nlp/modeling/layers/block_diag_feedforward_test.py,sha256=JrSs8t7Ry4ZZHYGrk_W4Tp0EBGqhboHRtKbj4jT68Lo,4171
|
@@ -333,11 +333,11 @@ official/nlp/modeling/layers/on_device_embedding.py,sha256=Z7sDSzdVe7CaMLfPe2DEP
|
|
333
333
|
official/nlp/modeling/layers/on_device_embedding_test.py,sha256=t5s90xeiHB7WVJKb-YKw63hx2Pgjf9UcvPdZRpy_CwQ,8579
|
334
334
|
official/nlp/modeling/layers/pack_optimization.py,sha256=b1XKyUtxeygyVzsDJveZIm-vgsdy-xznSRoXjyZ_Mzk,10279
|
335
335
|
official/nlp/modeling/layers/pack_optimization_test.py,sha256=cuNoFFxwLB9cT5hw0lfy_aA7MTHOO3Mh231UJO65_Ak,2785
|
336
|
-
official/nlp/modeling/layers/per_dim_scale_attention.py,sha256=
|
336
|
+
official/nlp/modeling/layers/per_dim_scale_attention.py,sha256=qQIuonWRU92kO0-RHqoI4BHEWIm_v-XnozXYWxUjmbM,3406
|
337
337
|
official/nlp/modeling/layers/per_dim_scale_attention_test.py,sha256=7GZE3CyIHZx9x_s2Dwn5YNODLbyjhPH-SpAJojDET_0,1751
|
338
338
|
official/nlp/modeling/layers/position_embedding.py,sha256=X7yaGV7KoHN-t51wk3jWRVg6w_gpgoYnhLTxUP-qEMs,11337
|
339
339
|
official/nlp/modeling/layers/position_embedding_test.py,sha256=CueBbGQYpirkZpjlID2xAd4pmh49GWssx2QN7jKU6H8,8009
|
340
|
-
official/nlp/modeling/layers/relative_attention.py,sha256=
|
340
|
+
official/nlp/modeling/layers/relative_attention.py,sha256=flArMqRqCtHIQd6yWisscgclaEq6S_HxtpGV81tHrgw,20547
|
341
341
|
official/nlp/modeling/layers/relative_attention_test.py,sha256=rTN71dKVytcIceiHRUPLPKdDogPnHJ_GzVW9bU_OXHQ,6685
|
342
342
|
official/nlp/modeling/layers/reuse_attention.py,sha256=QjNS2hipbSKMSjqsldmNMjQ_3plYhGX_xj_X87_iQ08,25658
|
343
343
|
official/nlp/modeling/layers/reuse_attention_test.py,sha256=340CNOwxcNxf2RKZnBdS4rS-rv0VkrbiCww5TSdjLWs,14319
|
@@ -348,7 +348,7 @@ official/nlp/modeling/layers/rezero_transformer_test.py,sha256=NmD3v5xcnpRxTe9dk
|
|
348
348
|
official/nlp/modeling/layers/routing.py,sha256=BdNat8l8PWXE1eIdYXBh41Uy89BJa1b4OR7x9w69W34,4459
|
349
349
|
official/nlp/modeling/layers/routing_test.py,sha256=iEx_88DCG7G62ML_TCGR1rS5CXlNSZEGZ8E3Z8UkkEI,2216
|
350
350
|
official/nlp/modeling/layers/self_attention_mask.py,sha256=vFcKClhCma8R3DZ3V_f-5oIXr36eOFp4s7IPpXGyxWA,2163
|
351
|
-
official/nlp/modeling/layers/spectral_normalization.py,sha256=
|
351
|
+
official/nlp/modeling/layers/spectral_normalization.py,sha256=2y-Ay8MkDzl3DytPInMqUjx_ww7UfmYgOCvRZjjAYxA,10728
|
352
352
|
official/nlp/modeling/layers/spectral_normalization_test.py,sha256=acyfgLtuzKEpws14u-aqWqni6D41Z1gdq4hRfmvBQwM,3111
|
353
353
|
official/nlp/modeling/layers/talking_heads_attention.py,sha256=shkxv2FzH76jSzTdxtZyPgu8kBnuNzNZlaWGti_8xwY,6903
|
354
354
|
official/nlp/modeling/layers/talking_heads_attention_test.py,sha256=oqgiQPsNLpckHJy1BYN56O5P1jXzHyYj9Y2JaejVqzU,7012
|
@@ -373,7 +373,7 @@ official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py,sh
|
|
373
373
|
official/nlp/modeling/models/__init__.py,sha256=bOt0Hr5hjT9K7NQK-hzG9fsye3flx3U4X8RhkF8neN0,1654
|
374
374
|
official/nlp/modeling/models/bert_classifier.py,sha256=qR8trNtfxhO_FWKUyTR43q3_Xu3LNWMxRUlvN3uQxhk,5896
|
375
375
|
official/nlp/modeling/models/bert_classifier_test.py,sha256=uAXgZgK2Zq2dtIsbDJWqUO1Aro0Beyjc4rFMBLE37OA,4815
|
376
|
-
official/nlp/modeling/models/bert_pretrainer.py,sha256=
|
376
|
+
official/nlp/modeling/models/bert_pretrainer.py,sha256=upQUW4JBvFPqibsSiMxyXL30Qwn8GGaZsR4ekHgi7ZU,11480
|
377
377
|
official/nlp/modeling/models/bert_pretrainer_test.py,sha256=g4yaAdaoNGcfzt2vDMhwJleTRH9pLNzCYOGol5o7D84,9900
|
378
378
|
official/nlp/modeling/models/bert_span_labeler.py,sha256=J7p_RZc_B89UUZbJ8Z3vrbujIwJWuMZ8_20Ti-4bsHI,4997
|
379
379
|
official/nlp/modeling/models/bert_span_labeler_test.py,sha256=s75mgXjAcZrOTnsOUVkcxuAS0v00nomOlnvf3EW80UQ,4769
|
@@ -381,13 +381,13 @@ official/nlp/modeling/models/bert_token_classifier.py,sha256=IoaQsv_uFCnmlfxTRFU
|
|
381
381
|
official/nlp/modeling/models/bert_token_classifier_test.py,sha256=bkL644Hp57GlW_NcQlrPobF2YO_ueJLLLauuxOsnLmU,4835
|
382
382
|
official/nlp/modeling/models/dual_encoder.py,sha256=_yUt8Ox4iWhJCF3zqIbBDba0WxcAvd1PrVjQACJZRjs,6583
|
383
383
|
official/nlp/modeling/models/dual_encoder_test.py,sha256=XZ1AGzR55WFWmo0VstV4eWcvkBzJlHyl_6H9nQIM7Xk,5120
|
384
|
-
official/nlp/modeling/models/electra_pretrainer.py,sha256=
|
384
|
+
official/nlp/modeling/models/electra_pretrainer.py,sha256=Eg9aOziKn761EDEUs3gqxUOgdq1nHArjUX_l2Vcc2VE,13036
|
385
385
|
official/nlp/modeling/models/electra_pretrainer_test.py,sha256=ArnVb5gURE5xj88DIq6HsZib0bcFk46ePpsd9Nqa6Fo,6503
|
386
|
-
official/nlp/modeling/models/seq2seq_transformer.py,sha256=
|
386
|
+
official/nlp/modeling/models/seq2seq_transformer.py,sha256=4oEIFMEWavyDRRvtaz2g-bAKiD3ehdrLM6MSzNfOw90,25791
|
387
387
|
official/nlp/modeling/models/seq2seq_transformer_test.py,sha256=3fGfPITjC_xdNL8B8U5_qVPoYXS6_7Ic3jlBKaeebrI,5399
|
388
388
|
official/nlp/modeling/models/t5.py,sha256=pQk5Gleeq3Sr90tVQQ5uZlftyViVNhqY2l6d3f0DUYc,57312
|
389
389
|
official/nlp/modeling/models/t5_test.py,sha256=UsF6AMuTRRf_OX9xAcbwIRsuYukp8KS-qtVHAy4Y2jM,27024
|
390
|
-
official/nlp/modeling/models/xlnet.py,sha256=
|
390
|
+
official/nlp/modeling/models/xlnet.py,sha256=kJkmNsKavBKlPPwiklbGA0okakCJYky3M87WUjtVQrc,12004
|
391
391
|
official/nlp/modeling/models/xlnet_test.py,sha256=VSG37nAOuZkgurNlNDCAXXpOZ-KeEXm8YRQvBWS1UzI,12608
|
392
392
|
official/nlp/modeling/networks/__init__.py,sha256=TQT6o5ZZUtc_ofg8gSUq30bA4FtId6Mz-z39dEF73Rg,1839
|
393
393
|
official/nlp/modeling/networks/albert_encoder.py,sha256=LgrfJB5jYZsNQ2LPpZhoeko3TekCwQuT_taPVxjhLa4,8877
|
@@ -477,7 +477,7 @@ official/projects/centernet/losses/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXN
|
|
477
477
|
official/projects/centernet/losses/centernet_losses.py,sha256=rN7tyZ0VwNMgCFUKn7ce9HCAPBcFhxQIVug5jPgPRXM,4548
|
478
478
|
official/projects/centernet/losses/centernet_losses_test.py,sha256=VlUVFqFxHhyMIjECHepmzsvePQvRpMG9f6Qa-loyFmM,3777
|
479
479
|
official/projects/centernet/modeling/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
480
|
-
official/projects/centernet/modeling/centernet_model.py,sha256=
|
480
|
+
official/projects/centernet/modeling/centernet_model.py,sha256=uAXOCcotwzXKQvCbXHCcGPPQA2OlyMq56fdJ7VqWEHE,2616
|
481
481
|
official/projects/centernet/modeling/centernet_model_test.py,sha256=WWUNxWXcu-zpI2ZcdeXrfOqcxvEzeODvRn5lnPCt6mc,2575
|
482
482
|
official/projects/centernet/modeling/backbones/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
483
483
|
official/projects/centernet/modeling/backbones/hourglass.py,sha256=M5aOsOHFGcHaTTAdEyRv5s0ThTqDh-hYAOs91KT6490,10285
|
@@ -517,7 +517,7 @@ official/projects/deepmac_maskrcnn/modeling/__init__.py,sha256=1ToRMjre4mErL4Ek4
|
|
517
517
|
official/projects/deepmac_maskrcnn/modeling/maskrcnn_model.py,sha256=_hRwbPrjII6WhhrK6Ehu1Zpo8XBmQQf_Evzmpy0VflY,9656
|
518
518
|
official/projects/deepmac_maskrcnn/modeling/maskrcnn_model_test.py,sha256=0_zFQo30yEDME7o_lo_KJqfegmIJjlxUl0Y-saHdqIs,5731
|
519
519
|
official/projects/deepmac_maskrcnn/modeling/heads/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
520
|
-
official/projects/deepmac_maskrcnn/modeling/heads/hourglass_network.py,sha256=
|
520
|
+
official/projects/deepmac_maskrcnn/modeling/heads/hourglass_network.py,sha256=dochwOGv2S_O_Fd8ew0_nTn58wNh_Q6vLWKb9cychyg,21705
|
521
521
|
official/projects/deepmac_maskrcnn/modeling/heads/instance_heads.py,sha256=fsCbLANUxJ5QRVUVhThZgymJIWLQC39Zr0uooP7GO2g,11466
|
522
522
|
official/projects/deepmac_maskrcnn/modeling/heads/instance_heads_test.py,sha256=M8iuImwpTHnnVXDLNCHVVk7NunIzKJrMrRyQc9_RTtU,3172
|
523
523
|
official/projects/deepmac_maskrcnn/serving/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
@@ -651,7 +651,7 @@ official/projects/qat/vision/tasks/retinanet_test.py,sha256=Kl8s3-ag5ME58QthtwiY
|
|
651
651
|
official/projects/qat/vision/tasks/semantic_segmentation.py,sha256=mhQ7hb3M4e5f71kLDa82aWyqQA-gUIfOw9GFUSoD05o,1552
|
652
652
|
official/projects/roformer/__init__.py,sha256=wrNoNRP6wGQBFDLVCYB01k1HgwOGlQ8iZnfqimxlDyk,610
|
653
653
|
official/projects/roformer/roformer.py,sha256=6FV81XLin-7TqZQLwmS2gdOpfb4S8jSXvoo1Og3vZ18,2002
|
654
|
-
official/projects/roformer/roformer_attention.py,sha256=
|
654
|
+
official/projects/roformer/roformer_attention.py,sha256=cR1YY-dWLGnFN4yyof9kYH9v_DM09a8SCA5THdjK52U,4638
|
655
655
|
official/projects/roformer/roformer_attention_test.py,sha256=IQpM6PWCRCilVBxpAk80tVnqhwcUVM_t1lio871yBkQ,4037
|
656
656
|
official/projects/roformer/roformer_encoder.py,sha256=InBAjuLgz_bAH5Sgqe32cgdhmh6bSsLMD-NaqFYqvwE,11374
|
657
657
|
official/projects/roformer/roformer_encoder_block.py,sha256=Cfe3NStzPrHK8JsL3DjoM7jemGCY2aWGu6a6d1Z1r9M,13507
|
@@ -662,7 +662,7 @@ official/projects/roformer/train.py,sha256=zDMF_2jbLB8XrOG1ea_sGIYSS0q-q8k3QVkON
|
|
662
662
|
official/projects/teams/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
663
663
|
official/projects/teams/teams.py,sha256=JwhdzJnbsZkMHefM0vBP4eAPadh11LbMkEpH-UFC-po,4010
|
664
664
|
official/projects/teams/teams_experiments.py,sha256=gLFMx-Osj5LqgcuBAwOJkWOM4dG3Z0UgOHARMB42y5g,4424
|
665
|
-
official/projects/teams/teams_pretrainer.py,sha256=
|
665
|
+
official/projects/teams/teams_pretrainer.py,sha256=MZv7Tc2WHPPoWcffI3Kk9l_UbdnkBJkzrdbFPKxeyag,18683
|
666
666
|
official/projects/teams/teams_pretrainer_test.py,sha256=IwABozbqvlDfYujZiGIIIaoX4Fi4E4GN43PXFxTpRPU,7436
|
667
667
|
official/projects/teams/teams_task.py,sha256=VmNuUhOkyttFImSh85CuqdnGg7og6wrx12M-vJQcEfA,10170
|
668
668
|
official/projects/teams/teams_task_test.py,sha256=_3A-z-zLhH8Y8LdPbBaQJRbzAJiune-8JpZA94Fvqow,2184
|
@@ -830,13 +830,13 @@ official/utils/misc/model_helpers.py,sha256=zqgwrD_oi23qBaY7PnMIBKWHqUsNGJRwgral
|
|
830
830
|
official/utils/misc/model_helpers_test.py,sha256=WS0Eb0pPGk746JSDXRJyZ0Yu95_RvwwmN0uJakRf-ls,4549
|
831
831
|
official/utils/testing/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
832
832
|
official/utils/testing/integration.py,sha256=Wac7nJTxnaK-XHt5pNbFD9eaP4CwZ9mnr7ZOMkJosEw,2220
|
833
|
-
official/utils/testing/mock_task.py,sha256=
|
833
|
+
official/utils/testing/mock_task.py,sha256=HPDArX_WR_ZZqom2iCSwn_zy8tbSjOGIN38KLL36bhw,3302
|
834
834
|
official/vision/__init__.py,sha256=g2WviTDdhFjzzfQROU7MV-xdCchqoyXvs_RlIE8fiTc,744
|
835
835
|
official/vision/registry_imports.py,sha256=1T4yHR55Tct02_E5P71jR0LtAkft84IZC2Zt2hwm7dI,760
|
836
836
|
official/vision/train.py,sha256=QQlYD8NtC3_JYLvPhV3i0GoqTUOaYoEHRHi2UoNracQ,3724
|
837
837
|
official/vision/train_spatial_partitioning.py,sha256=sXVjCU8r29aXP6vc8kp8DxciEhfdn3APoyQOr-iCQZ8,5725
|
838
838
|
official/vision/configs/__init__.py,sha256=5Uk9jeIYHTEHXs8kBrq-vMa7cgg_JpLgSYuVaLPyUt8,1045
|
839
|
-
official/vision/configs/backbones.py,sha256=
|
839
|
+
official/vision/configs/backbones.py,sha256=E87hCWSPU07fqS-SPkUYsFeapkSIjGQBIAU_dq-Knw4,5197
|
840
840
|
official/vision/configs/backbones_3d.py,sha256=YGmeo0y1sedOQpfXdcXUubmgPpDS8IlyjM1yBSBTxAU,3650
|
841
841
|
official/vision/configs/common.py,sha256=S_SYgZs9mSU5udq9xUuGlt52dzhdemobVzBFrGfhNPc,4621
|
842
842
|
official/vision/configs/decoders.py,sha256=eSUlND3gFJK_z9yrHdoRzJN-PY_8v2_TcgwY36G9OqU,1948
|
@@ -913,11 +913,11 @@ official/vision/modeling/classification_model_test.py,sha256=M8t4MA5cI42rmfTqGsX
|
|
913
913
|
official/vision/modeling/factory.py,sha256=EfuDNq9629SIRYn01Dy1h4gGyjPa2wfPROhQTL-9dQs,17437
|
914
914
|
official/vision/modeling/factory_3d.py,sha256=Gtz4xWxPQ4rmMn360lTMgw_Xjha5q6UYawotcerXpYM,3530
|
915
915
|
official/vision/modeling/factory_test.py,sha256=I1aevMcMbD6dihHfU5ohFV4R2lHEdVLvdVra7QE-kU4,5286
|
916
|
-
official/vision/modeling/maskrcnn_model.py,sha256=
|
916
|
+
official/vision/modeling/maskrcnn_model.py,sha256=1eb4mIbZeYA1ofBk2MosF6VsNRDJph9gBPCxU0Zb95w,20733
|
917
917
|
official/vision/modeling/maskrcnn_model_test.py,sha256=eJ0WtCGFztZj7zbgAd0VDF4_PzsnLgPK0k4r1n85X5k,14834
|
918
918
|
official/vision/modeling/retinanet_model.py,sha256=kT9j8khuYxZeRbJ070Ys84y1hU412VzZcHpJVpBEe4Q,9695
|
919
919
|
official/vision/modeling/retinanet_model_test.py,sha256=_PTpYedMG0L8rq6Xydh1qR0vKpvE4uDu1M6HZF4M3Ig,11148
|
920
|
-
official/vision/modeling/segmentation_model.py,sha256=
|
920
|
+
official/vision/modeling/segmentation_model.py,sha256=Vyhq0G2_Smf0RCu1gRhVhu3C145r1lRRWAie6keU60M,3428
|
921
921
|
official/vision/modeling/segmentation_model_test.py,sha256=Z5toOhb3dIyU7DE38pLOalwBYKx2n1prIxawkOu4GHI,2807
|
922
922
|
official/vision/modeling/video_classification_model.py,sha256=em4Yt3grA7Y7GA9bSzPSCCrc1_l_RNUu6qEWmCZUrZ4,4703
|
923
923
|
official/vision/modeling/video_classification_model_test.py,sha256=eX0sBnEdLwPV7qPfR9nSApRw1qd0iUPew1yCLMn4j4E,3261
|
@@ -942,7 +942,7 @@ official/vision/modeling/backbones/spinenet.py,sha256=uIaJdINdVnvyh5xIs_1oRMajQA
|
|
942
942
|
official/vision/modeling/backbones/spinenet_mobile.py,sha256=vVpqM7oKSQMsCTJtZZBPWnbKei5oQHq33woOwIHiF8w,20774
|
943
943
|
official/vision/modeling/backbones/spinenet_mobile_test.py,sha256=UvLU7Jutk7f5q8Vz3jflWHlC-itRMvNJtB5-cXA0JxA,3948
|
944
944
|
official/vision/modeling/backbones/spinenet_test.py,sha256=YhQRhc3bDhpNd-EUQYxDkLSfrOhKjcpeUeh-nQ__pZg,4724
|
945
|
-
official/vision/modeling/backbones/vit.py,sha256=
|
945
|
+
official/vision/modeling/backbones/vit.py,sha256=eO18uyTYZ7Uy9FKp-m1zVpNqy4AHwIMBvgJqpwJ6IiY,14428
|
946
946
|
official/vision/modeling/backbones/vit_specs.py,sha256=Qfa5Oecd4rCWA73m8X_g8J_w1fpCZsNRSUw0qLIhO9Q,2412
|
947
947
|
official/vision/modeling/backbones/vit_test.py,sha256=aP11sEVNfOvb8i3VfRd3k_3qHWqFvunEYO9DRtCTZrA,3453
|
948
948
|
official/vision/modeling/decoders/__init__.py,sha256=F7uTs9fJjPKacKgZnuRRm21t1N4cV_c4o76dDQS9VkM,815
|
@@ -959,7 +959,7 @@ official/vision/modeling/heads/dense_prediction_heads.py,sha256=Xmtb4moAiinD3-wC
|
|
959
959
|
official/vision/modeling/heads/dense_prediction_heads_test.py,sha256=8O1K4iksJ7IBRVfiHPjaolcAYd3sl0DxOoTljfJf_c0,7227
|
960
960
|
official/vision/modeling/heads/instance_heads.py,sha256=P580eVgYex6ddcua1zTkXEmgOqPd1xnmqJbrM6Dz6Hk,17819
|
961
961
|
official/vision/modeling/heads/instance_heads_test.py,sha256=Hjsb196FnfJtG4NMXmZuibo_sgBINPy-qzPyXZ_Ru9A,4197
|
962
|
-
official/vision/modeling/heads/segmentation_heads.py,sha256=
|
962
|
+
official/vision/modeling/heads/segmentation_heads.py,sha256=ABDjiLFo7pmyV4uodPRXQrNxWHhzhcQ7xHgrjEsmPJk,20174
|
963
963
|
official/vision/modeling/heads/segmentation_heads_test.py,sha256=VJOg0_VcpdyMj3HYfxG1BcpHm0D9B5yEw6kZISiZCY8,3802
|
964
964
|
official/vision/modeling/layers/__init__.py,sha256=Yk0zkJVVgwi--tv5-NKIAKHCzpFw9tKrg5wYm5CZwh4,2597
|
965
965
|
official/vision/modeling/layers/box_sampler.py,sha256=hsFoo9lsD57LDUnj0SlIoM9QNg26yqQxvADPtQRVhV8,3401
|
@@ -970,11 +970,11 @@ official/vision/modeling/layers/detection_generator_test.py,sha256=-XXoY6dyjE1R3
|
|
970
970
|
official/vision/modeling/layers/edgetpu.py,sha256=AmfMIxT1XCSZo0zXoZO0h-WOdxNYYYZo4lGf_d7DL6g,16727
|
971
971
|
official/vision/modeling/layers/edgetpu_test.py,sha256=aIOtLXD1jCExn3tVFS-q6yiCZV7yIxCvG06p7fQdCgs,10165
|
972
972
|
official/vision/modeling/layers/mask_sampler.py,sha256=1INadZeRmqZoiFeEHiE_IMq39k8f2S66VvoyMxcM1mk,7914
|
973
|
-
official/vision/modeling/layers/nn_blocks.py,sha256=
|
973
|
+
official/vision/modeling/layers/nn_blocks.py,sha256=wU0OK-l1TPeMpnNmJwmUxypNpag2wXsaIhYp2NunmLo,72079
|
974
974
|
official/vision/modeling/layers/nn_blocks_3d.py,sha256=b0h_wWmzgrRUaG03eziHoAqBpqZJyfSn8riHfFbEKzo,10565
|
975
975
|
official/vision/modeling/layers/nn_blocks_3d_test.py,sha256=pSmDiikoPq83gZIP9hyZ0UbmJxMxVxippuLd38mHR74,2028
|
976
976
|
official/vision/modeling/layers/nn_blocks_test.py,sha256=-NNR8zVnkgkyLTyNzOL2BnkvH5IWz1ggO3sXCeYe-E4,33825
|
977
|
-
official/vision/modeling/layers/nn_layers.py,sha256=
|
977
|
+
official/vision/modeling/layers/nn_layers.py,sha256=Bv18J7Sa-FHLrlIru9vSXiJDOlCw2M9xtoOQH0_uuPM,52972
|
978
978
|
official/vision/modeling/layers/nn_layers_test.py,sha256=NMthj6XogpbTcbhjxsMiO8WEgq6ha4uLJwwqMDOEY5I,13309
|
979
979
|
official/vision/modeling/layers/roi_aligner.py,sha256=v09HensTcogJDNzw2dIUtvbQhoPzaxvSVLVsbISwYrw,2533
|
980
980
|
official/vision/modeling/layers/roi_aligner_test.py,sha256=p5_zu8nXS2L1mV5zX9iGIgC7OQc1gd0Fw7o1AffyzhA,1275
|
@@ -982,7 +982,7 @@ official/vision/modeling/layers/roi_generator.py,sha256=VGNMrm0_td1gaE7ENK-XzGEy
|
|
982
982
|
official/vision/modeling/layers/roi_sampler.py,sha256=3RZBvQ3esodqT_BzDvAG0r4x7mxOHKoxLwSk0V_JNEs,9996
|
983
983
|
official/vision/modeling/models/__init__.py,sha256=dKnmOzbcEAJcah2R85pfGE37VZYY090oMy9ol4eHk8k,1020
|
984
984
|
official/vision/ops/__init__.py,sha256=1ToRMjre4mErL4Ek4_dMVxMjXNPossNXggV8fqbISao,609
|
985
|
-
official/vision/ops/anchor.py,sha256=
|
985
|
+
official/vision/ops/anchor.py,sha256=iS97q20CKozhg9tpqFkkDGZ66MqrIfCYjSk3ZjndrRM,19102
|
986
986
|
official/vision/ops/anchor_generator.py,sha256=WwASWaE5uR9oIB69lnK8mW2xUfRnLY0LsgyVX1BwsQQ,7234
|
987
987
|
official/vision/ops/anchor_generator_test.py,sha256=aihKzRV8SQFhNSzS9_SpJNSadL4KcUcj0vT5dgmnkIQ,5286
|
988
988
|
official/vision/ops/anchor_test.py,sha256=UHue1kHD184Lh5eJzWQ2w0UAZk9FjolYJ9Eg6b1hIbo,7623
|
@@ -1020,7 +1020,6 @@ official/vision/serving/export_tfhub.py,sha256=ES-BDGUjhfv8u93S1RKnY-ZJw-g25NAxK
|
|
1020
1020
|
official/vision/serving/export_tfhub_lib.py,sha256=x_wtofpZH-eb8y92IBdPexjta0ecG6l-9HoHOePzt0g,2870
|
1021
1021
|
official/vision/serving/export_tflite.py,sha256=a-j5An1DigxPxpIg8IcqH2uQzaSXIQRL9irD2HmkAVg,5105
|
1022
1022
|
official/vision/serving/export_tflite_lib.py,sha256=3zOA1W1-yLcHCzngtgauyyiyidOPZXuJR35HSDP3Rwo,6907
|
1023
|
-
official/vision/serving/export_tflite_lib_test.py,sha256=2TW-215vjJdLT6lPYUtS_tSxYYNhd0-JbfvPK46oUTc,7185
|
1024
1023
|
official/vision/serving/export_utils.py,sha256=tPzbbJmbu7rLjiu5GRdCcaU6ADybNbAPH4DalJNURbo,4871
|
1025
1024
|
official/vision/serving/image_classification.py,sha256=YWHm-VPnIpv4Hn973jM39saR2NZGB_4n8CSmJA9tIvw,2858
|
1026
1025
|
official/vision/serving/image_classification_test.py,sha256=VwQc5WBRAMaB7Tm0TT2SJ2svoT2V9-OAWbZgT4QdsYA,4612
|
@@ -1084,9 +1083,9 @@ tensorflow_models/__init__.py,sha256=Ciz_YBke6teb6y42QyQTUBDdXJAiV7Qdu1zOoZvYiKw
|
|
1084
1083
|
tensorflow_models/tensorflow_models_test.py,sha256=Kz2y4V-rtBhZFFfKD2soCq52hviSfJVV1L2ztqS-9oM,1385
|
1085
1084
|
tensorflow_models/nlp/__init__.py,sha256=3dULDpUBpDi9vljpXadq6oJrWH4y6z42Bz2d3hopYZw,807
|
1086
1085
|
tensorflow_models/vision/__init__.py,sha256=4y77XkHaH8qLls3-6ta4tMp3Xj8CLbB0ihH91HsQ9z4,833
|
1087
|
-
tf_models_nightly-2.12.0.
|
1088
|
-
tf_models_nightly-2.12.0.
|
1089
|
-
tf_models_nightly-2.12.0.
|
1090
|
-
tf_models_nightly-2.12.0.
|
1091
|
-
tf_models_nightly-2.12.0.
|
1092
|
-
tf_models_nightly-2.12.0.
|
1086
|
+
tf_models_nightly-2.12.0.dev20230501.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1087
|
+
tf_models_nightly-2.12.0.dev20230501.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1088
|
+
tf_models_nightly-2.12.0.dev20230501.dist-info/METADATA,sha256=RngQsLLXQRNghqFmRtBPUdSsXzWsnGQdOWQPdT6q2vg,1393
|
1089
|
+
tf_models_nightly-2.12.0.dev20230501.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1090
|
+
tf_models_nightly-2.12.0.dev20230501.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1091
|
+
tf_models_nightly-2.12.0.dev20230501.dist-info/RECORD,,
|
@@ -1,192 +0,0 @@
|
|
1
|
-
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
"""Tests for export_tflite_lib."""
|
16
|
-
import os
|
17
|
-
|
18
|
-
from absl.testing import parameterized
|
19
|
-
import tensorflow as tf
|
20
|
-
|
21
|
-
from tensorflow.python.distribute import combinations
|
22
|
-
from official.core import exp_factory
|
23
|
-
from official.vision import registry_imports # pylint: disable=unused-import
|
24
|
-
from official.vision.dataloaders import tfexample_utils
|
25
|
-
from official.vision.serving import detection as detection_serving
|
26
|
-
from official.vision.serving import export_tflite_lib
|
27
|
-
from official.vision.serving import image_classification as image_classification_serving
|
28
|
-
from official.vision.serving import semantic_segmentation as semantic_segmentation_serving
|
29
|
-
|
30
|
-
|
31
|
-
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
|
32
|
-
|
33
|
-
def setUp(self):
|
34
|
-
super().setUp()
|
35
|
-
# Create test data for image classification.
|
36
|
-
self.test_tfrecord_file_cls = os.path.join(self.get_temp_dir(),
|
37
|
-
'cls_test.tfrecord')
|
38
|
-
example = tf.train.Example.FromString(
|
39
|
-
tfexample_utils.create_classification_example(
|
40
|
-
image_height=224, image_width=224))
|
41
|
-
self._create_test_tfrecord(
|
42
|
-
tfrecord_file=self.test_tfrecord_file_cls,
|
43
|
-
example=example,
|
44
|
-
num_samples=10)
|
45
|
-
|
46
|
-
# Create test data for object detection.
|
47
|
-
self.test_tfrecord_file_det = os.path.join(self.get_temp_dir(),
|
48
|
-
'det_test.tfrecord')
|
49
|
-
example = tfexample_utils.create_detection_test_example(
|
50
|
-
image_height=128, image_width=128, image_channel=3, num_instances=10)
|
51
|
-
self._create_test_tfrecord(
|
52
|
-
tfrecord_file=self.test_tfrecord_file_det,
|
53
|
-
example=example,
|
54
|
-
num_samples=10)
|
55
|
-
|
56
|
-
# Create test data for semantic segmentation.
|
57
|
-
self.test_tfrecord_file_seg = os.path.join(self.get_temp_dir(),
|
58
|
-
'seg_test.tfrecord')
|
59
|
-
example = tfexample_utils.create_segmentation_test_example(
|
60
|
-
image_height=512, image_width=512, image_channel=3)
|
61
|
-
self._create_test_tfrecord(
|
62
|
-
tfrecord_file=self.test_tfrecord_file_seg,
|
63
|
-
example=example,
|
64
|
-
num_samples=10)
|
65
|
-
|
66
|
-
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
|
67
|
-
examples = [example] * num_samples
|
68
|
-
tfexample_utils.dump_to_tfrecord(
|
69
|
-
record_file=tfrecord_file, tf_examples=examples)
|
70
|
-
|
71
|
-
def _export_from_module(self, module, input_type, saved_model_dir):
|
72
|
-
signatures = module.get_inference_signatures(
|
73
|
-
{input_type: 'serving_default'})
|
74
|
-
tf.saved_model.save(module, saved_model_dir, signatures=signatures)
|
75
|
-
|
76
|
-
@combinations.generate(
|
77
|
-
combinations.combine(
|
78
|
-
experiment=['mobilenet_imagenet'],
|
79
|
-
quant_type=[
|
80
|
-
None,
|
81
|
-
'default',
|
82
|
-
'fp16',
|
83
|
-
'int8_fallback',
|
84
|
-
'int8_full',
|
85
|
-
'int8_full_fp32_io',
|
86
|
-
'int8_full_int8_io',
|
87
|
-
]))
|
88
|
-
def test_export_tflite_image_classification(self, experiment, quant_type):
|
89
|
-
|
90
|
-
params = exp_factory.get_exp_config(experiment)
|
91
|
-
params.task.validation_data.input_path = self.test_tfrecord_file_cls
|
92
|
-
params.task.train_data.input_path = self.test_tfrecord_file_cls
|
93
|
-
params.task.train_data.shuffle_buffer_size = 10
|
94
|
-
temp_dir = self.get_temp_dir()
|
95
|
-
module = image_classification_serving.ClassificationModule(
|
96
|
-
params=params,
|
97
|
-
batch_size=1,
|
98
|
-
input_image_size=[224, 224],
|
99
|
-
input_type='tflite')
|
100
|
-
self._export_from_module(
|
101
|
-
module=module,
|
102
|
-
input_type='tflite',
|
103
|
-
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
|
104
|
-
|
105
|
-
tflite_model = export_tflite_lib.convert_tflite_model(
|
106
|
-
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
|
107
|
-
quant_type=quant_type,
|
108
|
-
params=params,
|
109
|
-
calibration_steps=5)
|
110
|
-
|
111
|
-
self.assertIsInstance(tflite_model, bytes)
|
112
|
-
|
113
|
-
@combinations.generate(
|
114
|
-
combinations.combine(
|
115
|
-
experiment=['retinanet_mobile_coco'],
|
116
|
-
quant_type=[
|
117
|
-
None,
|
118
|
-
'default',
|
119
|
-
'fp16',
|
120
|
-
'int8_fallback',
|
121
|
-
'int8_full',
|
122
|
-
'int8_full_fp32_io',
|
123
|
-
'int8_full_int8_io',
|
124
|
-
]))
|
125
|
-
def test_export_tflite_detection(self, experiment, quant_type):
|
126
|
-
|
127
|
-
params = exp_factory.get_exp_config(experiment)
|
128
|
-
params.task.validation_data.input_path = self.test_tfrecord_file_det
|
129
|
-
params.task.train_data.input_path = self.test_tfrecord_file_det
|
130
|
-
params.task.model.num_classes = 2
|
131
|
-
params.task.model.backbone.spinenet_mobile.model_id = '49XS'
|
132
|
-
params.task.model.input_size = [128, 128, 3]
|
133
|
-
params.task.model.detection_generator.nms_version = 'v1'
|
134
|
-
params.task.train_data.shuffle_buffer_size = 5
|
135
|
-
temp_dir = self.get_temp_dir()
|
136
|
-
module = detection_serving.DetectionModule(
|
137
|
-
params=params,
|
138
|
-
batch_size=1,
|
139
|
-
input_image_size=[128, 128],
|
140
|
-
input_type='tflite')
|
141
|
-
self._export_from_module(
|
142
|
-
module=module,
|
143
|
-
input_type='tflite',
|
144
|
-
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
|
145
|
-
|
146
|
-
tflite_model = export_tflite_lib.convert_tflite_model(
|
147
|
-
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
|
148
|
-
quant_type=quant_type,
|
149
|
-
params=params,
|
150
|
-
calibration_steps=1)
|
151
|
-
|
152
|
-
self.assertIsInstance(tflite_model, bytes)
|
153
|
-
|
154
|
-
@combinations.generate(
|
155
|
-
combinations.combine(
|
156
|
-
experiment=['mnv2_deeplabv3_pascal'],
|
157
|
-
quant_type=[
|
158
|
-
None,
|
159
|
-
'default',
|
160
|
-
'fp16',
|
161
|
-
'int8_fallback',
|
162
|
-
'int8_full',
|
163
|
-
'int8_full_fp32_io',
|
164
|
-
'int8_full_int8_io',
|
165
|
-
]))
|
166
|
-
def test_export_tflite_semantic_segmentation(self, experiment, quant_type):
|
167
|
-
|
168
|
-
params = exp_factory.get_exp_config(experiment)
|
169
|
-
params.task.validation_data.input_path = self.test_tfrecord_file_seg
|
170
|
-
params.task.train_data.input_path = self.test_tfrecord_file_seg
|
171
|
-
params.task.train_data.shuffle_buffer_size = 10
|
172
|
-
temp_dir = self.get_temp_dir()
|
173
|
-
module = semantic_segmentation_serving.SegmentationModule(
|
174
|
-
params=params,
|
175
|
-
batch_size=1,
|
176
|
-
input_image_size=[512, 512],
|
177
|
-
input_type='tflite')
|
178
|
-
self._export_from_module(
|
179
|
-
module=module,
|
180
|
-
input_type='tflite',
|
181
|
-
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
|
182
|
-
|
183
|
-
tflite_model = export_tflite_lib.convert_tflite_model(
|
184
|
-
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
|
185
|
-
quant_type=quant_type,
|
186
|
-
params=params,
|
187
|
-
calibration_steps=5)
|
188
|
-
|
189
|
-
self.assertIsInstance(tflite_model, bytes)
|
190
|
-
|
191
|
-
if __name__ == '__main__':
|
192
|
-
tf.test.main()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|