torch-weighttracker 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_weighttracker-0.1.0/LICENSE +22 -0
- torch_weighttracker-0.1.0/MANIFEST.in +6 -0
- torch_weighttracker-0.1.0/PKG-INFO +93 -0
- torch_weighttracker-0.1.0/README.md +58 -0
- torch_weighttracker-0.1.0/pyproject.toml +77 -0
- torch_weighttracker-0.1.0/setup.cfg +4 -0
- torch_weighttracker-0.1.0/tests/__init__.py +1 -0
- torch_weighttracker-0.1.0/tests/dependency_graph.py +46 -0
- torch_weighttracker-0.1.0/tests/fixtures_models.py +35 -0
- torch_weighttracker-0.1.0/tests/test_attention_axes_e2e.py +148 -0
- torch_weighttracker-0.1.0/tests/test_attention_axes_unit.py +104 -0
- torch_weighttracker-0.1.0/tests/test_calculation_runtime.py +267 -0
- torch_weighttracker-0.1.0/tests/test_calculation_specs.py +263 -0
- torch_weighttracker-0.1.0/tests/test_group_lasso_regularizer.py +480 -0
- torch_weighttracker-0.1.0/tests/test_linear_weight_sum_specs.py +44 -0
- torch_weighttracker-0.1.0/tests/test_mha_operations.py +86 -0
- torch_weighttracker-0.1.0/tests/test_model_structured_sparsity.py +379 -0
- torch_weighttracker-0.1.0/tests/test_module_bitrate_extractor.py +181 -0
- torch_weighttracker-0.1.0/tests/test_param_pr_unit.py +392 -0
- torch_weighttracker-0.1.0/tests/test_param_unit_calculator.py +65 -0
- torch_weighttracker-0.1.0/tests/test_reduction_plan_builder.py +631 -0
- torch_weighttracker-0.1.0/tests/test_sparsity_controller.py +247 -0
- torch_weighttracker-0.1.0/tests/test_weight_tracker_unit_extensive.py +345 -0
- torch_weighttracker-0.1.0/tests/test_weighttracker_fvcore_e2e.py +385 -0
- torch_weighttracker-0.1.0/torch_weighttracker/__init__.py +5 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/__init__.py +77 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/base.py +54 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/cached_calc.py +23 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/__init__.py +68 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/active_macs_pr_module.py +58 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/active_units.py +31 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_group_sizes.py +51 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_macs_pr_module.py +100 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_module_axes.py +102 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/baseline_param_pr_unit_pr_group.py +144 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/bitrate_pr_module.py +91 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/group_change_effect.py +136 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/group_sizes.py +49 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/group_unit_param_change.py +179 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/groups_to_units.py +79 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/l2_norm_pr_unit.py +47 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/param_pr_unit.py +64 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/structured_unit_sum.py +31 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/unit_active_mask.py +38 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/unit_delta_to_module_axis.py +185 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/units_to_group.py +64 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calcs/units_to_module_axis.py +118 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/calculations.py +75 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/context.py +56 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/pipeline_calc.py +136 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/reduction_calc.py +149 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/registry.py +84 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/spec.py +18 -0
- torch_weighttracker-0.1.0/torch_weighttracker/calculations/static_calc.py +21 -0
- torch_weighttracker-0.1.0/torch_weighttracker/canonical_units.py +448 -0
- torch_weighttracker-0.1.0/torch_weighttracker/consumer_ignore.py +53 -0
- torch_weighttracker-0.1.0/torch_weighttracker/extractors/codeq_bitrate_extractor.py +222 -0
- torch_weighttracker-0.1.0/torch_weighttracker/extractors/extractor.py +106 -0
- torch_weighttracker-0.1.0/torch_weighttracker/extractors/parameter_extractor.py +103 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/__init__.py +53 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/base.py +41 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/conv.py +29 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/generic.py +176 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/linear.py +24 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/mha.py +285 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/norm.py +39 -0
- torch_weighttracker-0.1.0/torch_weighttracker/operations/resolver.py +44 -0
- torch_weighttracker-0.1.0/torch_weighttracker/plans/mapping_plan.py +90 -0
- torch_weighttracker-0.1.0/torch_weighttracker/plans/unit_weight_operation_plan.py +226 -0
- torch_weighttracker-0.1.0/torch_weighttracker/py.typed +1 -0
- torch_weighttracker-0.1.0/torch_weighttracker/reductions/builder.py +498 -0
- torch_weighttracker-0.1.0/torch_weighttracker/reductions/compiler.py +92 -0
- torch_weighttracker-0.1.0/torch_weighttracker/reductions/helpers.py +75 -0
- torch_weighttracker-0.1.0/torch_weighttracker/reductions/ops.py +96 -0
- torch_weighttracker-0.1.0/torch_weighttracker/regularizers/__init__.py +17 -0
- torch_weighttracker-0.1.0/torch_weighttracker/regularizers/base.py +79 -0
- torch_weighttracker-0.1.0/torch_weighttracker/regularizers/group_lasso.py +43 -0
- torch_weighttracker-0.1.0/torch_weighttracker/regularizers/group_lasso_with_bitrate.py +9 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/__init__.py +75 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/_helpers.py +79 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/__init__.py +9 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/constants.py +10 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/dependency.py +109 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/graph.py +613 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/group.py +143 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/index_mapping.py +403 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/node.py +60 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/dependency/shape_infer.py +124 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/ops.py +346 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/pruner/__init__.py +53 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/pruner/function.py +574 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/utils/__init__.py +3 -0
- torch_weighttracker-0.1.0/torch_weighttracker/torch_pruning/utils/utils.py +19 -0
- torch_weighttracker-0.1.0/torch_weighttracker/trackers/__init__.py +11 -0
- torch_weighttracker-0.1.0/torch_weighttracker/trackers/base.py +96 -0
- torch_weighttracker-0.1.0/torch_weighttracker/trackers/structured_bops.py +122 -0
- torch_weighttracker-0.1.0/torch_weighttracker/weight_tracker.py +593 -0
- torch_weighttracker-0.1.0/torch_weighttracker.egg-info/PKG-INFO +93 -0
- torch_weighttracker-0.1.0/torch_weighttracker.egg-info/SOURCES.txt +100 -0
- torch_weighttracker-0.1.0/torch_weighttracker.egg-info/dependency_links.txt +1 -0
- torch_weighttracker-0.1.0/torch_weighttracker.egg-info/requires.txt +14 -0
- torch_weighttracker-0.1.0/torch_weighttracker.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 torch-weighttracker contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
22
|
+
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-weighttracker
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Tools for tracking structured weight sparsity in PyTorch models.
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Project-URL: Homepage, https://github.com/dadyownes15/torch-weighttracker
|
|
7
|
+
Project-URL: Issues, https://github.com/dadyownes15/torch-weighttracker/issues
|
|
8
|
+
Project-URL: Repository, https://github.com/dadyownes15/torch-weighttracker
|
|
9
|
+
Keywords: pytorch,torch,weights,tracking,structured-sparsity,pruning
|
|
10
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: numpy>=1.24
|
|
23
|
+
Requires-Dist: torch>=2.0
|
|
24
|
+
Provides-Extra: structured-bops
|
|
25
|
+
Requires-Dist: fvcore; extra == "structured-bops"
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: build>=1.2; extra == "dev"
|
|
28
|
+
Requires-Dist: fvcore; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff>=0.6; extra == "dev"
|
|
31
|
+
Requires-Dist: timm>=1.0; extra == "dev"
|
|
32
|
+
Requires-Dist: torchvision>=0.15; extra == "dev"
|
|
33
|
+
Requires-Dist: twine>=5.0; extra == "dev"
|
|
34
|
+
Dynamic: license-file
|
|
35
|
+
|
|
36
|
+
# torch-weighttracker
|
|
37
|
+
|
|
38
|
+
PyTorch tools for tracking structured weight sparsity, regularization signals,
|
|
39
|
+
and bit-operation estimates in neural network modules.
|
|
40
|
+
|
|
41
|
+
The public API is centered on `WeightTracker`:
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
import torch
|
|
45
|
+
from torch import nn
|
|
46
|
+
|
|
47
|
+
from torch_weighttracker import WeightTracker
|
|
48
|
+
|
|
49
|
+
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
|
|
50
|
+
tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
|
|
51
|
+
print(tracker.view_structures())
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Installation
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
python -m pip install torch-weighttracker
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
Structured BOPs (bit-operation) accounting uses `fvcore` to measure baseline per-module MACs; install the optional extra to enable it:
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
python -m pip install "torch-weighttracker[structured-bops]"
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## Development
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
python -m pip install -e ".[dev]"
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
Run tests and lint checks:
|
|
73
|
+
|
|
74
|
+
```bash
|
|
75
|
+
pytest
|
|
76
|
+
ruff check .
|
|
77
|
+
ruff format --check .
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## Smoke Test
|
|
81
|
+
|
|
82
|
+
```bash
|
|
83
|
+
python -c "from torch_weighttracker import WeightTracker; print(WeightTracker)"
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Status
|
|
87
|
+
|
|
88
|
+
This package is pre-1.0. Public APIs may still change while the tracker,
|
|
89
|
+
calculation, and regularizer surfaces settle.
|
|
90
|
+
|
|
91
|
+
## License
|
|
92
|
+
|
|
93
|
+
MIT
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# torch-weighttracker
|
|
2
|
+
|
|
3
|
+
PyTorch tools for tracking structured weight sparsity, regularization signals,
|
|
4
|
+
and bit-operation estimates in neural network modules.
|
|
5
|
+
|
|
6
|
+
The public API is centered on `WeightTracker`:
|
|
7
|
+
|
|
8
|
+
```python
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from torch_weighttracker import WeightTracker
|
|
13
|
+
|
|
14
|
+
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
|
|
15
|
+
tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
|
|
16
|
+
print(tracker.view_structures())
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Installation
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
python -m pip install torch-weighttracker
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
Structured BOPs (bit-operation) accounting uses `fvcore` to measure baseline per-module MACs; install the optional extra to enable it:
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
python -m pip install "torch-weighttracker[structured-bops]"
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
## Development
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
python -m pip install -e ".[dev]"
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Run tests and lint checks:
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
pytest
|
|
41
|
+
ruff check .
|
|
42
|
+
ruff format --check .
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Smoke Test
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
python -c "from torch_weighttracker import WeightTracker; print(WeightTracker)"
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Status
|
|
52
|
+
|
|
53
|
+
This package is pre-1.0. Public APIs may still change while the tracker,
|
|
54
|
+
calculation, and regularizer surfaces settle.
|
|
55
|
+
|
|
56
|
+
## License
|
|
57
|
+
|
|
58
|
+
MIT
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "torch-weighttracker"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Tools for tracking structured weight sparsity in PyTorch models."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
keywords = ["pytorch", "torch", "weights", "tracking", "structured-sparsity", "pruning"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 2 - Pre-Alpha",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"Programming Language :: Python :: 3",
|
|
19
|
+
"Programming Language :: Python :: 3.10",
|
|
20
|
+
"Programming Language :: Python :: 3.11",
|
|
21
|
+
"Programming Language :: Python :: 3.12",
|
|
22
|
+
"Programming Language :: Python :: 3.13",
|
|
23
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
24
|
+
]
|
|
25
|
+
dependencies = [
|
|
26
|
+
"numpy>=1.24",
|
|
27
|
+
"torch>=2.0",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
[project.optional-dependencies]
|
|
31
|
+
structured-bops = [
|
|
32
|
+
"fvcore",
|
|
33
|
+
]
|
|
34
|
+
dev = [
|
|
35
|
+
"build>=1.2",
|
|
36
|
+
"fvcore",
|
|
37
|
+
"pytest>=8.0",
|
|
38
|
+
"ruff>=0.6",
|
|
39
|
+
"timm>=1.0",
|
|
40
|
+
"torchvision>=0.15",
|
|
41
|
+
"twine>=5.0",
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
[project.urls]
|
|
45
|
+
Homepage = "https://github.com/dadyownes15/torch-weighttracker"
|
|
46
|
+
Issues = "https://github.com/dadyownes15/torch-weighttracker/issues"
|
|
47
|
+
Repository = "https://github.com/dadyownes15/torch-weighttracker"
|
|
48
|
+
|
|
49
|
+
[tool.setuptools.packages.find]
|
|
50
|
+
where = ["."]
|
|
51
|
+
include = ["torch_weighttracker*"]
|
|
52
|
+
|
|
53
|
+
[tool.setuptools.package-data]
|
|
54
|
+
torch_weighttracker = ["py.typed"]
|
|
55
|
+
|
|
56
|
+
[tool.pytest.ini_options]
|
|
57
|
+
addopts = [
|
|
58
|
+
"-ra",
|
|
59
|
+
"--ignore=tests/test_attention_axes_e2e.py",
|
|
60
|
+
"--ignore=tests/test_attention_axes_unit.py",
|
|
61
|
+
"--ignore=tests/test_model_structured_sparsity.py",
|
|
62
|
+
"--ignore=tests/test_sparsity_controller.py",
|
|
63
|
+
"--ignore=tests/ignored",
|
|
64
|
+
]
|
|
65
|
+
pythonpath = ["."]
|
|
66
|
+
testpaths = ["tests"]
|
|
67
|
+
|
|
68
|
+
[tool.ruff]
|
|
69
|
+
line-length = 88
|
|
70
|
+
target-version = "py310"
|
|
71
|
+
|
|
72
|
+
[tool.ruff.lint]
|
|
73
|
+
select = ["E", "F", "I", "UP", "B"]
|
|
74
|
+
|
|
75
|
+
[tool.ruff.format]
|
|
76
|
+
quote-style = "double"
|
|
77
|
+
indent-style = "space"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch
|
|
7
|
+
from torch_weighttracker.torch_pruning.dependency import DependencyGraph
|
|
8
|
+
from torch_weighttracker.torch_pruning.pruner.function import LinearPruner, prune_linear_out_channels
|
|
9
|
+
|
|
10
|
+
class simpleMLP(nn.Module):
    """Three stacked linear layers (1 -> 3 -> 3 -> 1) with no activations.

    Every layer is reachable both as a named attribute and through
    ``self.net``, which chains the very same module objects.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in input-to-output order. The attribute names
        # (including the ``secound_layer`` spelling) determine the parameter
        # names of the model and are therefore left untouched.
        self.first_layer = nn.Linear(in_features=1, out_features=3)
        self.secound_layer = nn.Linear(in_features=3, out_features=3)
        self.third_layer = nn.Linear(in_features=3, out_features=1)
        stages = (self.first_layer, self.secound_layer, self.third_layer)
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Feed ``x`` through the sequential stack and return its output."""
        return self.net(x)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_specs_creation():
    """Build a dependency graph for ``simpleMLP`` and check one handler.

    The first dependency of the first item in the group at index 2 is
    expected to use ``prune_linear_out_channels`` as its handler. The
    ``print`` calls are debugging output only and do not affect the result.
    """
    model = simpleMLP()
    input_ex = torch.tensor([[1.0]])
    # Call the module itself rather than ``forward`` so the regular
    # ``nn.Module.__call__`` path (hooks included) is what gets exercised.
    print(model(input_ex))
    graph = DependencyGraph().build_dependency(model=model, example_inputs=input_ex)

    # Materialise every group so that one can be picked by position.
    all_groups = list(graph.get_all_groups())

    inspected_group = all_groups[2]
    first_item = inspected_group[0]
    print("first item: ", first_item)
    dep = first_item[0]
    # Indexing the second item also requires the group to hold two entries.
    dep2 = inspected_group[1][0]
    print(dep.handler)
    print(vars(dep))
    print("dep 2")
    print(vars(dep2))

    assert dep.handler == prune_linear_out_channels
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TinyTransformerClassifier(nn.Module):
    """Single attention block plus MLP, ending in a 5-way linear head.

    Token ids are looked up in a 32-entry embedding table and positions in an
    8-entry table, both of width 16, so inputs longer than 8 positions cannot
    be embedded. The class logits are computed from position 0 only.
    """

    def __init__(self) -> None:
        super().__init__()
        self.token_embed = nn.Embedding(num_embeddings=32, embedding_dim=16)
        self.position_embed = nn.Embedding(num_embeddings=8, embedding_dim=16)
        self.attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
        self.norm1 = nn.LayerNorm(16)
        self.mlp_in = nn.Linear(in_features=16, out_features=32)
        self.activation = nn.GELU()
        self.mlp_out = nn.Linear(in_features=32, out_features=16)
        self.norm2 = nn.LayerNorm(16)
        self.head = nn.Linear(in_features=16, out_features=5)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Return logits for ``token_ids``.

        Dimension 1 of ``token_ids`` is used as the sequence length
        (presumably the input is ``(batch, seq)`` -- confirm with callers).
        """
        seq_len = token_ids.size(1)
        position_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        hidden = self.token_embed(token_ids) + self.position_embed(position_ids)

        # Self-attention with a residual connection, then normalisation.
        attended, _ = self.attn(hidden, hidden, hidden, need_weights=False)
        hidden = self.norm1(hidden + attended)

        # Feed-forward sub-block with its own residual connection.
        expanded = self.mlp_in(hidden)
        projected = self.mlp_out(self.activation(expanded))
        hidden = self.norm2(hidden + projected)

        first_position = hidden[:, 0]
        return self.head(first_position)
|
|
35
|
+
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
import torch_structure_analyser as tsa
|
|
7
|
+
from torch_structure_analyser.analysis import StructureAxis
|
|
8
|
+
from tests.fixtures_models import TinyTransformerClassifier
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AttentionProbeNet(nn.Module):
    """Two-head self-attention of width 8 followed by a bias-free 8x8 linear."""

    def __init__(self):
        super().__init__()
        width = 8
        self.mha = nn.MultiheadAttention(embed_dim=width, num_heads=2, batch_first=True)
        self.lin = nn.Linear(width, width, bias=False)

    def forward(self, x):
        """Self-attend over ``x`` and project the attention output."""
        # The second element returned by the attention module is not needed.
        attended, _ = self.mha(x, x, x)
        return self.lin(attended)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _make_probe_controller(
    model: AttentionProbeNet | None = None,
) -> tsa.SparsityTracker:
    """Wrap an ``AttentionProbeNet`` in a ``tsa.SparsityTracker``.

    A fresh ``AttentionProbeNet`` is created when ``model`` is ``None``. The
    tracker is built with a random ``(1, 4, 8)`` example input, attention and
    linear modules as root types, the attention module mapped to 2 heads, and
    both ``prune_num_heads`` and ``prune_head_dims`` switched on.
    """
    probe = AttentionProbeNet() if model is None else model
    return tsa.SparsityTracker(
        probe,
        example_inputs=torch.randn(1, 4, 8),
        root_module_types=[nn.MultiheadAttention, nn.Linear],
        num_heads={probe.mha: 2},
        prune_num_heads=True,
        prune_head_dims=True,
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _make_transformer_controller(
    model: TinyTransformerClassifier | None = None,
) -> tsa.SparsityTracker:
    """Wrap a ``TinyTransformerClassifier`` in a ``tsa.SparsityTracker``.

    A fresh classifier is created when ``model`` is ``None``. The example
    input is a random ``(1, 8)`` tensor of token ids drawn from ``[0, 32)``,
    and the head count is read from the model's own attention module.
    """
    classifier = TinyTransformerClassifier() if model is None else model
    example_ids = torch.randint(0, 32, (1, 8))
    attention = classifier.attn
    return tsa.SparsityTracker(
        classifier,
        example_inputs=example_ids,
        root_module_types=[nn.MultiheadAttention, nn.Linear],
        num_heads={attention: attention.num_heads},
        prune_num_heads=True,
        prune_head_dims=True,
    )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _zero_mha_out_slices(attention: nn.MultiheadAttention, idxs: list[int]) -> None:
    """Zero, in place, the weight slices of ``attention`` selected by ``idxs``.

    Rows ``idxs`` of the separate q/k/v projection weights are cleared when
    those weights exist. For the packed projection weight, rows ``idxs`` are
    cleared at offsets 0, ``embed_dim`` and ``2 * embed_dim`` together with
    columns ``idxs``. Rows and columns ``idxs`` of the output projection
    weight are cleared as well. Biases are not touched.
    """
    dim = attention.embed_dim
    # Same ordering as concatenating the three shifted copies of ``idxs``.
    packed_rows = [block * dim + idx for block in range(3) for idx in idxs]

    with torch.no_grad():
        for weight_name in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
            separate_weight = getattr(attention, weight_name)
            if separate_weight is not None:
                separate_weight[idxs, :] = 0

        packed_weight = attention.in_proj_weight
        if packed_weight is not None:
            packed_weight[packed_rows, :] = 0
            packed_weight[:, idxs] = 0

        out_proj = attention.out_proj
        if out_proj is not None:
            out_proj.weight[idxs, :] = 0
            out_proj.weight[:, idxs] = 0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _zero_member_slices(module, handler, idxs: list[int]) -> None:
    """Zero the parameters of ``module`` that ``handler`` would act on.

    The slices to clear are chosen from the ``tsa`` prune handler:

    * linear out-channels: weight rows ``idxs`` and, if present, bias entries;
    * linear in-channels and embedding out-channels: weight columns ``idxs``;
    * layer norm: weight entries and optional bias, only when
      ``elementwise_affine`` is set;
    * batch/group/instance norm: the same, only when ``affine`` is truthy;
    * multi-head attention out-channels: delegated to ``_zero_mha_out_slices``.

    Raises:
        AssertionError: if ``handler`` matches none of the rules above.
    """

    def _zero_affine_params() -> None:
        # Shared by all norm-style handlers: weight entries plus optional bias.
        module.weight[idxs] = 0
        if module.bias is not None:
            module.bias[idxs] = 0

    with torch.no_grad():
        if handler in (tsa.prune_linear_out_channels,):
            module.weight[idxs, :] = 0
            if module.bias is not None:
                module.bias[idxs] = 0
            return

        if handler in (tsa.prune_linear_in_channels,):
            module.weight[:, idxs] = 0
            return

        if handler in (tsa.prune_layernorm_out_channels,):
            if module.elementwise_affine:
                _zero_affine_params()
            return

        if handler in (
            tsa.prune_batchnorm_out_channels,
            tsa.prune_groupnorm_out_channels,
            tsa.prune_instancenorm_out_channels,
        ):
            if getattr(module, "affine", False):
                _zero_affine_params()
            return

        if handler in (tsa.prune_embedding_out_channels,):
            module.weight[:, idxs] = 0
            return

        if handler in (tsa.prune_multihead_attention_out_channels,):
            _zero_mha_out_slices(module, idxs)
            return

        raise AssertionError(f"Unhandled test zeroing rule for {handler.__name__}")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _zero_group_unit(
    controller: tsa.SparsityTracker,
    group_id: str,
    root_indices: list[int],
) -> None:
    """Zero, in every measurable member of a group, the selected root indices.

    The group is the first view from ``controller.iter_groups()`` whose
    ``group_id`` matches; ``StopIteration`` propagates when none does. For
    each measurable member, the local indices paired with a root index in
    ``root_indices`` are passed to ``_zero_member_slices``. Members without
    any matching index are skipped.
    """
    target_view = next(
        candidate
        for candidate in controller.iter_groups()
        if candidate.group_id == group_id
    )
    wanted_roots = set(root_indices)

    for member in target_view.members:
        if member.measurable:
            index_pairs = zip(member.local_idxs, member.root_idxs)
            selected = [local for local, root in index_pairs if root in wanted_roots]
            if selected:
                _zero_member_slices(member.module, member.handler, selected)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def test_attention_head_dim_report_detects_zeroed_shared_dimension():
    """Zeroing root indices 0 and 4 is reported as one zero head-dim unit.

    The whole-head view of the same group must report nothing removed.
    """
    head_dim_group = "lin:prune_in_channels:head_dim"
    head_group = "lin:prune_in_channels:head"

    controller = _make_probe_controller()
    _zero_group_unit(controller, head_dim_group, [0, 4])

    report = controller.structured_sparsity()

    head_dim_entry = report.by_group[head_dim_group]
    assert head_dim_entry.zero_prune_units == ((0, 4),)
    head_entry = report.by_group[head_group]
    assert head_entry.stats.removed == 0
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def test_attention_head_report_detects_zeroed_whole_head():
    """Zeroing root indices 0-3 is reported as one zero unit in the head view.

    The head-dim view of the same group must report nothing removed.
    """
    head_group = "lin:prune_in_channels:head"
    head_dim_group = "lin:prune_in_channels:head_dim"

    controller = _make_probe_controller()
    _zero_group_unit(controller, head_group, [0, 1, 2, 3])

    report = controller.structured_sparsity()

    head_entry = report.by_group[head_group]
    assert head_entry.zero_prune_units == ((0, 1, 2, 3),)
    head_dim_entry = report.by_group[head_dim_group]
    assert head_dim_entry.stats.removed == 0
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def test_prune_zero_structures_updates_logical_num_heads_after_head_prune():
    """Pruning one zeroed head leaves the tracker with 3 logical heads.

    The model starts with 4 heads. After zeroing root indices 0-3 of the
    head group and pruning, the reported group id, the configured head count
    and the size of the first HEAD-axis view are all checked.
    """
    head_group = "attn:prune_out_channels:head"
    model = TinyTransformerClassifier()
    controller = _make_transformer_controller(model)
    _zero_group_unit(controller, head_group, [0, 1, 2, 3])

    result = controller.prune_zero_structures()
    # Views are collected only after pruning so they show the updated state.
    attention_views = []
    for view in controller.iter_groups():
        if view.axis == StructureAxis.HEAD:
            attention_views.append(view)

    assert result.pruned_group_ids == (head_group,)
    assert controller.config.num_heads[model.attn] == 3
    assert attention_views[0].size == 3
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def test_tiny_transformer_exposes_attention_head_and_head_dim_views_end_to_end():
    """The attention out-channel group yields one HEAD and one HEAD_DIM view."""
    controller = _make_transformer_controller()
    prefix = "attn:prune_out_channels"

    attention_views = [
        view for view in controller.iter_groups() if view.group_id.startswith(prefix)
    ]

    assert len(attention_views) == 2
    observed_axes = {view.axis for view in attention_views}
    assert observed_axes == {StructureAxis.HEAD, StructureAxis.HEAD_DIM}
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
import torch_structure_analyser as tsa
|
|
7
|
+
from torch_structure_analyser.analysis import (
|
|
8
|
+
StructureAxis,
|
|
9
|
+
build_atomic_prune_units,
|
|
10
|
+
build_grouped_prune_units,
|
|
11
|
+
build_head_prune_units,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AttentionProbeNet(nn.Module):
    """Self-attention (embed dim 8, 2 heads) feeding a bias-free linear layer."""

    def __init__(self):
        super().__init__()
        embed_dim = 8
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=2,
            batch_first=True,
        )
        self.lin = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        """Use ``x`` as query, key and value, then apply the linear layer."""
        attention_result = self.mha(x, x, x)
        # Only the first element of the returned pair is used.
        attended = attention_result[0]
        return self.lin(attended)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _make_controller() -> tsa.SparsityTracker:
    """Create a ``tsa.SparsityTracker`` around a new ``AttentionProbeNet``.

    The tracker gets a random ``(1, 4, 8)`` example input, attention and
    linear modules as root types, the attention module mapped to 2 heads, and
    both ``prune_num_heads`` and ``prune_head_dims`` switched on.
    """
    probe = AttentionProbeNet()
    heads_by_module = {probe.mha: 2}
    return tsa.SparsityTracker(
        probe,
        example_inputs=torch.randn(1, 4, 8),
        root_module_types=[nn.MultiheadAttention, nn.Linear],
        num_heads=heads_by_module,
        prune_num_heads=True,
        prune_head_dims=True,
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _zero_attention_indices(model: AttentionProbeNet, idxs: list[int]) -> None:
    """Zero, in place, the slices of ``model`` addressed by ``idxs``.

    On ``model.mha`` this clears rows ``idxs`` of any separate q/k/v
    projection weights, rows ``idxs`` of the packed projection weight at
    offsets 0, ``embed_dim`` and ``2 * embed_dim`` plus its columns ``idxs``,
    and rows and columns ``idxs`` of the output projection weight. Columns
    ``idxs`` of ``model.lin.weight`` are always cleared. Biases are left alone.
    """
    mha = model.mha
    dim = mha.embed_dim
    packed_rows = [
        *idxs,
        *(idx + dim for idx in idxs),
        *(idx + 2 * dim for idx in idxs),
    ]

    with torch.no_grad():
        separate_weights = (mha.q_proj_weight, mha.k_proj_weight, mha.v_proj_weight)
        for weight in separate_weights:
            if weight is not None:
                weight[idxs, :] = 0

        if mha.in_proj_weight is not None:
            mha.in_proj_weight[packed_rows, :] = 0
            mha.in_proj_weight[:, idxs] = 0

        if mha.out_proj is not None:
            out_weight = mha.out_proj.weight
            out_weight[idxs, :] = 0
            out_weight[:, idxs] = 0

        # The linear layer loses the matching input columns.
        model.lin.weight[:, idxs] = 0
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def test_build_atomic_prune_units_returns_singletons():
    """Four atomic units each hold exactly one root index, in order."""
    observed = [unit.root_indices for unit in build_atomic_prune_units(4)]
    expected = [(0,), (1,), (2,), (3,)]
    assert observed == expected
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_build_grouped_prune_units_matches_head_dim_pattern():
    """Arguments ``(8, 2)`` pair each index ``i`` in 0-3 with ``i + 4``."""
    observed = [unit.root_indices for unit in build_grouped_prune_units(8, 2)]
    expected = [(0, 4), (1, 5), (2, 6), (3, 7)]
    assert observed == expected
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def test_build_head_prune_units_returns_contiguous_blocks():
    """Arguments ``(8, 2)`` yield two consecutive blocks of four indices."""
    observed = [unit.root_indices for unit in build_head_prune_units(8, 2)]
    expected = [(0, 1, 2, 3), (4, 5, 6, 7)]
    assert observed == expected
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_attention_group_views_include_head_and_head_dim_axes():
    """The linear in-channel group has a HEAD_DIM view followed by a HEAD view.

    The head-dim view pairs index ``i`` with ``i + 4``; the head view uses two
    contiguous blocks of four indices.
    """
    controller = _make_controller()
    prefix = "lin:prune_in_channels"
    attention_views = [
        view for view in controller.iter_groups() if view.group_id.startswith(prefix)
    ]

    # The axis order is asserted before indexing into the list of views.
    observed_axes = [view.axis for view in attention_views]
    assert observed_axes == [StructureAxis.HEAD_DIM, StructureAxis.HEAD]

    head_dim_units = [unit.root_indices for unit in attention_views[0].prune_units]
    assert head_dim_units == [(0, 4), (1, 5), (2, 6), (3, 7)]

    head_units = [unit.root_indices for unit in attention_views[1].prune_units]
    assert head_units == [(0, 1, 2, 3), (4, 5, 6, 7)]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_zero_structure_candidates_prioritize_heads_before_head_dims():
    """With all 8 indices zeroed, the HEAD candidate precedes HEAD_DIM."""
    controller = _make_controller()
    every_index = list(range(8))
    _zero_attention_indices(controller.model, every_index)

    prefix = "lin:prune_in_channels"
    attention_candidates = [
        candidate
        for candidate in controller.zero_structure_candidates()
        if candidate.group_id.startswith(prefix)
    ]

    leading_axes = [candidate.axis for candidate in attention_candidates[:2]]
    assert leading_axes == [StructureAxis.HEAD, StructureAxis.HEAD_DIM]
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def test_group_lasso_axes_filter_keeps_attention_views_separate():
    """Group lasso restricted to HEAD and to HEAD_DIM gives separate terms.

    Both losses must be scalars that require grad, both term collections must
    be non-empty, every head term key must contain ``[0]`` or ``[1]``, and
    there must be fewer head terms than head-dim terms.
    """
    controller = _make_controller()

    head_result = controller.group_lasso(axes=(StructureAxis.HEAD,))
    head_loss, head_named_terms = head_result
    head_dim_result = controller.group_lasso(axes=(StructureAxis.HEAD_DIM,))
    head_dim_loss, head_dim_named_terms = head_dim_result

    assert head_loss.ndim == 0
    assert head_dim_loss.ndim == 0
    assert head_loss.requires_grad
    assert head_dim_loss.requires_grad
    assert head_named_terms
    assert head_dim_named_terms
    # No key may lack both head markers.
    assert not any(
        "[0]" not in key and "[1]" not in key for key in head_named_terms
    )
    assert len(head_named_terms) < len(head_dim_named_terms)
|