ultralytics-thop 2.0.14__py3-none-any.whl
- thop/__init__.py +11 -0
- thop/fx_profile.py +237 -0
- thop/profile.py +228 -0
- thop/rnn_hooks.py +202 -0
- thop/utils.py +50 -0
- thop/vision/__init__.py +1 -0
- thop/vision/basic_hooks.py +146 -0
- thop/vision/calc_func.py +127 -0
- ultralytics_thop-2.0.14.dist-info/LICENSE +661 -0
- ultralytics_thop-2.0.14.dist-info/METADATA +199 -0
- ultralytics_thop-2.0.14.dist-info/RECORD +13 -0
- ultralytics_thop-2.0.14.dist-info/WHEEL +5 -0
- ultralytics_thop-2.0.14.dist-info/top_level.txt +1 -0
thop/__init__.py
ADDED
thop/fx_profile.py
ADDED
@@ -0,0 +1,237 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import logging
from distutils.version import LooseVersion

import torch
import torch as th
import torch.nn as nn

if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
    logging.warning(f"torch.fx requires PyTorch >= 1.8.0, but you are using PyTorch {torch.__version__}.")


def count_clamp(input_shapes, output_shapes):
    """Returns 0 ops for clamp, which involves no multiply-accumulate operations."""
    return 0


def count_mul(input_shapes, output_shapes):
    """Returns the number of elements in the first output shape, i.e. one op per output element."""
    return output_shapes[0].numel()


def count_matmul(input_shapes, output_shapes):
    """Calculates matrix multiplication ops based on input and output tensor shapes for performance profiling."""
    in_shape = input_shapes[0]
    out_shape = output_shapes[0]
    in_features = in_shape[-1]
    num_elements = out_shape.numel()
    return in_features * num_elements


def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
    """Calculates the total FLOPs for a linear layer, including bias operations if specified."""
    flops = count_matmul(input_shapes, output_shapes)
    if "bias" in kwargs:
        flops += output_shapes[0].numel()
    return flops


from .vision.calc_func import calculate_conv


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
    """Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using
    `calculate_conv`.
    """
    inputs, weight, bias, stride, padding, dilation, groups = args
    if len(input_shapes) == 2:
        x_shape, k_shape = input_shapes
    elif len(input_shapes) == 3:
        x_shape, k_shape, b_shape = input_shapes
    out_shape = output_shapes[0]

    kernel_parameters = k_shape[2:].numel()
    bias_op = 0  # TODO: bias is currently ignored here
    in_channel = x_shape[1]

    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)
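
As a sanity check on the matmul rule above (an illustrative aside, not part of the wheel): count_matmul charges one multiply-accumulate per inner-product term, so for x @ W with x of shape (N, d_in) and output of shape (N, d_out) it returns d_in * N * d_out.

import torch

in_shape, out_shape = torch.Size([20, 5]), torch.Size([20, 3])
macs = in_shape[-1] * out_shape.numel()  # the count_matmul rule: 5 * 60
assert macs == 300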
def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
    """Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
    return count_matmul(input_shapes, output_shapes)


def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
    """Returns 0 for modules whose operations are treated as free, e.g. activations and pooling."""
    return 0


def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
    """Calculates FLOPs for an nn.Conv2d layer from its weight shape and the given output shape."""
    bias_op = 1 if module.bias is not None else 0
    out_shape = output_shapes[0]

    in_channel = module.in_channels
    groups = module.groups
    kernel_ops = module.weight.shape[2:].numel()
    total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)


def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
    """Calculates FLOPs for an nn.BatchNorm2d layer based on the given output shape."""
    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
    y = output_shapes[0]
    return 2 * y.numel()


zero_ops = (
    nn.ReLU,
    nn.ReLU6,
    nn.Dropout,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
)

count_map = {
    nn.Linear: count_nn_linear,
    nn.Conv2d: count_nn_conv2d,
    nn.BatchNorm2d: count_nn_bn2d,
    "function linear": count_fn_linear,
    "clamp": count_clamp,
    "built-in function add": count_zero_ops,
    "built-in method fl": count_zero_ops,
    "built-in method conv2d of type object": count_fn_conv2d,
    "built-in function mul": count_mul,
    "built-in function truediv": count_mul,
}

for k in zero_ops:
    count_map[k] = count_zero_ops

missing_maps = {}

from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

from .utils import prRed, prYellow


def null_print(*args, **kwargs):
    """A no-op print function that accepts any arguments without performing any action."""
    return


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
    """Profiles an nn.Module, returning total FLOPs across all traced operations; prints per-node details if
    verbose.
    """
    gm: torch.fx.GraphModule = symbolic_trace(mod)
    ShapeProp(gm).propagate(input)

    fprint = null_print
    if verbose:
        fprint = print

    v_maps = {}
    total_flops = 0

    for node in gm.graph.nodes:
        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
        node_flops = None

        input_shapes = []
        fprint("input_shape:", end="\t")
        for arg in node.args:
            if str(arg) not in v_maps:
                continue
            fprint(f"{v_maps[str(arg)]}", end="\t")
            input_shapes.append(v_maps[str(arg)])
        fprint()
        fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
        output_shapes = [node.meta["tensor_meta"].shape]
        if node.op in ["output", "placeholder"]:
            node_flops = 0
        elif node.op == "call_function":
            # torch internal functions
            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
            else:
                missing_maps[key] = (node.op, key)
                prRed(f"|{key}| is missing")
        elif node.op == "call_method":
            # torch tensor methods
            key = str(node.target)
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes)
            else:
                missing_maps[key] = (node.op, key)
                prRed(f"{key} is missing")
        elif node.op == "call_module":
            # torch.nn modules
            m = mod.get_submodule(node.target)
            key = type(m)
            fprint(type(m), type(m) in count_map)
            if type(m) in count_map:
                node_flops = count_map[type(m)](m, input_shapes, output_shapes)
            else:
                missing_maps[key] = (node.op,)
                prRed(f"{key} is missing")
            print("module type:", type(m))
            if isinstance(m, zero_ops):
                print("weight_shape: None")
            else:
                print(type(m))
                print(f"weight_shape: {mod.state_dict()[f'{node.target}.weight'].shape}")

        v_maps[str(node.name)] = node.meta["tensor_meta"].shape
        if node_flops is not None:
            total_flops += node_flops
        prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
        fprint("==" * 40)

    if len(missing_maps.keys()) > 0:
        from pprint import pprint

        print("Missing operators: ")
        pprint(missing_maps)
    return total_flops


if __name__ == "__main__":

    class MyOP(nn.Module):
        def forward(self, input):
            """Performs a forward pass on the given input."""
            return input / 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            """Initializes MyModule with two linear layers and a custom MyOP operator."""
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 3)
            self.linear2 = torch.nn.Linear(5, 3)
            self.myop = MyOP()

        def forward(self, x):
            """Applies two linear transformations to the input tensor, clamps the second, then combines and
            processes them with the MyOP operator.
            """
            out1 = self.linear1(x)
            out2 = self.linear2(x).clamp(min=0.0, max=1.0)
            return self.myop(out1 + out2)

    net = MyModule()
    data = th.randn(20, 5)
    flops = fx_profile(net, data, verbose=False)
    print(flops)
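
Beyond the built-in demo, the same entry point works for any module that torch.fx can symbolically trace, provided every submodule has a rule in count_map. A minimal sketch (my example, assuming the wheel is installed and importable as thop):

import torch
import torch.nn as nn

from thop.fx_profile import fx_profile

# Conv2d, BatchNorm2d and ReLU all have rules in count_map above.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
flops = fx_profile(net, torch.randn(1, 3, 32, 32), verbose=False)
print(flops)

Note that modules with data-dependent control flow cannot be symbolically traced; this is a general torch.fx limitation, not specific to this package.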
thop/profile.py
ADDED
@@ -0,0 +1,228 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import logging

import torch
import torch.nn as nn

from thop.rnn_hooks import *
from thop.vision.basic_hooks import *

from .utils import prRed

default_dtype = torch.float64

register_hooks = {
    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd,
    nn.ConvTranspose2d: count_convNd,
    nn.ConvTranspose3d: count_convNd,
    nn.BatchNorm1d: count_normalization,
    nn.BatchNorm2d: count_normalization,
    nn.BatchNorm3d: count_normalization,
    nn.LayerNorm: count_normalization,
    nn.InstanceNorm1d: count_normalization,
    nn.InstanceNorm2d: count_normalization,
    nn.InstanceNorm3d: count_normalization,
    nn.PReLU: count_prelu,
    nn.Softmax: count_softmax,
    nn.ReLU: zero_ops,
    nn.ReLU6: zero_ops,
    nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops,
    nn.MaxPool2d: zero_ops,
    nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops,
    nn.AdaptiveMaxPool2d: zero_ops,
    nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear,
    nn.Dropout: zero_ops,
    nn.Upsample: count_upsample,
    nn.UpsamplingBilinear2d: count_upsample,
    nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell,
    nn.GRUCell: count_gru_cell,
    nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,
    nn.Sequential: zero_ops,
    nn.PixelShuffle: zero_ops,
    nn.SyncBatchNorm: count_normalization,
}
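Any module type missing from this table can be supplied at call time through the custom_ops argument of the profilers below. A hedged sketch (the Swish module and its counting rule are my own illustration, assuming the usual thop top-level exports):

import torch
import torch.nn as nn

from thop import profile


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


def count_swish(m: nn.Module, x, y):
    # one multiply per output element; the sigmoid is treated as free here
    m.total_ops += torch.DoubleTensor([y.numel()])


net = nn.Sequential(nn.Linear(16, 16), Swish())
macs, params = profile(net, inputs=(torch.randn(1, 16),), custom_ops={Swish: count_swish})

The rule receives the module, its input tuple, and its output tensor (the standard forward-hook signature) and accumulates into the total_ops buffer that the profiler registers on each module.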
def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Profiles a PyTorch model's operations and parameters, applying either custom or default hooks."""
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        verbose = True

    def add_hooks(m):
        if list(m.children()):
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                f"Either .total_ops or .total_params is already defined in {str(m)}. "
                "Be careful, it might change your code's behavior."
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:  # if a rule is defined in both maps, custom_ops takes precedence
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Custom rule {fn.__qualname__}() for {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.")

        if fn is not None:
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        types_collection.add(m_type)

    training = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if list(m.children()):  # skip non-leaf modules
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    # reset model to original status
    model.train(training)
    for handler in handler_collection:
        handler.remove()

    # remove temporary buffers
    for n, m in model.named_modules():
        if list(m.children()):
            continue
        if "total_ops" in m._buffers:
            m._buffers.pop("total_ops")
        if "total_params" in m._buffers:
            m._buffers.pop("total_params")

    return total_ops, total_params


def profile(
    model: nn.Module,
    inputs,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Profiles a PyTorch model, returning total operations, parameters, and optionally layer-wise details."""
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # `report_missing` implies `verbose`, so overwrite that option when it is enabled
        verbose = True

    def add_hooks(m: nn.Module):
        """Registers hooks on a module to track its total operations and parameters."""
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if a rule is defined in both maps, custom_ops takes precedence
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Custom rule {fn.__qualname__}() for {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.")

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    prev_training_status = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    def dfs_count(module: nn.Module, prefix="\t"):
        """Recursively counts the operations and parameters of a module and its submodules, returning
        (total_ops, total_params, per-layer dict).
        """
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            next_dict = {}
            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params, ret_dict

    total_ops, total_params, ret_dict = dfs_count(model)

    # reset model to original status
    model.train(prev_training_status)
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        m._buffers.pop("total_ops")
        m._buffers.pop("total_params")

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
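
A typical call pairs profile() with clever_format from thop/utils.py to pretty-print the counts. A sketch (assuming the usual thop top-level exports and a torchvision install; not taken from this wheel's own docs):

import torch
from torchvision.models import resnet18

from thop import clever_format, profile

model = resnet18()
macs, params = profile(model, inputs=(torch.randn(1, 3, 224, 224),), verbose=False)
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)  # roughly 1.8G MACs and 11.7M params for resnet18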
thop/rnn_hooks.py
ADDED
@@ -0,0 +1,202 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence


def _count_rnn_cell(input_size, hidden_size, bias=True):
    """Calculates the total operations for an RNN cell given input size, hidden size, and optional bias."""
    total_ops = hidden_size * (input_size + hidden_size) + hidden_size
    if bias:
        total_ops += hidden_size * 2

    return total_ops


def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
    """Counts the total RNN cell operations based on input tensor, hidden size, bias, and batch size."""
    total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)

    batch_size = x[0].size(0)
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_gru_cell(input_size, hidden_size, bias=True):
    """Counts the total operations for a GRU cell based on input size, hidden size, and bias configuration."""
    total_ops = 0
    # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
    # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
    state_ops = (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        state_ops += hidden_size * 2
    total_ops += state_ops * 2

    # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
    total_ops += (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        total_ops += hidden_size * 2
    # r Hadamard: r * (~)
    total_ops += hidden_size

    # h' = (1 - z) * n + z * h
    # Hadamard, Hadamard, add
    total_ops += hidden_size * 3

    return total_ops


def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
    """Calculates and updates the total operations for a GRU cell in a mini-batch during inference."""
    total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)

    batch_size = x[0].size(0)
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_lstm_cell(input_size, hidden_size, bias=True):
    """Counts LSTM cell operations during inference based on input size, hidden size, and bias configuration."""
    total_ops = 0

    # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
    # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
    # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
    # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
    state_ops = (input_size + hidden_size) * hidden_size + hidden_size
    if bias:
        state_ops += hidden_size * 2
    total_ops += state_ops * 4

    # c' = f * c + i * g
    # Hadamard, Hadamard, add
    total_ops += hidden_size * 3

    # h' = o * \tanh(c')
    total_ops += hidden_size

    return total_ops
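
As a worked check of this accounting (my arithmetic, not from the source): with input_size=10, hidden_size=20 and bias enabled, each gate costs (10 + 20) * 20 + 20 + 2 * 20 = 660 ops, the four gates together cost 2640, the cell update c' adds 3 * 20 = 60 and the hidden update h' another 20, giving 2720 per time step and batch element.

from thop.rnn_hooks import _count_lstm_cell

assert _count_lstm_cell(10, 20, bias=True) == 4 * 660 + 60 + 20 == 2720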
def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
    """Counts and updates the total operations for an LSTM cell in a mini-batch during inference."""
    total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

    batch_size = x[0].size(0)
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_rnn(m: nn.RNN, x, y):
    """Calculates and updates the total operations for an RNN layer over all layers, time steps, and batch."""
    bias = m.bias
    input_size = m.input_size
    hidden_size = m.hidden_size
    num_layers = m.num_layers

    if isinstance(x[0], PackedSequence):
        batch_size = torch.max(x[0].batch_sizes)
        num_steps = x[0].batch_sizes.size(0)
    elif m.batch_first:
        batch_size = x[0].size(0)
        num_steps = x[0].size(1)
    else:
        batch_size = x[0].size(1)
        num_steps = x[0].size(0)

    total_ops = 0
    if m.bidirectional:
        total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2
    else:
        total_ops += _count_rnn_cell(input_size, hidden_size, bias)

    for _ in range(num_layers - 1):
        total_ops += (
            _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2
            if m.bidirectional
            else _count_rnn_cell(hidden_size, hidden_size, bias)
        )

    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_gru(m: nn.GRU, x, y):
    """Calculates total operations for a GRU layer, updating the module's operation count based on batch size."""
    bias = m.bias
    input_size = m.input_size
    hidden_size = m.hidden_size
    num_layers = m.num_layers

    if isinstance(x[0], PackedSequence):
        batch_size = torch.max(x[0].batch_sizes)
        num_steps = x[0].batch_sizes.size(0)
    elif m.batch_first:
        batch_size = x[0].size(0)
        num_steps = x[0].size(1)
    else:
        batch_size = x[0].size(1)
        num_steps = x[0].size(0)

    total_ops = 0
    if m.bidirectional:
        total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2
    else:
        total_ops += _count_gru_cell(input_size, hidden_size, bias)

    for _ in range(num_layers - 1):
        total_ops += (
            _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2
            if m.bidirectional
            else _count_gru_cell(hidden_size, hidden_size, bias)
        )

    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_lstm(m: nn.LSTM, x, y):
    """Calculates total operations for an LSTM layer (including bidirectional), updating the module's count."""
    bias = m.bias
    input_size = m.input_size
    hidden_size = m.hidden_size
    num_layers = m.num_layers

    if isinstance(x[0], PackedSequence):
        batch_size = torch.max(x[0].batch_sizes)
        num_steps = x[0].batch_sizes.size(0)
    elif m.batch_first:
        batch_size = x[0].size(0)
        num_steps = x[0].size(1)
    else:
        batch_size = x[0].size(1)
        num_steps = x[0].size(0)

    total_ops = 0
    if m.bidirectional:
        total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2
    else:
        total_ops += _count_lstm_cell(input_size, hidden_size, bias)

    for _ in range(num_layers - 1):
        total_ops += (
            _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2
            if m.bidirectional
            else _count_lstm_cell(hidden_size, hidden_size, bias)
        )

    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
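
To confirm these layer rules compose as the comments describe (per-cell cost, times time steps, times batch, with later layers fed hidden_size inputs), one can compare profile() against the closed form. A sketch (my example, assuming the usual thop exports):

import torch
import torch.nn as nn

from thop import profile
from thop.rnn_hooks import _count_lstm_cell

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)  # (num_steps=5, batch=3, input_size=10)
macs, _ = profile(lstm, inputs=(x,), verbose=False)

# layer 1 sees input_size=10; layer 2 sees the hidden_size=20 output
expected = (_count_lstm_cell(10, 20, True) + _count_lstm_cell(20, 20, True)) * 5 * 3
assert int(macs) == expected  # 93600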