ultralytics-thop 2.0.14 (ultralytics_thop-2.0.14-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- thop/__init__.py +11 -0
- thop/fx_profile.py +237 -0
- thop/profile.py +228 -0
- thop/rnn_hooks.py +202 -0
- thop/utils.py +50 -0
- thop/vision/__init__.py +1 -0
- thop/vision/basic_hooks.py +146 -0
- thop/vision/calc_func.py +127 -0
- ultralytics_thop-2.0.14.dist-info/LICENSE +661 -0
- ultralytics_thop-2.0.14.dist-info/METADATA +199 -0
- ultralytics_thop-2.0.14.dist-info/RECORD +13 -0
- ultralytics_thop-2.0.14.dist-info/WHEEL +5 -0
- ultralytics_thop-2.0.14.dist-info/top_level.txt +1 -0
thop/__init__.py
ADDED
thop/fx_profile.py
ADDED
@@ -0,0 +1,237 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import logging
+from distutils.version import LooseVersion
+
+import torch
+import torch as th
+import torch.nn as nn
+
+if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
+    logging.warning(
+        f"torch.fx requires version higher than 1.8.0. But You are using an old version PyTorch {torch.__version__}. "
+    )
+
+
+def count_clamp(input_shapes, output_shapes):
+    """Ensures tensor array sizes are appropriate by clamping specified input and output shapes."""
+    return 0
+
+
+def count_mul(input_shapes, output_shapes):
+    """Returns the number of elements in the first output shape."""
+    return output_shapes[0].numel()
+
+
+def count_matmul(input_shapes, output_shapes):
+    """Calculates matrix multiplication ops based on input and output tensor shapes for performance profiling."""
+    in_shape = input_shapes[0]
+    out_shape = output_shapes[0]
+    in_features = in_shape[-1]
+    num_elements = out_shape.numel()
+    return in_features * num_elements
+
+
+def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
+    """Calculates the total FLOPs for a linear layer, including bias operations if specified."""
+    flops = count_matmul(input_shapes, output_shapes)
+    if "bias" in kwargs:
+        flops += output_shapes[0].numel()
+    return flops
+
+
+from .vision.calc_func import calculate_conv
+
+
+def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
+    """Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using
+    `calculate_conv`.
+    """
+    inputs, weight, bias, stride, padding, dilation, groups = args
+    if len(input_shapes) == 2:
+        x_shape, k_shape = input_shapes
+    elif len(input_shapes) == 3:
+        x_shape, k_shape, b_shape = input_shapes
+    out_shape = output_shapes[0]
+
+    kernel_parameters = k_shape[2:].numel()
+    bias_op = 0  # check it later
+    in_channel = x_shape[1]
+
+    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
+    return int(total_ops)
+
+
+def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
+    """Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
+    return count_matmul(input_shapes, output_shapes)
+
+
+def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
+    """Returns 0 for a neural network module, input shapes, and output shapes in PyTorch."""
+    return 0
+
+
+def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
+    """Calculates FLOPs for a 2D Conv2D layer in an nn.Module using input and output shapes."""
+    bias_op = 1 if module.bias is not None else 0
+    out_shape = output_shapes[0]
+
+    in_channel = module.in_channels
+    groups = module.groups
+    kernel_ops = module.weight.shape[2:].numel()
+    total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
+    return int(total_ops)
+
+
+def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
+    """Calculate FLOPs for an nn.BatchNorm2d layer based on the given output shape."""
+    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
+    y = output_shapes[0]
+    return 2 * y.numel()
+
+
+zero_ops = (
+    nn.ReLU,
+    nn.ReLU6,
+    nn.Dropout,
+    nn.MaxPool2d,
+    nn.AvgPool2d,
+    nn.AdaptiveAvgPool2d,
+)
+
+count_map = {
+    nn.Linear: count_nn_linear,
+    nn.Conv2d: count_nn_conv2d,
+    nn.BatchNorm2d: count_nn_bn2d,
+    "function linear": count_fn_linear,
+    "clamp": count_clamp,
+    "built-in function add": count_zero_ops,
+    "built-in method fl": count_zero_ops,
+    "built-in method conv2d of type object": count_fn_conv2d,
+    "built-in function mul": count_mul,
+    "built-in function truediv": count_mul,
+}
+
+for k in zero_ops:
+    count_map[k] = count_zero_ops
+
+missing_maps = {}
+
+from torch.fx import symbolic_trace
+from torch.fx.passes.shape_prop import ShapeProp
+
+from .utils import prRed, prYellow
+
+
+def null_print(*args, **kwargs):
+    """A no-op print function that takes any arguments without performing any actions."""
+    return
+
+
+def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
+    """Profiles nn.Module for total FLOPs per operation and prints detailed nodes if verbose."""
+    gm: torch.fx.GraphModule = symbolic_trace(mod)
+    ShapeProp(gm).propagate(input)
+
+    fprint = null_print
+    if verbose:
+        fprint = print
+
+    v_maps = {}
+    total_flops = 0
+
+    for node in gm.graph.nodes:
+        # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
+        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
+        # node_op_type = str(node.target).split(".")[-1]
+        node_flops = None
+
+        input_shapes = []
+        fprint("input_shape:", end="\t")
+        for arg in node.args:
+            if str(arg) not in v_maps:
+                continue
+            fprint(f"{v_maps[str(arg)]}", end="\t")
+            input_shapes.append(v_maps[str(arg)])
+        fprint()
+        fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
+        output_shapes = [node.meta["tensor_meta"].shape]
+        if node.op in ["output", "placeholder"]:
+            node_flops = 0
+        elif node.op == "call_function":
+            # torch internal functions
+            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
+            if key in count_map:
+                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
+            else:
+                missing_maps[key] = (node.op, key)
+                prRed(f"|{key}| is missing")
+        elif node.op == "call_method":
+            # torch internal functions
+            # fprint(str(node.target) in count_map, str(node.target), count_map.keys())
+            key = str(node.target)
+            if key in count_map:
+                node_flops = count_map[key](input_shapes, output_shapes)
+            else:
+                missing_maps[key] = (node.op, key)
+                prRed(f"{key} is missing")
+        elif node.op == "call_module":
+            # torch.nn modules
+            # m = getattr(mod, node.target, None)
+            m = mod.get_submodule(node.target)
+            key = type(m)
+            fprint(type(m), type(m) in count_map)
+            if type(m) in count_map:
+                node_flops = count_map[type(m)](m, input_shapes, output_shapes)
+            else:
+                missing_maps[key] = (node.op,)
+                prRed(f"{key} is missing")
+            print("module type:", type(m))
+            if isinstance(m, zero_ops):
+                print("weight_shape: None")
+            else:
+                print(type(m))
+                print(f"weight_shape: {mod.state_dict()[f'{node.target}.weight'].shape}")
+
+        v_maps[str(node.name)] = node.meta["tensor_meta"].shape
+        if node_flops is not None:
+            total_flops += node_flops
+            prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
+        fprint("==" * 40)
+
+    if len(missing_maps.keys()) > 0:
+        from pprint import pprint
+
+        print("Missing operators: ")
+        pprint(missing_maps)
+    return total_flops
+
+
+if __name__ == "__main__":
+
+    class MyOP(nn.Module):
+        def forward(self, input):
+            """Performs forward pass on given input data."""
+            return input / 1
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            """Initializes MyModule with two linear layers and a custom MyOP operator."""
+            super().__init__()
+            self.linear1 = torch.nn.Linear(5, 3)
+            self.linear2 = torch.nn.Linear(5, 3)
+            self.myop = MyOP()
+
+        def forward(self, x):
+            """Applies two linear transformations to the input tensor, clamps the second, then combines and processes
+            with MyOP operator.
+            """
+            out1 = self.linear1(x)
+            out2 = self.linear2(x).clamp(min=0.0, max=1.0)
+            return self.myop(out1 + out2)
+
+    net = MyModule()
+    data = th.randn(20, 5)
+    flops = fx_profile(net, data, verbose=False)
+    print(flops)
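
As a quick sanity check of the tracer-based counters above: `count_matmul` charges a linear layer in_features × output elements, so a single `nn.Linear(5, 3)` over a batch of 20 should cost 5 × 60 = 300 ops. A minimal sketch (the toy model here is illustrative, not part of the package):

    import torch
    import torch.nn as nn

    from thop.fx_profile import fx_profile

    net = nn.Sequential(nn.Linear(5, 3))  # hypothetical toy model
    flops = fx_profile(net, torch.randn(20, 5))  # traced via torch.fx, shapes propagated
    assert flops == 5 * 20 * 3  # count_matmul: in_features * out.numel() = 300
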
thop/profile.py
ADDED
@@ -0,0 +1,228 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from thop.rnn_hooks import *
+from thop.vision.basic_hooks import *
+
+from .utils import prRed
+
+default_dtype = torch.float64
+
+register_hooks = {
+    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
+    nn.Conv1d: count_convNd,
+    nn.Conv2d: count_convNd,
+    nn.Conv3d: count_convNd,
+    nn.ConvTranspose1d: count_convNd,
+    nn.ConvTranspose2d: count_convNd,
+    nn.ConvTranspose3d: count_convNd,
+    nn.BatchNorm1d: count_normalization,
+    nn.BatchNorm2d: count_normalization,
+    nn.BatchNorm3d: count_normalization,
+    nn.LayerNorm: count_normalization,
+    nn.InstanceNorm1d: count_normalization,
+    nn.InstanceNorm2d: count_normalization,
+    nn.InstanceNorm3d: count_normalization,
+    nn.PReLU: count_prelu,
+    nn.Softmax: count_softmax,
+    nn.ReLU: zero_ops,
+    nn.ReLU6: zero_ops,
+    nn.LeakyReLU: count_relu,
+    nn.MaxPool1d: zero_ops,
+    nn.MaxPool2d: zero_ops,
+    nn.MaxPool3d: zero_ops,
+    nn.AdaptiveMaxPool1d: zero_ops,
+    nn.AdaptiveMaxPool2d: zero_ops,
+    nn.AdaptiveMaxPool3d: zero_ops,
+    nn.AvgPool1d: count_avgpool,
+    nn.AvgPool2d: count_avgpool,
+    nn.AvgPool3d: count_avgpool,
+    nn.AdaptiveAvgPool1d: count_adap_avgpool,
+    nn.AdaptiveAvgPool2d: count_adap_avgpool,
+    nn.AdaptiveAvgPool3d: count_adap_avgpool,
+    nn.Linear: count_linear,
+    nn.Dropout: zero_ops,
+    nn.Upsample: count_upsample,
+    nn.UpsamplingBilinear2d: count_upsample,
+    nn.UpsamplingNearest2d: count_upsample,
+    nn.RNNCell: count_rnn_cell,
+    nn.GRUCell: count_gru_cell,
+    nn.LSTMCell: count_lstm_cell,
+    nn.RNN: count_rnn,
+    nn.GRU: count_gru,
+    nn.LSTM: count_lstm,
+    nn.Sequential: zero_ops,
+    nn.PixelShuffle: zero_ops,
+    nn.SyncBatchNorm: count_normalization,
+}
+
+
+def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
+    """Profiles a PyTorch model's operations and parameters, applying either custom or default hooks."""
+    handler_collection = []
+    types_collection = set()
+    if custom_ops is None:
+        custom_ops = {}
+    if report_missing:
+        verbose = True
+
+    def add_hooks(m):
+        if list(m.children()):
+            return
+
+        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
+            logging.warning(
+                f"Either .total_ops or .total_params is already defined in {str(m)}. "
+                "Be careful, it might change your code's behavior."
+            )
+
+        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
+        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))
+
+        for p in m.parameters():
+            m.total_params += torch.DoubleTensor([p.numel()])
+
+        m_type = type(m)
+
+        fn = None
+        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
+            fn = custom_ops[m_type]
+            if m_type not in types_collection and verbose:
+                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
+        elif m_type in register_hooks:
+            fn = register_hooks[m_type]
+            if m_type not in types_collection and verbose:
+                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
+        else:
+            if m_type not in types_collection and report_missing:
+                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")
+
+        if fn is not None:
+            handler = m.register_forward_hook(fn)
+            handler_collection.append(handler)
+        types_collection.add(m_type)
+
+    training = model.training
+
+    model.eval()
+    model.apply(add_hooks)
+
+    with torch.no_grad():
+        model(*inputs)
+
+    total_ops = 0
+    total_params = 0
+    for m in model.modules():
+        if list(m.children()):  # skip for non-leaf module
+            continue
+        total_ops += m.total_ops
+        total_params += m.total_params
+
+    total_ops = total_ops.item()
+    total_params = total_params.item()
+
+    # reset model to original status
+    model.train(training)
+    for handler in handler_collection:
+        handler.remove()
+
+    # remove temporal buffers
+    for n, m in model.named_modules():
+        if list(m.children()):
+            continue
+        if "total_ops" in m._buffers:
+            m._buffers.pop("total_ops")
+        if "total_params" in m._buffers:
+            m._buffers.pop("total_params")
+
+    return total_ops, total_params
+
+
+def profile(
+    model: nn.Module,
+    inputs,
+    custom_ops=None,
+    verbose=True,
+    ret_layer_info=False,
+    report_missing=False,
+):
+    """Profiles a PyTorch model, returning total operations, parameters, and optionally layer-wise details."""
+    handler_collection = {}
+    types_collection = set()
+    if custom_ops is None:
+        custom_ops = {}
+    if report_missing:
+        # overwrite `verbose` option when enable report_missing
+        verbose = True
+
+    def add_hooks(m: nn.Module):
+        """Registers hooks to a neural network module to track total operations and parameters."""
+        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
+        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
+
+        # for p in m.parameters():
+        #     m.total_params += torch.DoubleTensor([p.numel()])
+
+        m_type = type(m)
+
+        fn = None
+        if m_type in custom_ops:
+            # if defined both op maps, use custom_ops to overwrite.
+            fn = custom_ops[m_type]
+            if m_type not in types_collection and verbose:
+                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
+        elif m_type in register_hooks:
+            fn = register_hooks[m_type]
+            if m_type not in types_collection and verbose:
+                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
+        else:
+            if m_type not in types_collection and report_missing:
+                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")
+
+        if fn is not None:
+            handler_collection[m] = (
+                m.register_forward_hook(fn),
+                m.register_forward_hook(count_parameters),
+            )
+        types_collection.add(m_type)
+
+    prev_training_status = model.training
+
+    model.eval()
+    model.apply(add_hooks)
+
+    with torch.no_grad():
+        model(*inputs)
+
+    def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
+        """Recursively counts the total operations and parameters of the given PyTorch module and its submodules."""
+        total_ops, total_params = module.total_ops.item(), 0
+        ret_dict = {}
+        for n, m in module.named_children():
+            # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"):  # and len(list(m.children())) > 0:
+            #     m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
+            # else:
+            #     m_ops, m_params = m.total_ops, m.total_params
+            next_dict = {}
+            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
+                m_ops, m_params = m.total_ops.item(), m.total_params.item()
+            else:
+                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
+            ret_dict[n] = (m_ops, m_params, next_dict)
+            total_ops += m_ops
+            total_params += m_params
+        # print(prefix, module._get_name(), (total_ops, total_params))
+        return total_ops, total_params, ret_dict
+
+    total_ops, total_params, ret_dict = dfs_count(model)
+
+    # reset model to original status
+    model.train(prev_training_status)
+    for m, (op_handler, params_handler) in handler_collection.items():
+        op_handler.remove()
+        params_handler.remove()
+        m._buffers.pop("total_ops")
+        m._buffers.pop("total_params")
+
+    if ret_layer_info:
+        return total_ops, total_params, ret_dict
+    return total_ops, total_params
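
Typical usage of `profile` looks like the sketch below (it assumes `thop/__init__.py`, whose 11 lines are not shown in this diff, re-exports `profile` at the top level, as the package README documents). `custom_ops` lets callers supply a counting hook for module types that have no rule in `register_hooks`; the `Swish` module and `count_swish` hook here are hypothetical illustrations, not part of the package:

    import torch
    import torch.nn as nn

    from thop import profile  # assumes the usual top-level export

    model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
    x = torch.randn(1, 3, 32, 32)
    macs, params = profile(model, inputs=(x,), verbose=False)

    # Layer-wise breakdown, keyed by submodule name:
    macs, params, layers = profile(model, inputs=(x,), ret_layer_info=True, verbose=False)

    # Supplying a rule for an unknown module type via custom_ops:
    class Swish(nn.Module):  # hypothetical module with no default rule
        def forward(self, x):
            return x * torch.sigmoid(x)

    def count_swish(m, x, y):
        # rough assumption: one sigmoid + one multiply per output element
        m.total_ops += torch.DoubleTensor([2 * y.numel()])

    macs, params = profile(nn.Sequential(Swish()), inputs=(x,), custom_ops={Swish: count_swish})
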
thop/rnn_hooks.py
ADDED
@@ -0,0 +1,202 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import PackedSequence
+
+
+def _count_rnn_cell(input_size, hidden_size, bias=True):
+    """Calculate the total operations for an RNN cell given input size, hidden size, and optional bias."""
+    total_ops = hidden_size * (input_size + hidden_size) + hidden_size
+    if bias:
+        total_ops += hidden_size * 2
+
+    return total_ops
+
+
+def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
+    """Counts the total RNN cell operations based on input tensor, hidden size, bias, and batch size."""
+    total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)
+
+    batch_size = x[0].size(0)
+    total_ops *= batch_size
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def _count_gru_cell(input_size, hidden_size, bias=True):
+    """Counts the total operations for a GRU cell based on input size, hidden size, and bias configuration."""
+    total_ops = 0
+    # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
+    # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
+    state_ops = (hidden_size + input_size) * hidden_size + hidden_size
+    if bias:
+        state_ops += hidden_size * 2
+    total_ops += state_ops * 2
+
+    # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
+    total_ops += (hidden_size + input_size) * hidden_size + hidden_size
+    if bias:
+        total_ops += hidden_size * 2
+    # r hadamard : r * (~)
+    total_ops += hidden_size
+
+    # h' = (1 - z) * n + z * h
+    # hadamard hadamard add
+    total_ops += hidden_size * 3
+
+    return total_ops
+
+
+def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
+    """Calculates and updates the total operations for a GRU cell in a mini-batch during inference."""
+    total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)
+
+    batch_size = x[0].size(0)
+    total_ops *= batch_size
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def _count_lstm_cell(input_size, hidden_size, bias=True):
+    """Counts LSTM cell operations during inference based on input size, hidden size, and bias configuration."""
+    total_ops = 0
+
+    # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
+    # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
+    # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
+    # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
+    state_ops = (input_size + hidden_size) * hidden_size + hidden_size
+    if bias:
+        state_ops += hidden_size * 2
+    total_ops += state_ops * 4
+
+    # c' = f * c + i * g \\
+    # hadamard hadamard add
+    total_ops += hidden_size * 3
+
+    # h' = o * \tanh(c') \\
+    total_ops += hidden_size
+
+    return total_ops
+
+
+def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
+    """Counts and updates the total operations for an LSTM cell in a mini-batch during inference."""
+    total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)
+
+    batch_size = x[0].size(0)
+    total_ops *= batch_size
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def count_rnn(m: nn.RNN, x, y):
+    """Calculate and update the total number of operations for each RNN cell in a given batch."""
+    bias = m.bias
+    input_size = m.input_size
+    hidden_size = m.hidden_size
+    num_layers = m.num_layers
+
+    if isinstance(x[0], PackedSequence):
+        batch_size = torch.max(x[0].batch_sizes)
+        num_steps = x[0].batch_sizes.size(0)
+    elif m.batch_first:
+        batch_size = x[0].size(0)
+        num_steps = x[0].size(1)
+    else:
+        batch_size = x[0].size(1)
+        num_steps = x[0].size(0)
+
+    total_ops = 0
+    if m.bidirectional:
+        total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2
+    else:
+        total_ops += _count_rnn_cell(input_size, hidden_size, bias)
+
+    for _ in range(num_layers - 1):
+        total_ops += (
+            _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2
+            if m.bidirectional
+            else _count_rnn_cell(hidden_size, hidden_size, bias)
+        )
+    # time unroll
+    total_ops *= num_steps
+    # batch_size
+    total_ops *= batch_size
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def count_gru(m: nn.GRU, x, y):
+    """Calculates total operations for a GRU layer, updating the model's operation count based on batch size."""
+    bias = m.bias
+    input_size = m.input_size
+    hidden_size = m.hidden_size
+    num_layers = m.num_layers
+
+    if isinstance(x[0], PackedSequence):
+        batch_size = torch.max(x[0].batch_sizes)
+        num_steps = x[0].batch_sizes.size(0)
+    elif m.batch_first:
+        batch_size = x[0].size(0)
+        num_steps = x[0].size(1)
+    else:
+        batch_size = x[0].size(1)
+        num_steps = x[0].size(0)
+
+    total_ops = 0
+    if m.bidirectional:
+        total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2
+    else:
+        total_ops += _count_gru_cell(input_size, hidden_size, bias)
+
+    for _ in range(num_layers - 1):
+        total_ops += (
+            _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2
+            if m.bidirectional
+            else _count_gru_cell(hidden_size, hidden_size, bias)
+        )
+    # time unroll
+    total_ops *= num_steps
+    # batch_size
+    total_ops *= batch_size
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def count_lstm(m: nn.LSTM, x, y):
+    """Calculate total operations for LSTM layers, including bidirectional, updating model's total operations."""
+    bias = m.bias
+    input_size = m.input_size
+    hidden_size = m.hidden_size
+    num_layers = m.num_layers
+
+    if isinstance(x[0], PackedSequence):
+        batch_size = torch.max(x[0].batch_sizes)
+        num_steps = x[0].batch_sizes.size(0)
+    elif m.batch_first:
+        batch_size = x[0].size(0)
+        num_steps = x[0].size(1)
+    else:
+        batch_size = x[0].size(1)
+        num_steps = x[0].size(0)
+
+    total_ops = 0
+    if m.bidirectional:
+        total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2
+    else:
+        total_ops += _count_lstm_cell(input_size, hidden_size, bias)
+
+    for _ in range(num_layers - 1):
+        total_ops += (
+            _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2
+            if m.bidirectional
+            else _count_lstm_cell(hidden_size, hidden_size, bias)
+        )
+    # time unroll
+    total_ops *= num_steps
+    # batch_size
+    total_ops *= batch_size
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
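
As a worked check of `_count_lstm_cell`: with `input_size=10`, `hidden_size=20`, and bias, each of the four gates costs (10 + 20) × 20 + 20 + 2 × 20 = 660 ops, the cell update c' adds 3 × 20 = 60, and h' adds 20, giving 4 × 660 + 60 + 20 = 2720 ops per timestep per sample. `count_lstm` then scales this by `num_steps` and `batch_size`. A minimal sketch verifying the arithmetic (again assuming the top-level `thop` export):

    import torch
    import torch.nn as nn

    from thop import profile

    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
    x = torch.randn(7, 5, 10)  # (num_steps=7, batch_size=5, input_size=10); batch_first=False
    macs, _ = profile(lstm, inputs=(x,), verbose=False)
    assert macs == 2720 * 7 * 5  # per-cell ops * num_steps * batch_size = 95200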