torch-weighttracker 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. torch_weighttracker-0.1.0/LICENSE +22 -0
  2. torch_weighttracker-0.1.0/MANIFEST.in +6 -0
  3. torch_weighttracker-0.1.0/PKG-INFO +93 -0
  4. torch_weighttracker-0.1.0/README.md +58 -0
  5. torch_weighttracker-0.1.0/pyproject.toml +77 -0
  6. torch_weighttracker-0.1.0/setup.cfg +4 -0
  7. torch_weighttracker-0.1.0/tests/__init__.py +1 -0
  8. torch_weighttracker-0.1.0/tests/dependency_graph.py +46 -0
  9. torch_weighttracker-0.1.0/tests/fixtures_models.py +35 -0
  10. torch_weighttracker-0.1.0/tests/test_attention_axes_e2e.py +148 -0
  11. torch_weighttracker-0.1.0/tests/test_attention_axes_unit.py +104 -0
  12. torch_weighttracker-0.1.0/tests/test_calculation_runtime.py +267 -0
  13. torch_weighttracker-0.1.0/tests/test_calculation_specs.py +263 -0
  14. torch_weighttracker-0.1.0/tests/test_group_lasso_regularizer.py +480 -0
  15. torch_weighttracker-0.1.0/tests/test_linear_weight_sum_specs.py +44 -0
  16. torch_weighttracker-0.1.0/tests/test_mha_operations.py +86 -0
  17. torch_weighttracker-0.1.0/tests/test_model_structured_sparsity.py +379 -0
  18. torch_weighttracker-0.1.0/tests/test_module_bitrate_extractor.py +181 -0
  19. torch_weighttracker-0.1.0/tests/test_param_pr_unit.py +392 -0
  20. torch_weighttracker-0.1.0/tests/test_param_unit_calculator.py +65 -0
  21. torch_weighttracker-0.1.0/tests/test_reduction_plan_builder.py +631 -0
  22. torch_weighttracker-0.1.0/tests/test_sparsity_controller.py +247 -0
  23. torch_weighttracker-0.1.0/tests/test_weight_tracker_unit_extensive.py +345 -0
  24. torch_weighttracker-0.1.0/tests/test_weighttracker_fvcore_e2e.py +385 -0
  25. torch_weighttracker-0.1.0/torch_weighttracker/__init__.py +5 -0
  26. torch_weighttracker-0.1.0/torch_weighttracker/calculations/__init__.py +77 -0
  27. torch_weighttracker-0.1.0/torch_weighttracker/calculations/base.py +54 -0
  28. torch_weighttracker-0.1.0/torch_weighttracker/calculations/cached_calc.py +23 -0
  29. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/__init__.py +68 -0
  30. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/active_macs_pr_module.py +58 -0
  31. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/active_units.py +31 -0
  32. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_group_sizes.py +51 -0
  33. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_macs_pr_module.py +100 -0
  34. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_module_axes.py +102 -0
  35. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_param_pr_unit_pr_group.py +144 -0
  36. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/bitrate_pr_module.py +91 -0
  37. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/group_change_effect.py +136 -0
  38. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/group_sizes.py +49 -0
  39. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/group_unit_param_change.py +179 -0
  40. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/groups_to_units.py +79 -0
  41. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/l2_norm_pr_unit.py +47 -0
  42. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/param_pr_unit.py +64 -0
  43. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/structured_unit_sum.py +31 -0
  44. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/unit_active_mask.py +38 -0
  45. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/unit_delta_to_module_axis.py +185 -0
  46. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/units_to_group.py +64 -0
  47. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/units_to_module_axis.py +118 -0
  48. torch_weighttracker-0.1.0/torch_weighttracker/calculations/calculations.py +75 -0
  49. torch_weighttracker-0.1.0/torch_weighttracker/calculations/context.py +56 -0
  50. torch_weighttracker-0.1.0/torch_weighttracker/calculations/pipeline_calc.py +136 -0
  51. torch_weighttracker-0.1.0/torch_weighttracker/calculations/reduction_calc.py +149 -0
  52. torch_weighttracker-0.1.0/torch_weighttracker/calculations/registry.py +84 -0
  53. torch_weighttracker-0.1.0/torch_weighttracker/calculations/spec.py +18 -0
  54. torch_weighttracker-0.1.0/torch_weighttracker/calculations/static_calc.py +21 -0
  55. torch_weighttracker-0.1.0/torch_weighttracker/canonical_units.py +448 -0
  56. torch_weighttracker-0.1.0/torch_weighttracker/consumer_ignore.py +53 -0
  57. torch_weighttracker-0.1.0/torch_weighttracker/extractors/codeq_bitrate_extractor.py +222 -0
  58. torch_weighttracker-0.1.0/torch_weighttracker/extractors/extractor.py +106 -0
  59. torch_weighttracker-0.1.0/torch_weighttracker/extractors/parameter_extractor.py +103 -0
  60. torch_weighttracker-0.1.0/torch_weighttracker/operations/__init__.py +53 -0
  61. torch_weighttracker-0.1.0/torch_weighttracker/operations/base.py +41 -0
  62. torch_weighttracker-0.1.0/torch_weighttracker/operations/conv.py +29 -0
  63. torch_weighttracker-0.1.0/torch_weighttracker/operations/generic.py +176 -0
  64. torch_weighttracker-0.1.0/torch_weighttracker/operations/linear.py +24 -0
  65. torch_weighttracker-0.1.0/torch_weighttracker/operations/mha.py +285 -0
  66. torch_weighttracker-0.1.0/torch_weighttracker/operations/norm.py +39 -0
  67. torch_weighttracker-0.1.0/torch_weighttracker/operations/resolver.py +44 -0
  68. torch_weighttracker-0.1.0/torch_weighttracker/plans/mapping_plan.py +90 -0
  69. torch_weighttracker-0.1.0/torch_weighttracker/plans/unit_weight_operation_plan.py +226 -0
  70. torch_weighttracker-0.1.0/torch_weighttracker/py.typed +1 -0
  71. torch_weighttracker-0.1.0/torch_weighttracker/reductions/builder.py +498 -0
  72. torch_weighttracker-0.1.0/torch_weighttracker/reductions/compiler.py +92 -0
  73. torch_weighttracker-0.1.0/torch_weighttracker/reductions/helpers.py +75 -0
  74. torch_weighttracker-0.1.0/torch_weighttracker/reductions/ops.py +96 -0
  75. torch_weighttracker-0.1.0/torch_weighttracker/regularizers/__init__.py +17 -0
  76. torch_weighttracker-0.1.0/torch_weighttracker/regularizers/base.py +79 -0
  77. torch_weighttracker-0.1.0/torch_weighttracker/regularizers/group_lasso.py +43 -0
  78. torch_weighttracker-0.1.0/torch_weighttracker/regularizers/group_lasso_with_bitrate.py +9 -0
  79. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/__init__.py +75 -0
  80. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/_helpers.py +79 -0
  81. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/__init__.py +9 -0
  82. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/constants.py +10 -0
  83. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/dependency.py +109 -0
  84. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/graph.py +613 -0
  85. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/group.py +143 -0
  86. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/index_mapping.py +403 -0
  87. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/node.py +60 -0
  88. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/shape_infer.py +124 -0
  89. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/ops.py +346 -0
  90. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/pruner/__init__.py +53 -0
  91. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/pruner/function.py +574 -0
  92. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/utils/__init__.py +3 -0
  93. torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/utils/utils.py +19 -0
  94. torch_weighttracker-0.1.0/torch_weighttracker/trackers/__init__.py +11 -0
  95. torch_weighttracker-0.1.0/torch_weighttracker/trackers/base.py +96 -0
  96. torch_weighttracker-0.1.0/torch_weighttracker/trackers/structured_bops.py +122 -0
  97. torch_weighttracker-0.1.0/torch_weighttracker/weight_tracker.py +593 -0
  98. torch_weighttracker-0.1.0/torch_weighttracker.egg-info/PKG-INFO +93 -0
  99. torch_weighttracker-0.1.0/torch_weighttracker.egg-info/SOURCES.txt +100 -0
  100. torch_weighttracker-0.1.0/torch_weighttracker.egg-info/dependency_links.txt +1 -0
  101. torch_weighttracker-0.1.0/torch_weighttracker.egg-info/requires.txt +14 -0
  102. torch_weighttracker-0.1.0/torch_weighttracker.egg-info/top_level.txt +1 -0
@@ -0,0 +1,22 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 torch-weighttracker contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
@@ -0,0 +1,6 @@
1
+ include LICENSE
2
+ include README.md
3
+ include pyproject.toml
4
+ include torch_weighttracker/py.typed
5
+ recursive-include tests *.py
6
+ prune tests/ignored
@@ -0,0 +1,93 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-weighttracker
3
+ Version: 0.1.0
4
+ Summary: Tools for tracking structured weight sparsity in PyTorch models.
5
+ License-Expression: MIT
6
+ Project-URL: Homepage, https://github.com/dadyownes15/torch-weighttracker
7
+ Project-URL: Issues, https://github.com/dadyownes15/torch-weighttracker/issues
8
+ Project-URL: Repository, https://github.com/dadyownes15/torch-weighttracker
9
+ Keywords: pytorch,torch,weights,tracking,structured-sparsity,pruning
10
+ Classifier: Development Status :: 2 - Pre-Alpha
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Requires-Python: >=3.10
20
+ Description-Content-Type: text/markdown
21
+ License-File: LICENSE
22
+ Requires-Dist: numpy>=1.24
23
+ Requires-Dist: torch>=2.0
24
+ Provides-Extra: structured-bops
25
+ Requires-Dist: fvcore; extra == "structured-bops"
26
+ Provides-Extra: dev
27
+ Requires-Dist: build>=1.2; extra == "dev"
28
+ Requires-Dist: fvcore; extra == "dev"
29
+ Requires-Dist: pytest>=8.0; extra == "dev"
30
+ Requires-Dist: ruff>=0.6; extra == "dev"
31
+ Requires-Dist: timm>=1.0; extra == "dev"
32
+ Requires-Dist: torchvision>=0.15; extra == "dev"
33
+ Requires-Dist: twine>=5.0; extra == "dev"
34
+ Dynamic: license-file
35
+
36
+ # torch-weighttracker
37
+
38
+ PyTorch tools for tracking structured weight sparsity, regularization signals,
39
+ and bit-operation estimates in neural network modules.
40
+
41
+ The public API is centered on `WeightTracker`:
42
+
43
+ ```python
44
+ import torch
45
+ from torch import nn
46
+
47
+ from torch_weighttracker import WeightTracker
48
+
49
+ model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
50
+ tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
51
+ print(tracker.view_structures())
52
+ ```
53
+
54
+ ## Installation
55
+
56
+ ```bash
57
+ python -m pip install torch-weighttracker
58
+ ```
59
+
60
+ The structured BOPs (bit-operation) tracker uses `fvcore` to obtain baseline per-module MAC counts. Install the optional extra to enable it:
61
+
62
+ ```bash
63
+ python -m pip install "torch-weighttracker[structured-bops]"
64
+ ```
65
+
66
+ ## Development
67
+
68
+ ```bash
69
+ python -m pip install -e ".[dev]"
70
+ ```
71
+
72
+ Run tests and lint checks:
73
+
74
+ ```bash
75
+ pytest
76
+ ruff check .
77
+ ruff format --check .
78
+ ```
79
+
80
+ ## Smoke Test
81
+
82
+ ```bash
83
+ python -c "from torch_weighttracker import WeightTracker; print(WeightTracker)"
84
+ ```
85
+
86
+ ## Status
87
+
88
+ This package is pre-1.0. Public APIs may still change while the tracker,
89
+ calculation, and regularizer surfaces settle.
90
+
91
+ ## License
92
+
93
+ MIT
@@ -0,0 +1,58 @@
1
+ # torch-weighttracker
2
+
3
+ PyTorch tools for tracking structured weight sparsity, regularization signals,
4
+ and bit-operation estimates in neural network modules.
5
+
6
+ The public API is centered on `WeightTracker`:
7
+
8
+ ```python
9
+ import torch
10
+ from torch import nn
11
+
12
+ from torch_weighttracker import WeightTracker
13
+
14
+ model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
15
+ tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
16
+ print(tracker.view_structures())
17
+ ```
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ python -m pip install torch-weighttracker
23
+ ```
24
+
25
+ The structured BOPs (bit-operation) tracker uses `fvcore` to obtain baseline per-module MAC counts. Install the optional extra to enable it:
26
+
27
+ ```bash
28
+ python -m pip install "torch-weighttracker[structured-bops]"
29
+ ```
30
+
31
+ ## Development
32
+
33
+ ```bash
34
+ python -m pip install -e ".[dev]"
35
+ ```
36
+
37
+ Run tests and lint checks:
38
+
39
+ ```bash
40
+ pytest
41
+ ruff check .
42
+ ruff format --check .
43
+ ```
44
+
45
+ ## Smoke Test
46
+
47
+ ```bash
48
+ python -c "from torch_weighttracker import WeightTracker; print(WeightTracker)"
49
+ ```
50
+
51
+ ## Status
52
+
53
+ This package is pre-1.0. Public APIs may still change while the tracker,
54
+ calculation, and regularizer surfaces settle.
55
+
56
+ ## License
57
+
58
+ MIT
@@ -0,0 +1,77 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "torch-weighttracker"
7
+ version = "0.1.0"
8
+ description = "Tools for tracking structured weight sparsity in PyTorch models."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = "MIT"
12
+ license-files = ["LICENSE"]
13
+ keywords = ["pytorch", "torch", "weights", "tracking", "structured-sparsity", "pruning"]
14
+ classifiers = [
15
+ "Development Status :: 2 - Pre-Alpha",
16
+ "Intended Audience :: Developers",
17
+ "Intended Audience :: Science/Research",
18
+ "Programming Language :: Python :: 3",
19
+ "Programming Language :: Python :: 3.10",
20
+ "Programming Language :: Python :: 3.11",
21
+ "Programming Language :: Python :: 3.12",
22
+ "Programming Language :: Python :: 3.13",
23
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
24
+ ]
25
+ dependencies = [
26
+ "numpy>=1.24",
27
+ "torch>=2.0",
28
+ ]
29
+
30
+ [project.optional-dependencies]
31
+ structured-bops = [
32
+ "fvcore",
33
+ ]
34
+ dev = [
35
+ "build>=1.2",
36
+ "fvcore",
37
+ "pytest>=8.0",
38
+ "ruff>=0.6",
39
+ "timm>=1.0",
40
+ "torchvision>=0.15",
41
+ "twine>=5.0",
42
+ ]
43
+
44
+ [project.urls]
45
+ Homepage = "https://github.com/dadyownes15/torch-weighttracker"
46
+ Issues = "https://github.com/dadyownes15/torch-weighttracker/issues"
47
+ Repository = "https://github.com/dadyownes15/torch-weighttracker"
48
+
49
+ [tool.setuptools.packages.find]
50
+ where = ["."]
51
+ include = ["torch_weighttracker*"]
52
+
53
+ [tool.setuptools.package-data]
54
+ torch_weighttracker = ["py.typed"]
55
+
56
+ [tool.pytest.ini_options]
57
+ addopts = [
58
+ "-ra",
59
+ "--ignore=tests/test_attention_axes_e2e.py",
60
+ "--ignore=tests/test_attention_axes_unit.py",
61
+ "--ignore=tests/test_model_structured_sparsity.py",
62
+ "--ignore=tests/test_sparsity_controller.py",
63
+ "--ignore=tests/ignored",
64
+ ]
65
+ pythonpath = ["."]
66
+ testpaths = ["tests"]
67
+
68
+ [tool.ruff]
69
+ line-length = 88
70
+ target-version = "py310"
71
+
72
+ [tool.ruff.lint]
73
+ select = ["E", "F", "I", "UP", "B"]
74
+
75
+ [tool.ruff.format]
76
+ quote-style = "double"
77
+ indent-style = "space"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,46 @@
1
+ from __future__ import annotations
2
+
3
+
4
+ import pytest
5
+ import torch.nn as nn
6
+ import torch
7
+ from torch_weighttracker.torch_pruning.dependency import DependencyGraph
8
+ from torch_weighttracker.torch_pruning.pruner.function import LinearPruner, prune_linear_out_channels
9
+
10
class simpleMLP(nn.Module):
    """Three chained linear layers (1 -> 3 -> 3 -> 1) with no activations.

    Each layer is registered both as a named attribute and as a stage of
    ``self.net``, so the same module objects are reachable under both names.
    """

    def __init__(self) -> None:
        super().__init__()
        self.first_layer = nn.Linear(in_features=1, out_features=3)
        self.secound_layer = nn.Linear(in_features=3, out_features=3)
        self.third_layer = nn.Linear(in_features=3, out_features=1)
        stages = (self.first_layer, self.secound_layer, self.third_layer)
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the three layers in order."""
        out = self.net(x)
        return out
24
+
25
+
26
def test_specs_creation():
    """Check that a Linear dependency resolves to the out-channel pruning handler.

    Builds the dependency graph for ``simpleMLP`` from a single example input,
    takes the group at index 2 of ``get_all_groups()`` and asserts that the
    first dependency of its first item uses ``prune_linear_out_channels``.
    """
    model = simpleMLP()
    input_ex = torch.tensor([[1.0]])

    # Call the module itself rather than ``forward`` so that ``__call__``
    # (and any registered hooks) runs exactly as it does in normal use.
    output = model(input_ex)
    assert tuple(output.shape) == (1, 1)

    graph = DependencyGraph().build_dependency(model=model, example_inputs=input_ex)
    all_groups = list(graph.get_all_groups())
    assert len(all_groups) > 2, f"expected at least 3 groups, got {len(all_groups)}"

    # NOTE(review): index 2 is the group the original test inspected; which
    # layer it belongs to depends on the ordering of get_all_groups() - confirm.
    inspected_group = all_groups[2]
    first_item = inspected_group[0]
    dep = first_item[0]

    # The failure message carries the details the debug prints used to dump.
    assert dep.handler == prune_linear_out_channels, (
        f"unexpected handler {dep.handler!r}; dependency state: {vars(dep)}"
    )
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+
7
+
8
class TinyTransformerClassifier(nn.Module):
    """Single-block transformer encoder with a classification head (test fixture).

    The block is: token + learned position embeddings, self-attention with a
    residual connection followed by LayerNorm, a GELU MLP with a residual
    connection followed by LayerNorm, and a linear head applied to the first
    token of the sequence.

    All sizes are keyword-only and default to the values the fixture has
    always used, so ``TinyTransformerClassifier()`` builds the same model.

    Args:
        vocab_size: Number of rows in the token embedding table.
        max_positions: Number of rows in the position embedding table; input
            sequences must not be longer than this.
        embed_dim: Width of the embeddings and of the attention block.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward block.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        *,
        vocab_size: int = 32,
        max_positions: int = 8,
        embed_dim: int = 16,
        num_heads: int = 4,
        mlp_dim: int = 32,
        num_classes: int = 5,
    ) -> None:
        super().__init__()
        # Submodules are created in the original order so that seeded
        # initialisation yields the same weights as before.
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.position_embed = nn.Embedding(max_positions, embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp_in = nn.Linear(embed_dim, mlp_dim)
        self.activation = nn.GELU()
        self.mlp_out = nn.Linear(mlp_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Return the head output computed from the first token's features.

        ``token_ids`` is indexed as ``(batch, sequence)``: dimension 1 gives
        the number of positions and ``x[:, 0]`` selects the first token.
        """
        seq_len = token_ids.size(1)
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        x = self.token_embed(token_ids) + self.position_embed(positions)

        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)

        ff = self.mlp_out(self.activation(self.mlp_in(x)))
        x = self.norm2(x + ff)
        return self.head(x[:, 0])
35
+
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ import torch_structure_analyser as tsa
7
+ from torch_structure_analyser.analysis import StructureAxis
8
+ from tests.fixtures_models import TinyTransformerClassifier
9
+
10
+
11
class AttentionProbeNet(nn.Module):
    """Self-attention (width 8, 2 heads, batch-first) feeding a bias-free linear."""

    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(8, 2, batch_first=True)
        self.lin = nn.Linear(in_features=8, out_features=8, bias=False)

    def forward(self, x):
        """Attend ``x`` to itself and project the attention output."""
        # The attention module returns (output, weights); only the output is used.
        attended = self.mha(x, x, x)[0]
        return self.lin(attended)
20
+
21
+
22
def _make_probe_controller(
    model: AttentionProbeNet | None = None,
) -> tsa.SparsityTracker:
    """Wrap an ``AttentionProbeNet`` in a ``tsa.SparsityTracker``.

    A fresh ``AttentionProbeNet`` is created when ``model`` is ``None``. The
    tracker is given a random ``(1, 4, 8)`` example input, attention and
    linear modules as root module types, two heads registered for
    ``model.mha``, and both head and head-dim pruning switched on.
    """
    probe = AttentionProbeNet() if model is None else model
    example = torch.randn(1, 4, 8)
    return tsa.SparsityTracker(
        probe,
        example_inputs=example,
        root_module_types=[nn.MultiheadAttention, nn.Linear],
        num_heads={probe.mha: 2},
        prune_num_heads=True,
        prune_head_dims=True,
    )
33
+
34
+
35
def _make_transformer_controller(
    model: TinyTransformerClassifier | None = None,
) -> tsa.SparsityTracker:
    """Wrap a ``TinyTransformerClassifier`` in a ``tsa.SparsityTracker``.

    A fresh classifier is created when ``model`` is ``None``. The example
    input is a random ``(1, 8)`` batch of token ids drawn from ``[0, 32)``,
    and the head count registered for ``model.attn`` is read from the
    attention module itself.
    """
    classifier = TinyTransformerClassifier() if model is None else model
    example_tokens = torch.randint(0, 32, (1, 8))
    attention = classifier.attn
    return tsa.SparsityTracker(
        classifier,
        example_inputs=example_tokens,
        root_module_types=[nn.MultiheadAttention, nn.Linear],
        num_heads={attention: attention.num_heads},
        prune_num_heads=True,
        prune_head_dims=True,
    )
47
+
48
+
49
def _zero_mha_out_slices(attention: nn.MultiheadAttention, idxs: list[int]) -> None:
    """Zero the weight slices of ``attention`` that belong to channels ``idxs``.

    Rows ``idxs`` of any separate q/k/v projection weights are cleared. For
    the packed ``in_proj_weight`` the same rows are cleared at offsets 0,
    ``embed_dim`` and ``2 * embed_dim``, together with columns ``idxs``.
    Rows and columns ``idxs`` of ``out_proj.weight`` are cleared as well.
    Biases are left untouched.
    """
    width = attention.embed_dim
    # Same channel indices repeated at the three embed_dim-sized row offsets
    # of the packed projection.
    stacked_rows = [idx + part * width for part in range(3) for idx in idxs]

    with torch.no_grad():
        for name in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
            separate = getattr(attention, name)
            if separate is not None:
                separate[idxs, :] = 0

        packed = attention.in_proj_weight
        if packed is not None:
            packed[stacked_rows, :] = 0
            packed[:, idxs] = 0

        if attention.out_proj is not None:
            projection = attention.out_proj.weight
            projection[idxs, :] = 0
            projection[:, idxs] = 0
65
+
66
+
67
def _zero_member_slices(module, handler, idxs: list[int]) -> None:
    """Zero the parameter slices of ``module`` that ``handler`` would prune.

    The handler decides which axis is cleared: rows (plus bias entries) for
    linear out-channels, columns for linear in-channels and embeddings,
    per-channel weight/bias entries for affine normalisation layers, and the
    attention-specific slices for multi-head attention.

    Raises:
        AssertionError: If ``handler`` has no zeroing rule here.
    """

    def _clear_affine_entries() -> None:
        # Per-channel scale and, when present, shift of a normalisation layer.
        module.weight[idxs] = 0
        if module.bias is not None:
            module.bias[idxs] = 0

    with torch.no_grad():
        if handler == tsa.prune_linear_out_channels:
            module.weight[idxs, :] = 0
            if module.bias is not None:
                module.bias[idxs] = 0
        elif handler == tsa.prune_linear_in_channels:
            module.weight[:, idxs] = 0
        elif handler == tsa.prune_layernorm_out_channels:
            if module.elementwise_affine:
                _clear_affine_entries()
        elif handler in (
            tsa.prune_batchnorm_out_channels,
            tsa.prune_groupnorm_out_channels,
            tsa.prune_instancenorm_out_channels,
        ):
            if getattr(module, "affine", False):
                _clear_affine_entries()
        elif handler == tsa.prune_embedding_out_channels:
            module.weight[:, idxs] = 0
        elif handler == tsa.prune_multihead_attention_out_channels:
            _zero_mha_out_slices(module, idxs)
        else:
            raise AssertionError(f"Unhandled test zeroing rule for {handler.__name__}")
91
+
92
+
93
def _zero_group_unit(
    controller: tsa.SparsityTracker,
    group_id: str,
    root_indices: list[int],
) -> None:
    """Zero every measurable member slice of one group for the given root indices.

    The group view is looked up by ``group_id`` among ``controller.iter_groups()``.
    For each measurable member, the local indices paired with a root index in
    ``root_indices`` are zeroed via ``_zero_member_slices``. Members with no
    matching index are skipped.
    """
    target_view = next(
        view for view in controller.iter_groups() if view.group_id == group_id
    )
    wanted_roots = set(root_indices)

    for member in target_view.members:
        if member.measurable:
            selected = [
                local
                for local, root in zip(member.local_idxs, member.root_idxs)
                if root in wanted_roots
            ]
            if selected:
                _zero_member_slices(member.module, member.handler, selected)
107
+
108
+
109
def test_attention_head_dim_report_detects_zeroed_shared_dimension():
    """Zeroing root indices (0, 4) is reported on the head-dim view only."""
    head_dim_group = "lin:prune_in_channels:head_dim"
    head_group = "lin:prune_in_channels:head"

    tracker = _make_probe_controller()
    _zero_group_unit(tracker, head_dim_group, [0, 4])

    report = tracker.structured_sparsity()

    assert report.by_group[head_dim_group].zero_prune_units == ((0, 4),)
    assert report.by_group[head_group].stats.removed == 0
117
+
118
+
119
def test_attention_head_report_detects_zeroed_whole_head():
    """Zeroing root indices 0..3 is reported on the head view only."""
    head_group = "lin:prune_in_channels:head"
    head_dim_group = "lin:prune_in_channels:head_dim"

    tracker = _make_probe_controller()
    _zero_group_unit(tracker, head_group, [0, 1, 2, 3])

    report = tracker.structured_sparsity()

    assert report.by_group[head_group].zero_prune_units == ((0, 1, 2, 3),)
    assert report.by_group[head_dim_group].stats.removed == 0
127
+
128
+
129
def test_prune_zero_structures_updates_logical_num_heads_after_head_prune():
    """Pruning one zeroed head leaves three heads in the config and the head view."""
    head_group = "attn:prune_out_channels:head"
    classifier = TinyTransformerClassifier()
    tracker = _make_transformer_controller(classifier)
    _zero_group_unit(tracker, head_group, [0, 1, 2, 3])

    outcome = tracker.prune_zero_structures()
    head_views = [
        view for view in tracker.iter_groups() if view.axis == StructureAxis.HEAD
    ]

    assert outcome.pruned_group_ids == (head_group,)
    assert tracker.config.num_heads[classifier.attn] == 3
    assert head_views[0].size == 3
140
+
141
+
142
def test_tiny_transformer_exposes_attention_head_and_head_dim_views_end_to_end():
    """The attention out-channel group is exposed as one head and one head-dim view."""
    tracker = _make_transformer_controller()

    prefix = "attn:prune_out_channels"
    matching_views = [
        view for view in tracker.iter_groups() if view.group_id.startswith(prefix)
    ]
    seen_axes = {view.axis for view in matching_views}

    assert len(matching_views) == 2
    assert seen_axes == {StructureAxis.HEAD, StructureAxis.HEAD_DIM}
@@ -0,0 +1,104 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ import torch_structure_analyser as tsa
7
+ from torch_structure_analyser.analysis import (
8
+ StructureAxis,
9
+ build_atomic_prune_units,
10
+ build_grouped_prune_units,
11
+ build_head_prune_units,
12
+ )
13
+
14
+
15
class AttentionProbeNet(nn.Module):
    """Two-head, width-8, batch-first self-attention followed by a bias-free linear."""

    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(8, 2, batch_first=True)
        self.lin = nn.Linear(in_features=8, out_features=8, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x`` and project the result with ``lin``."""
        # Only the first element (the attention output) of the returned pair is used.
        attention_output = self.mha(x, x, x)[0]
        return self.lin(attention_output)
24
+
25
+
26
def _make_controller() -> tsa.SparsityTracker:
    """Build a ``tsa.SparsityTracker`` around a fresh ``AttentionProbeNet``.

    The tracker receives a random ``(1, 4, 8)`` example input, attention and
    linear modules as root module types, two heads registered for the probe's
    attention module, and both head and head-dim pruning enabled.
    """
    probe = AttentionProbeNet()
    example = torch.randn(1, 4, 8)
    return tsa.SparsityTracker(
        probe,
        example_inputs=example,
        root_module_types=[nn.MultiheadAttention, nn.Linear],
        num_heads={probe.mha: 2},
        prune_num_heads=True,
        prune_head_dims=True,
    )
36
+
37
+
38
def _zero_attention_indices(model: AttentionProbeNet, idxs: list[int]) -> None:
    """Zero the attention and linear weight slices for channels ``idxs``.

    In ``model.mha`` this clears rows ``idxs`` of any separate q/k/v weights.
    In the packed ``in_proj_weight`` it clears the same rows at offsets 0,
    ``embed_dim`` and ``2 * embed_dim``, plus columns ``idxs``. In
    ``out_proj.weight`` it clears rows and columns ``idxs``. Finally, columns
    ``idxs`` of ``model.lin.weight`` are cleared. Biases are not modified.
    """
    attention = model.mha
    width = attention.embed_dim
    # Channel indices repeated at each of the three embed_dim-sized row offsets.
    stacked_rows = [idx + part * width for part in range(3) for idx in idxs]

    with torch.no_grad():
        for name in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
            separate = getattr(attention, name)
            if separate is not None:
                separate[idxs, :] = 0

        packed = attention.in_proj_weight
        if packed is not None:
            packed[stacked_rows, :] = 0
            packed[:, idxs] = 0

        if attention.out_proj is not None:
            projection = attention.out_proj.weight
            projection[idxs, :] = 0
            projection[:, idxs] = 0

        model.lin.weight[:, idxs] = 0
55
+
56
+
57
def test_build_atomic_prune_units_returns_singletons():
    """Four atomic units cover root indices 0..3 one index at a time."""
    expected = [(0,), (1,), (2,), (3,)]
    observed = [unit.root_indices for unit in build_atomic_prune_units(4)]
    assert observed == expected
60
+
61
+
62
def test_build_grouped_prune_units_matches_head_dim_pattern():
    """Grouped units for (8, 2) pair each index i in 0..3 with i + 4."""
    expected = [(0, 4), (1, 5), (2, 6), (3, 7)]
    observed = [unit.root_indices for unit in build_grouped_prune_units(8, 2)]
    assert observed == expected
65
+
66
+
67
def test_build_head_prune_units_returns_contiguous_blocks():
    """Head units for (8, 2) are the two contiguous blocks 0..3 and 4..7."""
    expected = [(0, 1, 2, 3), (4, 5, 6, 7)]
    observed = [unit.root_indices for unit in build_head_prune_units(8, 2)]
    assert observed == expected
70
+
71
+
72
def test_attention_group_views_include_head_and_head_dim_axes():
    """The ``lin`` in-channel group yields a head-dim view followed by a head view."""
    tracker = _make_controller()
    prefix = "lin:prune_in_channels"
    matching_views = [
        view for view in tracker.iter_groups() if view.group_id.startswith(prefix)
    ]

    axes = [view.axis for view in matching_views]
    assert axes == [StructureAxis.HEAD_DIM, StructureAxis.HEAD]

    head_dim_units = [unit.root_indices for unit in matching_views[0].prune_units]
    assert head_dim_units == [(0, 4), (1, 5), (2, 6), (3, 7)]

    head_units = [unit.root_indices for unit in matching_views[1].prune_units]
    assert head_units == [(0, 1, 2, 3), (4, 5, 6, 7)]
79
+
80
+
81
def test_zero_structure_candidates_prioritize_heads_before_head_dims():
    """With all eight channels zeroed, the head candidate precedes the head-dim one."""
    tracker = _make_controller()
    _zero_attention_indices(tracker.model, list(range(8)))

    prefix = "lin:prune_in_channels"
    matching = [
        candidate
        for candidate in tracker.zero_structure_candidates()
        if candidate.group_id.startswith(prefix)
    ]

    leading_axes = [candidate.axis for candidate in matching[:2]]
    assert leading_axes == [StructureAxis.HEAD, StructureAxis.HEAD_DIM]
89
+
90
+
91
def test_group_lasso_axes_filter_keeps_attention_views_separate():
    """Group lasso restricted to one axis returns a scalar loss and its own terms."""
    tracker = _make_controller()

    head_loss, head_terms = tracker.group_lasso(axes=(StructureAxis.HEAD,))
    dim_loss, dim_terms = tracker.group_lasso(axes=(StructureAxis.HEAD_DIM,))

    # Both losses are 0-dim tensors.
    assert head_loss.ndim == 0
    assert dim_loss.ndim == 0
    # Both losses take part in autograd.
    assert head_loss.requires_grad
    assert dim_loss.requires_grad
    # Each axis contributes at least one named term.
    assert head_terms
    assert dim_terms
    # Every head term key mentions unit index 0 or 1.
    assert all(any(tag in key for tag in ("[0]", "[1]")) for key in head_terms)
    assert len(head_terms) < len(dim_terms)