ultralytics-thop 0.0.1__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/PKG-INFO +10 -9
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/README.md +8 -8
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/pyproject.toml +2 -1
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/tests/test_conv2d.py +5 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/tests/test_matmul.py +3 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/tests/test_relu.py +3 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/tests/test_utils.py +4 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/fx_profile.py +18 -1
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/onnx_profile.py +6 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/profile.py +8 -3
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/rnn_hooks.py +15 -1
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/utils.py +3 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/vision/basic_hooks.py +16 -3
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/vision/calc_func.py +19 -2
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/vision/onnx_counter.py +28 -8
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/ultralytics_thop.egg-info/PKG-INFO +10 -9
 - ultralytics_thop-0.0.3/ultralytics_thop.egg-info/requires.txt +2 -0
 - ultralytics_thop-0.0.1/ultralytics_thop.egg-info/requires.txt +0 -1
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/LICENSE +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/setup.cfg +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/__init__.py +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/__version__.py +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/vision/__init__.py +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/thop/vision/efficientnet.py +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/ultralytics_thop.egg-info/SOURCES.txt +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/ultralytics_thop.egg-info/dependency_links.txt +0 -0
 - {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/ultralytics_thop.egg-info/top_level.txt +0 -0
 
| 
         @@ -1,6 +1,6 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            Metadata-Version: 2.1
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: ultralytics-thop
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.0.1
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.0.3
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: A tool to count the FLOPs of PyTorch model.
         
     | 
| 
       5 
5 
     | 
    
         
             
            Author-email: Ligeng Zhu <ligeng.zhu+github@gmail.com>
         
     | 
| 
       6 
6 
     | 
    
         
             
            Maintainer-email: Ligeng Zhu <ligeng.zhu+github@gmail.com>
         
     | 
| 
         @@ -688,6 +688,7 @@ Classifier: Operating System :: Microsoft :: Windows 
     | 
|
| 
       688 
688 
     | 
    
         
             
            Requires-Python: >=3.8
         
     | 
| 
       689 
689 
     | 
    
         
             
            Description-Content-Type: text/markdown
         
     | 
| 
       690 
690 
     | 
    
         
             
            License-File: LICENSE
         
     | 
| 
      
 691 
     | 
    
         
            +
            Requires-Dist: packaging
         
     | 
| 
       691 
692 
     | 
    
         
             
            Requires-Dist: torch
         
     | 
| 
       692 
693 
     | 
    
         | 
| 
       693 
694 
     | 
    
         
             
            <br>
         
     | 
| 
         @@ -697,7 +698,7 @@ Requires-Dist: torch 
     | 
|
| 
       697 
698 
     | 
    
         | 
| 
       698 
699 
     | 
    
         
             
            Welcome to the [THOP](https://github.com/ultralytics/thop) repository, your comprehensive solution for profiling PyTorch models by computing the number of Multiply-Accumulate Operations (MACs) and parameters. This tool is essential for deep learning practitioners to evaluate model efficiency and performance.
         
     | 
| 
       699 
700 
     | 
    
         | 
| 
       700 
     | 
    
         
            -
            [](https://github.com/ultralytics/thop/actions/workflows/main.yml) [](https://github.com/ultralytics/thop/actions/workflows/main.yml) [](https://badge.fury.io/py/ultralytics-thop) <a href="https://ultralytics.com/discord"><img alt="Discord" src="https://img.shields.io/discord/1089800235347353640?logo=discord&logoColor=white&label=Discord&color=blue"></a>
         
     | 
| 
       701 
702 
     | 
    
         | 
| 
       702 
703 
     | 
    
         
             
            ## 📄 Description
         
     | 
| 
       703 
704 
     | 
    
         | 
| 
         @@ -830,17 +831,17 @@ For bugs or feature requests, please open an issue on [GitHub Issues](https://gi 
     | 
|
| 
       830 
831 
     | 
    
         | 
| 
       831 
832 
     | 
    
         
             
            <br>
         
     | 
| 
       832 
833 
     | 
    
         
             
            <div align="center">
         
     | 
| 
       833 
     | 
    
         
            -
              <a href="https://github.com/ultralytics 
     | 
| 
      
 834 
     | 
    
         
            +
              <a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
         
     | 
| 
       834 
835 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       835 
     | 
    
         
            -
              <a href="https://www.linkedin.com/company/ 
     | 
| 
      
 836 
     | 
    
         
            +
              <a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
         
     | 
| 
       836 
837 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       837 
     | 
    
         
            -
              <a href="https://twitter.com/ 
     | 
| 
      
 838 
     | 
    
         
            +
              <a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
         
     | 
| 
       838 
839 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       839 
     | 
    
         
            -
              <a href="https://youtube.com/ 
     | 
| 
      
 840 
     | 
    
         
            +
              <a href="https://youtube.com/ultralytics?sub_confirmation=1"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
         
     | 
| 
       840 
841 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       841 
     | 
    
         
            -
              <a href="https://www.tiktok.com/@ 
     | 
| 
      
 842 
     | 
    
         
            +
              <a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
         
     | 
| 
       842 
843 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       843 
     | 
    
         
            -
              <a href="https://www.instagram.com/ 
     | 
| 
      
 844 
     | 
    
         
            +
              <a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
         
     | 
| 
       844 
845 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       845 
     | 
    
         
            -
              <a href="https:// 
     | 
| 
      
 846 
     | 
    
         
            +
              <a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
         
     | 
| 
       846 
847 
     | 
    
         
             
            </div>
         
     | 
| 
         @@ -5,7 +5,7 @@ 
     | 
|
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         
             
            Welcome to the [THOP](https://github.com/ultralytics/thop) repository, your comprehensive solution for profiling PyTorch models by computing the number of Multiply-Accumulate Operations (MACs) and parameters. This tool is essential for deep learning practitioners to evaluate model efficiency and performance.
         
     | 
| 
       7 
7 
     | 
    
         | 
| 
       8 
     | 
    
         
            -
            [](https://github.com/ultralytics/thop/actions/workflows/main.yml) [](https://github.com/ultralytics/thop/actions/workflows/main.yml) [](https://badge.fury.io/py/ultralytics-thop) <a href="https://ultralytics.com/discord"><img alt="Discord" src="https://img.shields.io/discord/1089800235347353640?logo=discord&logoColor=white&label=Discord&color=blue"></a>
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
            ## 📄 Description
         
     | 
| 
       11 
11 
     | 
    
         | 
| 
         @@ -138,17 +138,17 @@ For bugs or feature requests, please open an issue on [GitHub Issues](https://gi 
     | 
|
| 
       138 
138 
     | 
    
         | 
| 
       139 
139 
     | 
    
         
             
            <br>
         
     | 
| 
       140 
140 
     | 
    
         
             
            <div align="center">
         
     | 
| 
       141 
     | 
    
         
            -
              <a href="https://github.com/ultralytics 
     | 
| 
      
 141 
     | 
    
         
            +
              <a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
         
     | 
| 
       142 
142 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       143 
     | 
    
         
            -
              <a href="https://www.linkedin.com/company/ 
     | 
| 
      
 143 
     | 
    
         
            +
              <a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
         
     | 
| 
       144 
144 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       145 
     | 
    
         
            -
              <a href="https://twitter.com/ 
     | 
| 
      
 145 
     | 
    
         
            +
              <a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
         
     | 
| 
       146 
146 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       147 
     | 
    
         
            -
              <a href="https://youtube.com/ 
     | 
| 
      
 147 
     | 
    
         
            +
              <a href="https://youtube.com/ultralytics?sub_confirmation=1"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
         
     | 
| 
       148 
148 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       149 
     | 
    
         
            -
              <a href="https://www.tiktok.com/@ 
     | 
| 
      
 149 
     | 
    
         
            +
              <a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
         
     | 
| 
       150 
150 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       151 
     | 
    
         
            -
              <a href="https://www.instagram.com/ 
     | 
| 
      
 151 
     | 
    
         
            +
              <a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
         
     | 
| 
       152 
152 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       153 
     | 
    
         
            -
              <a href="https:// 
     | 
| 
      
 153 
     | 
    
         
            +
              <a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
         
     | 
| 
       154 
154 
     | 
    
         
             
            </div>
         
     | 
| 
         @@ -25,7 +25,7 @@ build-backend = "setuptools.build_meta" 
     | 
|
| 
       25 
25 
     | 
    
         | 
| 
       26 
26 
     | 
    
         
             
            [project]
         
     | 
| 
       27 
27 
     | 
    
         
             
            name = "ultralytics-thop"
         
     | 
| 
       28 
     | 
    
         
            -
            version = "0.0.1"
     | 
| 
      
 28 
     | 
    
         
            +
            version = "0.0.3"  # Placeholder version, needs to be dynamically set
         
     | 
| 
       29 
29 
     | 
    
         
             
            description = "A tool to count the FLOPs of PyTorch model."
         
     | 
| 
       30 
30 
     | 
    
         
             
            readme = "README.md"
         
     | 
| 
       31 
31 
     | 
    
         
             
            requires-python = ">=3.8"
         
     | 
| 
         @@ -57,6 +57,7 @@ classifiers = [ 
     | 
|
| 
       57 
57 
     | 
    
         
             
                "Operating System :: Microsoft :: Windows",
         
     | 
| 
       58 
58 
     | 
    
         
             
            ]
         
     | 
| 
       59 
59 
     | 
    
         
             
            dependencies = [
         
     | 
| 
      
 60 
     | 
    
         
            +
                "packaging",
         
     | 
| 
       60 
61 
     | 
    
         
             
                "torch",
         
     | 
| 
       61 
62 
     | 
    
         
             
            ]
         
     | 
| 
       62 
63 
     | 
    
         | 
| 
         @@ -8,6 +8,9 @@ from thop import profile 
     | 
|
| 
       8 
8 
     | 
    
         | 
| 
       9 
9 
     | 
    
         
             
            class TestUtils:
         
     | 
| 
       10 
10 
     | 
    
         
             
                def test_conv2d_no_bias(self):
         
     | 
| 
      
 11 
     | 
    
         
            +
                    """Tests a 2D convolutional layer without bias using THOP profiling with predefined input dimensions and
         
     | 
| 
      
 12 
     | 
    
         
            +
                    convolution parameters.
         
     | 
| 
      
 13 
     | 
    
         
            +
                    """
         
     | 
| 
       11 
14 
     | 
    
         
             
                    n, in_c, ih, iw = 1, 3, 32, 32  # torch.randint(1, 10, (4,)).tolist()
         
     | 
| 
       12 
15 
     | 
    
         
             
                    out_c, kh, kw = 12, 5, 5
         
     | 
| 
       13 
16 
     | 
    
         
             
                    s, p, d, g = 1, 1, 1, 1
         
     | 
| 
         @@ -22,6 +25,7 @@ class TestUtils: 
     | 
|
| 
       22 
25 
     | 
    
         
             
                    assert flops == 810000, f"{flops} v.s. {810000}"
         
     | 
| 
       23 
26 
     | 
    
         | 
| 
       24 
27 
     | 
    
         
             
                def test_conv2d(self):
         
     | 
| 
      
 28 
     | 
    
         
            +
                    """Tests a Conv2D layer with specific input dimensions, kernel size, stride, padding, dilation, and groups."""
         
     | 
| 
       25 
29 
     | 
    
         
             
                    n, in_c, ih, iw = 1, 3, 32, 32  # torch.randint(1, 10, (4,)).tolist()
         
     | 
| 
       26 
30 
     | 
    
         
             
                    out_c, kh, kw = 12, 5, 5
         
     | 
| 
       27 
31 
     | 
    
         
             
                    s, p, d, g = 1, 1, 1, 1
         
     | 
| 
         @@ -36,6 +40,7 @@ class TestUtils: 
     | 
|
| 
       36 
40 
     | 
    
         
             
                    assert flops == 810000, f"{flops} v.s. {810000}"
         
     | 
| 
       37 
41 
     | 
    
         | 
| 
       38 
42 
     | 
    
         
             
                def test_conv2d_random(self):
         
     | 
| 
      
 43 
     | 
    
         
            +
                    """Test Conv2D layer with random parameters and validate the computed FLOPs and parameters using 'profile'."""
         
     | 
| 
       39 
44 
     | 
    
         
             
                    for i in range(10):
         
     | 
| 
       40 
45 
     | 
    
         
             
                        out_c, kh, kw = torch.randint(1, 20, (3,)).tolist()
         
     | 
| 
       41 
46 
     | 
    
         
             
                        n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist()  # torch.randint(1, 10, (4,)).tolist()
         
     | 
| 
         @@ -7,6 +7,7 @@ from thop import profile 
     | 
|
| 
       7 
7 
     | 
    
         | 
| 
       8 
8 
     | 
    
         
             
            class TestUtils:
         
     | 
| 
       9 
9 
     | 
    
         
             
                def test_matmul_case2(self):
         
     | 
| 
      
 10 
     | 
    
         
            +
                    """Test matrix multiplication case asserting the FLOPs and parameters of a nn.Linear layer."""
         
     | 
| 
       10 
11 
     | 
    
         
             
                    n, in_c, out_c = 1, 100, 200
         
     | 
| 
       11 
12 
     | 
    
         
             
                    net = nn.Linear(in_c, out_c)
         
     | 
| 
       12 
13 
     | 
    
         
             
                    flops, params = profile(net, inputs=(torch.randn(n, in_c),))
         
     | 
| 
         @@ -14,6 +15,7 @@ class TestUtils: 
     | 
|
| 
       14 
15 
     | 
    
         
             
                    assert flops == n * in_c * out_c
         
     | 
| 
       15 
16 
     | 
    
         | 
| 
       16 
17 
     | 
    
         
             
                def test_matmul_case2(self):
         
     | 
| 
      
 18 
     | 
    
         
            +
                    """Tests matrix multiplication to assert FLOPs and parameters of nn.Linear layer using random dimensions."""
         
     | 
| 
       17 
19 
     | 
    
         
             
                    for i in range(10):
         
     | 
| 
       18 
20 
     | 
    
         
             
                        n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
         
     | 
| 
       19 
21 
     | 
    
         
             
                        net = nn.Linear(in_c, out_c)
         
     | 
| 
         @@ -22,6 +24,7 @@ class TestUtils: 
     | 
|
| 
       22 
24 
     | 
    
         
             
                        assert flops == n * in_c * out_c
         
     | 
| 
       23 
25 
     | 
    
         | 
| 
       24 
26 
     | 
    
         
             
                def test_conv2d(self):
         
     | 
| 
      
 27 
     | 
    
         
            +
                    """Tests the number of FLOPs and parameters for a randomly initialized nn.Linear layer using torch.profiler."""
         
     | 
| 
       25 
28 
     | 
    
         
             
                    n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
         
     | 
| 
       26 
29 
     | 
    
         
             
                    net = nn.Linear(in_c, out_c)
         
     | 
| 
       27 
30 
     | 
    
         
             
                    flops, params = profile(net, inputs=(torch.randn(n, in_c),))
         
     | 
| 
         @@ -7,6 +7,9 @@ from thop import profile 
     | 
|
| 
       7 
7 
     | 
    
         | 
| 
       8 
8 
     | 
    
         
             
            class TestUtils:
         
     | 
| 
       9 
9 
     | 
    
         
             
                def test_relu(self):
         
     | 
| 
      
 10 
     | 
    
         
            +
                    """Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP
         
     | 
| 
      
 11 
     | 
    
         
            +
                    profiling.
         
     | 
| 
      
 12 
     | 
    
         
            +
                    """
         
     | 
| 
       10 
13 
     | 
    
         
             
                    n, in_c, out_c = 1, 100, 200
         
     | 
| 
       11 
14 
     | 
    
         
             
                    data = torch.randn(n, in_c)
         
     | 
| 
       12 
15 
     | 
    
         
             
                    net = nn.ReLU()
         
     | 
| 
         @@ -5,12 +5,16 @@ from thop import utils 
     | 
|
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         
             
            class TestUtils:
         
     | 
| 
       7 
7 
     | 
    
         
             
                def test_clever_format_returns_formatted_number(self):
         
     | 
| 
      
 8 
     | 
    
         
            +
                    """Tests that the clever_format function returns a formatted number string with a '1.00B' pattern."""
         
     | 
| 
       8 
9 
     | 
    
         
             
                    nums = 1
         
     | 
| 
       9 
10 
     | 
    
         
             
                    format = "%.2f"
         
     | 
| 
       10 
11 
     | 
    
         
             
                    clever_nums = utils.clever_format(nums, format)
         
     | 
| 
       11 
12 
     | 
    
         
             
                    assert clever_nums == "1.00B"
         
     | 
| 
       12 
13 
     | 
    
         | 
| 
       13 
14 
     | 
    
         
             
                def test_clever_format_returns_formatted_numbers(self):
         
     | 
| 
      
 15 
     | 
    
         
            +
                    """Tests that the clever_format function correctly formats a list of numbers as strings with a '1.00B'
         
     | 
| 
      
 16 
     | 
    
         
            +
                    pattern.
         
     | 
| 
      
 17 
     | 
    
         
            +
                    """
         
     | 
| 
       14 
18 
     | 
    
         
             
                    nums = [1, 2]
         
     | 
| 
       15 
19 
     | 
    
         
             
                    format = "%.2f"
         
     | 
| 
       16 
20 
     | 
    
         
             
                    clever_nums = utils.clever_format(nums, format)
         
     | 
| 
         @@ -13,15 +13,17 @@ if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): 
     | 
|
| 
       13 
13 
     | 
    
         | 
| 
       14 
14 
     | 
    
         | 
| 
       15 
15 
     | 
    
         
             
            def count_clamp(input_shapes, output_shapes):
         
     | 
| 
      
 16 
     | 
    
         
            +
                """Ensures proper array sizes for tensors by clamping input and output shapes."""
         
     | 
| 
       16 
17 
     | 
    
         
             
                return 0
         
     | 
| 
       17 
18 
     | 
    
         | 
| 
       18 
19 
     | 
    
         | 
| 
       19 
20 
     | 
    
         
             
            def count_mul(input_shapes, output_shapes):
         
     | 
| 
       20 
     | 
    
         
            -
                 
     | 
| 
      
 21 
     | 
    
         
            +
                """Returns the number of elements in the first output shape."""
         
     | 
| 
       21 
22 
     | 
    
         
             
                return output_shapes[0].numel()
         
     | 
| 
       22 
23 
     | 
    
         | 
| 
       23 
24 
     | 
    
         | 
| 
       24 
25 
     | 
    
         
             
            def count_matmul(input_shapes, output_shapes):
         
     | 
| 
      
 26 
     | 
    
         
            +
                """Calculates the total number of operations for a matrix multiplication given input and output shapes."""
         
     | 
| 
       25 
27 
     | 
    
         
             
                in_shape = input_shapes[0]
         
     | 
| 
       26 
28 
     | 
    
         
             
                out_shape = output_shapes[0]
         
     | 
| 
       27 
29 
     | 
    
         
             
                in_features = in_shape[-1]
         
     | 
| 
         @@ -30,6 +32,7 @@ def count_matmul(input_shapes, output_shapes): 
     | 
|
| 
       30 
32 
     | 
    
         | 
| 
       31 
33 
     | 
    
         | 
| 
       32 
34 
     | 
    
         
             
            def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
         
     | 
| 
      
 35 
     | 
    
         
            +
                """Calculates total operations (FLOPs) for a linear layer given input and output shapes."""
         
     | 
| 
       33 
36 
     | 
    
         
             
                mul_flops = count_matmul(input_shapes, output_shapes)
         
     | 
| 
       34 
37 
     | 
    
         
             
                if "bias" in kwargs:
         
     | 
| 
       35 
38 
     | 
    
         
             
                    add_flops = output_shapes[0].numel()
         
     | 
| 
         @@ -40,6 +43,7 @@ from .vision.calc_func import calculate_conv 
     | 
|
| 
       40 
43 
     | 
    
         | 
| 
       41 
44 
     | 
    
         | 
| 
       42 
45 
     | 
    
         
             
            def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
         
     | 
| 
      
 46 
     | 
    
         
            +
                """Calculates total operations (FLOPs) for a 2D convolutional layer given input and output shapes."""
         
     | 
| 
       43 
47 
     | 
    
         
             
                inputs, weight, bias, stride, padding, dilation, groups = args
         
     | 
| 
       44 
48 
     | 
    
         
             
                if len(input_shapes) == 2:
         
     | 
| 
       45 
49 
     | 
    
         
             
                    x_shape, k_shape = input_shapes
         
     | 
| 
         @@ -56,14 +60,17 @@ def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs): 
     | 
|
| 
       56 
60 
     | 
    
         | 
| 
       57 
61 
     | 
    
         | 
| 
       58 
62 
     | 
    
         
             
            def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
         
     | 
| 
      
 63 
     | 
    
         
            +
                """Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
         
     | 
| 
       59 
64 
     | 
    
         
             
                return count_matmul(input_shapes, output_shapes)
         
     | 
| 
       60 
65 
     | 
    
         | 
| 
       61 
66 
     | 
    
         | 
| 
       62 
67 
     | 
    
         
             
            def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
         
     | 
| 
      
 68 
     | 
    
         
            +
                """Returns 0 for the given neural network module, input shapes, and output shapes."""
         
     | 
| 
       63 
69 
     | 
    
         
             
                return 0
         
     | 
| 
       64 
70 
     | 
    
         | 
| 
       65 
71 
     | 
    
         | 
| 
       66 
72 
     | 
    
         
             
            def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
         
     | 
| 
      
 73 
     | 
    
         
            +
                """Calculates total operations for a 2D convolutional neural network layer in a given neural network module."""
         
     | 
| 
       67 
74 
     | 
    
         
             
                bias_op = 1 if module.bias is not None else 0
         
     | 
| 
       68 
75 
     | 
    
         
             
                out_shape = output_shapes[0]
         
     | 
| 
       69 
76 
     | 
    
         | 
| 
         @@ -75,6 +82,7 @@ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes): 
     | 
|
| 
       75 
82 
     | 
    
         | 
| 
       76 
83 
     | 
    
         | 
| 
       77 
84 
     | 
    
         
             
            def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
         
     | 
| 
      
 85 
     | 
    
         
            +
                """Calculate the total operations for a given nn.BatchNorm2d module based on its output shape."""
         
     | 
| 
       78 
86 
     | 
    
         
             
                assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
         
     | 
| 
       79 
87 
     | 
    
         
             
                y = output_shapes[0]
         
     | 
| 
       80 
88 
     | 
    
         
             
                # y = (x - mean) / \sqrt{var + e} * weight + bias
         
     | 
| 
         @@ -116,10 +124,14 @@ from .utils import prGreen, prRed, prYellow 
     | 
|
| 
       116 
124 
     | 
    
         | 
| 
       117 
125 
     | 
    
         | 
| 
       118 
126 
     | 
    
         
             
            def null_print(*args, **kwargs):
         
     | 
| 
      
 127 
     | 
    
         
            +
                """A no-op print function that takes any arguments without performing any actions."""
         
     | 
| 
       119 
128 
     | 
    
         
             
                return
         
     | 
| 
       120 
129 
     | 
    
         | 
| 
       121 
130 
     | 
    
         | 
| 
       122 
131 
     | 
    
         
             
            def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
         
     | 
| 
      
 132 
     | 
    
         
            +
                """Profiles the given torch.nn Module to calculate total FLOPs for each operation and prints detailed node
         
     | 
| 
      
 133 
     | 
    
         
            +
                information if verbose.
         
     | 
| 
      
 134 
     | 
    
         
            +
                """
         
     | 
| 
       123 
135 
     | 
    
         
             
                gm: torch.fx.GraphModule = symbolic_trace(mod)
         
     | 
| 
       124 
136 
     | 
    
         
             
                g = gm.graph
         
     | 
| 
       125 
137 
     | 
    
         
             
                ShapeProp(gm).propagate(input)
         
     | 
| 
         @@ -204,16 +216,21 @@ if __name__ == "__main__": 
     | 
|
| 
       204 
216 
     | 
    
         | 
| 
       205 
217 
     | 
    
         
             
                class MyOP(nn.Module):
         
     | 
| 
       206 
218 
     | 
    
         
             
                    def forward(self, input):
         
     | 
| 
      
 219 
     | 
    
         
            +
                        """Performs forward pass on given input data."""
         
     | 
| 
       207 
220 
     | 
    
         
             
                        return input / 1
         
     | 
| 
       208 
221 
     | 
    
         | 
| 
       209 
222 
     | 
    
         
             
                class MyModule(torch.nn.Module):
         
     | 
| 
       210 
223 
     | 
    
         
             
                    def __init__(self):
         
     | 
| 
      
 224 
     | 
    
         
            +
                        """Initializes MyModule with two linear layers and a custom MyOP operator."""
         
     | 
| 
       211 
225 
     | 
    
         
             
                        super().__init__()
         
     | 
| 
       212 
226 
     | 
    
         
             
                        self.linear1 = torch.nn.Linear(5, 3)
         
     | 
| 
       213 
227 
     | 
    
         
             
                        self.linear2 = torch.nn.Linear(5, 3)
         
     | 
| 
       214 
228 
     | 
    
         
             
                        self.myop = MyOP()
         
     | 
| 
       215 
229 
     | 
    
         | 
| 
       216 
230 
     | 
    
         
             
                    def forward(self, x):
         
     | 
| 
      
 231 
     | 
    
         
            +
                        """Applies two linear transformations to the input tensor, clamps the second, then combines and processes
         
     | 
| 
      
 232 
     | 
    
         
            +
                        with MyOP operator.
         
     | 
| 
      
 233 
     | 
    
         
            +
                        """
         
     | 
| 
       217 
234 
     | 
    
         
             
                        out1 = self.linear1(x)
         
     | 
| 
       218 
235 
     | 
    
         
             
                        out2 = self.linear2(x).clamp(min=0.0, max=1.0)
         
     | 
| 
       219 
236 
     | 
    
         
             
                        return self.myop(out1 + out2)
         
     | 
| 
         @@ -9,9 +9,11 @@ from thop.vision.onnx_counter import onnx_operators 
     | 
|
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
            class OnnxProfile:
         
     | 
| 
       11 
11 
     | 
    
         
             
                def __init__(self):
         
     | 
| 
      
 12 
     | 
    
         
            +
                    """Initialize the OnnxProfile class with necessary imports for ONNX profiling."""
         
     | 
| 
       12 
13 
     | 
    
         
             
                    pass
         
     | 
| 
       13 
14 
     | 
    
         | 
| 
       14 
15 
     | 
    
         
             
                def calculate_params(self, model: onnx.ModelProto):
         
     | 
| 
      
 16 
     | 
    
         
            +
                    """Calculate the total number of parameters in an ONNX model."""
         
     | 
| 
       15 
17 
     | 
    
         
             
                    onnx_weights = model.graph.initializer
         
     | 
| 
       16 
18 
     | 
    
         
             
                    params = 0
         
     | 
| 
       17 
19 
     | 
    
         | 
| 
         @@ -25,6 +27,7 @@ class OnnxProfile: 
     | 
|
| 
       25 
27 
     | 
    
         
             
                    return params
         
     | 
| 
       26 
28 
     | 
    
         | 
| 
       27 
29 
     | 
    
         
             
                def create_dict(self, weight, input, output):
         
     | 
| 
      
 30 
     | 
    
         
            +
                    """Create and return a dictionary mapping weight, input, and output names to their respective dimensions."""
         
     | 
| 
       28 
31 
     | 
    
         
             
                    diction = {}
         
     | 
| 
       29 
32 
     | 
    
         
             
                    for w in weight:
         
     | 
| 
       30 
33 
     | 
    
         
             
                        dim = np.array(w.dims)
         
     | 
| 
         @@ -52,6 +55,9 @@ class OnnxProfile: 
     | 
|
| 
       52 
55 
     | 
    
         
             
                    return diction
         
     | 
| 
       53 
56 
     | 
    
         | 
| 
       54 
57 
     | 
    
         
             
                def nodes_counter(self, diction, node):
         
     | 
| 
      
 58 
     | 
    
         
            +
                    """Count nodes of a specific type in an ONNX graph, returning the count and associated node operation
         
     | 
| 
      
 59 
     | 
    
         
            +
                    details.
         
     | 
| 
      
 60 
     | 
    
         
            +
                    """
         
     | 
| 
       55 
61 
     | 
    
         
             
                    if node.op_type not in onnx_operators:
         
     | 
| 
       56 
62 
     | 
    
         
             
                        print("Sorry, we haven't add ", node.op_type, "into dictionary.")
         
     | 
| 
       57 
63 
     | 
    
         
             
                        return 0, None, None
         
     | 
| 
         @@ -1,4 +1,4 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from  
     | 
| 
      
 1 
     | 
    
         
            +
            from packaging.version import Version
         
     | 
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            from thop.rnn_hooks import *
         
     | 
| 
       4 
4 
     | 
    
         
             
            from thop.vision.basic_hooks import *
         
     | 
| 
         @@ -7,7 +7,7 @@ from thop.vision.basic_hooks import * 
     | 
|
| 
       7 
7 
     | 
    
         
             
            # logger.setLevel(logging.INFO)
         
     | 
| 
       8 
8 
     | 
    
         
             
            from .utils import prGreen, prRed, prYellow
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
     | 
    
         
            -
            if  
     | 
| 
      
 10 
     | 
    
         
            +
            if Version(torch.__version__) < Version("1.0.0"):
         
     | 
| 
       11 
11 
     | 
    
         
             
                logging.warning(
         
     | 
| 
       12 
12 
     | 
    
         
             
                    "You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__)
         
     | 
| 
       13 
13 
     | 
    
         
             
                )
         
     | 
| 
         @@ -61,11 +61,14 @@ register_hooks = { 
     | 
|
| 
       61 
61 
     | 
    
         
             
                nn.PixelShuffle: zero_ops,
         
     | 
| 
       62 
62 
     | 
    
         
             
            }
         
     | 
| 
       63 
63 
     | 
    
         | 
| 
       64 
     | 
    
         
            -
            if  
     | 
| 
      
 64 
     | 
    
         
            +
            if Version(torch.__version__) >= Version("1.1.0"):
         
     | 
| 
       65 
65 
     | 
    
         
             
                register_hooks.update({nn.SyncBatchNorm: count_normalization})
         
     | 
| 
       66 
66 
     | 
    
         | 
| 
       67 
67 
     | 
    
         | 
| 
       68 
68 
     | 
    
         
             
            def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
         
     | 
| 
      
 69 
     | 
    
         
            +
                """Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total
         
     | 
| 
      
 70 
     | 
    
         
            +
                operations and parameters.
         
     | 
| 
      
 71 
     | 
    
         
            +
                """
         
     | 
| 
       69 
72 
     | 
    
         
             
                handler_collection = []
         
     | 
| 
       70 
73 
     | 
    
         
             
                types_collection = set()
         
     | 
| 
       71 
74 
     | 
    
         
             
                if custom_ops is None:
         
     | 
| 
         @@ -162,6 +165,7 @@ def profile( 
     | 
|
| 
       162 
165 
     | 
    
         
             
                    verbose = True
         
     | 
| 
       163 
166 
     | 
    
         | 
| 
       164 
167 
     | 
    
         
             
                def add_hooks(m: nn.Module):
         
     | 
| 
      
 168 
     | 
    
         
            +
                    """Registers hooks to a neural network module to track total operations and parameters."""
         
     | 
| 
       165 
169 
     | 
    
         
             
                    m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
         
     | 
| 
       166 
170 
     | 
    
         
             
                    m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
         
     | 
| 
       167 
171 
     | 
    
         | 
| 
         @@ -200,6 +204,7 @@ def profile( 
     | 
|
| 
       200 
204 
     | 
    
         
             
                    model(*inputs)
         
     | 
| 
       201 
205 
     | 
    
         | 
| 
       202 
206 
     | 
    
         
             
                def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
         
     | 
| 
      
 207 
     | 
    
         
            +
                    """Recursively counts the total operations and parameters of the given PyTorch module and its submodules."""
         
     | 
| 
       203 
208 
     | 
    
         
             
                    total_ops, total_params = module.total_ops.item(), 0
         
     | 
| 
       204 
209 
     | 
    
         
             
                    ret_dict = {}
         
     | 
| 
       205 
210 
     | 
    
         
             
                    for n, m in module.named_children():
         
     | 
| 
         @@ -4,7 +4,7 @@ from torch.nn.utils.rnn import PackedSequence 
     | 
|
| 
       4 
4 
     | 
    
         | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         
             
            def _count_rnn_cell(input_size, hidden_size, bias=True):
         
     | 
| 
       7 
     | 
    
         
            -
                 
     | 
| 
      
 7 
     | 
    
         
            +
                """Calculate the total operations for an RNN cell based on input size, hidden size, and bias configuration."""
         
     | 
| 
       8 
8 
     | 
    
         
             
                total_ops = hidden_size * (input_size + hidden_size) + hidden_size
         
     | 
| 
       9 
9 
     | 
    
         
             
                if bias:
         
     | 
| 
       10 
10 
     | 
    
         
             
                    total_ops += hidden_size * 2
         
     | 
| 
         @@ -13,6 +13,7 @@ def _count_rnn_cell(input_size, hidden_size, bias=True): 
     | 
|
| 
       13 
13 
     | 
    
         | 
| 
       14 
14 
     | 
    
         | 
| 
       15 
15 
     | 
    
         
             
            def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
         
     | 
| 
      
 16 
     | 
    
         
            +
                """Counts RNN cell operations based on input, hidden size, bias, and batch size."""
         
     | 
| 
       16 
17 
     | 
    
         
             
                total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)
         
     | 
| 
       17 
18 
     | 
    
         | 
| 
       18 
19 
     | 
    
         
             
                batch_size = x[0].size(0)
         
     | 
| 
         @@ -22,6 +23,7 @@ def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor): 
     | 
|
| 
       22 
23 
     | 
    
         | 
| 
       23 
24 
     | 
    
         | 
| 
       24 
25 
     | 
    
         
             
            def _count_gru_cell(input_size, hidden_size, bias=True):
         
     | 
| 
      
 26 
     | 
    
         
            +
                """Counts the total operations for a GRU cell based on input size, hidden size, and bias."""
         
     | 
| 
       25 
27 
     | 
    
         
             
                total_ops = 0
         
     | 
| 
       26 
28 
     | 
    
         
             
                # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
         
     | 
| 
       27 
29 
     | 
    
         
             
                # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
         
     | 
| 
         @@ -45,6 +47,7 @@ def _count_gru_cell(input_size, hidden_size, bias=True): 
     | 
|
| 
       45 
47 
     | 
    
         | 
| 
       46 
48 
     | 
    
         | 
| 
       47 
49 
     | 
    
         
             
            def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
         
     | 
| 
      
 50 
     | 
    
         
            +
                """Calculates and updates the total operations for a GRU cell in a mini-batch during inference."""
         
     | 
| 
       48 
51 
     | 
    
         
             
                total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)
         
     | 
| 
       49 
52 
     | 
    
         | 
| 
       50 
53 
     | 
    
         
             
                batch_size = x[0].size(0)
         
     | 
| 
         @@ -54,6 +57,9 @@ def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor): 
     | 
|
| 
       54 
57 
     | 
    
         | 
| 
       55 
58 
     | 
    
         | 
| 
       56 
59 
     | 
    
         
             
            def _count_lstm_cell(input_size, hidden_size, bias=True):
         
     | 
| 
      
 60 
     | 
    
         
            +
                """Calculates the total operations for an LSTM cell during inference given input size, hidden size, and optional
         
     | 
| 
      
 61 
     | 
    
         
            +
                bias.
         
     | 
| 
      
 62 
     | 
    
         
            +
                """
         
     | 
| 
       57 
63 
     | 
    
         
             
                total_ops = 0
         
     | 
| 
       58 
64 
     | 
    
         | 
| 
       59 
65 
     | 
    
         
             
                # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
         
     | 
| 
         @@ -76,6 +82,9 @@ def _count_lstm_cell(input_size, hidden_size, bias=True): 
     | 
|
| 
       76 
82 
     | 
    
         | 
| 
       77 
83 
     | 
    
         | 
| 
       78 
84 
     | 
    
         
             
            def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
         
     | 
| 
      
 85 
     | 
    
         
            +
                """Count the number of operations for a single LSTM cell in a given batch, updating the model's total operations
         
     | 
| 
      
 86 
     | 
    
         
            +
                count.
         
     | 
| 
      
 87 
     | 
    
         
            +
                """
         
     | 
| 
       79 
88 
     | 
    
         
             
                total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)
         
     | 
| 
       80 
89 
     | 
    
         | 
| 
       81 
90 
     | 
    
         
             
                batch_size = x[0].size(0)
         
     | 
| 
         @@ -85,6 +94,7 @@ def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor): 
     | 
|
| 
       85 
94 
     | 
    
         | 
| 
       86 
95 
     | 
    
         | 
| 
       87 
96 
     | 
    
         
             
            def count_rnn(m: nn.RNN, x, y):
         
     | 
| 
      
 97 
     | 
    
         
            +
                """Calculate and update the total number of operations for a single RNN cell in a given batch."""
         
     | 
| 
       88 
98 
     | 
    
         
             
                bias = m.bias
         
     | 
| 
       89 
99 
     | 
    
         
             
                input_size = m.input_size
         
     | 
| 
       90 
100 
     | 
    
         
             
                hidden_size = m.hidden_size
         
     | 
| 
         @@ -122,6 +132,7 @@ def count_rnn(m: nn.RNN, x, y): 
     | 
|
| 
       122 
132 
     | 
    
         | 
| 
       123 
133 
     | 
    
         | 
| 
       124 
134 
     | 
    
         
             
            def count_gru(m: nn.GRU, x, y):
         
     | 
| 
      
 135 
     | 
    
         
            +
                """Calculate the total number of operations for a GRU layer in a neural network model."""
         
     | 
| 
       125 
136 
     | 
    
         
             
                bias = m.bias
         
     | 
| 
       126 
137 
     | 
    
         
             
                input_size = m.input_size
         
     | 
| 
       127 
138 
     | 
    
         
             
                hidden_size = m.hidden_size
         
     | 
| 
         @@ -159,6 +170,9 @@ def count_gru(m: nn.GRU, x, y): 
     | 
|
| 
       159 
170 
     | 
    
         | 
| 
       160 
171 
     | 
    
         | 
| 
       161 
172 
     | 
    
         
             
            def count_lstm(m: nn.LSTM, x, y):
         
     | 
| 
      
 173 
     | 
    
         
            +
                """Calculate the total operations for LSTM layers in a network, accounting for input size, hidden size, bias, and
         
     | 
| 
      
 174 
     | 
    
         
            +
                bidirectionality.
         
     | 
| 
      
 175 
     | 
    
         
            +
                """
         
     | 
| 
       162 
176 
     | 
    
         
             
                bias = m.bias
         
     | 
| 
       163 
177 
     | 
    
         
             
                input_size = m.input_size
         
     | 
| 
       164 
178 
     | 
    
         
             
                hidden_size = m.hidden_size
         
     | 
| 
         @@ -6,6 +6,8 @@ COLOR_YELLOW = "93m" 
     | 
|
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         | 
| 
       8 
8 
     | 
    
         
             
            def colorful_print(fn_print, color=COLOR_RED):
         
     | 
| 
      
 9 
     | 
    
         
            +
                """A decorator to print text in the specified terminal color by wrapping the given print function."""
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
       9 
11 
     | 
    
         
             
                def actual_call(*args, **kwargs):
         
     | 
| 
       10 
12 
     | 
    
         
             
                    print(f"\033[{color}", end="")
         
     | 
| 
       11 
13 
     | 
    
         
             
                    fn_print(*args, **kwargs)
         
     | 
| 
         @@ -29,6 +31,7 @@ prYellow = colorful_print(print, color=COLOR_YELLOW) 
     | 
|
| 
       29 
31 
     | 
    
         | 
| 
       30 
32 
     | 
    
         | 
| 
       31 
33 
     | 
    
         
             
            def clever_format(nums, format="%.2f"):
         
     | 
| 
      
 34 
     | 
    
         
            +
                """Formats numerical values into a more readable string with units (K, M, G, T) based on their magnitude."""
         
     | 
| 
       32 
35 
     | 
    
         
             
                if not isinstance(nums, Iterable):
         
     | 
| 
       33 
36 
     | 
    
         
             
                    nums = [nums]
         
     | 
| 
       34 
37 
     | 
    
         
             
                clever_nums = []
         
     | 
| 
         @@ -11,6 +11,7 @@ multiply_adds = 1 
     | 
|
| 
       11 
11 
     | 
    
         | 
| 
       12 
12 
     | 
    
         | 
| 
       13 
13 
     | 
    
         
             
            def count_parameters(m, x, y):
         
     | 
| 
      
 14 
     | 
    
         
            +
                """Calculate and update the total number of parameters in a given PyTorch model."""
         
     | 
| 
       14 
15 
     | 
    
         
             
                total_params = 0
         
     | 
| 
       15 
16 
     | 
    
         
             
                for p in m.parameters():
         
     | 
| 
       16 
17 
     | 
    
         
             
                    total_params += torch.DoubleTensor([p.numel()])
         
     | 
| 
         @@ -18,10 +19,12 @@ def count_parameters(m, x, y): 
     | 
|
| 
       18 
19 
     | 
    
         | 
| 
       19 
20 
     | 
    
         | 
| 
       20 
21 
     | 
    
         
             
            def zero_ops(m, x, y):
         
     | 
| 
      
 22 
     | 
    
         
            +
                """Incrementally add the number of zero operations to the model's total operations count."""
         
     | 
| 
       21 
23 
     | 
    
         
             
                m.total_ops += calculate_zero_ops()
         
     | 
| 
       22 
24 
     | 
    
         | 
| 
       23 
25 
     | 
    
         | 
| 
       24 
26 
     | 
    
         
             
            def count_convNd(m: _ConvNd, x, y: torch.Tensor):
         
     | 
| 
      
 27 
     | 
    
         
            +
                """Calculate and add the number of convolutional operations (FLOPs) to the model's total operations count."""
         
     | 
| 
       25 
28 
     | 
    
         
             
                x = x[0]
         
     | 
| 
       26 
29 
     | 
    
         | 
| 
       27 
30 
     | 
    
         
             
                kernel_ops = torch.zeros(m.weight.size()[2:]).numel()  # Kw x Kh
         
     | 
| 
         @@ -45,6 +48,7 @@ def count_convNd(m: _ConvNd, x, y: torch.Tensor): 
     | 
|
| 
       45 
48 
     | 
    
         | 
| 
       46 
49 
     | 
    
         | 
| 
       47 
50 
     | 
    
         
             
            def count_convNd_ver2(m: _ConvNd, x, y: torch.Tensor):
         
     | 
| 
      
 51 
     | 
    
         
            +
                """Calculates the total operations for a convolutional layer and updates the layer's total_ops attribute."""
         
     | 
| 
       48 
52 
     | 
    
         
             
                x = x[0]
         
     | 
| 
       49 
53 
     | 
    
         | 
| 
       50 
54 
     | 
    
         
             
                # N x H x W (exclude Cout)
         
     | 
| 
         @@ -60,7 +64,9 @@ def count_convNd_ver2(m: _ConvNd, x, y: torch.Tensor): 
     | 
|
| 
       60 
64 
     | 
    
         | 
| 
       61 
65 
     | 
    
         | 
| 
       62 
66 
     | 
    
         
             
            def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
         
     | 
| 
       63 
     | 
    
         
            -
                 
     | 
| 
      
 67 
     | 
    
         
            +
                """Calculate and add the FLOPs for a batch normalization layer, considering elementwise operations and possible
         
     | 
| 
      
 68 
     | 
    
         
            +
                affine parameters.
         
     | 
| 
      
 69 
     | 
    
         
            +
                """
         
     | 
| 
       64 
70 
     | 
    
         
             
                # https://github.com/Lyken17/pytorch-OpCounter/issues/124
         
     | 
| 
       65 
71 
     | 
    
         
             
                # y = (x - mean) / sqrt(eps + var) * weight + bias
         
     | 
| 
       66 
72 
     | 
    
         
             
                x = x[0]
         
     | 
| 
         @@ -82,6 +88,7 @@ def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y): 
     | 
|
| 
       82 
88 
     | 
    
         | 
| 
       83 
89 
     | 
    
         | 
| 
       84 
90 
     | 
    
         
             
            def count_prelu(m, x, y):
         
     | 
| 
      
 91 
     | 
    
         
            +
                """Calculate and update the total operation counts for a PReLU layer."""
         
     | 
| 
       85 
92 
     | 
    
         
             
                x = x[0]
         
     | 
| 
       86 
93 
     | 
    
         | 
| 
       87 
94 
     | 
    
         
             
                nelements = x.numel()
         
     | 
| 
         @@ -90,6 +97,7 @@ def count_prelu(m, x, y): 
     | 
|
| 
       90 
97 
     | 
    
         | 
| 
       91 
98 
     | 
    
         | 
| 
       92 
99 
     | 
    
         
             
            def count_relu(m, x, y):
         
     | 
| 
      
 100 
     | 
    
         
            +
                """Calculate and update the total operation counts for a ReLU layer."""
         
     | 
| 
       93 
101 
     | 
    
         
             
                x = x[0]
         
     | 
| 
       94 
102 
     | 
    
         | 
| 
       95 
103 
     | 
    
         
             
                nelements = x.numel()
         
     | 
| 
         @@ -98,6 +106,7 @@ def count_relu(m, x, y): 
     | 
|
| 
       98 
106 
     | 
    
         | 
| 
       99 
107 
     | 
    
         | 
| 
       100 
108 
     | 
    
         
             
            def count_softmax(m, x, y):
         
     | 
| 
      
 109 
     | 
    
         
            +
                """Calculate and update the total operation counts for a Softmax layer."""
         
     | 
| 
       101 
110 
     | 
    
         
             
                x = x[0]
         
     | 
| 
       102 
111 
     | 
    
         
             
                nfeatures = x.size()[m.dim]
         
     | 
| 
       103 
112 
     | 
    
         
             
                batch_size = x.numel() // nfeatures
         
     | 
| 
         @@ -106,7 +115,7 @@ def count_softmax(m, x, y): 
     | 
|
| 
       106 
115 
     | 
    
         | 
| 
       107 
116 
     | 
    
         | 
| 
       108 
117 
     | 
    
         
             
            def count_avgpool(m, x, y):
         
     | 
| 
       109 
     | 
    
         
            -
                 
     | 
| 
      
 118 
     | 
    
         
            +
                """Calculate and update the total operation counts for an AvgPool layer."""
         
     | 
| 
       110 
119 
     | 
    
         
             
                # total_div = 1
         
     | 
| 
       111 
120 
     | 
    
         
             
                # kernel_ops = total_add + total_div
         
     | 
| 
       112 
121 
     | 
    
         
             
                num_elements = y.numel()
         
     | 
| 
         @@ -114,6 +123,7 @@ def count_avgpool(m, x, y): 
     | 
|
| 
       114 
123 
     | 
    
         | 
| 
       115 
124 
     | 
    
         | 
| 
       116 
125 
     | 
    
         
             
            def count_adap_avgpool(m, x, y):
         
     | 
| 
      
 126 
     | 
    
         
            +
                """Calculate and update the total operation counts for an AdaptiveAvgPool layer."""
         
     | 
| 
       117 
127 
     | 
    
         
             
                kernel = torch.div(torch.DoubleTensor([*(x[0].shape[2:])]), torch.DoubleTensor([*(y.shape[2:])]))
         
     | 
| 
       118 
128 
     | 
    
         
             
                total_add = torch.prod(kernel)
         
     | 
| 
       119 
129 
     | 
    
         
             
                num_elements = y.numel()
         
     | 
| 
         @@ -122,6 +132,7 @@ def count_adap_avgpool(m, x, y): 
     | 
|
| 
       122 
132 
     | 
    
         | 
| 
       123 
133 
     | 
    
         
             
            # TODO: verify the accuracy
         
     | 
| 
       124 
134 
     | 
    
         
             
            def count_upsample(m, x, y):
         
     | 
| 
      
 135 
     | 
    
         
            +
                """Update the total operations counter in the given module for supported upsampling modes."""
         
     | 
| 
       125 
136 
     | 
    
         
             
                if m.mode not in (
         
     | 
| 
       126 
137 
     | 
    
         
             
                    "nearest",
         
     | 
| 
       127 
138 
     | 
    
         
             
                    "linear",
         
     | 
| 
         @@ -137,7 +148,9 @@ def count_upsample(m, x, y): 
     | 
|
| 
       137 
148 
     | 
    
         | 
| 
       138 
149 
     | 
    
         
             
            # nn.Linear
         
     | 
| 
       139 
150 
     | 
    
         
             
            def count_linear(m, x, y):
         
     | 
| 
       140 
     | 
    
         
            -
                 
     | 
| 
      
 151 
     | 
    
         
            +
                """Counts total operations for nn.Linear layers by calculating multiplications and additions based on input and
         
     | 
| 
      
 152 
     | 
    
         
            +
                output elements.
         
     | 
| 
      
 153 
     | 
    
         
            +
                """
         
     | 
| 
       141 
154 
     | 
    
         
             
                total_mul = m.in_features
         
     | 
| 
       142 
155 
     | 
    
         
             
                # total_add = m.in_features - 1
         
     | 
| 
       143 
156 
     | 
    
         
             
                # total_add += 1 if m.bias is not None else 0
         
     | 
| 
         @@ -5,6 +5,7 @@ import torch 
     | 
|
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         
             
            def l_prod(in_list):
         
     | 
| 
      
 8 
     | 
    
         
            +
                """Calculate the product of all elements in a list."""
         
     | 
| 
       8 
9 
     | 
    
         
             
                res = 1
         
     | 
| 
       9 
10 
     | 
    
         
             
                for _ in in_list:
         
     | 
| 
       10 
11 
     | 
    
         
             
                    res *= _
         
     | 
| 
         @@ -12,6 +13,7 @@ def l_prod(in_list): 
     | 
|
| 
       12 
13 
     | 
    
         | 
| 
       13 
14 
     | 
    
         | 
| 
       14 
15 
     | 
    
         
             
            def l_sum(in_list):
         
     | 
| 
      
 16 
     | 
    
         
            +
                """Calculate the sum of all elements in a list."""
         
     | 
| 
       15 
17 
     | 
    
         
             
                res = 0
         
     | 
| 
       16 
18 
     | 
    
         
             
                for _ in in_list:
         
     | 
| 
       17 
19 
     | 
    
         
             
                    res += _
         
     | 
| 
         @@ -19,6 +21,7 @@ def l_sum(in_list): 
     | 
|
| 
       19 
21 
     | 
    
         | 
| 
       20 
22 
     | 
    
         | 
| 
       21 
23 
     | 
    
         
             
            def calculate_parameters(param_list):
         
     | 
| 
      
 24 
     | 
    
         
            +
                """Calculate the total number of parameters in a list of tensors."""
         
     | 
| 
       22 
25 
     | 
    
         
             
                total_params = 0
         
     | 
| 
       23 
26 
     | 
    
         
             
                for p in param_list:
         
     | 
| 
       24 
27 
     | 
    
         
             
                    total_params += torch.DoubleTensor([p.nelement()])
         
     | 
| 
         @@ -26,11 +29,12 @@ def calculate_parameters(param_list): 
     | 
|
| 
       26 
29 
     | 
    
         | 
| 
       27 
30 
     | 
    
         | 
| 
       28 
31 
     | 
    
         
             
            def calculate_zero_ops():
         
     | 
| 
      
 32 
     | 
    
         
            +
                """Return a tensor initialized to zero."""
         
     | 
| 
       29 
33 
     | 
    
         
             
                return torch.DoubleTensor([int(0)])
         
     | 
| 
       30 
34 
     | 
    
         | 
| 
       31 
35 
     | 
    
         | 
| 
       32 
36 
     | 
    
         
             
            def calculate_conv2d_flops(input_size: list, output_size: list, kernel_size: list, groups: int, bias: bool = False):
         
     | 
| 
       33 
     | 
    
         
            -
                 
     | 
| 
      
 37 
     | 
    
         
            +
                """Calculate FLOPs for a Conv2D layer given input/output sizes, kernel size, groups, and bias flag."""
         
     | 
| 
       34 
38 
     | 
    
         
             
                # n, in_c, ih, iw = input_size
         
     | 
| 
       35 
39 
     | 
    
         
             
                # out_c, in_c, kh, kw = kernel_size
         
     | 
| 
       36 
40 
     | 
    
         
             
                in_c = input_size[1]
         
     | 
| 
         @@ -50,16 +54,20 @@ def calculate_norm(input_size): 
     | 
|
| 
       50 
54 
     | 
    
         | 
| 
       51 
55 
     | 
    
         | 
| 
       52 
56 
     | 
    
         
             
            def calculate_relu_flops(input_size):
         
     | 
| 
       53 
     | 
    
         
            -
                 
     | 
| 
      
 57 
     | 
    
         
            +
                """Calculates the FLOPs for a ReLU activation function based on the input size."""
         
     | 
| 
       54 
58 
     | 
    
         
             
                return 0
         
     | 
| 
       55 
59 
     | 
    
         | 
| 
       56 
60 
     | 
    
         | 
| 
       57 
61 
     | 
    
         
             
            def calculate_relu(input_size: torch.Tensor):
         
     | 
| 
      
 62 
     | 
    
         
            +
                """Convert an input tensor to a DoubleTensor with the same value."""
         
     | 
| 
       58 
63 
     | 
    
         
             
                warnings.warn("This API is being deprecated")
         
     | 
| 
       59 
64 
     | 
    
         
             
                return torch.DoubleTensor([int(input_size)])
         
     | 
| 
       60 
65 
     | 
    
         | 
| 
       61 
66 
     | 
    
         | 
| 
       62 
67 
     | 
    
         
             
            def calculate_softmax(batch_size, nfeatures):
         
     | 
| 
      
 68 
     | 
    
         
            +
                """Calculate the number of FLOPs required for a softmax activation function based on batch size and number of
         
     | 
| 
      
 69 
     | 
    
         
            +
                features.
         
     | 
| 
      
 70 
     | 
    
         
            +
                """
         
     | 
| 
       63 
71 
     | 
    
         
             
                total_exp = nfeatures
         
     | 
| 
       64 
72 
     | 
    
         
             
                total_add = nfeatures - 1
         
     | 
| 
       65 
73 
     | 
    
         
             
                total_div = nfeatures
         
     | 
| 
         @@ -68,16 +76,19 @@ def calculate_softmax(batch_size, nfeatures): 
     | 
|
| 
       68 
76 
     | 
    
         | 
| 
       69 
77 
     | 
    
         | 
| 
       70 
78 
     | 
    
         
             
            def calculate_avgpool(input_size):
         
     | 
| 
      
 79 
     | 
    
         
            +
                """Calculate the average pooling size given the input size."""
         
     | 
| 
       71 
80 
     | 
    
         
             
                return torch.DoubleTensor([int(input_size)])
         
     | 
| 
       72 
81 
     | 
    
         | 
| 
       73 
82 
     | 
    
         | 
| 
       74 
83 
     | 
    
         
             
            def calculate_adaptive_avg(kernel_size, output_size):
         
     | 
| 
      
 84 
     | 
    
         
            +
                """Calculate the number of operations for adaptive average pooling given kernel and output sizes."""
         
     | 
| 
       75 
85 
     | 
    
         
             
                total_div = 1
         
     | 
| 
       76 
86 
     | 
    
         
             
                kernel_op = kernel_size + total_div
         
     | 
| 
       77 
87 
     | 
    
         
             
                return torch.DoubleTensor([int(kernel_op * output_size)])
         
     | 
| 
       78 
88 
     | 
    
         | 
| 
       79 
89 
     | 
    
         | 
| 
       80 
90 
     | 
    
         
             
            def calculate_upsample(mode: str, output_size):
         
     | 
| 
      
 91 
     | 
    
         
            +
                """Calculate the number of operations for upsample methods given the mode and output size."""
         
     | 
| 
       81 
92 
     | 
    
         
             
                total_ops = output_size
         
     | 
| 
       82 
93 
     | 
    
         
             
                if mode == "linear":
         
     | 
| 
       83 
94 
     | 
    
         
             
                    total_ops *= 5
         
     | 
| 
         @@ -93,26 +104,32 @@ def calculate_upsample(mode: str, output_size): 
     | 
|
| 
       93 
104 
     | 
    
         | 
| 
       94 
105 
     | 
    
         | 
| 
       95 
106 
     | 
    
         
             
            def calculate_linear(in_feature, num_elements):
         
     | 
| 
      
 107 
     | 
    
         
            +
                """Calculate the linear operation count for an input feature and number of elements."""
         
     | 
| 
       96 
108 
     | 
    
         
             
                return torch.DoubleTensor([int(in_feature * num_elements)])
         
     | 
| 
       97 
109 
     | 
    
         | 
| 
       98 
110 
     | 
    
         | 
| 
       99 
111 
     | 
    
         
             
            def counter_matmul(input_size, output_size):
         
     | 
| 
      
 112 
     | 
    
         
            +
                """Calculate the total number of operations for a matrix multiplication given input and output sizes."""
         
     | 
| 
       100 
113 
     | 
    
         
             
                input_size = np.array(input_size)
         
     | 
| 
       101 
114 
     | 
    
         
             
                output_size = np.array(output_size)
         
     | 
| 
       102 
115 
     | 
    
         
             
                return np.prod(input_size) * output_size[-1]
         
     | 
| 
       103 
116 
     | 
    
         | 
| 
       104 
117 
     | 
    
         | 
| 
       105 
118 
     | 
    
         
             
            def counter_mul(input_size):
         
     | 
| 
      
 119 
     | 
    
         
            +
                """Calculate the total number of operations for a matrix multiplication given input and output sizes."""
         
     | 
| 
       106 
120 
     | 
    
         
             
                return input_size
         
     | 
| 
       107 
121 
     | 
    
         | 
| 
       108 
122 
     | 
    
         | 
| 
       109 
123 
     | 
    
         
             
            def counter_pow(input_size):
         
     | 
| 
      
 124 
     | 
    
         
            +
                """Calculate the total number of scalar multiplications for a power operation given an input size."""
         
     | 
| 
       110 
125 
     | 
    
         
             
                return input_size
         
     | 
| 
       111 
126 
     | 
    
         | 
| 
       112 
127 
     | 
    
         | 
| 
       113 
128 
     | 
    
         
             
            def counter_sqrt(input_size):
         
     | 
| 
      
 129 
     | 
    
         
            +
                """Calculate the total number of scalar operations for a square root operation given an input size."""
         
     | 
| 
       114 
130 
     | 
    
         
             
                return input_size
         
     | 
| 
       115 
131 
     | 
    
         | 
| 
       116 
132 
     | 
    
         | 
| 
       117 
133 
     | 
    
         
             
            def counter_div(input_size):
         
     | 
| 
      
 134 
     | 
    
         
            +
                """Calculate the total number of scalar operations for a division operation given an input size."""
         
     | 
| 
       118 
135 
     | 
    
         
             
                return input_size
         
     | 
| 
         @@ -19,6 +19,7 @@ from .calc_func import ( 
     | 
|
| 
       19 
19 
     | 
    
         | 
| 
       20 
20 
     | 
    
         | 
| 
       21 
21 
     | 
    
         
             
            def onnx_counter_matmul(diction, node):
         
     | 
| 
      
 22 
     | 
    
         
            +
                """Calculates multiply-accumulate operations and output size for matrix multiplication in an ONNX model node."""
         
     | 
| 
       22 
23 
     | 
    
         
             
                input1 = node.input[0]
         
     | 
| 
       23 
24 
     | 
    
         
             
                input2 = node.input[1]
         
     | 
| 
       24 
25 
     | 
    
         
             
                input1_dim = diction[input1]
         
     | 
| 
         @@ -30,6 +31,7 @@ def onnx_counter_matmul(diction, node): 
     | 
|
| 
       30 
31 
     | 
    
         | 
| 
       31 
32 
     | 
    
         | 
| 
       32 
33 
     | 
    
         
             
            def onnx_counter_add(diction, node):
         
     | 
| 
      
 34 
     | 
    
         
            +
                """Calculate multiply-accumulate operations (MACs), output size, and output name for ONNX addition nodes."""
         
     | 
| 
       33 
35 
     | 
    
         
             
                if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size:
         
     | 
| 
       34 
36 
     | 
    
         
             
                    out_size = diction[node.input[1]]
         
     | 
| 
       35 
37 
     | 
    
         
             
                else:
         
     | 
| 
         @@ -42,7 +44,9 @@ def onnx_counter_add(diction, node): 
     | 
|
| 
       42 
44 
     | 
    
         | 
| 
       43 
45 
     | 
    
         | 
| 
       44 
46 
     | 
    
         
             
            def onnx_counter_conv(diction, node):
         
     | 
| 
       45 
     | 
    
         
            -
                 
     | 
| 
      
 47 
     | 
    
         
            +
                """Calculates MACs, output size, and name for an ONNX convolution node based on input tensor dimensions and node
         
     | 
| 
      
 48 
     | 
    
         
            +
                attributes.
         
     | 
| 
      
 49 
     | 
    
         
            +
                """
         
     | 
| 
       46 
50 
     | 
    
         
             
                # bias,kernelsize,outputsize
         
     | 
| 
       47 
51 
     | 
    
         
             
                dim_bias = 0
         
     | 
| 
       48 
52 
     | 
    
         
             
                input_count = 0
         
     | 
| 
         @@ -81,7 +85,7 @@ def onnx_counter_conv(diction, node): 
     | 
|
| 
       81 
85 
     | 
    
         | 
| 
       82 
86 
     | 
    
         | 
| 
       83 
87 
     | 
    
         
             
            def onnx_counter_constant(diction, node):
         
     | 
| 
       84 
     | 
    
         
            -
                 
     | 
| 
      
 88 
     | 
    
         
            +
                """Calculate MACs, output size, and output name for a constant operation in an ONNX model."""
         
     | 
| 
       85 
89 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       86 
90 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
       87 
91 
     | 
    
         
             
                output_size = [1]
         
     | 
| 
         @@ -90,6 +94,7 @@ def onnx_counter_constant(diction, node): 
     | 
|
| 
       90 
94 
     | 
    
         | 
| 
       91 
95 
     | 
    
         | 
| 
       92 
96 
     | 
    
         
             
            def onnx_counter_mul(diction, node):
         
     | 
| 
      
 97 
     | 
    
         
            +
                """Calculate MACs, output size, and output name for a multiplication operation in an ONNX model."""
         
     | 
| 
       93 
98 
     | 
    
         
             
                if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size:
         
     | 
| 
       94 
99 
     | 
    
         
             
                    input_size = diction[node.input[1]]
         
     | 
| 
       95 
100 
     | 
    
         
             
                else:
         
     | 
| 
         @@ -101,6 +106,7 @@ def onnx_counter_mul(diction, node): 
     | 
|
| 
       101 
106 
     | 
    
         | 
| 
       102 
107 
     | 
    
         | 
| 
       103 
108 
     | 
    
         
             
            def onnx_counter_bn(diction, node):
         
     | 
| 
      
 109 
     | 
    
         
            +
                """Calculates MACs, output size, and output name for batch normalization layers in an ONNX model."""
         
     | 
| 
       104 
110 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       105 
111 
     | 
    
         
             
                macs = calculate_norm(np.prod(input_size))
         
     | 
| 
       106 
112 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
         @@ -109,6 +115,7 @@ def onnx_counter_bn(diction, node): 
     | 
|
| 
       109 
115 
     | 
    
         | 
| 
       110 
116 
     | 
    
         | 
| 
       111 
117 
     | 
    
         
             
            def onnx_counter_relu(diction, node):
         
     | 
| 
      
 118 
     | 
    
         
            +
                """Calculates MACs, output size, and output name for ReLU layers in an ONNX model."""
         
     | 
| 
       112 
119 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       113 
120 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       114 
121 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
         @@ -120,6 +127,9 @@ def onnx_counter_relu(diction, node): 
     | 
|
| 
       120 
127 
     | 
    
         | 
| 
       121 
128 
     | 
    
         | 
| 
       122 
129 
     | 
    
         
             
            def onnx_counter_reducemean(diction, node):
         
     | 
| 
      
 130 
     | 
    
         
            +
                """Compute MACs, output size, and name for the ReduceMean ONNX node, adjusting dimensions based on the 'axes' and
         
     | 
| 
      
 131 
     | 
    
         
            +
                'keepdims' attributes.
         
     | 
| 
      
 132 
     | 
    
         
            +
                """
         
     | 
| 
       123 
133 
     | 
    
         
             
                keep_dim = 0
         
     | 
| 
       124 
134 
     | 
    
         
             
                for attr in node.attribute:
         
     | 
| 
       125 
135 
     | 
    
         
             
                    if "axes" in attr.name:
         
     | 
| 
         @@ -139,6 +149,7 @@ def onnx_counter_reducemean(diction, node): 
     | 
|
| 
       139 
149 
     | 
    
         | 
| 
       140 
150 
     | 
    
         | 
| 
       141 
151 
     | 
    
         
             
            def onnx_counter_sub(diction, node):
         
     | 
| 
      
 152 
     | 
    
         
            +
                """Computes MACs, output size, and output name for a given ONNX node with specified input size."""
         
     | 
| 
       142 
153 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       143 
154 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       144 
155 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
         @@ -147,6 +158,7 @@ def onnx_counter_sub(diction, node): 
     | 
|
| 
       147 
158 
     | 
    
         | 
| 
       148 
159 
     | 
    
         | 
| 
       149 
160 
     | 
    
         
             
            def onnx_counter_pow(diction, node):
         
     | 
| 
      
 161 
     | 
    
         
            +
                """Calculates MACs, output size, and output name for a given ONNX 'Pow' node with specified input size."""
         
     | 
| 
       150 
162 
     | 
    
         
             
                if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size:
         
     | 
| 
       151 
163 
     | 
    
         
             
                    input_size = diction[node.input[1]]
         
     | 
| 
       152 
164 
     | 
    
         
             
                else:
         
     | 
| 
         @@ -158,6 +170,7 @@ def onnx_counter_pow(diction, node): 
     | 
|
| 
       158 
170 
     | 
    
         | 
| 
       159 
171 
     | 
    
         | 
| 
       160 
172 
     | 
    
         
             
            def onnx_counter_sqrt(diction, node):
         
     | 
| 
      
 173 
     | 
    
         
            +
                """Calculate MACs and output information for the SQRT operation in an ONNX node."""
         
     | 
| 
       161 
174 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       162 
175 
     | 
    
         
             
                macs = counter_sqrt(np.prod(input_size))
         
     | 
| 
       163 
176 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
         @@ -166,6 +179,7 @@ def onnx_counter_sqrt(diction, node): 
     | 
|
| 
       166 
179 
     | 
    
         | 
| 
       167 
180 
     | 
    
         | 
| 
       168 
181 
     | 
    
         
             
            def onnx_counter_div(diction, node):
         
     | 
| 
      
 182 
     | 
    
         
            +
                """Calculate MACs and output information for the DIV operation in an ONNX node."""
         
     | 
| 
       169 
183 
     | 
    
         
             
                if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size:
         
     | 
| 
       170 
184 
     | 
    
         
             
                    input_size = diction[node.input[1]]
         
     | 
| 
       171 
185 
     | 
    
         
             
                else:
         
     | 
| 
         @@ -177,6 +191,7 @@ def onnx_counter_div(diction, node): 
     | 
|
| 
       177 
191 
     | 
    
         | 
| 
       178 
192 
     | 
    
         | 
| 
       179 
193 
     | 
    
         
             
            def onnx_counter_instance(diction, node):
         
     | 
| 
      
 194 
     | 
    
         
            +
                """Calculate MACs, output size, and name for an ONNX node instance."""
         
     | 
| 
       180 
195 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       181 
196 
     | 
    
         
             
                macs = calculate_norm(np.prod(input_size))
         
     | 
| 
       182 
197 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
         @@ -185,6 +200,7 @@ def onnx_counter_instance(diction, node): 
     | 
|
| 
       185 
200 
     | 
    
         | 
| 
       186 
201 
     | 
    
         | 
| 
       187 
202 
     | 
    
         
             
            def onnx_counter_softmax(diction, node):
         
     | 
| 
      
 203 
     | 
    
         
            +
                """Calculate MACs, output size, and name for an ONNX softmax node instance."""
         
     | 
| 
       188 
204 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       189 
205 
     | 
    
         
             
                dim = node.attribute[0].i
         
     | 
| 
       190 
206 
     | 
    
         
             
                nfeatures = input_size[dim]
         
     | 
| 
         @@ -196,7 +212,7 @@ def onnx_counter_softmax(diction, node): 
     | 
|
| 
       196 
212 
     | 
    
         | 
| 
       197 
213 
     | 
    
         | 
| 
       198 
214 
     | 
    
         
             
            def onnx_counter_pad(diction, node):
         
     | 
| 
       199 
     | 
    
         
            -
                 
     | 
| 
      
 215 
     | 
    
         
            +
                """Compute memory access cost (MACs), output size, and output name for ONNX pad operation."""
         
     | 
| 
       200 
216 
     | 
    
         
             
                # if
         
     | 
| 
       201 
217 
     | 
    
         
             
                # if (np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size):
         
     | 
| 
       202 
218 
     | 
    
         
             
                #     input_size = diction[node.input[1]]
         
     | 
| 
         @@ -210,7 +226,7 @@ def onnx_counter_pad(diction, node): 
     | 
|
| 
       210 
226 
     | 
    
         | 
| 
       211 
227 
     | 
    
         | 
| 
       212 
228 
     | 
    
         
             
            def onnx_counter_averagepool(diction, node):
         
     | 
| 
       213 
     | 
    
         
            -
                 
     | 
| 
      
 229 
     | 
    
         
            +
                """Calculate MACs and output size for an AveragePool ONNX operation based on input dimensions and attributes."""
         
     | 
| 
       214 
230 
     | 
    
         
             
                macs = calculate_avgpool(np.prod(diction[node.input[0]]))
         
     | 
| 
       215 
231 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
       216 
232 
     | 
    
         
             
                dim_pad = None
         
     | 
| 
         @@ -240,7 +256,7 @@ def onnx_counter_averagepool(diction, node): 
     | 
|
| 
       240 
256 
     | 
    
         | 
| 
       241 
257 
     | 
    
         | 
| 
       242 
258 
     | 
    
         
             
            def onnx_counter_flatten(diction, node):
         
     | 
| 
       243 
     | 
    
         
            -
                 
     | 
| 
      
 259 
     | 
    
         
            +
                """Returns MACs, output size, and output name for an ONNX Flatten node."""
         
     | 
| 
       244 
260 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       245 
261 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
       246 
262 
     | 
    
         
             
                axis = node.attribute[0].i
         
     | 
| 
         @@ -251,7 +267,7 @@ def onnx_counter_flatten(diction, node): 
     | 
|
| 
       251 
267 
     | 
    
         | 
| 
       252 
268 
     | 
    
         | 
| 
       253 
269 
     | 
    
         
             
            def onnx_counter_gemm(diction, node):
         
     | 
| 
       254 
     | 
    
         
            -
                 
     | 
| 
      
 270 
     | 
    
         
            +
                """Calculate multiply–accumulate operations (MACs), output size, and name for ONNX Gemm node."""
         
     | 
| 
       255 
271 
     | 
    
         
             
                # Compute Y = alpha * A' * B' + beta * C
         
     | 
| 
       256 
272 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
       257 
273 
     | 
    
         
             
                dim_weight = diction[node.input[1]]
         
     | 
| 
         @@ -264,7 +280,7 @@ def onnx_counter_gemm(diction, node): 
     | 
|
| 
       264 
280 
     | 
    
         | 
| 
       265 
281 
     | 
    
         | 
| 
       266 
282 
     | 
    
         
             
            def onnx_counter_maxpool(diction, node):
         
     | 
| 
       267 
     | 
    
         
            -
                 
     | 
| 
      
 283 
     | 
    
         
            +
                """Calculate MACs and output size for ONNX MaxPool operation based on input node attributes and dimensions."""
         
     | 
| 
       268 
284 
     | 
    
         
             
                # print(node)
         
     | 
| 
       269 
285 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       270 
286 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
         @@ -295,6 +311,7 @@ def onnx_counter_maxpool(diction, node): 
     | 
|
| 
       295 
311 
     | 
    
         | 
| 
       296 
312 
     | 
    
         | 
| 
       297 
313 
     | 
    
         
             
            def onnx_counter_globalaveragepool(diction, node):
         
     | 
| 
      
 314 
     | 
    
         
            +
                """Counts MACs and computes output size for a global average pooling layer in an ONNX model."""
         
     | 
| 
       298 
315 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       299 
316 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
       300 
317 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
         @@ -303,7 +320,7 @@ def onnx_counter_globalaveragepool(diction, node): 
     | 
|
| 
       303 
320 
     | 
    
         | 
| 
       304 
321 
     | 
    
         | 
| 
       305 
322 
     | 
    
         
             
            def onnx_counter_concat(diction, node):
         
     | 
| 
       306 
     | 
    
         
            -
                 
     | 
| 
      
 323 
     | 
    
         
            +
                """Counts MACs and computes output size for a concatenation layer along a specified axis in an ONNX model."""
         
     | 
| 
       307 
324 
     | 
    
         
             
                # print(diction[node.input[0]])
         
     | 
| 
       308 
325 
     | 
    
         
             
                axis = node.attribute[0].i
         
     | 
| 
       309 
326 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
         @@ -317,6 +334,9 @@ def onnx_counter_concat(diction, node): 
     | 
|
| 
       317 
334 
     | 
    
         | 
| 
       318 
335 
     | 
    
         | 
| 
       319 
336 
     | 
    
         
             
            def onnx_counter_clip(diction, node):
         
     | 
| 
      
 337 
     | 
    
         
            +
                """Calculate MACs, output size, and output name for an ONNX node clip operation using provided dimensions and input
         
     | 
| 
      
 338 
     | 
    
         
            +
                size.
         
     | 
| 
      
 339 
     | 
    
         
            +
                """
         
     | 
| 
       320 
340 
     | 
    
         
             
                macs = calculate_zero_ops()
         
     | 
| 
       321 
341 
     | 
    
         
             
                output_name = node.output[0]
         
     | 
| 
       322 
342 
     | 
    
         
             
                input_size = diction[node.input[0]]
         
     | 
| 
         @@ -1,6 +1,6 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            Metadata-Version: 2.1
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: ultralytics-thop
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.0. 
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.0.3
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: A tool to count the FLOPs of PyTorch model.
         
     | 
| 
       5 
5 
     | 
    
         
             
            Author-email: Ligeng Zhu <ligeng.zhu+github@gmail.com>
         
     | 
| 
       6 
6 
     | 
    
         
             
            Maintainer-email: Ligeng Zhu <ligeng.zhu+github@gmail.com>
         
     | 
| 
         @@ -688,6 +688,7 @@ Classifier: Operating System :: Microsoft :: Windows 
     | 
|
| 
       688 
688 
     | 
    
         
             
            Requires-Python: >=3.8
         
     | 
| 
       689 
689 
     | 
    
         
             
            Description-Content-Type: text/markdown
         
     | 
| 
       690 
690 
     | 
    
         
             
            License-File: LICENSE
         
     | 
| 
      
 691 
     | 
    
         
            +
            Requires-Dist: packaging
         
     | 
| 
       691 
692 
     | 
    
         
             
            Requires-Dist: torch
         
     | 
| 
       692 
693 
     | 
    
         | 
| 
       693 
694 
     | 
    
         
             
            <br>
         
     | 
| 
         @@ -697,7 +698,7 @@ Requires-Dist: torch 
     | 
|
| 
       697 
698 
     | 
    
         | 
| 
       698 
699 
     | 
    
         
             
            Welcome to the [THOP](https://github.com/ultralytics/thop) repository, your comprehensive solution for profiling PyTorch models by computing the number of Multiply-Accumulate Operations (MACs) and parameters. This tool is essential for deep learning practitioners to evaluate model efficiency and performance.
         
     | 
| 
       699 
700 
     | 
    
         | 
| 
       700 
     | 
    
         
            -
            [](https://github.com/ultralytics/thop/actions/workflows/main.yml) [](https://github.com/ultralytics/thop/actions/workflows/main.yml) [](https://badge.fury.io/py/ultralytics-thop) <a href="https://ultralytics.com/discord"><img alt="Discord" src="https://img.shields.io/discord/1089800235347353640?logo=discord&logoColor=white&label=Discord&color=blue"></a>
         
     | 
| 
       701 
702 
     | 
    
         | 
| 
       702 
703 
     | 
    
         
             
            ## 📄 Description
         
     | 
| 
       703 
704 
     | 
    
         | 
| 
         @@ -830,17 +831,17 @@ For bugs or feature requests, please open an issue on [GitHub Issues](https://gi 
     | 
|
| 
       830 
831 
     | 
    
         | 
| 
       831 
832 
     | 
    
         
             
            <br>
         
     | 
| 
       832 
833 
     | 
    
         
             
            <div align="center">
         
     | 
| 
       833 
     | 
    
         
            -
              <a href="https://github.com/ultralytics 
     | 
| 
      
 834 
     | 
    
         
            +
              <a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
         
     | 
| 
       834 
835 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       835 
     | 
    
         
            -
              <a href="https://www.linkedin.com/company/ 
     | 
| 
      
 836 
     | 
    
         
            +
              <a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
         
     | 
| 
       836 
837 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       837 
     | 
    
         
            -
              <a href="https://twitter.com/ 
     | 
| 
      
 838 
     | 
    
         
            +
              <a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
         
     | 
| 
       838 
839 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       839 
     | 
    
         
            -
              <a href="https://youtube.com/ 
     | 
| 
      
 840 
     | 
    
         
            +
              <a href="https://youtube.com/ultralytics?sub_confirmation=1"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
         
     | 
| 
       840 
841 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       841 
     | 
    
         
            -
              <a href="https://www.tiktok.com/@ 
     | 
| 
      
 842 
     | 
    
         
            +
              <a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
         
     | 
| 
       842 
843 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       843 
     | 
    
         
            -
              <a href="https://www.instagram.com/ 
     | 
| 
      
 844 
     | 
    
         
            +
              <a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
         
     | 
| 
       844 
845 
     | 
    
         
             
              <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%" alt="space">
         
     | 
| 
       845 
     | 
    
         
            -
              <a href="https:// 
     | 
| 
      
 846 
     | 
    
         
            +
              <a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
         
     | 
| 
       846 
847 
     | 
    
         
             
            </div>
         
     | 
| 
         @@ -1 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            torch
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
    
        {ultralytics_thop-0.0.1 → ultralytics_thop-0.0.3}/ultralytics_thop.egg-info/dependency_links.txt
    RENAMED
    
    | 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |