ultralytics-thop 2.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
thop/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
# Package version string; matches the released ultralytics-thop wheel version.
__version__ = "2.0.14"
4
+
5
+
6
+ import torch
7
+
8
+ from .profile import profile, profile_origin
9
+ from .utils import clever_format
10
+
11
# Package-level default dtype for operation/parameter counters.
# NOTE(review): thop/profile.py defines its own module-level ``default_dtype`` (also float64), and that
# copy is the one profile_origin reads — changing this value may not affect profiling; confirm intent.
default_dtype = torch.float64
thop/fx_profile.py ADDED
@@ -0,0 +1,237 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import logging
4
+ from distutils.version import LooseVersion
5
+
6
+ import torch
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
def _torch_version_tuple(version):
    """Parse the leading numeric ``major.minor.patch`` fields of a version string into a tuple of ints.

    Local/pre-release suffixes are ignored, e.g. ``"2.1.0+cu118"`` -> ``(2, 1, 0)`` and
    ``"1.8.0a0"`` -> ``(1, 8, 0)``. Used instead of ``distutils.version.LooseVersion``, which is
    deprecated and no longer shipped with Python 3.12+.
    """
    numbers = []
    for field in str(version).split("+")[0].split(".")[:3]:
        digits = ""
        for ch in field:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
        if len(digits) < len(field):
            break  # a pre-release/dev suffix ends the numeric part
    return tuple(numbers)


# torch.fx (symbolic_trace / ShapeProp, used below) first shipped in PyTorch 1.8.0.
if _torch_version_tuple(torch.__version__) < (1, 8, 0):
    logging.warning(
        f"torch.fx requires PyTorch 1.8.0 or newer, but you are using an old version {torch.__version__}. "
    )
14
+
15
+
16
def count_clamp(input_shapes, output_shapes, *args, **kwargs):
    """Return 0: clamping is counted as a free operation.

    Args:
        input_shapes: Shapes of the node's tensor arguments (unused).
        output_shapes: Shapes of the node's outputs (unused).
        *args, **kwargs: Raw call arguments, which ``fx_profile`` passes to ``call_function``
            counters; accepted and ignored so this counter works on both dispatch paths.

    Returns:
        Always 0.
    """
    return 0
19
+
20
+
21
def count_mul(input_shapes, output_shapes, *args, **kwargs):
    """Return one operation per output element (used for element-wise ``mul`` and ``truediv``).

    ``fx_profile`` invokes ``call_function`` counters as
    ``counter(input_shapes, output_shapes, *node.args, **node.kwargs)``, so the raw call arguments
    are accepted here and ignored.

    Args:
        input_shapes: Shapes of the node's tensor arguments (unused).
        output_shapes: List whose first entry is the output shape.

    Returns:
        Number of elements in ``output_shapes[0]``.
    """
    return output_shapes[0].numel()
24
+
25
+
26
def count_matmul(input_shapes, output_shapes):
    """Return the multiply-accumulate count of a matrix product.

    Each output element needs one MAC per entry of the contracted (last) dimension of the first
    input, so the result is ``input_shapes[0][-1] * output_shapes[0].numel()``.
    """
    contracted_dim = input_shapes[0][-1]
    output_elements = output_shapes[0].numel()
    return contracted_dim * output_elements
33
+
34
+
35
def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
    """Return the FLOPs of a functional ``linear(input, weight, bias=None)`` call.

    The matmul cost comes from ``count_matmul``. One extra op per output element is added when a
    bias is actually supplied, whether positionally (third raw argument) or as ``bias=`` keyword.

    Args:
        input_shapes: Shapes of the call's tensor arguments; ``input_shapes[0]`` is the input.
        output_shapes: List whose first entry is the output shape.
        *args, **kwargs: Raw call arguments as recorded in the fx graph.

    Returns:
        Operation count.
    """
    flops = count_matmul(input_shapes, output_shapes)
    if "bias" in kwargs:
        bias = kwargs["bias"]
    elif len(args) > 2:
        bias = args[2]
    else:
        bias = None
    if bias is not None:
        flops += output_shapes[0].numel()
    return flops
41
+
42
+
43
+ from .vision.calc_func import calculate_conv
44
+
45
+
46
def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
    """Return the FLOPs of a functional ``conv2d`` call, computed with ``calculate_conv``.

    Args:
        input_shapes: Shapes of the call's graph-tensor arguments, in order: input, weight and —
            only when the bias is a graph tensor — bias.
        output_shapes: List whose first entry is the output shape.
        *args: Raw positional call arguments,
            ``(input, weight, bias, stride, padding, dilation, groups)``; trailing ones may be omitted.
        **kwargs: Raw keyword call arguments; ``groups`` is read from here when not positional.

    Returns:
        Operation count as an ``int``.

    Raises:
        ValueError: If ``input_shapes`` does not hold 2 (input, weight) or 3 (plus bias) shapes.
    """
    if len(input_shapes) not in (2, 3):
        raise ValueError(
            f"conv2d expects input, weight and optional bias tensor shapes, got {len(input_shapes)}"
        )
    x_shape, k_shape = input_shapes[0], input_shapes[1]
    out_shape = output_shapes[0]
    # ``groups`` is the 7th positional argument of conv2d and defaults to 1.
    groups = args[6] if len(args) > 6 else kwargs.get("groups", 1)

    kernel_parameters = k_shape[2:].numel()  # product of the spatial kernel dims
    # A third shape is present only when the bias is a graph tensor (a None bias is not a node);
    # mirror count_nn_conv2d, which passes bias_op=1 when the module has a bias.
    bias_op = 1 if len(input_shapes) == 3 else 0
    in_channel = x_shape[1]

    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)
63
+
64
+
65
def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
    """Return the matmul FLOPs of an ``nn.Linear`` call.

    The module object is not inspected, so no bias term is included. The count is the input's
    last dimension (in_features) times the number of output elements.
    """
    in_features = input_shapes[0][-1]
    return in_features * output_shapes[0].numel()
68
+
69
+
70
def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
    """Counter for operations treated as free: ignores every argument and reports 0 FLOPs.

    It is also registered for ``call_function`` keys such as "built-in function add". There the
    positional slots receive ``(input_shapes, output_shapes, *node.args)`` rather than
    ``(module, input_shapes, output_shapes)``, which is harmless because nothing is read.
    """
    flops = 0
    return flops
73
+
74
+
75
def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
    """Return the FLOPs of an ``nn.Conv2d`` call via ``calculate_conv``.

    The inputs are taken from the module's configuration: bias presence (passed as ``bias_op`` 1/0),
    product of spatial kernel dims, ``in_channels`` and ``groups``, plus the output element count.
    ``input_shapes`` is not used.
    """
    has_bias = module.bias is not None
    n_output = output_shapes[0].numel()
    kernel_size_prod = module.weight.shape[2:].numel()
    ops = calculate_conv(int(has_bias), kernel_size_prod, n_output, module.in_channels, module.groups)
    return int(ops.item())
85
+
86
+
87
def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
    """Return two operations per output element of an ``nn.BatchNorm2d`` call.

    ``module`` and ``input_shapes`` are not inspected; exactly one output shape is required.
    """
    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
    (out_shape,) = output_shapes
    return out_shape.numel() * 2
92
+
93
+
94
# Leaf ``nn.Module`` types whose cost is treated as zero FLOPs.
zero_ops = (
    nn.ReLU,
    nn.ReLU6,
    nn.Dropout,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
)

# Node key -> FLOP counter, used by ``fx_profile``.
#   * ``nn.Module`` types are keys for ``call_module`` nodes (counter(module, in_shapes, out_shapes)).
#   * Method names (e.g. "clamp") are keys for ``call_method`` nodes (counter(in_shapes, out_shapes)).
#   * Other strings are keys for ``call_function`` nodes
#     (counter(in_shapes, out_shapes, *node.args, **node.kwargs)).
# ``fx_profile`` derives function keys as str(target).split("at")[0] with "<"/">" removed. That cut
# happens at the first "at" anywhere in the repr, so names containing "at" are truncated
# (e.g. "flatten" -> "built-in method fl").
count_map = {
    nn.Linear: count_nn_linear,
    nn.Conv2d: count_nn_conv2d,
    nn.BatchNorm2d: count_nn_bn2d,
    "function linear": count_fn_linear,
    # torch.nn.functional.linear is a C builtin whose repr is "<built-in function linear>", so its
    # derived key is "built-in function linear"; "function linear" alone never matched it.
    "built-in function linear": count_fn_linear,
    "clamp": count_clamp,
    "built-in function add": count_zero_ops,
    "built-in method fl": count_zero_ops,
    "built-in method conv2d of type object": count_fn_conv2d,
    "built-in function mul": count_mul,
    "built-in function truediv": count_mul,
}

for k in zero_ops:
    count_map[k] = count_zero_ops

# Operators without a counter seen by ``fx_profile``: key -> (node.op, ...).
missing_maps = {}
120
+
121
+ from torch.fx import symbolic_trace
122
+ from torch.fx.passes.shape_prop import ShapeProp
123
+
124
+ from .utils import prRed, prYellow
125
+
126
+
127
def null_print(*args, **kwargs):
    """Accept the same arguments as ``print`` and do nothing; used to silence non-verbose output."""
    del args, kwargs  # intentionally unused
130
+
131
+
132
def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
    """Trace ``mod`` with torch.fx and return the summed FLOPs of every node that has a counter.

    How it works:
    * ``symbolic_trace`` builds the graph.
    * ``ShapeProp`` runs ``input`` through the graph and records each node's output shape.
    * Each node is then dispatched to a counter in ``count_map``:
      - ``placeholder`` / ``output`` nodes cost 0;
      - ``call_function`` nodes are keyed by a string derived from ``str(node.target)``;
      - ``call_method`` nodes are keyed by the method name;
      - ``call_module`` nodes are keyed by the submodule's type.

    Nodes without a counter are reported (``prRed`` plus a summary at the end) and add nothing to
    the total. They are also merged into the module-level ``missing_maps``.

    Args:
        mod: Module to profile; must be traceable by ``torch.fx.symbolic_trace``.
        input: Example input used for shape propagation.
        verbose: If True, print per-node op/target/shape details and running FLOP totals.

    Returns:
        Total FLOPs over all nodes that had a counter.
    """
    gm: torch.fx.GraphModule = symbolic_trace(mod)
    ShapeProp(gm).propagate(input)

    fprint = print if verbose else null_print

    v_maps = {}  # node name -> output shape, used to look up the shapes of later nodes' arguments
    missing = {}  # operators without a counter encountered during *this* call
    total_flops = 0

    for node in gm.graph.nodes:
        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
        node_flops = None

        # Only arguments that are earlier graph nodes with a recorded shape contribute an input
        # shape; constants (and nested containers of nodes) are skipped.
        input_shapes = []
        fprint("input_shape:", end="\t")
        for arg in node.args:
            if str(arg) not in v_maps:
                continue
            fprint(f"{v_maps[str(arg)]}", end="\t")
            input_shapes.append(v_maps[str(arg)])
        fprint()

        # ``tensor_meta`` is absent for non-tensor results, and has no ``.shape`` for nodes that
        # produce several tensors (e.g. a tuple ``output``); treat both as "no shape".
        out_shape = getattr(node.meta.get("tensor_meta"), "shape", None)
        fprint(f"output_shape:\t{out_shape}")
        output_shapes = [out_shape]

        if node.op in ("output", "placeholder"):
            node_flops = 0
        elif node.op == "call_function":
            # torch / operator functions. The key is cut at the first "at" substring: this strips the
            # " at 0x..." address from reprs, but also truncates names containing "at"
            # (e.g. flatten -> "fl"). The count_map keys rely on this exact behaviour.
            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
            else:
                missing[key] = (node.op, key)
                prRed(f"|{key}| is missing")
        elif node.op == "call_method":
            # Tensor methods, keyed by method name (e.g. "clamp").
            key = str(node.target)
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes)
            else:
                missing[key] = (node.op, key)
                prRed(f"{key} is missing")
        elif node.op == "call_module":
            # Leaf torch.nn modules, keyed by their type.
            m = mod.get_submodule(node.target)
            key = type(m)
            fprint(key, key in count_map)
            if key in count_map:
                node_flops = count_map[key](m, input_shapes, output_shapes)
            else:
                missing[key] = (node.op,)
                prRed(f"{key} is missing")
                print("module type:", key)
                # Modules without a ``weight`` tensor (activations, pooling, custom ops) report None.
                weight = getattr(m, "weight", None)
                weight_shape = weight.shape if isinstance(weight, torch.Tensor) else None
                print(f"weight_shape: {weight_shape}")

        if out_shape is not None:
            v_maps[str(node.name)] = out_shape
        if node_flops is not None:
            total_flops += node_flops
            if verbose:
                prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
        fprint("==" * 40)

    # Keep the module-level record for backward compatibility, but report only this call's operators.
    missing_maps.update(missing)
    if missing:
        from pprint import pprint

        print("Missing operators: ")
        pprint(missing)
    return total_flops
209
+
210
+
211
if __name__ == "__main__":

    class MyOP(nn.Module):
        def forward(self, tensor):
            """Return the input divided by 1; torch.fx records this as an ``operator.truediv`` node."""
            return tensor / 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            """Create two independent 5 -> 3 linear layers and a MyOP instance."""
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 3)
            self.linear2 = torch.nn.Linear(5, 3)
            self.myop = MyOP()

        def forward(self, x):
            """Sum ``linear1(x)`` and clamped ``linear2(x)``, then pass the sum through MyOP."""
            first = self.linear1(x)
            second = self.linear2(x).clamp(min=0.0, max=1.0)
            return self.myop(first + second)

    # Build the model before drawing the sample so RNG consumption order is unchanged.
    print(fx_profile(MyModule(), th.randn(20, 5), verbose=False))
thop/profile.py ADDED
@@ -0,0 +1,228 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from thop.rnn_hooks import *
4
+ from thop.vision.basic_hooks import *
5
+
6
+ from .utils import prRed
7
+
8
# Dtype of the temporary ``total_ops`` / ``total_params`` buffers created by ``profile_origin``.
# (``profile`` in this file hard-codes ``torch.float64`` for its buffers instead of reading this.)
default_dtype = torch.float64

# Default forward hook per module type. These are used when a type is not in the caller's
# ``custom_ops``, which take precedence. The hook functions come from the star imports of
# ``thop.rnn_hooks`` and ``thop.vision.basic_hooks``:
# - the rnn_hooks counters add their result into ``m.total_ops``;
# - the basic_hooks counters (not visible here) presumably do the same.
# NOTE: ``zero_ops`` here is the hook from basic_hooks, not the tuple of the same name in fx_profile.py.
register_hooks = {
    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
    # Convolutions, regular and transposed.
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd,
    nn.ConvTranspose2d: count_convNd,
    nn.ConvTranspose3d: count_convNd,
    # Normalization layers.
    nn.BatchNorm1d: count_normalization,
    nn.BatchNorm2d: count_normalization,
    nn.BatchNorm3d: count_normalization,
    nn.LayerNorm: count_normalization,
    nn.InstanceNorm1d: count_normalization,
    nn.InstanceNorm2d: count_normalization,
    nn.InstanceNorm3d: count_normalization,
    # Activations.
    nn.PReLU: count_prelu,
    nn.Softmax: count_softmax,
    nn.ReLU: zero_ops,
    nn.ReLU6: zero_ops,
    nn.LeakyReLU: count_relu,
    # Max pooling is registered with zero_ops.
    nn.MaxPool1d: zero_ops,
    nn.MaxPool2d: zero_ops,
    nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops,
    nn.AdaptiveMaxPool2d: zero_ops,
    nn.AdaptiveMaxPool3d: zero_ops,
    # Average pooling.
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear,
    nn.Dropout: zero_ops,
    # Upsampling.
    nn.Upsample: count_upsample,
    nn.UpsamplingBilinear2d: count_upsample,
    nn.UpsamplingNearest2d: count_upsample,
    # Recurrent cells and layers (thop/rnn_hooks.py).
    nn.RNNCell: count_rnn_cell,
    nn.GRUCell: count_gru_cell,
    nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,
    # Containers / reshapes. In ``profile``, a hooked Sequential is still expanded by dfs_count.
    nn.Sequential: zero_ops,
    nn.PixelShuffle: zero_ops,
    nn.SyncBatchNorm: count_normalization,
}
57
+
58
+
59
def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Profile a model's operations and parameters with forward hooks on its leaf modules (legacy API).

    How it works:
    * Every leaf module (one without children) gets ``total_ops`` / ``total_params`` buffers.
    * ``total_params`` is pre-filled with the module's parameter count.
    * When a rule exists, the module also gets a forward hook, taken from ``custom_ops``
      (preferred) or ``register_hooks``.
    * The model runs once in eval mode under ``torch.no_grad()`` and the per-leaf buffers are summed.

    The model is always restored afterwards, even if the forward pass raises: its training mode is
    reset, the hooks are removed and the temporary buffers are popped.

    Args:
        model: Module to profile.
        inputs: Tuple of positional inputs, passed as ``model(*inputs)``.
        custom_ops: Optional mapping of module type -> hook that overrides ``register_hooks``.
        verbose: Print which rule is registered for each module type.
        report_missing: Warn about module types without a rule; forces ``verbose=True``.

    Returns:
        Tuple ``(total_ops, total_params)`` as Python floats.
    """
    handler_collection = []
    types_collection = set()
    hooked_modules = set()  # leaf modules already processed; a shared module may be visited twice
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        verbose = True

    def add_hooks(m):
        """Attach counting buffers and, when a rule exists, a forward hook to leaf module ``m``."""
        if list(m.children()) or m in hooked_modules:
            return
        hooked_modules.add(m)

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                f"Either .total_ops or .total_params is already defined in {str(m)}. "
                "Be careful, it might change your code's behavior."
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        elif m_type not in types_collection and report_missing:
            prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

        if fn is not None:
            handler_collection.append(m.register_forward_hook(fn))
        types_collection.add(m_type)

    training = model.training

    model.eval()
    try:
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        total_ops = 0
        total_params = 0
        for m in model.modules():
            if list(m.children()):  # skip for non-leaf module
                continue
            total_ops += m.total_ops
            total_params += m.total_params

        total_ops = total_ops.item()
        total_params = total_params.item()
    finally:
        # Reset the model to its original state even if tracing or the forward pass failed.
        model.train(training)
        for handler in handler_collection:
            handler.remove()
        for m in hooked_modules:
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)

    return total_ops, total_params
138
+
139
+
140
def profile(
    model: nn.Module,
    inputs,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Profile a PyTorch model and return its total operations and parameters.

    Setup:
    * Every module gets temporary float64 ``total_ops`` / ``total_params`` buffers.
    * Module types with a rule get two forward hooks: one from ``custom_ops`` (preferred) or
      ``register_hooks``, plus ``count_parameters``.

    Counting:
    * The model runs once in eval mode under ``torch.no_grad()``.
    * Totals are then gathered by a depth-first walk over the module tree.

    Whatever happens, the model is restored afterwards: its training mode is reset, all hooks are
    removed, and the temporary buffers are popped from *every* module. Leftover buffers would
    otherwise show up in ``state_dict()``.

    Args:
        model: Module to profile.
        inputs: Tuple of positional inputs, passed as ``model(*inputs)``.
        custom_ops: Optional mapping of module type -> hook that overrides ``register_hooks``.
        verbose: Print which rule is registered for each module type.
        ret_layer_info: Also return the nested per-child ``{name: (ops, params, children)}`` dict.
        report_missing: Warn about module types without a rule; forces ``verbose=True``.

    Returns:
        ``(total_ops, total_params)``, or ``(total_ops, total_params, ret_dict)`` when
        ``ret_layer_info`` is True.
    """
    handler_collection = {}  # hooked module -> (op hook handle, params hook handle)
    types_collection = set()
    buffered_modules = set()  # every module that received the temporary counting buffers
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # overwrite `verbose` option when enable report_missing
        verbose = True

    def add_hooks(m: nn.Module):
        """Register counting buffers on ``m`` and, when a rule exists, its counting hooks."""
        if m in buffered_modules:
            # ``Module.apply`` reaches a module shared by several parents once per parent; hooking it
            # twice would double count and leak the first pair of hook handles.
            return
        buffered_modules.add(m)
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        elif m_type not in types_collection and report_missing:
            prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    def dfs_count(module: nn.Module, prefix="\t") -> tuple:
        """Return ``(ops, params, per_child_info)`` for ``module``.

        A child that received hooks (and is not a Sequential/ModuleList container) contributes its
        own buffers; any other child is expanded recursively.
        """
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            next_dict = {}
            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params, ret_dict

    prev_training_status = model.training

    model.eval()
    try:
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        total_ops, total_params, ret_dict = dfs_count(model)
    finally:
        # reset model to original status, even if the forward pass raised
        model.train(prev_training_status)
        for op_handler, params_handler in handler_collection.values():
            op_handler.remove()
            params_handler.remove()
        for m in buffered_modules:
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
thop/rnn_hooks.py ADDED
@@ -0,0 +1,202 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.utils.rnn import PackedSequence
6
+
7
+
8
def _count_rnn_cell(input_size, hidden_size, bias=True):
    """Return the per-sample operation count of one Elman RNN cell step.

    The count is ``hidden_size * (input_size + hidden_size)`` for the two matrix products, plus
    ``hidden_size``, plus ``2 * hidden_size`` for the two bias vectors when ``bias`` is True.
    """
    total_ops = hidden_size * (input_size + hidden_size) + hidden_size
    if bias:
        total_ops += hidden_size * 2

    return total_ops


def count_rnn_cell(m: nn.RNNCell, x: tuple, y: torch.Tensor):
    """Forward hook: add the operation count of an ``nn.RNNCell`` call to ``m.total_ops``.

    Args:
        m: The cell; must carry a ``total_ops`` tensor (registered by the profiler).
        x: Hook input tuple. ``x[0]`` is either ``(batch, input_size)`` or an unbatched
            ``(input_size,)`` tensor, which counts as a single sample.
        y: Hook output (unused).
    """
    total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)

    inp = x[0]
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
25
+
26
+
27
def _count_gru_cell(input_size, hidden_size, bias=True):
    """Return the per-sample operation count of one GRU cell step (reset/update gates, candidate, output)."""
    total_ops = 0
    # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
    # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
    state_ops = (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        state_ops += hidden_size * 2
    total_ops += state_ops * 2

    # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
    total_ops += (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        total_ops += hidden_size * 2
    # r hadamard : r * (~)
    total_ops += hidden_size

    # h' = (1 - z) * n + z * h
    # hadamard hadamard add
    total_ops += hidden_size * 3

    return total_ops


def count_gru_cell(m: nn.GRUCell, x: tuple, y: torch.Tensor):
    """Forward hook: add the operation count of an ``nn.GRUCell`` call to ``m.total_ops``.

    Args:
        m: The cell; must carry a ``total_ops`` tensor (registered by the profiler).
        x: Hook input tuple. ``x[0]`` is either ``(batch, input_size)`` or an unbatched
            ``(input_size,)`` tensor, which counts as a single sample.
        y: Hook output (unused).
    """
    total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)

    inp = x[0]
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
59
+
60
+
61
def _count_lstm_cell(input_size, hidden_size, bias=True):
    """Return the per-sample operation count of one LSTM cell step (four gates plus state updates)."""
    total_ops = 0

    # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
    # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
    # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
    # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
    state_ops = (input_size + hidden_size) * hidden_size + hidden_size
    if bias:
        state_ops += hidden_size * 2
    total_ops += state_ops * 4

    # c' = f * c + i * g \\
    # hadamard hadamard add
    total_ops += hidden_size * 3

    # h' = o * \tanh(c') \\
    total_ops += hidden_size

    return total_ops


def count_lstm_cell(m: nn.LSTMCell, x: tuple, y: torch.Tensor):
    """Forward hook: add the operation count of an ``nn.LSTMCell`` call to ``m.total_ops``.

    Args:
        m: The cell; must carry a ``total_ops`` tensor (registered by the profiler).
        x: Hook input tuple. ``x[0]`` is either ``(batch, input_size)`` or an unbatched
            ``(input_size,)`` tensor, which counts as a single sample.
        y: Hook output (unused).
    """
    total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

    inp = x[0]
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
92
+
93
+
94
def count_rnn(m: nn.RNN, x, y):
    """Forward hook: add the operation count of a (multi-layer, possibly bidirectional) ``nn.RNN`` call.

    Per time step, the count is:
    * the cell cost of the first layer (``input_size -> hidden_size``);
    * plus one cell cost per deeper layer, whose input is ``hidden_size``, or
      ``2 * hidden_size`` when bidirectional;
    * each doubled when bidirectional.
    The per-step count is then multiplied by the number of steps and the batch size.

    Args:
        m: The RNN module; its ``total_ops`` tensor is incremented.
        x: Hook input tuple. ``x[0]`` is one of:
            - a ``PackedSequence``: the count uses the max batch size times the number of steps,
              so it is an upper bound for variable-length batches;
            - an unbatched ``(seq_len, input_size)`` tensor;
            - a batched 3-D tensor laid out according to ``m.batch_first``.
        y: Hook output (unused).
    """
    bias = m.bias
    input_size = m.input_size
    hidden_size = m.hidden_size
    num_layers = m.num_layers

    inp = x[0]
    if isinstance(inp, PackedSequence):
        batch_size = torch.max(inp.batch_sizes)
        num_steps = inp.batch_sizes.size(0)
    elif inp.dim() == 2:
        # Unbatched input is (seq_len, input_size) regardless of ``batch_first``.
        batch_size = 1
        num_steps = inp.size(0)
    elif m.batch_first:
        batch_size = inp.size(0)
        num_steps = inp.size(1)
    else:
        batch_size = inp.size(1)
        num_steps = inp.size(0)

    total_ops = 0
    if m.bidirectional:
        total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2
    else:
        total_ops += _count_rnn_cell(input_size, hidden_size, bias)

    for _ in range(num_layers - 1):
        total_ops += (
            _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2
            if m.bidirectional
            else _count_rnn_cell(hidden_size, hidden_size, bias)
        )
    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
129
+
130
+
131
def count_gru(m: nn.GRU, x, y):
    """Forward hook: add the operation count of a (multi-layer, possibly bidirectional) ``nn.GRU`` call.

    Per time step, the count is:
    * the GRU cell cost of the first layer (``input_size -> hidden_size``);
    * plus one cell cost per deeper layer, whose input is ``hidden_size``, or
      ``2 * hidden_size`` when bidirectional;
    * each doubled when bidirectional.
    The per-step count is then multiplied by the number of steps and the batch size.

    Args:
        m: The GRU module; its ``total_ops`` tensor is incremented.
        x: Hook input tuple. ``x[0]`` is one of:
            - a ``PackedSequence``: the count uses the max batch size times the number of steps,
              so it is an upper bound for variable-length batches;
            - an unbatched ``(seq_len, input_size)`` tensor;
            - a batched 3-D tensor laid out according to ``m.batch_first``.
        y: Hook output (unused).
    """
    bias = m.bias
    input_size = m.input_size
    hidden_size = m.hidden_size
    num_layers = m.num_layers

    inp = x[0]
    if isinstance(inp, PackedSequence):
        batch_size = torch.max(inp.batch_sizes)
        num_steps = inp.batch_sizes.size(0)
    elif inp.dim() == 2:
        # Unbatched input is (seq_len, input_size) regardless of ``batch_first``.
        batch_size = 1
        num_steps = inp.size(0)
    elif m.batch_first:
        batch_size = inp.size(0)
        num_steps = inp.size(1)
    else:
        batch_size = inp.size(1)
        num_steps = inp.size(0)

    total_ops = 0
    if m.bidirectional:
        total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2
    else:
        total_ops += _count_gru_cell(input_size, hidden_size, bias)

    for _ in range(num_layers - 1):
        total_ops += (
            _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2
            if m.bidirectional
            else _count_gru_cell(hidden_size, hidden_size, bias)
        )
    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
166
+
167
+
168
def count_lstm(m: nn.LSTM, x, y):
    """Forward hook: add the operation count of a (multi-layer, possibly bidirectional) ``nn.LSTM`` call.

    Per time step, the count is:
    * the LSTM cell cost of the first layer (``input_size -> hidden_size``);
    * plus one cell cost per deeper layer, whose input is ``hidden_size``, or
      ``2 * hidden_size`` when bidirectional;
    * each doubled when bidirectional.
    The per-step count is then multiplied by the number of steps and the batch size.

    Args:
        m: The LSTM module; its ``total_ops`` tensor is incremented.
        x: Hook input tuple. ``x[0]`` is one of:
            - a ``PackedSequence``: the count uses the max batch size times the number of steps,
              so it is an upper bound for variable-length batches;
            - an unbatched ``(seq_len, input_size)`` tensor;
            - a batched 3-D tensor laid out according to ``m.batch_first``.
        y: Hook output (unused).
    """
    bias = m.bias
    input_size = m.input_size
    hidden_size = m.hidden_size
    num_layers = m.num_layers

    inp = x[0]
    if isinstance(inp, PackedSequence):
        batch_size = torch.max(inp.batch_sizes)
        num_steps = inp.batch_sizes.size(0)
    elif inp.dim() == 2:
        # Unbatched input is (seq_len, input_size) regardless of ``batch_first``.
        batch_size = 1
        num_steps = inp.size(0)
    elif m.batch_first:
        batch_size = inp.size(0)
        num_steps = inp.size(1)
    else:
        batch_size = inp.size(1)
        num_steps = inp.size(0)

    total_ops = 0
    if m.bidirectional:
        total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2
    else:
        total_ops += _count_lstm_cell(input_size, hidden_size, bias)

    for _ in range(num_layers - 1):
        total_ops += (
            _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2
            if m.bidirectional
            else _count_lstm_cell(hidden_size, hidden_size, bias)
        )
    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])