rasnatune-0.0.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rasnatune-0.0.1/LICENSE +21 -0
- rasnatune-0.0.1/PKG-INFO +79 -0
- rasnatune-0.0.1/README.md +66 -0
- rasnatune-0.0.1/pyproject.toml +18 -0
- rasnatune-0.0.1/rasnatune/__init__.py +6 -0
- rasnatune-0.0.1/rasnatune/base.py +15 -0
- rasnatune-0.0.1/rasnatune/compression.py +35 -0
- rasnatune-0.0.1/rasnatune/quantization.py +52 -0
- rasnatune-0.0.1/rasnatune/sparsification.py +61 -0
- rasnatune-0.0.1/rasnatune.egg-info/PKG-INFO +79 -0
- rasnatune-0.0.1/rasnatune.egg-info/SOURCES.txt +16 -0
- rasnatune-0.0.1/rasnatune.egg-info/dependency_links.txt +1 -0
- rasnatune-0.0.1/rasnatune.egg-info/requires.txt +1 -0
- rasnatune-0.0.1/rasnatune.egg-info/top_level.txt +1 -0
- rasnatune-0.0.1/setup.cfg +4 -0
- rasnatune-0.0.1/tests/test_compression.py +172 -0
- rasnatune-0.0.1/tests/test_quantization.py +127 -0
- rasnatune-0.0.1/tests/test_sparsification.py +123 -0
rasnatune-0.0.1/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 rasnatune contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
rasnatune-0.0.1/PKG-INFO
ADDED
@@ -0,0 +1,79 @@
+Metadata-Version: 2.4
+Name: rasnatune
+Version: 0.0.1
+Summary: Runtime compression tuning utilities for PyTorch
+Author: rasnatune contributors
+License-Expression: MIT
+Keywords: pytorch,compression,sparsification,quantization
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=2.0
+Dynamic: license-file
+
+# rasnatune
+
+`rasnatune` is a lightweight PyTorch compression helper based on forward hooks.
+
+## Installation
+
+```bash
+pip install rasnatune
+```
+
+## What It Does
+
+- Applies compression at runtime using hooks (no permanent weight rewrite).
+- Supports quantization and unstructured sparsification.
+- Targets `torch.nn.Conv2d` and `torch.nn.Linear`.
+
+## Quick Start
+
+```python
+import torch
+from rasnatune import Compression, SparseWeightUnstructured
+
+model = torch.nn.Sequential(
+    torch.nn.Linear(128, 64),
+    torch.nn.ReLU(),
+    torch.nn.Linear(64, 10),
+)
+
+wrapped = Compression(model)
+wrapped.attach(
+    SparseWeightUnstructured,
+    filter=lambda m: isinstance(m, torch.nn.Linear),
+    sparsity=0.5,
+)
+
+x = torch.randn(128)
+y = wrapped(x)
+```
+
+## Public API
+
+Top-level exports:
+
+- `rasnatune.Compression`
+- `rasnatune.Compressor`
+- `rasnatune.QuantizeWeight`
+- `rasnatune.QuantizeActivation`
+- `rasnatune.SparseWeightUnstructured`
+- `rasnatune.SparseActivationUnstructured`
+
+## Compression Classes
+
+- `QuantizeWeight(min=-128, max=127)`:
+  Quantizes layer weights only during forward, then restores original weights.
+- `QuantizeActivation(min=-128, max=127)`:
+  Quantizes the first input activation of the layer before forward.
+- `SparseWeightUnstructured(sparsity=0.5)`:
+  Applies unstructured sparsity to layer weights only during forward, then restores.
+- `SparseActivationUnstructured(sparsity=0.5)`:
+  Applies unstructured sparsity to the first input activation before forward.
+
+## Notes
+
+- `Compression.attach` uses `filter=` to choose target modules.
+- If `filter` is omitted, it will try all submodules, and unsupported modules will raise an assertion.
+- Supported module types for built-in compressors are `torch.nn.Conv2d` and `torch.nn.Linear`.
rasnatune-0.0.1/README.md
ADDED
@@ -0,0 +1,66 @@
+# rasnatune
+
+`rasnatune` is a lightweight PyTorch compression helper based on forward hooks.
+
+## Installation
+
+```bash
+pip install rasnatune
+```
+
+## What It Does
+
+- Applies compression at runtime using hooks (no permanent weight rewrite).
+- Supports quantization and unstructured sparsification.
+- Targets `torch.nn.Conv2d` and `torch.nn.Linear`.
+
+## Quick Start
+
+```python
+import torch
+from rasnatune import Compression, SparseWeightUnstructured
+
+model = torch.nn.Sequential(
+    torch.nn.Linear(128, 64),
+    torch.nn.ReLU(),
+    torch.nn.Linear(64, 10),
+)
+
+wrapped = Compression(model)
+wrapped.attach(
+    SparseWeightUnstructured,
+    filter=lambda m: isinstance(m, torch.nn.Linear),
+    sparsity=0.5,
+)
+
+x = torch.randn(128)
+y = wrapped(x)
+```
+
+## Public API
+
+Top-level exports:
+
+- `rasnatune.Compression`
+- `rasnatune.Compressor`
+- `rasnatune.QuantizeWeight`
+- `rasnatune.QuantizeActivation`
+- `rasnatune.SparseWeightUnstructured`
+- `rasnatune.SparseActivationUnstructured`
+
+## Compression Classes
+
+- `QuantizeWeight(min=-128, max=127)`:
+  Quantizes layer weights only during forward, then restores original weights.
+- `QuantizeActivation(min=-128, max=127)`:
+  Quantizes the first input activation of the layer before forward.
+- `SparseWeightUnstructured(sparsity=0.5)`:
+  Applies unstructured sparsity to layer weights only during forward, then restores.
+- `SparseActivationUnstructured(sparsity=0.5)`:
+  Applies unstructured sparsity to the first input activation before forward.
+
+## Notes
+
+- `Compression.attach` uses `filter=` to choose target modules.
+- If `filter` is omitted, it will try all submodules, and unsupported modules will raise an assertion.
+- Supported module types for built-in compressors are `torch.nn.Conv2d` and `torch.nn.Linear`.
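The Quick Start shows sparsification; quantization attaches the same way. A sketch under the same setup — the `min=-8, max=7` range (signed 4-bit) mirrors the values used in the package's own tests, not a recommendation:

```python
import torch

from rasnatune import Compression, QuantizeActivation, QuantizeWeight

model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
wrapped = Compression(model)

# Fake-quantize both the weights and the incoming activations of every
# Linear layer; hooks apply per forward pass, weights are never rewritten.
for compressor in (QuantizeWeight, QuantizeActivation):
    wrapped.attach(
        compressor,
        filter=lambda m: isinstance(m, torch.nn.Linear),
        min=-8,
        max=7,
    )

y = wrapped(torch.randn(128))
```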
rasnatune-0.0.1/pyproject.toml
ADDED
@@ -0,0 +1,18 @@
+[build-system]
+requires = ["setuptools>=69", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "rasnatune"
+version = "0.0.1"
+description = "Runtime compression tuning utilities for PyTorch"
+readme = "README.md"
+requires-python = ">=3.10"
+authors = [{ name = "rasnatune contributors" }]
+license = "MIT"
+keywords = ["pytorch", "compression", "sparsification", "quantization"]
+dependencies = ["torch>=2.0"]
+
+[tool.setuptools.packages.find]
+include = ["rasnatune*"]
+exclude = ["__pycache__*"]
rasnatune-0.0.1/rasnatune/base.py
ADDED
@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+import torch
+
+
+class Compressor:
+    _backup: torch.Tensor
+    _hooks: list[torch.utils.hooks.RemovableHandle]
+
+    def attach(self, module: torch.nn.Module) -> None:
+        raise NotImplementedError
+
+    def detach(self, module: torch.nn.Module) -> None:
+        for hook in self._hooks:
+            hook.remove()
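`Compressor` defines the contract the rest of the package builds on: `attach` installs forward hooks on a single module and stashes their handles in `_hooks`, and the inherited `detach` removes those handles. As a sketch of how a custom compressor could plug into this contract — `ClampWeight` and its clamping behavior are illustrative, not part of the package:

```python
from dataclasses import dataclass

import torch

from rasnatune.base import Compressor


@dataclass
class ClampWeight(Compressor):
    # Hypothetical example: clamp weights into [-bound, bound] during forward.
    bound: float = 1.0

    def attach(self, module: torch.nn.Module) -> None:
        def pre_hook(mod: torch.nn.Module, _inputs) -> None:
            # Back up the full-precision weight, then clamp it in place.
            self._backup = mod.weight.data.detach().clone()
            mod.weight.data.clamp_(-self.bound, self.bound)

        def post_hook(mod: torch.nn.Module, _inputs, output):
            # Restore the original weight once the forward pass is done.
            mod.weight.data.copy_(self._backup)
            return output

        self._hooks = [
            module.register_forward_pre_hook(pre_hook),
            module.register_forward_hook(post_hook),
        ]
```

The pre-hook/post-hook pairing mirrors the built-in weight compressors: the weight is modified only for the duration of the forward pass, then restored from `_backup`, and the base-class `detach` removes both hooks.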
rasnatune-0.0.1/rasnatune/compression.py
ADDED
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable
+
+import torch
+
+from .base import Compressor
+
+
+class Compression(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self.model = model
+        self.registrations: list[Compressor] = []
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def attach(self, compressor: type[Compressor], filter: Callable, *args, **kwargs):
+        for name, module in self.model.named_modules():
+            if not filter(module):
+                continue
+            instance = compressor(*args, **kwargs)
+            instance.attach(module)
+            self.registrations.append(instance)
+
+    def detach(self, compressor: Compressor):
+        compressor.detach(self.model)
+        if compressor in self.registrations:
+            self.registrations.remove(compressor)
+
+    def clear(self):
+        for compressor in list(self.registrations):
+            self.detach(compressor)
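`Compression` keeps the wrapped model untouched and tracks one compressor instance per matching module in `registrations`; `detach` unwinds a single instance and `clear` unwinds them all. A minimal sketch of that lifecycle (layer sizes are arbitrary):

```python
import torch

from rasnatune import Compression, QuantizeWeight

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
wrapped = Compression(model)

# attach() instantiates one QuantizeWeight per module passing the filter,
# so two registrations here (one per Linear layer).
wrapped.attach(QuantizeWeight, filter=lambda m: isinstance(m, torch.nn.Linear))
assert len(wrapped.registrations) == 2

y = wrapped(torch.randn(1, 8))            # forward runs with quantized weights
wrapped.detach(wrapped.registrations[0])  # unhook a single compressor
wrapped.clear()                           # unhook everything that remains
assert wrapped.registrations == []
```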
rasnatune-0.0.1/rasnatune/quantization.py
ADDED
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+import torch
+
+from .base import Compressor
+
+
+def quantize(x: torch.Tensor, qmin: int, qmax: int) -> torch.Tensor:
+    max_abs = x.detach().abs().max()
+    if max_abs.item() == 0:
+        return x
+    scale = max_abs / max(qmax, -qmin)
+    q = torch.clamp(torch.round(x / scale), qmin, qmax)
+    quantized = q * scale
+    error = x.detach() - quantized
+    return x - error
+
+
+@dataclass
+class QuantizeWeight(Compressor):
+    min: int = -128
+    max: int = 127
+
+    def attach(self, module: torch.nn.Module) -> None:
+        def pre_hook(mod: torch.nn.Module, inputs) -> None:
+            self._backup = mod.weight.data.detach().clone()
+            mod.weight.data.copy_(quantize(mod.weight.data, qmin=self.min, qmax=self.max))
+
+        def post_hook(mod: torch.nn.Module, inputs, output):
+            mod.weight.data.copy_(self._backup)
+            return output
+
+        self._hooks = [
+            module.register_forward_pre_hook(pre_hook),
+            module.register_forward_hook(post_hook)
+        ]
+
+
+@dataclass
+class QuantizeActivation(Compressor):
+    min: int = -128
+    max: int = 127
+
+    def attach(self, module: torch.nn.Module) -> None:
+        def pre_hook(_module: torch.nn.Module, inputs):
+            return (quantize(inputs[0], qmin=self.min, qmax=self.max), *inputs[1:])
+
+        self._hooks = [
+            module.register_forward_pre_hook(pre_hook)
+        ]
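`quantize` implements symmetric fake quantization with a straight-through gradient: `scale = max|x| / max(qmax, -qmin)`, values are rounded to integer multiples of `scale` and clamped, and because the rounding error is computed from `x.detach()`, `x - error` equals the quantized value in the forward pass while the backward pass sees the identity. A short sketch of both properties, with numbers taken from the package's own tests:

```python
import torch

from rasnatune.quantization import quantize

# qmin=-2, qmax=2 gives scale = 3.0 / 2 = 1.5, so values snap to multiples
# of 1.5: round(-1 / 1.5) = -1 maps to -1.5, and the extremes stay at +/-3.
x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0], requires_grad=True)
y = quantize(x, qmin=-2, qmax=2)
print(y.detach())  # tensor([-3.0000, -1.5000,  0.0000,  1.5000,  3.0000])

# The rounding error is detached, so gradients pass through unchanged.
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1., 1., 1.])
```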
rasnatune-0.0.1/rasnatune/sparsification.py
ADDED
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+from .base import Compressor
+
+
+def sparsify(x: torch.Tensor, sparsity: float) -> torch.Tensor:
+    if not 0.0 <= sparsity <= 1.0:
+        raise ValueError(f"sparsity must be between 0 and 1, got {sparsity}")
+
+    flat_abs = x.detach().abs().flatten()
+    if flat_abs.numel() == 0 or sparsity == 0.0:
+        return x
+    if sparsity == 1.0:
+        return torch.zeros_like(x)
+
+    num_zero = int(flat_abs.numel() * sparsity)
+    if num_zero == 0:
+        return x
+
+    zero_indices = flat_abs.topk(num_zero, largest=False).indices
+    sparsity_mask = torch.ones_like(flat_abs, dtype=torch.bool)
+    sparsity_mask[zero_indices] = False
+    return x * sparsity_mask.view_as(x).to(x.dtype)
+
+
+@dataclass
+class SparseWeightUnstructured(Compressor):
+    sparsity: float = 0.5
+
+    def attach(self, module: torch.nn.Module) -> None:
+        def pre_hook(mod: torch.nn.Module, _inputs) -> None:
+            self._backup = mod.weight.data.detach().clone()
+            mod.weight.data.copy_(
+                sparsify(mod.weight.data, sparsity=self.sparsity)
+            )
+
+        def post_hook(mod: torch.nn.Module, inputs, output):
+            mod.weight.data.copy_(self._backup)
+            return output
+
+        self._hooks = [
+            module.register_forward_pre_hook(pre_hook),
+            module.register_forward_hook(post_hook),
+        ]
+
+
+@dataclass
+class SparseActivationUnstructured(Compressor):
+    sparsity: float = 0.5
+
+    def attach(self, module: torch.nn.Module) -> None:
+        def pre_hook(_module: torch.nn.Module, inputs):
+            return (sparsify(inputs[0], sparsity=self.sparsity), *inputs[1:])
+
+        self._hooks = [
+            module.register_forward_pre_hook(pre_hook),
+        ]
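`sparsify` performs global magnitude pruning: it zeroes the `int(numel * sparsity)` smallest-magnitude entries and leaves the rest untouched, so gradients still flow through the surviving values. A worked example matching the test suite:

```python
import torch

from rasnatune.sparsification import sparsify

# num_zero = int(4 * 0.5) = 2, so the two smallest magnitudes (1.0, -2.0)
# are masked out and the larger entries pass through unchanged.
x = torch.tensor([1.0, -2.0, 3.0, -4.0])
print(sparsify(x, sparsity=0.5))  # tensor([ 0.,  0.,  3., -4.])
```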
rasnatune-0.0.1/rasnatune.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,79 @@
+Metadata-Version: 2.4
+Name: rasnatune
+Version: 0.0.1
+Summary: Runtime compression tuning utilities for PyTorch
+Author: rasnatune contributors
+License-Expression: MIT
+Keywords: pytorch,compression,sparsification,quantization
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=2.0
+Dynamic: license-file
+
+# rasnatune
+
+`rasnatune` is a lightweight PyTorch compression helper based on forward hooks.
+
+## Installation
+
+```bash
+pip install rasnatune
+```
+
+## What It Does
+
+- Applies compression at runtime using hooks (no permanent weight rewrite).
+- Supports quantization and unstructured sparsification.
+- Targets `torch.nn.Conv2d` and `torch.nn.Linear`.
+
+## Quick Start
+
+```python
+import torch
+from rasnatune import Compression, SparseWeightUnstructured
+
+model = torch.nn.Sequential(
+    torch.nn.Linear(128, 64),
+    torch.nn.ReLU(),
+    torch.nn.Linear(64, 10),
+)
+
+wrapped = Compression(model)
+wrapped.attach(
+    SparseWeightUnstructured,
+    filter=lambda m: isinstance(m, torch.nn.Linear),
+    sparsity=0.5,
+)
+
+x = torch.randn(128)
+y = wrapped(x)
+```
+
+## Public API
+
+Top-level exports:
+
+- `rasnatune.Compression`
+- `rasnatune.Compressor`
+- `rasnatune.QuantizeWeight`
+- `rasnatune.QuantizeActivation`
+- `rasnatune.SparseWeightUnstructured`
+- `rasnatune.SparseActivationUnstructured`
+
+## Compression Classes
+
+- `QuantizeWeight(min=-128, max=127)`:
+  Quantizes layer weights only during forward, then restores original weights.
+- `QuantizeActivation(min=-128, max=127)`:
+  Quantizes the first input activation of the layer before forward.
+- `SparseWeightUnstructured(sparsity=0.5)`:
+  Applies unstructured sparsity to layer weights only during forward, then restores.
+- `SparseActivationUnstructured(sparsity=0.5)`:
+  Applies unstructured sparsity to the first input activation before forward.
+
+## Notes
+
+- `Compression.attach` uses `filter=` to choose target modules.
+- If `filter` is omitted, it will try all submodules, and unsupported modules will raise an assertion.
+- Supported module types for built-in compressors are `torch.nn.Conv2d` and `torch.nn.Linear`.
rasnatune-0.0.1/rasnatune.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,16 @@
+LICENSE
+README.md
+pyproject.toml
+rasnatune/__init__.py
+rasnatune/base.py
+rasnatune/compression.py
+rasnatune/quantization.py
+rasnatune/sparsification.py
+rasnatune.egg-info/PKG-INFO
+rasnatune.egg-info/SOURCES.txt
+rasnatune.egg-info/dependency_links.txt
+rasnatune.egg-info/requires.txt
+rasnatune.egg-info/top_level.txt
+tests/test_compression.py
+tests/test_quantization.py
+tests/test_sparsification.py
rasnatune-0.0.1/rasnatune.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
+
rasnatune-0.0.1/rasnatune.egg-info/requires.txt
ADDED
@@ -0,0 +1 @@
+torch>=2.0
rasnatune-0.0.1/rasnatune.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+rasnatune
rasnatune-0.0.1/tests/test_compression.py
ADDED
@@ -0,0 +1,172 @@
+import torch
+import torch.nn.functional as F
+
+try:
+    _TORCHVISION_NMS_LIB = torch.library.Library("torchvision", "DEF")
+    _TORCHVISION_NMS_LIB.define("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")
+except RuntimeError:
+    _TORCHVISION_NMS_LIB = None
+
+from torchvision.models import efficientnet_b0, resnet50
+
+from rasnatune.compression import Compression
+from rasnatune.quantization import QuantizeActivation, QuantizeWeight
+from rasnatune.sparsification import SparseActivationUnstructured, SparseWeightUnstructured
+
+
+def _supported_module_filter(module: torch.nn.Module) -> bool:
+    return isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
+
+
+def _supported_modules(model: torch.nn.Module) -> list[torch.nn.Module]:
+    return [module for module in model.modules() if _supported_module_filter(module)]
+
+
+def _make_toy_model() -> torch.nn.Sequential:
+    model = torch.nn.Sequential(
+        torch.nn.Linear(4, 4, bias=False),
+        torch.nn.ReLU(),
+        torch.nn.Linear(4, 2, bias=False),
+    )
+    with torch.no_grad():
+        model[0].weight.copy_(
+            torch.tensor(
+                [
+                    [1.75, -0.25, 0.5, -1.1],
+                    [0.2, -0.6, 1.4, -0.8],
+                    [0.9, 0.3, -0.4, 1.1],
+                    [-1.2, 0.7, 0.6, -0.5],
+                ],
+                dtype=torch.float32,
+            )
+        )
+        model[2].weight.copy_(
+            torch.tensor(
+                [
+                    [0.8, -1.4, 0.6, -0.2],
+                    [-0.5, 0.9, -1.1, 1.3],
+                ],
+                dtype=torch.float32,
+            )
+        )
+    return model
+
+
+def _attach_quantization(wrapped: Compression) -> None:
+    wrapped.attach(QuantizeWeight, filter=_supported_module_filter, min=-8, max=7)
+    wrapped.attach(QuantizeActivation, filter=_supported_module_filter, min=-8, max=7)
+
+
+def _attach_sparsification(wrapped: Compression) -> None:
+    wrapped.attach(SparseWeightUnstructured, filter=_supported_module_filter, sparsity=0.25)
+    wrapped.attach(SparseActivationUnstructured, filter=_supported_module_filter, sparsity=0.25)
+
+
+def _run_compressed_training_step(model: torch.nn.Module, attach_fn) -> None:
+    wrapped = Compression(model)
+    modules = _supported_modules(wrapped.model)
+    first_module = modules[0]
+    original_weight = first_module.weight.detach().clone()
+
+    attach_fn(wrapped)
+
+    assert len(wrapped.registrations) == 2 * len(modules)
+
+    optimizer = torch.optim.SGD(wrapped.parameters(), lr=1e-2)
+    inputs = torch.randn(2, 3, 64, 64, dtype=torch.float32)
+    targets = torch.tensor([1, 3], dtype=torch.long)
+
+    optimizer.zero_grad(set_to_none=True)
+    logits = wrapped(inputs)
+    loss = F.cross_entropy(logits, targets)
+    loss.backward()
+
+    gradients = [parameter.grad for parameter in wrapped.parameters() if parameter.grad is not None]
+    assert gradients
+    assert all(torch.isfinite(gradient).all().item() for gradient in gradients)
+    torch.testing.assert_close(first_module.weight.detach(), original_weight)
+
+    parameters_before_step = [parameter.detach().clone() for parameter in wrapped.parameters()]
+    optimizer.step()
+
+    assert any(
+        not torch.equal(parameter.detach(), before_step)
+        for parameter, before_step in zip(wrapped.parameters(), parameters_before_step)
+    )
+
+    wrapped.clear()
+    assert wrapped.registrations == []
+
+    cleared_logits = wrapped(inputs)
+    assert cleared_logits.shape == logits.shape
+
+
+def test_compression_attach_and_detach_manage_registrations() -> None:
+    model = _make_toy_model()
+    wrapped = Compression(model)
+    x = torch.tensor([[1.0, -2.0, 3.0, -4.0]], dtype=torch.float32)
+
+    wrapped.attach(QuantizeWeight, filter=lambda module: isinstance(module, torch.nn.Linear), min=-1, max=1)
+
+    assert len(wrapped.registrations) == 2
+
+    compressed_output = wrapped(x)
+    wrapped.detach(wrapped.registrations[0])
+
+    assert len(wrapped.registrations) == 1
+
+    wrapped.clear()
+    assert wrapped.registrations == []
+
+    restored_output = wrapped(x)
+    reference_output = model(x)
+
+    torch.testing.assert_close(restored_output, reference_output)
+    assert not torch.equal(compressed_output, reference_output)
+
+
+def test_compression_clear_removes_sparse_registrations() -> None:
+    model = _make_toy_model()
+    wrapped = Compression(model)
+
+    _attach_sparsification(wrapped)
+
+    assert len(wrapped.registrations) == 4
+
+    wrapped.clear()
+
+    assert wrapped.registrations == []
+    output = wrapped(torch.randn(1, 4, dtype=torch.float32))
+    assert output.shape == (1, 2)
+
+
+def test_resnet50_quantized_training_step_runs() -> None:
+    torch.manual_seed(0)
+    model = resnet50(weights=None, num_classes=5)
+    model.train()
+
+    _run_compressed_training_step(model, _attach_quantization)
+
+
+def test_efficientnet_b0_quantized_training_step_runs() -> None:
+    torch.manual_seed(0)
+    model = efficientnet_b0(weights=None, num_classes=5)
+    model.train()
+
+    _run_compressed_training_step(model, _attach_quantization)
+
+
+def test_resnet50_sparse_training_step_runs() -> None:
+    torch.manual_seed(0)
+    model = resnet50(weights=None, num_classes=5)
+    model.train()
+
+    _run_compressed_training_step(model, _attach_sparsification)
+
+
+def test_efficientnet_b0_sparse_training_step_runs() -> None:
+    torch.manual_seed(0)
+    model = efficientnet_b0(weights=None, num_classes=5)
+    model.train()
+
+    _run_compressed_training_step(model, _attach_sparsification)
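`_run_compressed_training_step` exercises exactly one optimizer step; the same pattern extends to a full compression-aware fine-tuning loop. A minimal sketch, assuming a `loader` yielding `(inputs, targets)` batches (the helper name and hyperparameters are illustrative, not part of the package):

```python
import torch
import torch.nn.functional as F

from rasnatune import Compression, QuantizeWeight


def finetune_quantized(model: torch.nn.Module, loader, epochs: int = 1, lr: float = 1e-3):
    # Hooks stay attached across steps: each forward quantizes the weights
    # and each post-hook restores them, so the optimizer keeps updating the
    # full-precision parameters (quantization-aware training).
    wrapped = Compression(model)
    wrapped.attach(
        QuantizeWeight,
        filter=lambda m: isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)),
        min=-8,
        max=7,
    )
    optimizer = torch.optim.SGD(wrapped.parameters(), lr=lr)
    for _ in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad(set_to_none=True)
            loss = F.cross_entropy(wrapped(inputs), targets)
            loss.backward()
            optimizer.step()
    wrapped.clear()  # remove all hooks once fine-tuning is done
    return model
```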
rasnatune-0.0.1/tests/test_quantization.py
ADDED
@@ -0,0 +1,127 @@
+import torch
+import torch.nn.functional as F
+
+from rasnatune.quantization import QuantizeActivation, QuantizeWeight, quantize
+
+
+def _make_linear() -> torch.nn.Linear:
+    module = torch.nn.Linear(2, 1, bias=False)
+    with torch.no_grad():
+        module.weight.copy_(torch.tensor([[1.75, -0.25]], dtype=torch.float32))
+    return module
+
+
+def _make_conv2d() -> torch.nn.Conv2d:
+    module = torch.nn.Conv2d(1, 1, kernel_size=2, bias=False)
+    with torch.no_grad():
+        module.weight.copy_(
+            torch.tensor([[[[1.75, -0.25], [0.5, -1.1]]]], dtype=torch.float32)
+        )
+    return module
+
+
+def test_quantize_maps_values_to_expected_levels() -> None:
+    x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=torch.float32)
+
+    actual = quantize(x, qmin=-2, qmax=2)
+    expected = torch.tensor([-3.0, -1.5, 0.0, 1.5, 3.0], dtype=torch.float32)
+
+    torch.testing.assert_close(actual, expected)
+
+
+def test_quantize_clamps_to_asymmetric_range() -> None:
+    x = torch.tensor([-2.0, -0.6, 0.2, 1.7], dtype=torch.float32)
+
+    actual = quantize(x, qmin=-1, qmax=2)
+    expected = torch.tensor([-1.0, -1.0, 0.0, 2.0], dtype=torch.float32)
+
+    torch.testing.assert_close(actual, expected)
+
+
+def test_quantize_preserves_straight_through_gradients() -> None:
+    x = torch.tensor([0.3, -0.6], dtype=torch.float32, requires_grad=True)
+    upstream = torch.tensor([2.0, -3.0], dtype=torch.float32)
+
+    loss = (quantize(x, qmin=-2, qmax=1) * upstream).sum()
+    loss.backward()
+
+    torch.testing.assert_close(x.grad, upstream)
+
+
+def test_quantize_zero_tensor_returns_zero_without_nan() -> None:
+    x = torch.zeros(4, dtype=torch.float32, requires_grad=True)
+
+    actual = quantize(x, qmin=-128, qmax=127)
+    actual.sum().backward()
+
+    torch.testing.assert_close(actual, torch.zeros_like(x))
+    torch.testing.assert_close(x.grad, torch.ones_like(x))
+    assert torch.isfinite(actual).all().item()
+
+
+def test_quantize_weight_applies_and_restores_on_linear() -> None:
+    module = _make_linear()
+    compressor = QuantizeWeight(min=-1, max=1)
+    x = torch.tensor([[2.0, 3.0]], dtype=torch.float32)
+    original_weight = module.weight.detach().clone()
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected_weight = quantize(original_weight, qmin=-1, qmax=1)
+    expected = F.linear(x, expected_weight)
+
+    torch.testing.assert_close(actual, expected)
+    torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+def test_quantize_weight_applies_and_restores_on_conv2d() -> None:
+    module = _make_conv2d()
+    compressor = QuantizeWeight(min=-1, max=1)
+    x = torch.tensor([[[[1.0, 0.5], [-0.25, 0.75]]]], dtype=torch.float32)
+    original_weight = module.weight.detach().clone()
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected_weight = quantize(original_weight, qmin=-1, qmax=1)
+    expected = F.conv2d(x, expected_weight)
+
+    torch.testing.assert_close(actual, expected)
+    torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+def test_quantize_activation_applies_to_linear_inputs() -> None:
+    module = _make_linear()
+    compressor = QuantizeActivation(min=-1, max=1)
+    x = torch.tensor([[1.75, -0.25]], dtype=torch.float32)
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected = F.linear(quantize(x, qmin=-1, qmax=1), module.weight.detach())
+
+    torch.testing.assert_close(actual, expected)
+
+    compressor.detach(module)
+    restored = module(x)
+    torch.testing.assert_close(restored, F.linear(x, module.weight.detach()))
+    assert not torch.equal(actual, restored)
+
+
+def test_quantize_activation_applies_to_conv2d_inputs() -> None:
+    module = _make_conv2d()
+    compressor = QuantizeActivation(min=-1, max=1)
+    x = torch.tensor([[[[1.75, -0.25], [0.5, -1.1]]]], dtype=torch.float32)
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected = F.conv2d(quantize(x, qmin=-1, qmax=1), module.weight.detach())
+
+    torch.testing.assert_close(actual, expected)
+
+    compressor.detach(module)
+    restored = module(x)
+    torch.testing.assert_close(restored, F.conv2d(x, module.weight.detach()))
+    assert not torch.equal(actual, restored)
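The apply-and-restore tests above all check the same invariant: after the forward pass, the module's stored weight is bit-identical to the backup taken by the pre-hook. A condensed sketch of that invariant outside pytest:

```python
import torch

from rasnatune.quantization import QuantizeWeight

module = torch.nn.Linear(2, 1, bias=False)
original = module.weight.detach().clone()

compressor = QuantizeWeight(min=-1, max=1)
compressor.attach(module)
_ = module(torch.randn(1, 2))  # forward runs with the quantized weight

# The post-hook copied the backup back, so the stored weight is unchanged.
assert torch.equal(module.weight.detach(), original)
compressor.detach(module)
```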
rasnatune-0.0.1/tests/test_sparsification.py
ADDED
@@ -0,0 +1,123 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+from rasnatune.sparsification import (
+    SparseActivationUnstructured,
+    SparseWeightUnstructured,
+    sparsify,
+)
+
+
+def _make_linear() -> torch.nn.Linear:
+    module = torch.nn.Linear(4, 1, bias=False)
+    with torch.no_grad():
+        module.weight.copy_(torch.tensor([[1.0, -2.0, 3.0, -4.0]], dtype=torch.float32))
+    return module
+
+
+def _make_conv2d() -> torch.nn.Conv2d:
+    module = torch.nn.Conv2d(1, 1, kernel_size=2, bias=False)
+    with torch.no_grad():
+        module.weight.copy_(
+            torch.tensor([[[[1.0, -2.0], [3.0, -4.0]]]], dtype=torch.float32)
+        )
+    return module
+
+
+def test_sparsify_returns_input_for_zero_sparsity() -> None:
+    x = torch.tensor([1.0, -2.0, 3.0, -4.0], dtype=torch.float32)
+
+    actual = sparsify(x, sparsity=0.0)
+
+    torch.testing.assert_close(actual, x)
+
+
+def test_sparsify_zeroes_smallest_magnitudes() -> None:
+    x = torch.tensor([1.0, -2.0, 3.0, -4.0], dtype=torch.float32)
+
+    actual = sparsify(x, sparsity=0.5)
+    expected = torch.tensor([0.0, 0.0, 3.0, -4.0], dtype=torch.float32)
+
+    torch.testing.assert_close(actual, expected)
+
+
+def test_sparsify_zeros_all_values_for_full_sparsity() -> None:
+    x = torch.tensor([1.0, -2.0, 3.0, -4.0], dtype=torch.float32)
+
+    actual = sparsify(x, sparsity=1.0)
+
+    torch.testing.assert_close(actual, torch.zeros_like(x))
+
+
+def test_sparsify_rejects_invalid_sparsity() -> None:
+    with pytest.raises(ValueError):
+        sparsify(torch.ones(2, dtype=torch.float32), sparsity=1.5)
+
+
+def test_sparse_weight_applies_and_restores_on_linear() -> None:
+    module = _make_linear()
+    compressor = SparseWeightUnstructured(sparsity=0.5)
+    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32)
+    original_weight = module.weight.detach().clone()
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected_weight = sparsify(original_weight, sparsity=0.5)
+    expected = F.linear(x, expected_weight)
+
+    torch.testing.assert_close(actual, expected)
+    torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+def test_sparse_weight_applies_and_restores_on_conv2d() -> None:
+    module = _make_conv2d()
+    compressor = SparseWeightUnstructured(sparsity=0.5)
+    x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32)
+    original_weight = module.weight.detach().clone()
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected_weight = sparsify(original_weight, sparsity=0.5)
+    expected = F.conv2d(x, expected_weight)
+
+    torch.testing.assert_close(actual, expected)
+    torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+def test_sparse_activation_applies_to_linear_inputs() -> None:
+    module = _make_linear()
+    compressor = SparseActivationUnstructured(sparsity=0.5)
+    x = torch.tensor([[1.0, -2.0, 3.0, -4.0]], dtype=torch.float32)
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected = F.linear(sparsify(x, sparsity=0.5), module.weight.detach())
+
+    torch.testing.assert_close(actual, expected)
+
+    compressor.detach(module)
+    restored = module(x)
+    torch.testing.assert_close(restored, F.linear(x, module.weight.detach()))
+    assert not torch.equal(actual, restored)
+
+
+def test_sparse_activation_applies_to_conv2d_inputs() -> None:
+    module = _make_conv2d()
+    compressor = SparseActivationUnstructured(sparsity=0.5)
+    x = torch.tensor([[[[1.0, -2.0], [3.0, -4.0]]]], dtype=torch.float32)
+
+    compressor.attach(module)
+    actual = module(x)
+
+    expected = F.conv2d(sparsify(x, sparsity=0.5), module.weight.detach())
+
+    torch.testing.assert_close(actual, expected)
+
+    compressor.detach(module)
+    restored = module(x)
+    torch.testing.assert_close(restored, F.conv2d(x, module.weight.detach()))
+    assert not torch.equal(actual, restored)