ultralytics-thop 0.0.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
thop/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # from .onnx_profile import OnnxProfile
+ import torch
+
+ from .profile import profile, profile_origin
+ from .utils import clever_format
+
+ default_dtype = torch.float64
+ from .__version__ import __version__
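For orientation, the two names exported here, profile and clever_format, are the package's main entry points. A minimal usage sketch follows (the model and shapes are illustrative, not taken from the package):

    import torch
    from thop import clever_format, profile

    model = torch.nn.Linear(128, 64)   # any nn.Module works
    inp = torch.randn(1, 128)
    macs, params = profile(model, inputs=(inp,))
    macs, params = clever_format([macs, params], "%.3f")
    print(macs, params)                # e.g. "8.192K 8.256K" for this layer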
thop/__version__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.1.1"
thop/fx_profile.py ADDED
@@ -0,0 +1,224 @@
+ import logging
+ from distutils.version import LooseVersion
+
+ import torch
+ import torch as th
+ import torch.nn as nn
+
+ if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
+     logging.warning(
+         "torch.fx requires PyTorch 1.8.0 or higher, "
+         f"but you are using an older version, PyTorch {torch.__version__}."
+     )
+
+
+ def count_clamp(input_shapes, output_shapes):
+     return 0
+
+
+ def count_mul(input_shapes, output_shapes):
+     # element-wise: one operation per output element
+     return output_shapes[0].numel()
+
+
+ def count_matmul(input_shapes, output_shapes):
+     in_shape = input_shapes[0]
+     out_shape = output_shapes[0]
+     in_features = in_shape[-1]
+     num_elements = out_shape.numel()
+     return in_features * num_elements
+
+
+ def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
+     mul_flops = count_matmul(input_shapes, output_shapes)
+     if "bias" in kwargs:
+         # the bias term adds one addition per output element
+         add_flops = output_shapes[0].numel()
+         return mul_flops + add_flops
+     return mul_flops
+
+
+ from .vision.calc_func import calculate_conv
+
+
+ def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
+     inputs, weight, bias, stride, padding, dilation, groups = args
+     if len(input_shapes) == 2:
+         x_shape, k_shape = input_shapes
+     elif len(input_shapes) == 3:
+         x_shape, k_shape, b_shape = input_shapes
+     out_shape = output_shapes[0]
+
+     kernel_parameters = k_shape[2:].numel()
+     bias_op = 0  # check it later
+     in_channel = x_shape[1]
+
+     total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
+     return int(total_ops)
+
+
+ def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
+     return count_matmul(input_shapes, output_shapes)
+
+
+ def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
+     return 0
+
+
+ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
+     bias_op = 1 if module.bias is not None else 0
+     out_shape = output_shapes[0]
+
+     in_channel = module.in_channels
+     groups = module.groups
+     kernel_ops = module.weight.shape[2:].numel()
+     total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
+     return int(total_ops)
+
+
+ def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
+     assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
+     y = output_shapes[0]
+     # y = (x - mean) / sqrt(var + eps) * weight + bias
+     total_ops = 2 * y.numel()
+     return total_ops
+
+
+ zero_ops = (
+     nn.ReLU,
+     nn.ReLU6,
+     nn.Dropout,
+     nn.MaxPool2d,
+     nn.AvgPool2d,
+     nn.AdaptiveAvgPool2d,
+ )
+
+ count_map = {
+     nn.Linear: count_nn_linear,
+     nn.Conv2d: count_nn_conv2d,
+     nn.BatchNorm2d: count_nn_bn2d,
+     "function linear": count_fn_linear,
+     "clamp": count_clamp,
+     "built-in function add": count_zero_ops,
+     "built-in method fl": count_zero_ops,
+     "built-in method conv2d of type object": count_fn_conv2d,
+     "built-in function mul": count_mul,
+     "built-in function truediv": count_mul,
+ }
+
+ for k in zero_ops:
+     count_map[k] = count_zero_ops
+
+ missing_maps = {}
+
+ from torch.fx import symbolic_trace
+ from torch.fx.passes.shape_prop import ShapeProp
+
+ from .utils import prGreen, prRed, prYellow
+
+
+ def null_print(*args, **kwargs):
+     return
+
+
+ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
+     gm: torch.fx.GraphModule = symbolic_trace(mod)
+     g = gm.graph
+     ShapeProp(gm).propagate(input)
+
+     fprint = null_print
+     if verbose:
+         fprint = print
+
+     v_maps = {}
+     total_flops = 0
+
+     for node in gm.graph.nodes:
+         # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
+         fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
+         # node_op_type = str(node.target).split(".")[-1]
+         node_flops = None
+
+         input_shapes = []
+         output_shapes = []
+         fprint("input_shape:", end="\t")
+         for arg in node.args:
+             if str(arg) not in v_maps:
+                 continue
+             fprint(f"{v_maps[str(arg)]}", end="\t")
+             input_shapes.append(v_maps[str(arg)])
+         fprint()
+         fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
+         output_shapes.append(node.meta["tensor_meta"].shape)
+
+         if node.op in ["output", "placeholder"]:
+             node_flops = 0
+         elif node.op == "call_function":
+             # torch internal functions
+             key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
+             if key in count_map:
+                 node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
+             else:
+                 missing_maps[key] = (node.op, key)
+                 prRed(f"|{key}| is missing")
+         elif node.op == "call_method":
+             # torch internal methods
+             # fprint(str(node.target) in count_map, str(node.target), count_map.keys())
+             key = str(node.target)
+             if key in count_map:
+                 node_flops = count_map[key](input_shapes, output_shapes)
+             else:
+                 missing_maps[key] = (node.op, key)
+                 prRed(f"{key} is missing")
+         elif node.op == "call_module":
+             # torch.nn modules
+             # m = getattr(mod, node.target, None)
+             m = mod.get_submodule(node.target)
+             key = type(m)
+             fprint(type(m), type(m) in count_map)
+             if type(m) in count_map:
+                 node_flops = count_map[type(m)](m, input_shapes, output_shapes)
+             else:
+                 missing_maps[key] = (node.op,)
+                 prRed(f"{key} is missing")
+             print("module type:", type(m))
+             if isinstance(m, zero_ops):
+                 print("weight_shape: None")
+             else:
+                 print(type(m))
+                 print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}")
+
+         v_maps[str(node.name)] = node.meta["tensor_meta"].shape
+         if node_flops is not None:
+             total_flops += node_flops
+         prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
+         fprint("==" * 40)
+
+     if len(missing_maps.keys()) > 0:
+         from pprint import pprint
+
+         print("Missing operators: ")
+         pprint(missing_maps)
+     return total_flops
+
+
+ if __name__ == "__main__":
+
+     class MyOP(nn.Module):
+         def forward(self, input):
+             return input / 1
+
+     class MyModule(torch.nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.linear1 = torch.nn.Linear(5, 3)
+             self.linear2 = torch.nn.Linear(5, 3)
+             self.myop = MyOP()
+
+         def forward(self, x):
+             out1 = self.linear1(x)
+             out2 = self.linear2(x).clamp(min=0.0, max=1.0)
+             return self.myop(out1 + out2)
+
+     net = MyModule()
+     data = th.randn(20, 5)
+     flops = fx_profile(net, data, verbose=False)
+     print(flops)
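As a hand check on the example above (my arithmetic from the counting rules in this file, not output captured from the package): each Linear(5, 3) on a (20, 5) input yields a (20, 3) output, so count_matmul reports in_features * output.numel(); clamp and add map to zero-cost rules, while the division in MyOP routes through count_mul at one op per output element:

    linear1 = 5 * (20 * 3)   # count_matmul: 300
    linear2 = 5 * (20 * 3)   # 300
    clamp_and_add = 0        # count_clamp / count_zero_ops
    div = 20 * 3             # "built-in function truediv" -> count_mul: 60
    total = linear1 + linear2 + clamp_and_add + div   # 660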
thop/onnx_profile.py ADDED
@@ -0,0 +1,76 @@
+ import numpy as np
+ import onnx
+ import torch
+ import torch.nn
+ from onnx import numpy_helper
+
+ from thop.vision.onnx_counter import onnx_operators
+
+
+ class OnnxProfile:
+     def __init__(self):
+         pass
+
+     def calculate_params(self, model: onnx.ModelProto):
+         onnx_weights = model.graph.initializer
+         params = 0
+
+         for onnx_w in onnx_weights:
+             try:
+                 weight = numpy_helper.to_array(onnx_w)
+                 params += np.prod(weight.shape)
+             except Exception as _:
+                 pass
+
+         return params
+
+     def create_dict(self, weight, input, output):
+         diction = {}
+         for w in weight:
+             dim = np.array(w.dims)
+             diction[str(w.name)] = dim
+             if dim.size == 1:
+                 diction[str(w.name)] = np.append(1, dim)
+         for i in input:
+             # collect every dimension of the input tensor's shape
+             dim = []
+             for key in i.type.tensor_type.shape.dim:
+                 dim = np.append(dim, int(key.dim_value))
+             diction[str(i.name)] = dim
+             if dim.size == 1:
+                 diction[str(i.name)] = np.append(1, dim)
+         for o in output:
+             dim = np.array(o.type.tensor_type.shape.dim[0].dim_value)
+             diction[str(o.name)] = [dim]
+             if dim.size == 1:
+                 diction[str(o.name)] = np.append(1, dim)
+         return diction
+
+     def nodes_counter(self, diction, node):
+         if node.op_type not in onnx_operators:
+             print("Sorry, we haven't added", node.op_type, "to the dictionary yet.")
+             return 0, None, None
+         else:
+             fn = onnx_operators[node.op_type]
+             return fn(diction, node)
+
+     def calculate_macs(self, model: onnx.ModelProto) -> torch.DoubleTensor:
+         weight = model.graph.initializer
+         nodes = model.graph.node
+         input = model.graph.input
+         output = model.graph.output
+         name2dims = self.create_dict(weight, input, output)
+         macs = 0
+         for n in nodes:
+             macs_adding, out_size, outname = self.nodes_counter(name2dims, n)
+
+             name2dims[outname] = out_size
+             macs += macs_adding
+         return np.array(macs[0])
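A usage sketch for the ONNX path (hedged: the file name and the torchvision model below are illustrative, and operator coverage depends on what thop.vision.onnx_counter registers):

    import onnx
    import torch
    import torchvision
    from thop.onnx_profile import OnnxProfile

    model = torchvision.models.resnet18()
    dummy = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy, "resnet18.onnx")  # serialize to ONNX first

    onnx_model = onnx.load("resnet18.onnx")
    profiler = OnnxProfile()
    print("params:", profiler.calculate_params(onnx_model))
    print("MACs:", profiler.calculate_macs(onnx_model))

Note that OnnxProfile is not re-exported from thop/__init__.py (the import there is commented out), so it must be imported from thop.onnx_profile directly.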
thop/profile.py ADDED
@@ -0,0 +1,233 @@
+ import logging
+ from distutils.version import LooseVersion
+
+ import torch
+ import torch.nn as nn
+
+ from thop.rnn_hooks import *
+ from thop.vision.basic_hooks import *
+
+ # logger = logging.getLogger(__name__)
+ # logger.setLevel(logging.INFO)
+ from .utils import prGreen, prRed, prYellow
+
+ if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
+     logging.warning(
+         "You are using an old version of PyTorch ({version}), which THOP does NOT support.".format(version=torch.__version__)
+     )
+
+ default_dtype = torch.float64
+
+ register_hooks = {
+     nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
+     nn.Conv1d: count_convNd,
+     nn.Conv2d: count_convNd,
+     nn.Conv3d: count_convNd,
+     nn.ConvTranspose1d: count_convNd,
+     nn.ConvTranspose2d: count_convNd,
+     nn.ConvTranspose3d: count_convNd,
+     nn.BatchNorm1d: count_normalization,
+     nn.BatchNorm2d: count_normalization,
+     nn.BatchNorm3d: count_normalization,
+     nn.LayerNorm: count_normalization,
+     nn.InstanceNorm1d: count_normalization,
+     nn.InstanceNorm2d: count_normalization,
+     nn.InstanceNorm3d: count_normalization,
+     nn.PReLU: count_prelu,
+     nn.Softmax: count_softmax,
+     nn.ReLU: zero_ops,
+     nn.ReLU6: zero_ops,
+     nn.LeakyReLU: count_relu,
+     nn.MaxPool1d: zero_ops,
+     nn.MaxPool2d: zero_ops,
+     nn.MaxPool3d: zero_ops,
+     nn.AdaptiveMaxPool1d: zero_ops,
+     nn.AdaptiveMaxPool2d: zero_ops,
+     nn.AdaptiveMaxPool3d: zero_ops,
+     nn.AvgPool1d: count_avgpool,
+     nn.AvgPool2d: count_avgpool,
+     nn.AvgPool3d: count_avgpool,
+     nn.AdaptiveAvgPool1d: count_adap_avgpool,
+     nn.AdaptiveAvgPool2d: count_adap_avgpool,
+     nn.AdaptiveAvgPool3d: count_adap_avgpool,
+     nn.Linear: count_linear,
+     nn.Dropout: zero_ops,
+     nn.Upsample: count_upsample,
+     nn.UpsamplingBilinear2d: count_upsample,
+     nn.UpsamplingNearest2d: count_upsample,
+     nn.RNNCell: count_rnn_cell,
+     nn.GRUCell: count_gru_cell,
+     nn.LSTMCell: count_lstm_cell,
+     nn.RNN: count_rnn,
+     nn.GRU: count_gru,
+     nn.LSTM: count_lstm,
+     nn.Sequential: zero_ops,
+     nn.PixelShuffle: zero_ops,
+ }
+
+ if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
+     register_hooks.update({nn.SyncBatchNorm: count_normalization})
+
+
+ def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
+     handler_collection = []
+     types_collection = set()
+     if custom_ops is None:
+         custom_ops = {}
+     if report_missing:
+         verbose = True
+
+     def add_hooks(m):
+         if len(list(m.children())) > 0:
+             return
+
+         if hasattr(m, "total_ops") or hasattr(m, "total_params"):
+             logging.warning(
+                 "Either .total_ops or .total_params is already defined in %s. "
+                 "Be careful, it might change your code's behavior." % str(m)
+             )
+
+         m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
+         m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))
+
+         for p in m.parameters():
+             m.total_params += torch.DoubleTensor([p.numel()])
+
+         m_type = type(m)
+
+         fn = None
+         if m_type in custom_ops:  # if a rule exists in both op maps, custom_ops wins.
+             fn = custom_ops[m_type]
+             if m_type not in types_collection and verbose:
+                 print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
+         elif m_type in register_hooks:
+             fn = register_hooks[m_type]
+             if m_type not in types_collection and verbose:
+                 print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
+         else:
+             if m_type not in types_collection and report_missing:
+                 prRed("[WARN] Cannot find rule for %s. Treat it as zero MACs and zero params." % m_type)
+
+         if fn is not None:
+             handler = m.register_forward_hook(fn)
+             handler_collection.append(handler)
+         types_collection.add(m_type)
+
+     training = model.training
+
+     model.eval()
+     model.apply(add_hooks)
+
+     with torch.no_grad():
+         model(*inputs)
+
+     total_ops = 0
+     total_params = 0
+     for m in model.modules():
+         if len(list(m.children())) > 0:  # skip non-leaf modules
+             continue
+         total_ops += m.total_ops
+         total_params += m.total_params
+
+     total_ops = total_ops.item()
+     total_params = total_params.item()
+
+     # reset model to original status
+     model.train(training)
+     for handler in handler_collection:
+         handler.remove()
+
+     # remove temporary buffers
+     for n, m in model.named_modules():
+         if len(list(m.children())) > 0:
+             continue
+         if "total_ops" in m._buffers:
+             m._buffers.pop("total_ops")
+         if "total_params" in m._buffers:
+             m._buffers.pop("total_params")
+
+     return total_ops, total_params
+
+
+ def profile(
+     model: nn.Module,
+     inputs,
+     custom_ops=None,
+     verbose=True,
+     ret_layer_info=False,
+     report_missing=False,
+ ):
+     handler_collection = {}
+     types_collection = set()
+     if custom_ops is None:
+         custom_ops = {}
+     if report_missing:
+         # report_missing implies verbose: overwrite the `verbose` option.
+         verbose = True
+
+     def add_hooks(m: nn.Module):
+         m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
+         m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
+
+         # for p in m.parameters():
+         #     m.total_params += torch.DoubleTensor([p.numel()])
+
+         m_type = type(m)
+
+         fn = None
+         if m_type in custom_ops:
+             # if a rule exists in both op maps, custom_ops wins.
+             fn = custom_ops[m_type]
+             if m_type not in types_collection and verbose:
+                 print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
+         elif m_type in register_hooks:
+             fn = register_hooks[m_type]
+             if m_type not in types_collection and verbose:
+                 print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
+         else:
+             if m_type not in types_collection and report_missing:
+                 prRed("[WARN] Cannot find rule for %s. Treat it as zero MACs and zero params." % m_type)
+
+         if fn is not None:
+             handler_collection[m] = (
+                 m.register_forward_hook(fn),
+                 m.register_forward_hook(count_parameters),
+             )
+         types_collection.add(m_type)
+
+     prev_training_status = model.training
+
+     model.eval()
+     model.apply(add_hooks)
+
+     with torch.no_grad():
+         model(*inputs)
+
+     def dfs_count(module: nn.Module, prefix="\t") -> (int, int, dict):
+         total_ops, total_params = module.total_ops.item(), 0
+         ret_dict = {}
+         for n, m in module.named_children():
+             # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"):  # and len(list(m.children())) > 0:
+             #     m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
+             # else:
+             #     m_ops, m_params = m.total_ops, m.total_params
+             next_dict = {}
+             if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
+                 m_ops, m_params = m.total_ops.item(), m.total_params.item()
+             else:
+                 m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
+             ret_dict[n] = (m_ops, m_params, next_dict)
+             total_ops += m_ops
+             total_params += m_params
+         # print(prefix, module._get_name(), (total_ops, total_params))
+         return total_ops, total_params, ret_dict
+
+     total_ops, total_params, ret_dict = dfs_count(model)
+
+     # reset model to original status
+     model.train(prev_training_status)
+     for m, (op_handler, params_handler) in handler_collection.items():
+         op_handler.remove()
+         params_handler.remove()
+         m._buffers.pop("total_ops")
+         m._buffers.pop("total_params")
+
+     if ret_layer_info:
+         return total_ops, total_params, ret_dict
+     return total_ops, total_params
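Finally, a sketch of the custom_ops escape hatch for modules that profile cannot match (the module and counting rule below are hypothetical). A rule uses the standard forward-hook signature (module, input, output) and accumulates into the total_ops buffer that add_hooks registers:

    import torch
    import torch.nn as nn
    from thop import clever_format, profile

    class Swish(nn.Module):              # hypothetical module with no built-in rule
        def forward(self, x):
            return x * torch.sigmoid(x)

    def count_swish(m, x, y):            # (module, input, output) hook signature
        m.total_ops += torch.DoubleTensor([int(y.numel())])  # assume 1 op/element

    model = nn.Sequential(nn.Linear(16, 16), Swish())
    inp = torch.randn(1, 16)
    macs, params = profile(model, inputs=(inp,), custom_ops={Swish: count_swish})
    print(clever_format([macs, params], "%.3f"))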