ultralytics-thop 2.0.14__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
thop/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
# Release version of the ultralytics-thop package.
__version__ = "2.0.14"

import torch

from .profile import profile, profile_origin
from .utils import clever_format

# Package-level default dtype (float64).
# NOTE(review): thop/profile.py assigns its own module-level `default_dtype`, so rebinding this package attribute
# does not change the dtype used there.
default_dtype = torch.float64
thop/fx_profile.py ADDED
@@ -0,0 +1,237 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import logging
4
+ from distutils.version import LooseVersion
5
+
6
+ import torch
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+ if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
11
+ logging.warning(
12
+ f"torch.fx requires version higher than 1.8.0. But You are using an old version PyTorch {torch.__version__}. "
13
+ )
14
+
15
+
16
def count_clamp(input_shapes, output_shapes, *args, **kwargs):
    """Return the FLOP count of a clamp op, which is treated as free (always 0).

    Nothing is clamped here. Extra positional/keyword arguments are accepted and ignored so the counter works both
    when dispatched with shapes only (``call_method`` nodes) and when dispatched with the node's args/kwargs appended
    (``call_function`` nodes).
    """
    return 0
19
+
20
+
21
def count_mul(input_shapes, output_shapes, *args, **kwargs):
    """Return the number of elements of the first output shape (one op per element for element-wise mul/div).

    This counter is registered for ``call_function`` targets (``operator.mul`` / ``operator.truediv``). Those are
    dispatched as ``fn(input_shapes, output_shapes, *node.args, **node.kwargs)``, so the extra arguments are accepted
    and ignored instead of raising ``TypeError``.
    """
    return output_shapes[0].numel()
24
+
25
+
26
def count_matmul(input_shapes, output_shapes):
    """Estimate the multiply-accumulate count of a matmul-like op.

    Every output element is costed as a reduction over the last dimension of the first input, giving
    ``input_shapes[0][-1] * output_shapes[0].numel()``.
    """
    reduce_dim = input_shapes[0][-1]
    return reduce_dim * output_shapes[0].numel()
33
+
34
+
35
def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
    """Return the FLOPs of a functional ``linear(input, weight, bias=None)`` call.

    The matmul part is delegated to ``count_matmul``. One extra op per output element is added only when a bias is
    actually supplied, whether positionally (third argument) or via the ``bias`` keyword. An explicit
    ``bias=None`` adds nothing.
    """
    flops = count_matmul(input_shapes, output_shapes)
    bias = args[2] if len(args) > 2 else kwargs.get("bias")
    if bias is not None:
        flops += output_shapes[0].numel()
    return flops
41
+
42
+
43
+ from .vision.calc_func import calculate_conv
44
+
45
+
46
def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
    """Return the FLOPs of a functional ``conv2d(input, weight, bias, stride, padding, dilation, groups)`` call.

    Trailing arguments may be omitted or passed by keyword, so ``bias`` and ``groups`` are read from either the
    positional ``args`` or ``kwargs`` (defaults: no bias, ``groups=1``). The count comes from ``calculate_conv``.
    A bias term is included only when a bias is supplied, matching ``count_nn_conv2d``.

    Raises:
        ValueError: If fewer than two input shapes (activation and weight) were resolved for the node.
    """
    if len(input_shapes) < 2:
        raise ValueError(f"conv2d expects at least input and weight shapes, got {len(input_shapes)}")
    x_shape, k_shape = input_shapes[0], input_shapes[1]
    out_shape = output_shapes[0]

    bias = args[2] if len(args) > 2 else kwargs.get("bias")
    groups = args[6] if len(args) > 6 else kwargs.get("groups", 1)

    bias_op = 1 if bias is not None else 0
    kernel_parameters = k_shape[2:].numel()  # product of the weight's trailing (spatial) dims
    in_channel = x_shape[1]

    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)
63
+
64
+
65
def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
    """Return the FLOPs of an ``nn.Linear`` submodule call.

    Only the shapes are used (via ``count_matmul``); the module itself is not inspected and no bias term is added.
    """
    macs = count_matmul(input_shapes, output_shapes)
    return macs
68
+
69
+
70
def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
    """Counter for operations treated as free: ignores every argument and reports 0 FLOPs."""
    del module, input_shapes, output_shapes, args, kwargs  # intentionally unused
    return 0
73
+
74
+
75
def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
    """Return the FLOPs of an ``nn.Conv2d`` submodule call, computed by ``calculate_conv``.

    Input channels, groups, kernel size (product of the weight's trailing dims) and bias presence are read from the
    module. Only the output element count comes from ``output_shapes``.
    """
    has_bias = module.bias is not None
    output_numel = output_shapes[0].numel()
    kernel_ops = module.weight.shape[2:].numel()
    ops = calculate_conv(int(has_bias), kernel_ops, output_numel, module.in_channels, module.groups)
    return int(ops.item())
85
+
86
+
87
def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
    """Return 2 FLOPs per output element for an ``nn.BatchNorm2d`` call; module and inputs are not inspected."""
    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
    (out_shape,) = output_shapes
    return out_shape.numel() * 2
92
+
93
+
94
# nn.Module types whose calls are counted as 0 FLOPs.
zero_ops = (
    nn.ReLU,
    nn.ReLU6,
    nn.Dropout,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
)

# Dispatch table for fx_profile:
#   * nn.Module types -> counters called as fn(module, input_shapes, output_shapes) for "call_module" nodes;
#   * string keys     -> counters for "call_function" nodes (key derived from str(node.target)) and for
#                        "call_method" nodes (key is the method name, e.g. "clamp").
count_map = {
    nn.Linear: count_nn_linear,
    nn.Conv2d: count_nn_conv2d,
    nn.BatchNorm2d: count_nn_bn2d,
    "function linear": count_fn_linear,  # F.linear when it is a plain Python function
    # NOTE(review): newer PyTorch exposes F.linear as a C builtin whose repr gives this key — confirm per version.
    "built-in function linear": count_fn_linear,
    "clamp": count_clamp,
    "built-in function add": count_zero_ops,
    # fx_profile derives keys with str(target).split("at")[0], which truncates "<built-in method flatten ...>" to this.
    "built-in method fl": count_zero_ops,
    "built-in method conv2d of type object": count_fn_conv2d,
    "built-in function mul": count_mul,
    "built-in function truediv": count_mul,
}
count_map.update(dict.fromkeys(zero_ops, count_zero_ops))

# Record of operators for which fx_profile found no counter.
missing_maps = {}
120
+
121
+ from torch.fx import symbolic_trace
122
+ from torch.fx.passes.shape_prop import ShapeProp
123
+
124
+ from .utils import prRed, prYellow
125
+
126
+
127
def null_print(*args, **kwargs):
    """Drop-in stand-in for ``print`` that discards all of its arguments (used when verbose output is off)."""
    del args, kwargs
130
+
131
+
132
def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
    """Profile the FLOPs of ``mod`` by tracing it with torch.fx and summing per-node counts from ``count_map``.

    The module is symbolically traced and ``ShapeProp`` records every node's output shape for ``input``. Each node is
    then dispatched on its op kind:

    * ``placeholder`` / ``output``: 0 FLOPs.
    * ``call_function``: keyed by a string derived from ``str(node.target)``; the counter receives the node's
      args/kwargs after the shapes.
    * ``call_method``: keyed by the method name.
    * ``call_module``: keyed by the submodule's type; the counter receives the submodule.

    Nodes without a counter are reported with ``prRed`` and summarised at the end; they contribute no FLOPs. The
    missing entries of this call are also merged into the module-level ``missing_maps``.

    Args:
        mod (nn.Module): Model to profile; must be traceable by ``torch.fx.symbolic_trace``.
        input (torch.Tensor): Example input passed to ``ShapeProp.propagate``.
        verbose (bool): Print per-node details (op, target, shapes, module weight shape, running FLOP total).

    Returns:
        Total FLOPs summed over all counted nodes.
    """
    gm: torch.fx.GraphModule = symbolic_trace(mod)
    ShapeProp(gm).propagate(input)

    fprint = print if verbose else null_print

    v_maps = {}  # node name -> output shape, used to resolve the shapes of later nodes' inputs
    missing = {}  # operators without a counter seen during *this* call only
    total_flops = 0

    for node in gm.graph.nodes:
        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
        node_flops = None

        # Only args produced by earlier nodes have a recorded shape; constants are skipped.
        input_shapes = []
        fprint("input_shape:", end="\t")
        for arg in node.args:
            if str(arg) not in v_maps:
                continue
            fprint(f"{v_maps[str(arg)]}", end="\t")
            input_shapes.append(v_maps[str(arg)])
        fprint()
        fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
        output_shapes = [node.meta["tensor_meta"].shape]

        if node.op in ["output", "placeholder"]:
            node_flops = 0
        elif node.op == "call_function":
            # torch / operator functions, e.g. "<built-in function add>" -> "built-in function add".
            # NOTE: split("at") also cuts names containing "at" (flatten -> "built-in method fl"); count_map keys are
            # written against this exact behaviour, so keep both in sync when changing it.
            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
            else:
                missing[key] = (node.op, key)
                prRed(f"|{key}| is missing")
        elif node.op == "call_method":
            # Tensor methods such as x.clamp(...), keyed by the method name.
            key = str(node.target)
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes)
            else:
                missing[key] = (node.op, key)
                prRed(f"{key} is missing")
        elif node.op == "call_module":
            # torch.nn submodules, keyed by their exact type.
            m = mod.get_submodule(node.target)
            key = type(m)
            fprint(key, key in count_map)
            if key in count_map:
                node_flops = count_map[key](m, input_shapes, output_shapes)
            else:
                missing[key] = (node.op,)
                prRed(f"{key} is missing")
            fprint("module type:", key)
            # Not every module owns a weight (e.g. nn.Flatten, nn.Identity, BatchNorm with affine=False); indexing
            # state_dict() with f"{target}.weight" raised KeyError for those, so read the attribute defensively.
            weight = getattr(m, "weight", None)
            if isinstance(m, zero_ops) or not isinstance(weight, torch.Tensor):
                fprint("weight_shape: None")
            else:
                fprint(key)
                fprint(f"weight_shape: {weight.shape}")

        v_maps[str(node.name)] = node.meta["tensor_meta"].shape
        if node_flops is not None:
            total_flops += node_flops
        if verbose:
            prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
        fprint("==" * 40)

    missing_maps.update(missing)  # keep the module-level record for backward compatibility
    if missing:
        from pprint import pprint

        print("Missing operators: ")
        pprint(missing)
    return total_flops
209
+
210
+
211
if __name__ == "__main__":

    class MyOP(nn.Module):
        def forward(self, input):
            """Divide the input by 1, which contributes a ``truediv`` operation when traced."""
            return input / 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            """Build two parallel ``Linear(5, 3)`` layers followed by a ``MyOP``."""
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 3)
            self.linear2 = torch.nn.Linear(5, 3)
            self.myop = MyOP()

        def forward(self, x):
            """Add ``linear1(x)`` to ``linear2(x)`` clamped to [0, 1], then pass the sum through ``MyOP``."""
            summed = self.linear1(x) + self.linear2(x).clamp(min=0.0, max=1.0)
            return self.myop(summed)

    demo_net = MyModule()
    demo_input = th.randn(20, 5)
    print(fx_profile(demo_net, demo_input, verbose=False))
thop/profile.py ADDED
@@ -0,0 +1,228 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from thop.rnn_hooks import *
4
+ from thop.vision.basic_hooks import *
5
+
6
+ from .utils import prRed
7
+
8
# dtype of the `total_ops` / `total_params` counter buffers that profile_origin() registers on leaf modules.
default_dtype = torch.float64

# Default counting rules keyed by exact module class: profile() and profile_origin() look up `type(m)`, so subclasses
# of these classes are not matched. Each value is attached with `register_forward_hook`; an entry for the same class
# in the caller's `custom_ops` takes precedence.
register_hooks = {
    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
    # convolutions (regular and transposed)
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd,
    nn.ConvTranspose2d: count_convNd,
    nn.ConvTranspose3d: count_convNd,
    # normalization layers
    nn.BatchNorm1d: count_normalization,
    nn.BatchNorm2d: count_normalization,
    nn.BatchNorm3d: count_normalization,
    nn.LayerNorm: count_normalization,
    nn.InstanceNorm1d: count_normalization,
    nn.InstanceNorm2d: count_normalization,
    nn.InstanceNorm3d: count_normalization,
    # activations (ReLU / ReLU6 are mapped to zero_ops)
    nn.PReLU: count_prelu,
    nn.Softmax: count_softmax,
    nn.ReLU: zero_ops,
    nn.ReLU6: zero_ops,
    nn.LeakyReLU: count_relu,
    # pooling (max pooling is mapped to zero_ops)
    nn.MaxPool1d: zero_ops,
    nn.MaxPool2d: zero_ops,
    nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops,
    nn.AdaptiveMaxPool2d: zero_ops,
    nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    # linear, dropout and upsampling
    nn.Linear: count_linear,
    nn.Dropout: zero_ops,
    nn.Upsample: count_upsample,
    nn.UpsamplingBilinear2d: count_upsample,
    nn.UpsamplingNearest2d: count_upsample,
    # recurrent cells and layers (hooks defined in thop/rnn_hooks.py)
    nn.RNNCell: count_rnn_cell,
    nn.GRUCell: count_gru_cell,
    nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,
    # containers / reshuffling (mapped to zero_ops)
    nn.Sequential: zero_ops,
    nn.PixelShuffle: zero_ops,
    # synchronized (multi-process) batch norm
    nn.SyncBatchNorm: count_normalization,
}
57
+
58
+
59
def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Profile a model's total ops and parameters with forward hooks on its leaf modules (legacy implementation).

    Every leaf module gets ``total_ops`` / ``total_params`` buffers, and ``total_params`` is seeded with its parameter
    count. Leaf modules whose type has a rule in ``custom_ops`` (checked first) or ``register_hooks`` also get that
    rule as a forward hook. The model is run once in eval mode under ``torch.no_grad()`` and the buffers of all
    leaves are summed.

    The training mode is restored and all hooks and temporary buffers are removed afterwards, also when the forward
    pass raises.

    Args:
        model (nn.Module): Model to profile.
        inputs (tuple): Positional arguments, called as ``model(*inputs)``.
        custom_ops (dict | None): Optional ``{module_type: hook}`` rules that override ``register_hooks``.
        verbose (bool): Print which rule is registered for each module type.
        report_missing (bool): Warn about module types without a rule; forces ``verbose=True``.

    Returns:
        (tuple): ``(total_ops, total_params)`` as Python numbers.
    """
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        verbose = True

    def add_hooks(m):
        """Attach counter buffers and, when a rule exists for ``type(m)``, its forward hook to leaf module ``m``."""
        if list(m.children()):
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                f"Either .total_ops or .total_params is already defined in {str(m)}. "
                "Be careful, it might change your code's behavior."
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

        if fn is not None:
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        types_collection.add(m_type)

    training = model.training

    model.eval()
    try:
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        total_ops = 0
        total_params = 0
        for m in model.modules():
            if list(m.children()):  # skip for non-leaf module
                continue
            total_ops += m.total_ops
            total_params += m.total_params

        total_ops = total_ops.item()
        total_params = total_params.item()
    finally:
        # Reset the model to its original status even if the forward pass failed.
        model.train(training)
        for handler in handler_collection:
            handler.remove()

        # Remove the temporary buffers.
        for m in model.modules():
            if list(m.children()):
                continue
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)

    return total_ops, total_params
138
+
139
+
140
def profile(
    model: nn.Module,
    inputs,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Profile a model's total ops and parameters, optionally returning a per-layer breakdown.

    Every module gets temporary ``total_ops`` / ``total_params`` buffers. Modules whose type has a rule in
    ``custom_ops`` (checked first) or ``register_hooks`` also get that rule plus ``count_parameters`` as forward
    hooks. After one eval-mode forward pass under ``torch.no_grad()``, counts are aggregated recursively:
    - A hooked module (other than ``nn.Sequential`` / ``nn.ModuleList``) reports its own buffers.
    - Any other module reports its own ``total_ops`` plus the totals of its children.

    Afterwards the training mode is restored and all hooks are removed. The temporary buffers are removed from
    *every* module, not only hooked ones, so they never linger in ``state_dict()``. This cleanup also runs when the
    forward pass raises.

    Args:
        model (nn.Module): Model to profile.
        inputs (tuple): Positional arguments, called as ``model(*inputs)``.
        custom_ops (dict | None): Optional ``{module_type: hook}`` rules that override ``register_hooks``.
        verbose (bool): Print which rule is registered for each module type.
        ret_layer_info (bool): Also return the nested per-child breakdown.
        report_missing (bool): Warn about module types without a rule; forces ``verbose=True``.

    Returns:
        (tuple): ``(total_ops, total_params)``, or ``(total_ops, total_params, ret_dict)`` when ``ret_layer_info`` is
            True, where ``ret_dict`` maps child names to ``(ops, params, nested_dict)``.
    """
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # overwrite `verbose` option when enable report_missing
        verbose = True

    def add_hooks(m: nn.Module):
        """Registers hooks to a neural network module to track total operations and parameters."""
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    def dfs_count(module: nn.Module, prefix="\t") -> tuple:
        """Recursively sum ops/params of ``module`` and its children; returns ``(ops, params, per_child_dict)``."""
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            next_dict = {}
            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params, ret_dict

    prev_training_status = model.training

    model.eval()
    try:
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        total_ops, total_params, ret_dict = dfs_count(model)
    finally:
        # Reset the model to its original status even if the forward pass failed.
        model.train(prev_training_status)
        for op_handler, params_handler in handler_collection.values():
            op_handler.remove()
            params_handler.remove()
        # add_hooks registered the counter buffers on *every* module, not only the hooked ones, so remove them
        # everywhere; otherwise they would remain in the model's state_dict().
        for m in model.modules():
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
thop/rnn_hooks.py ADDED
@@ -0,0 +1,202 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.utils.rnn import PackedSequence
6
+
7
+
8
+ def _count_rnn_cell(input_size, hidden_size, bias=True):
9
+ """Calculate the total operations for an RNN cell given input size, hidden size, and optional bias."""
10
+ total_ops = hidden_size * (input_size + hidden_size) + hidden_size
11
+ if bias:
12
+ total_ops += hidden_size * 2
13
+
14
+ return total_ops
15
+
16
+
17
def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
    """Forward hook: add the ops of one ``nn.RNNCell`` step over the whole batch to ``m.total_ops``.

    ``x`` is the hook's tuple of positional inputs and ``x[0]`` is the cell input. A batched input has the batch
    size in dim 0; an unbatched 1-D input counts as a batch of one.
    """
    total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)

    # For an unbatched 1-D input, size(0) is the feature size rather than a batch size.
    batch_size = x[0].size(0) if x[0].dim() > 1 else 1
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
25
+
26
+
27
+ def _count_gru_cell(input_size, hidden_size, bias=True):
28
+ """Counts the total operations for a GRU cell based on input size, hidden size, and bias configuration."""
29
+ total_ops = 0
30
+ # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
31
+ # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
32
+ state_ops = (hidden_size + input_size) * hidden_size + hidden_size
33
+ if bias:
34
+ state_ops += hidden_size * 2
35
+ total_ops += state_ops * 2
36
+
37
+ # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
38
+ total_ops += (hidden_size + input_size) * hidden_size + hidden_size
39
+ if bias:
40
+ total_ops += hidden_size * 2
41
+ # r hadamard : r * (~)
42
+ total_ops += hidden_size
43
+
44
+ # h' = (1 - z) * n + z * h
45
+ # hadamard hadamard add
46
+ total_ops += hidden_size * 3
47
+
48
+ return total_ops
49
+
50
+
51
def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
    """Forward hook: add the ops of one ``nn.GRUCell`` step over the whole batch to ``m.total_ops``.

    ``x`` is the hook's tuple of positional inputs and ``x[0]`` is the cell input. A batched input has the batch
    size in dim 0; an unbatched 1-D input counts as a batch of one.
    """
    total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)

    # For an unbatched 1-D input, size(0) is the feature size rather than a batch size.
    batch_size = x[0].size(0) if x[0].dim() > 1 else 1
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
59
+
60
+
61
+ def _count_lstm_cell(input_size, hidden_size, bias=True):
62
+ """Counts LSTM cell operations during inference based on input size, hidden size, and bias configuration."""
63
+ total_ops = 0
64
+
65
+ # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
66
+ # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
67
+ # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
68
+ # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
69
+ state_ops = (input_size + hidden_size) * hidden_size + hidden_size
70
+ if bias:
71
+ state_ops += hidden_size * 2
72
+ total_ops += state_ops * 4
73
+
74
+ # c' = f * c + i * g \\
75
+ # hadamard hadamard add
76
+ total_ops += hidden_size * 3
77
+
78
+ # h' = o * \tanh(c') \\
79
+ total_ops += hidden_size
80
+
81
+ return total_ops
82
+
83
+
84
def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
    """Forward hook: add the ops of one ``nn.LSTMCell`` step over the whole batch to ``m.total_ops``.

    ``x`` is the hook's tuple of positional inputs and ``x[0]`` is the cell input. A batched input has the batch
    size in dim 0; an unbatched 1-D input counts as a batch of one.
    """
    total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

    # For an unbatched 1-D input, size(0) is the feature size rather than a batch size.
    batch_size = x[0].size(0) if x[0].dim() > 1 else 1
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])
92
+
93
+
94
+ def count_rnn(m: nn.RNN, x, y):
95
+ """Calculate and update the total number of operations for each RNN cell in a given batch."""
96
+ bias = m.bias
97
+ input_size = m.input_size
98
+ hidden_size = m.hidden_size
99
+ num_layers = m.num_layers
100
+
101
+ if isinstance(x[0], PackedSequence):
102
+ batch_size = torch.max(x[0].batch_sizes)
103
+ num_steps = x[0].batch_sizes.size(0)
104
+ elif m.batch_first:
105
+ batch_size = x[0].size(0)
106
+ num_steps = x[0].size(1)
107
+ else:
108
+ batch_size = x[0].size(1)
109
+ num_steps = x[0].size(0)
110
+
111
+ total_ops = 0
112
+ if m.bidirectional:
113
+ total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2
114
+ else:
115
+ total_ops += _count_rnn_cell(input_size, hidden_size, bias)
116
+
117
+ for _ in range(num_layers - 1):
118
+ total_ops += (
119
+ _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2
120
+ if m.bidirectional
121
+ else _count_rnn_cell(hidden_size, hidden_size, bias)
122
+ )
123
+ # time unroll
124
+ total_ops *= num_steps
125
+ # batch_size
126
+ total_ops *= batch_size
127
+
128
+ m.total_ops += torch.DoubleTensor([int(total_ops)])
129
+
130
+
131
+ def count_gru(m: nn.GRU, x, y):
132
+ """Calculates total operations for a GRU layer, updating the model's operation count based on batch size."""
133
+ bias = m.bias
134
+ input_size = m.input_size
135
+ hidden_size = m.hidden_size
136
+ num_layers = m.num_layers
137
+
138
+ if isinstance(x[0], PackedSequence):
139
+ batch_size = torch.max(x[0].batch_sizes)
140
+ num_steps = x[0].batch_sizes.size(0)
141
+ elif m.batch_first:
142
+ batch_size = x[0].size(0)
143
+ num_steps = x[0].size(1)
144
+ else:
145
+ batch_size = x[0].size(1)
146
+ num_steps = x[0].size(0)
147
+
148
+ total_ops = 0
149
+ if m.bidirectional:
150
+ total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2
151
+ else:
152
+ total_ops += _count_gru_cell(input_size, hidden_size, bias)
153
+
154
+ for _ in range(num_layers - 1):
155
+ total_ops += (
156
+ _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2
157
+ if m.bidirectional
158
+ else _count_gru_cell(hidden_size, hidden_size, bias)
159
+ )
160
+ # time unroll
161
+ total_ops *= num_steps
162
+ # batch_size
163
+ total_ops *= batch_size
164
+
165
+ m.total_ops += torch.DoubleTensor([int(total_ops)])
166
+
167
+
168
+ def count_lstm(m: nn.LSTM, x, y):
169
+ """Calculate total operations for LSTM layers, including bidirectional, updating model's total operations."""
170
+ bias = m.bias
171
+ input_size = m.input_size
172
+ hidden_size = m.hidden_size
173
+ num_layers = m.num_layers
174
+
175
+ if isinstance(x[0], PackedSequence):
176
+ batch_size = torch.max(x[0].batch_sizes)
177
+ num_steps = x[0].batch_sizes.size(0)
178
+ elif m.batch_first:
179
+ batch_size = x[0].size(0)
180
+ num_steps = x[0].size(1)
181
+ else:
182
+ batch_size = x[0].size(1)
183
+ num_steps = x[0].size(0)
184
+
185
+ total_ops = 0
186
+ if m.bidirectional:
187
+ total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2
188
+ else:
189
+ total_ops += _count_lstm_cell(input_size, hidden_size, bias)
190
+
191
+ for _ in range(num_layers - 1):
192
+ total_ops += (
193
+ _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2
194
+ if m.bidirectional
195
+ else _count_lstm_cell(hidden_size, hidden_size, bias)
196
+ )
197
+ # time unroll
198
+ total_ops *= num_steps
199
+ # batch_size
200
+ total_ops *= batch_size
201
+
202
+ m.total_ops += torch.DoubleTensor([int(total_ops)])