ultralytics-thop 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- thop/__init__.py +8 -0
- thop/__version__.py +1 -0
- thop/fx_profile.py +224 -0
- thop/onnx_profile.py +76 -0
- thop/profile.py +233 -0
- thop/rnn_hooks.py +195 -0
- thop/utils.py +56 -0
- thop/vision/__init__.py +0 -0
- thop/vision/basic_hooks.py +146 -0
- thop/vision/calc_func.py +118 -0
- thop/vision/efficientnet.py +9 -0
- thop/vision/onnx_counter.py +351 -0
- ultralytics_thop-0.0.1.dist-info/LICENSE +661 -0
- ultralytics_thop-0.0.1.dist-info/METADATA +846 -0
- ultralytics_thop-0.0.1.dist-info/RECORD +17 -0
- ultralytics_thop-0.0.1.dist-info/WHEEL +5 -0
- ultralytics_thop-0.0.1.dist-info/top_level.txt +1 -0
thop/__init__.py
ADDED
thop/__version__.py
ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.1"
thop/fx_profile.py
ADDED
@@ -0,0 +1,224 @@
+import logging
+from distutils.version import LooseVersion
+
+import torch
+import torch as th
+import torch.nn as nn
+
+if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
+    logging.warning(
+        f"torch.fx requires version higher than 1.8.0. "
+        f"But You are using an old version PyTorch {torch.__version__}. "
+    )
+
+
+def count_clamp(input_shapes, output_shapes):
+    return 0
+
+
+def count_mul(input_shapes, output_shapes):
+    # element-wise
+    return output_shapes[0].numel()
+
+
+def count_matmul(input_shapes, output_shapes):
+    in_shape = input_shapes[0]
+    out_shape = output_shapes[0]
+    in_features = in_shape[-1]
+    num_elements = out_shape.numel()
+    return in_features * num_elements
+
+
+def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
+    mul_flops = count_matmul(input_shapes, output_shapes)
+    if "bias" in kwargs:
+        add_flops = output_shapes[0].numel()
+    return mul_flops
+
+
+from .vision.calc_func import calculate_conv
+
+
+def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
+    inputs, weight, bias, stride, padding, dilation, groups = args
+    if len(input_shapes) == 2:
+        x_shape, k_shape = input_shapes
+    elif len(input_shapes) == 3:
+        x_shape, k_shape, b_shape = input_shapes
+    out_shape = output_shapes[0]
+
+    kernel_parameters = k_shape[2:].numel()
+    bias_op = 0  # check it later
+    in_channel = x_shape[1]
+
+    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
+    return int(total_ops)
+
+
+def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
+    return count_matmul(input_shapes, output_shapes)
+
+
+def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
+    return 0
+
+
+def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
+    bias_op = 1 if module.bias is not None else 0
+    out_shape = output_shapes[0]
+
+    in_channel = module.in_channels
+    groups = module.groups
+    kernel_ops = module.weight.shape[2:].numel()
+    total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
+    return int(total_ops)
+
+
+def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
+    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
+    y = output_shapes[0]
+    # y = (x - mean) / \sqrt{var + e} * weight + bias
+    total_ops = 2 * y.numel()
+    return total_ops
+
+
+zero_ops = (
+    nn.ReLU,
+    nn.ReLU6,
+    nn.Dropout,
+    nn.MaxPool2d,
+    nn.AvgPool2d,
+    nn.AdaptiveAvgPool2d,
+)
+
+count_map = {
+    nn.Linear: count_nn_linear,
+    nn.Conv2d: count_nn_conv2d,
+    nn.BatchNorm2d: count_nn_bn2d,
+    "function linear": count_fn_linear,
+    "clamp": count_clamp,
+    "built-in function add": count_zero_ops,
+    "built-in method fl": count_zero_ops,
+    "built-in method conv2d of type object": count_fn_conv2d,
+    "built-in function mul": count_mul,
+    "built-in function truediv": count_mul,
+}
+
+for k in zero_ops:
+    count_map[k] = count_zero_ops
+
+missing_maps = {}
+
+from torch.fx import symbolic_trace
+from torch.fx.passes.shape_prop import ShapeProp
+
+from .utils import prGreen, prRed, prYellow
+
+
+def null_print(*args, **kwargs):
+    return
+
+
+def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
+    gm: torch.fx.GraphModule = symbolic_trace(mod)
+    g = gm.graph
+    ShapeProp(gm).propagate(input)
+
+    fprint = null_print
+    if verbose:
+        fprint = print
+
+    v_maps = {}
+    total_flops = 0
+
+    for node in gm.graph.nodes:
+        # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
+        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
+        # node_op_type = str(node.target).split(".")[-1]
+        node_flops = None
+
+        input_shapes = []
+        output_shapes = []
+        fprint("input_shape:", end="\t")
+        for arg in node.args:
+            if str(arg) not in v_maps:
+                continue
+            fprint(f"{v_maps[str(arg)]}", end="\t")
+            input_shapes.append(v_maps[str(arg)])
+        fprint()
+        fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
+        output_shapes.append(node.meta["tensor_meta"].shape)
+
+        if node.op in ["output", "placeholder"]:
+            node_flops = 0
+        elif node.op == "call_function":
+            # torch internal functions
+            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
+            if key in count_map:
+                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
+            else:
+                missing_maps[key] = (node.op, key)
+                prRed(f"|{key}| is missing")
+        elif node.op == "call_method":
+            # torch internal functions
+            # fprint(str(node.target) in count_map, str(node.target), count_map.keys())
+            key = str(node.target)
+            if key in count_map:
+                node_flops = count_map[key](input_shapes, output_shapes)
+            else:
+                missing_maps[key] = (node.op, key)
+                prRed(f"{key} is missing")
+        elif node.op == "call_module":
+            # torch.nn modules
+            # m = getattr(mod, node.target, None)
+            m = mod.get_submodule(node.target)
+            key = type(m)
+            fprint(type(m), type(m) in count_map)
+            if type(m) in count_map:
+                node_flops = count_map[type(m)](m, input_shapes, output_shapes)
+            else:
+                missing_maps[key] = (node.op,)
+                prRed(f"{key} is missing")
+            print("module type:", type(m))
+            if isinstance(m, zero_ops):
+                print(f"weight_shape: None")
+            else:
+                print(type(m))
+                print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}")
+
+        v_maps[str(node.name)] = node.meta["tensor_meta"].shape
+        if node_flops is not None:
+            total_flops += node_flops
+            prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
+        fprint("==" * 40)
+
+    if len(missing_maps.keys()) > 0:
+        from pprint import pprint
+
+        print("Missing operators: ")
+        pprint(missing_maps)
+    return total_flops
+
+
+if __name__ == "__main__":
+
+    class MyOP(nn.Module):
+        def forward(self, input):
+            return input / 1
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(5, 3)
+            self.linear2 = torch.nn.Linear(5, 3)
+            self.myop = MyOP()
+
+        def forward(self, x):
+            out1 = self.linear1(x)
+            out2 = self.linear2(x).clamp(min=0.0, max=1.0)
+            return self.myop(out1 + out2)
+
+    net = MyModule()
+    data = th.randn(20, 5)
+    flops = fx_profile(net, data, verbose=False)
+    print(flops)
thop/onnx_profile.py
ADDED
@@ -0,0 +1,76 @@
+import numpy as np
+import onnx
+import torch
+import torch.nn
+from onnx import numpy_helper
+
+from thop.vision.onnx_counter import onnx_operators
+
+
+class OnnxProfile:
+    def __init__(self):
+        pass
+
+    def calculate_params(self, model: onnx.ModelProto):
+        onnx_weights = model.graph.initializer
+        params = 0
+
+        for onnx_w in onnx_weights:
+            try:
+                weight = numpy_helper.to_array(onnx_w)
+                params += np.prod(weight.shape)
+            except Exception as _:
+                pass
+
+        return params
+
+    def create_dict(self, weight, input, output):
+        diction = {}
+        for w in weight:
+            dim = np.array(w.dims)
+            diction[str(w.name)] = dim
+            if dim.size == 1:
+                diction[str(w.name)] = np.append(1, dim)
+        for i in input:
+            # print(i.type.tensor_type.shape.dim[0].dim_value)
+            dim = np.array(i.type.tensor_type.shape.dim[0].dim_value)
+            # print(i.type.tensor_type.shape.dim.__sizeof__())
+            # name2dims[str(i.name)] = [dim]
+            dim = []
+            for key in i.type.tensor_type.shape.dim:
+                dim = np.append(dim, int(key.dim_value))
+                # print(key.dim_value)
+            # print(dim)
+            diction[str(i.name)] = dim
+            if dim.size == 1:
+                diction[str(i.name)] = np.append(1, dim)
+        for o in output:
+            dim = np.array(o.type.tensor_type.shape.dim[0].dim_value)
+            diction[str(o.name)] = [dim]
+            if dim.size == 1:
+                diction[str(o.name)] = np.append(1, dim)
+        return diction
+
+    def nodes_counter(self, diction, node):
+        if node.op_type not in onnx_operators:
+            print("Sorry, we haven't add ", node.op_type, "into dictionary.")
+            return 0, None, None
+        else:
+            fn = onnx_operators[node.op_type]
+            return fn(diction, node)
+
+    def calculate_macs(self, model: onnx.ModelProto) -> torch.DoubleTensor:
+        macs = 0
+        name2dims = {}
+        weight = model.graph.initializer
+        nodes = model.graph.node
+        input = model.graph.input
+        output = model.graph.output
+        name2dims = self.create_dict(weight, input, output)
+        macs = 0
+        for n in nodes:
+            macs_adding, out_size, outname = self.nodes_counter(name2dims, n)
+
+            name2dims[outname] = out_size
+            macs += macs_adding
+        return np.array(macs[0])
thop/profile.py
ADDED
@@ -0,0 +1,233 @@
+from distutils.version import LooseVersion
+
+from thop.rnn_hooks import *
+from thop.vision.basic_hooks import *
+
+# logger = logging.getLogger(__name__)
+# logger.setLevel(logging.INFO)
+from .utils import prGreen, prRed, prYellow
+
+if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
+    logging.warning(
+        "You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__)
+    )
+
+default_dtype = torch.float64
+
+register_hooks = {
+    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
+    nn.Conv1d: count_convNd,
+    nn.Conv2d: count_convNd,
+    nn.Conv3d: count_convNd,
+    nn.ConvTranspose1d: count_convNd,
+    nn.ConvTranspose2d: count_convNd,
+    nn.ConvTranspose3d: count_convNd,
+    nn.BatchNorm1d: count_normalization,
+    nn.BatchNorm2d: count_normalization,
+    nn.BatchNorm3d: count_normalization,
+    nn.LayerNorm: count_normalization,
+    nn.InstanceNorm1d: count_normalization,
+    nn.InstanceNorm2d: count_normalization,
+    nn.InstanceNorm3d: count_normalization,
+    nn.PReLU: count_prelu,
+    nn.Softmax: count_softmax,
+    nn.ReLU: zero_ops,
+    nn.ReLU6: zero_ops,
+    nn.LeakyReLU: count_relu,
+    nn.MaxPool1d: zero_ops,
+    nn.MaxPool2d: zero_ops,
+    nn.MaxPool3d: zero_ops,
+    nn.AdaptiveMaxPool1d: zero_ops,
+    nn.AdaptiveMaxPool2d: zero_ops,
+    nn.AdaptiveMaxPool3d: zero_ops,
+    nn.AvgPool1d: count_avgpool,
+    nn.AvgPool2d: count_avgpool,
+    nn.AvgPool3d: count_avgpool,
+    nn.AdaptiveAvgPool1d: count_adap_avgpool,
+    nn.AdaptiveAvgPool2d: count_adap_avgpool,
+    nn.AdaptiveAvgPool3d: count_adap_avgpool,
+    nn.Linear: count_linear,
+    nn.Dropout: zero_ops,
+    nn.Upsample: count_upsample,
+    nn.UpsamplingBilinear2d: count_upsample,
+    nn.UpsamplingNearest2d: count_upsample,
+    nn.RNNCell: count_rnn_cell,
+    nn.GRUCell: count_gru_cell,
+    nn.LSTMCell: count_lstm_cell,
+    nn.RNN: count_rnn,
+    nn.GRU: count_gru,
+    nn.LSTM: count_lstm,
+    nn.Sequential: zero_ops,
+    nn.PixelShuffle: zero_ops,
+}
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
+    register_hooks.update({nn.SyncBatchNorm: count_normalization})
+
+
+def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
+    handler_collection = []
+    types_collection = set()
+    if custom_ops is None:
+        custom_ops = {}
+    if report_missing:
+        verbose = True
+
+    def add_hooks(m):
+        if len(list(m.children())) > 0:
+            return
+
+        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
+            logging.warning(
+                "Either .total_ops or .total_params is already defined in %s. "
+                "Be careful, it might change your code's behavior." % str(m)
+            )
+
+        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
+        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))
+
+        for p in m.parameters():
+            m.total_params += torch.DoubleTensor([p.numel()])
+
+        m_type = type(m)
+
+        fn = None
+        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
+            fn = custom_ops[m_type]
+            if m_type not in types_collection and verbose:
+                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
+        elif m_type in register_hooks:
+            fn = register_hooks[m_type]
+            if m_type not in types_collection and verbose:
+                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
+        else:
+            if m_type not in types_collection and report_missing:
+                prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
+
+        if fn is not None:
+            handler = m.register_forward_hook(fn)
+            handler_collection.append(handler)
+        types_collection.add(m_type)
+
+    training = model.training
+
+    model.eval()
+    model.apply(add_hooks)
+
+    with torch.no_grad():
+        model(*inputs)
+
+    total_ops = 0
+    total_params = 0
+    for m in model.modules():
+        if len(list(m.children())) > 0:  # skip for non-leaf module
+            continue
+        total_ops += m.total_ops
+        total_params += m.total_params
+
+    total_ops = total_ops.item()
+    total_params = total_params.item()
+
+    # reset model to original status
+    model.train(training)
+    for handler in handler_collection:
+        handler.remove()
+
+    # remove temporal buffers
+    for n, m in model.named_modules():
+        if len(list(m.children())) > 0:
+            continue
+        if "total_ops" in m._buffers:
+            m._buffers.pop("total_ops")
+        if "total_params" in m._buffers:
+            m._buffers.pop("total_params")
+
+    return total_ops, total_params
+
+
+def profile(
+    model: nn.Module,
+    inputs,
+    custom_ops=None,
+    verbose=True,
+    ret_layer_info=False,
+    report_missing=False,
+):
+    handler_collection = {}
+    types_collection = set()
+    if custom_ops is None:
+        custom_ops = {}
+    if report_missing:
+        # overwrite `verbose` option when enable report_missing
+        verbose = True
+
+    def add_hooks(m: nn.Module):
+        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
+        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
+
+        # for p in m.parameters():
+        #     m.total_params += torch.DoubleTensor([p.numel()])
+
+        m_type = type(m)
+
+        fn = None
+        if m_type in custom_ops:
+            # if defined both op maps, use custom_ops to overwrite.
+            fn = custom_ops[m_type]
+            if m_type not in types_collection and verbose:
+                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
+        elif m_type in register_hooks:
+            fn = register_hooks[m_type]
+            if m_type not in types_collection and verbose:
+                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
+        else:
+            if m_type not in types_collection and report_missing:
+                prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
+
+        if fn is not None:
+            handler_collection[m] = (
+                m.register_forward_hook(fn),
+                m.register_forward_hook(count_parameters),
+            )
+        types_collection.add(m_type)
+
+    prev_training_status = model.training
+
+    model.eval()
+    model.apply(add_hooks)
+
+    with torch.no_grad():
+        model(*inputs)
+
+    def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
+        total_ops, total_params = module.total_ops.item(), 0
+        ret_dict = {}
+        for n, m in module.named_children():
+            # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"):  # and len(list(m.children())) > 0:
+            #     m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
+            # else:
+            #     m_ops, m_params = m.total_ops, m.total_params
+            next_dict = {}
+            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
+                m_ops, m_params = m.total_ops.item(), m.total_params.item()
+            else:
+                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
+            ret_dict[n] = (m_ops, m_params, next_dict)
+            total_ops += m_ops
+            total_params += m_params
+        # print(prefix, module._get_name(), (total_ops, total_params))
+        return total_ops, total_params, ret_dict
+
+    total_ops, total_params, ret_dict = dfs_count(model)
+
+    # reset model to original status
+    model.train(prev_training_status)
+    for m, (op_handler, params_handler) in handler_collection.items():
+        op_handler.remove()
+        params_handler.remove()
+        m._buffers.pop("total_ops")
+        m._buffers.pop("total_params")
+
+    if ret_layer_info:
+        return total_ops, total_params, ret_dict
+    return total_ops, total_params