tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250109__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of tf-models-nightly has been flagged as a potentially problematic release.
- official/projects/detr/__init__.py +14 -0
- official/projects/detr/configs/__init__.py +14 -0
- official/projects/detr/configs/detr.py +277 -0
- official/projects/detr/configs/detr_test.py +51 -0
- official/projects/detr/dataloaders/__init__.py +14 -0
- official/projects/detr/dataloaders/coco.py +157 -0
- official/projects/detr/dataloaders/coco_test.py +111 -0
- official/projects/detr/dataloaders/detr_input.py +175 -0
- official/projects/detr/experiments/__init__.py +14 -0
- official/projects/detr/modeling/__init__.py +14 -0
- official/projects/detr/modeling/detr.py +345 -0
- official/projects/detr/modeling/detr_test.py +70 -0
- official/projects/detr/modeling/transformer.py +849 -0
- official/projects/detr/modeling/transformer_test.py +263 -0
- official/projects/detr/ops/__init__.py +14 -0
- official/projects/detr/ops/matchers.py +489 -0
- official/projects/detr/ops/matchers_test.py +95 -0
- official/projects/detr/optimization.py +151 -0
- official/projects/detr/serving/__init__.py +14 -0
- official/projects/detr/serving/export_module.py +103 -0
- official/projects/detr/serving/export_module_test.py +98 -0
- official/projects/detr/serving/export_saved_model.py +109 -0
- official/projects/detr/tasks/__init__.py +14 -0
- official/projects/detr/tasks/detection.py +421 -0
- official/projects/detr/tasks/detection_test.py +203 -0
- official/projects/detr/train.py +70 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/RECORD +32 -6
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Customized optimizer to match paper results."""

import dataclasses
import tensorflow as tf, tf_keras
from official.modeling import optimization
from official.nlp import optimization as nlp_optimization


@dataclasses.dataclass
class DETRAdamWConfig(optimization.AdamWeightDecayConfig):
  pass


@dataclasses.dataclass
class OptimizerConfig(optimization.OptimizerConfig):
  detr_adamw: DETRAdamWConfig = dataclasses.field(
      default_factory=DETRAdamWConfig
  )


@dataclasses.dataclass
class OptimizationConfig(optimization.OptimizationConfig):
  """Configuration for optimizer and learning rate schedule.

  Attributes:
    optimizer: optimizer oneof config.
    ema: optional exponential moving average optimizer config, if specified, ema
      optimizer will be used.
    learning_rate: learning rate oneof config.
    warmup: warmup oneof config.
  """
  optimizer: OptimizerConfig = dataclasses.field(
      default_factory=OptimizerConfig
  )


# TODO(frederickliu): figure out how to make this configuable.
# TODO(frederickliu): Study if this is needed.
class _DETRAdamW(nlp_optimization.AdamWeightDecay):
  """Custom AdamW to support different lr scaling for backbone.

  The code is copied from AdamWeightDecay and Adam with learning scaling.
  """

  def _resource_apply_dense(self, grad, var, apply_state=None):
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    apply_state = kwargs['apply_state']
    if 'detr' not in var.name:
      lr_t *= 0.1
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = ((apply_state or {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      m = self.get_slot(var, 'm')
      v = self.get_slot(var, 'v')
      lr = coefficients[
          'lr_t'] * 0.1 if 'detr' not in var.name else coefficients['lr_t']

      if not self.amsgrad:
        return tf.raw_ops.ResourceApplyAdam(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            beta1_power=coefficients['beta_1_power'],
            beta2_power=coefficients['beta_2_power'],
            lr=lr,
            beta1=coefficients['beta_1_t'],
            beta2=coefficients['beta_2_t'],
            epsilon=coefficients['epsilon'],
            grad=grad,
            use_locking=self._use_locking)
      else:
        vhat = self.get_slot(var, 'vhat')
        return tf.raw_ops.ResourceApplyAdamWithAmsgrad(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            vhat=vhat.handle,
            beta1_power=coefficients['beta_1_power'],
            beta2_power=coefficients['beta_2_power'],
            lr=lr,
            beta1=coefficients['beta_1_t'],
            beta2=coefficients['beta_2_t'],
            epsilon=coefficients['epsilon'],
            grad=grad,
            use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    apply_state = kwargs['apply_state']
    if 'detr' not in var.name:
      lr_t *= 0.1
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = ((apply_state or {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      # m_t = beta1 * m + (1 - beta1) * g_t
      m = self.get_slot(var, 'm')
      m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
      m_t = tf.compat.v1.assign(m, m * coefficients['beta_1_t'],
                                use_locking=self._use_locking)
      with tf.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

      # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
      v = self.get_slot(var, 'v')
      v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
      v_t = tf.compat.v1.assign(v, v * coefficients['beta_2_t'],
                                use_locking=self._use_locking)
      with tf.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
      lr = coefficients[
          'lr_t'] * 0.1 if 'detr' not in var.name else coefficients['lr_t']
      if not self.amsgrad:
        v_sqrt = tf.sqrt(v_t)
        var_update = tf.compat.v1.assign_sub(
            var, lr * m_t / (v_sqrt + coefficients['epsilon']),
            use_locking=self._use_locking)
        return tf.group(*[var_update, m_t, v_t])
      else:
        v_hat = self.get_slot(var, 'vhat')
        v_hat_t = tf.maximum(v_hat, v_t)
        with tf.control_dependencies([v_hat_t]):
          v_hat_t = tf.compat.v1.assign(
              v_hat, v_hat_t, use_locking=self._use_locking)
        v_hat_sqrt = tf.sqrt(v_hat_t)
        var_update = tf.compat.v1.assign_sub(
            var,
            lr * m_t / (v_hat_sqrt + coefficients['epsilon']),
            use_locking=self._use_locking)
        return tf.group(*[var_update, m_t, v_t, v_hat_t])

optimization.register_optimizer_cls('detr_adamw', _DETRAdamW)
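The optimizer above registers itself under the name 'detr_adamw' and scales the learning rate by 0.1 for every variable whose name does not contain 'detr' (i.e. the backbone), matching the lower backbone learning rate used for DETR. Below is a minimal sketch of selecting it through the oneof configs defined in this file; the weight-decay, clipnorm, and schedule values are illustrative assumptions, not values taken from this diff.

from official.projects.detr import optimization as detr_optimization

# Pick the registered 'detr_adamw' optimizer via the oneof config.
# All numeric values below are assumed for illustration only.
opt_config = detr_optimization.OptimizationConfig({
    'optimizer': {
        'type': 'detr_adamw',
        'detr_adamw': {
            'weight_decay_rate': 1e-4,
            'global_clipnorm': 0.1,
        },
    },
    'learning_rate': {
        'type': 'stepwise',
        'stepwise': {'boundaries': [120000], 'values': [1e-4, 1e-5]},
    },
})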
@@ -0,0 +1,14 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,103 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Export module for DETR model."""
import tensorflow as tf, tf_keras

from official.projects.detr.modeling import detr
from official.vision.modeling import backbones
from official.vision.ops import preprocess_ops
from official.vision.serving import detection


class DETRModule(detection.DetectionModule):
  """DETR detection module."""

  def _build_model(self) -> tf_keras.Model:
    input_specs = tf_keras.layers.InputSpec(shape=[self._batch_size] +
                                            self._input_image_size +
                                            [self._num_channels])

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=self.params.task.model.backbone,
        norm_activation_config=self.params.task.model.norm_activation)

    model = detr.DETR(backbone, self.params.task.model.backbone_endpoint_name,
                      self.params.task.model.num_queries,
                      self.params.task.model.hidden_size,
                      self.params.task.model.num_classes,
                      self.params.task.model.num_encoder_layers,
                      self.params.task.model.num_decoder_layers)
    model(tf_keras.Input(input_specs.shape[1:]))
    return model

  def _build_inputs(self, image: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
    """Builds detection model inputs for serving."""
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    image, image_info = preprocess_ops.resize_image(
        image, size=self._input_image_size)

    return image, image_info

  def serve(self, images: tf.Tensor) -> dict[str, tf.Tensor]:
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]

    Returns:
      Tensor holding classification output logits.
    """
    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    image_info = None
    if self._input_type != 'tflite':
      with tf.device('cpu:0'):
        images = tf.cast(images, dtype=tf.float32)

        images_spec = tf.TensorSpec(
            shape=self._input_image_size + [3], dtype=tf.float32)
        image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

        images, image_info = tf.nest.map_structure(
            tf.identity,
            tf.map_fn(
                self._build_inputs,
                elems=images,
                fn_output_signature=(images_spec, image_info_spec),
                parallel_iterations=32))

    outputs = self.inference_step(images)[-1]
    outputs = {
        'detection_boxes': outputs['detection_boxes'],
        'detection_scores': outputs['detection_scores'],
        'detection_classes': outputs['detection_classes'],
        'num_detections': outputs['num_detections']
    }
    if image_info is not None:
      outputs['detection_boxes'] = outputs['detection_boxes'] * tf.expand_dims(
          tf.concat([
              image_info[:, 1:2, 0], image_info[:, 1:2, 1],
              image_info[:, 1:2, 0], image_info[:, 1:2, 1]
          ],
                    axis=1),
          axis=1)

      outputs.update({'image_info': image_info})

    return outputs
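Because the model emits normalized [ymin, xmin, ymax, xmax] boxes, serve() above multiplies them by the resized image height and width taken from row 1 of image_info. A standalone sketch of that rescaling, assuming the image_info layout produced by preprocess_ops.resize_image (original size, resized size, scale, offset) and using made-up numbers:

import tensorflow as tf

# image_info has shape [batch, 4, 2]; row 1 holds the resized (height, width).
image_info = tf.constant([[[480., 640.],   # original size (assumed)
                           [384., 384.],   # resized size fed to the model
                           [0.8, 0.6],     # y/x scale
                           [0., 0.]]])     # offset
boxes = tf.constant([[[0.25, 0.25, 0.75, 0.75]]])  # normalized [y0, x0, y1, x1]

# Same concat/expand_dims pattern as in serve(): [h, w, h, w] per image.
scale = tf.concat([image_info[:, 1:2, 0], image_info[:, 1:2, 1],
                   image_info[:, 1:2, 0], image_info[:, 1:2, 1]], axis=1)
pixel_boxes = boxes * tf.expand_dims(scale, axis=1)
# pixel_boxes -> [[[96., 96., 288., 288.]]]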
@@ -0,0 +1,98 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test for DETR export module."""

import io
import os

from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf, tf_keras

from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg  # pylint: disable=unused-import
from official.projects.detr.serving import export_module


class ExportModuleTest(tf.test.TestCase, parameterized.TestCase):

  def _get_module(self, input_type):
    params = exp_factory.get_exp_config('detr_coco')
    return export_module.DETRModule(
        params,
        batch_size=1,
        input_image_size=[384, 384],
        input_type=input_type)

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type):
    """Gets dummy input for the given input type."""

    if input_type == 'image_tensor':
      return tf.zeros((1, 384, 384, 3), dtype=np.uint8)
    elif input_type == 'image_bytes':
      image = Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue()]
    elif input_type == 'tf_example':
      image_tensor = tf.zeros((384, 384, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      tf.train.Feature(
                          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
              })).SerializeToString()
      return [example]

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'image_bytes'},
      {'input_type': 'tf_example'},
  )
  def test_export(self, input_type='image_tensor'):
    tmp_dir = self.get_temp_dir()
    module = self._get_module(input_type)
    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    predict_fn = imported.signatures['serving_default']

    images = self._get_dummy_input(input_type)
    outputs = predict_fn(tf.constant(images))

    self.assertNotEmpty(outputs['detection_boxes'])
    self.assertNotEmpty(outputs['detection_classes'])
    self.assertNotEmpty(outputs['detection_scores'])
    self.assertNotEmpty(outputs['num_detections'])


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,109 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Vision models export binary for serving/inference.

To export a trained checkpoint in saved_model format (shell script):

EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
                   --export_dir=${EXPORT_DIR_PATH}/ \
                   --checkpoint_path=${CHECKPOINT_PATH} \
                   --batch_size=2 \
                   --input_image_size=224,224

To serve (python):

export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""

from absl import app
from absl import flags

from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.detr.configs import detr as exp_cfg  # pylint: disable=unused-import
from official.projects.detr.serving import export_module
from official.vision.serving import export_saved_model_lib

FLAGS = flags.FLAGS

_EXPERIMENT = flags.DEFINE_string('experiment', None,
                                  'experiment type, e.g. detr_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
                                       'Checkpoint path.')
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be overriden'
    ' on top of `config_file` template.')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
_IMAGE_TYPE = flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')


def main(_):

  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  for config_file in _CONFIG_FILE.value or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=False)
  if _PARAMS_OVERRIDE.value:
    params = hyperparams.override_params_dict(
        params, _PARAMS_OVERRIDE.value, is_strict=False)

  params.validate()
  params.lock()

  input_image_size = [int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')]
  module = export_module.DETRModule(
      params=params,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      input_type=_IMAGE_TYPE.value,
      num_channels=3)

  export_saved_model_lib.export_inference_graph(
      input_type=_IMAGE_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      params=params,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_dir=_EXPORT_DIR.value,
      export_module=module)


if __name__ == '__main__':
  app.run(main)
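For reference, a sketch of the "To serve (python)" path from the docstring above with its XX placeholders filled in; the export directory and the 1x384x384x3 uint8 batch are assumptions chosen to mirror export_module_test.py, not values from this diff.

import tensorflow as tf

# Hypothetical export directory; use whatever --export_dir was passed above.
export_dir_path = '/tmp/detr_export'

imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']

# Dummy batch matching input_type=image_tensor, batch_size=1 and
# input_image_size=384,384 (assumed, mirroring the unit test).
images = tf.zeros((1, 384, 384, 3), dtype=tf.uint8)
outputs = model_fn(images)
print(outputs['detection_boxes'].shape, outputs['num_detections'])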
@@ -0,0 +1,14 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.