rasnatune 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 rasnatune contributors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,79 @@
+ Metadata-Version: 2.4
+ Name: rasnatune
+ Version: 0.0.1
+ Summary: Runtime compression tuning utilities for PyTorch
+ Author: rasnatune contributors
+ License-Expression: MIT
+ Keywords: pytorch,compression,sparsification,quantization
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0
+ Dynamic: license-file
+
+ # rasnatune
+
+ `rasnatune` is a lightweight PyTorch compression helper based on forward hooks.
+
+ ## Installation
+
+ ```bash
+ pip install rasnatune
+ ```
+
+ ## What It Does
+
+ - Applies compression at runtime using hooks (no permanent weight rewrite).
+ - Supports quantization and unstructured sparsification.
+ - Targets `torch.nn.Conv2d` and `torch.nn.Linear`.
+
+ ## Quick Start
+
+ ```python
+ import torch
+ from rasnatune import Compression, SparseWeightUnstructured
+
+ model = torch.nn.Sequential(
+     torch.nn.Linear(128, 64),
+     torch.nn.ReLU(),
+     torch.nn.Linear(64, 10),
+ )
+
+ wrapped = Compression(model)
+ wrapped.attach(
+     SparseWeightUnstructured,
+     filter=lambda m: isinstance(m, torch.nn.Linear),
+     sparsity=0.5,
+ )
+
+ x = torch.randn(128)
+ y = wrapped(x)
+ ```
+
+ ## Public API
+
+ Top-level exports:
+
+ - `rasnatune.Compression`
+ - `rasnatune.Compressor`
+ - `rasnatune.QuantizeWeight`
+ - `rasnatune.QuantizeActivation`
+ - `rasnatune.SparseWeightUnstructured`
+ - `rasnatune.SparseActivationUnstructured`
+
+ ## Compression Classes
+
+ - `QuantizeWeight(min=-128, max=127)`:
+   Quantizes layer weights only during forward, then restores original weights.
+ - `QuantizeActivation(min=-128, max=127)`:
+   Quantizes the first input activation of the layer before forward.
+ - `SparseWeightUnstructured(sparsity=0.5)`:
+   Applies unstructured sparsity to layer weights only during forward, then restores.
+ - `SparseActivationUnstructured(sparsity=0.5)`:
+   Applies unstructured sparsity to the first input activation before forward.
+
+ ## Notes
+
+ - `Compression.attach` uses `filter=` to choose target modules.
+ - `filter` is required; it should select only modules the attached compressor supports, since the built-in weight compressors read and restore the module's `weight` and fail on modules without one.
+ - Supported module types for built-in compressors are `torch.nn.Conv2d` and `torch.nn.Linear`.
@@ -0,0 +1,66 @@
+ # rasnatune
+
+ `rasnatune` is a lightweight PyTorch compression helper based on forward hooks.
+
+ ## Installation
+
+ ```bash
+ pip install rasnatune
+ ```
+
+ ## What It Does
+
+ - Applies compression at runtime using hooks (no permanent weight rewrite).
+ - Supports quantization and unstructured sparsification.
+ - Targets `torch.nn.Conv2d` and `torch.nn.Linear`.
+
+ ## Quick Start
+
+ ```python
+ import torch
+ from rasnatune import Compression, SparseWeightUnstructured
+
+ model = torch.nn.Sequential(
+     torch.nn.Linear(128, 64),
+     torch.nn.ReLU(),
+     torch.nn.Linear(64, 10),
+ )
+
+ wrapped = Compression(model)
+ wrapped.attach(
+     SparseWeightUnstructured,
+     filter=lambda m: isinstance(m, torch.nn.Linear),
+     sparsity=0.5,
+ )
+
+ x = torch.randn(128)
+ y = wrapped(x)
+ ```
+
+ ## Public API
+
+ Top-level exports:
+
+ - `rasnatune.Compression`
+ - `rasnatune.Compressor`
+ - `rasnatune.QuantizeWeight`
+ - `rasnatune.QuantizeActivation`
+ - `rasnatune.SparseWeightUnstructured`
+ - `rasnatune.SparseActivationUnstructured`
+
+ ## Compression Classes
+
+ - `QuantizeWeight(min=-128, max=127)`:
+   Quantizes layer weights only during forward, then restores original weights.
+ - `QuantizeActivation(min=-128, max=127)`:
+   Quantizes the first input activation of the layer before forward.
+ - `SparseWeightUnstructured(sparsity=0.5)`:
+   Applies unstructured sparsity to layer weights only during forward, then restores.
+ - `SparseActivationUnstructured(sparsity=0.5)`:
+   Applies unstructured sparsity to the first input activation before forward.
+
+ ## Notes
+
+ - `Compression.attach` uses `filter=` to choose target modules.
+ - `filter` is required; it should select only modules the attached compressor supports, since the built-in weight compressors read and restore the module's `weight` and fail on modules without one.
+ - Supported module types for built-in compressors are `torch.nn.Conv2d` and `torch.nn.Linear`.
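
For the quantization compressors described above, the same `attach` pattern applies. Below is a minimal sketch mirroring the package's own tests; the model, layer sizes, and `min`/`max` bounds are illustrative choices, not part of the package:

```python
import torch

from rasnatune import Compression, QuantizeActivation, QuantizeWeight


def supported(module: torch.nn.Module) -> bool:
    # The built-in compressors target Conv2d and Linear layers only.
    return isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))


model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 30 * 30, 10),
)
wrapped = Compression(model)
wrapped.attach(QuantizeWeight, filter=supported, min=-8, max=7)      # fake-quantize weights per forward
wrapped.attach(QuantizeActivation, filter=supported, min=-8, max=7)  # fake-quantize layer inputs
logits = wrapped(torch.randn(1, 3, 32, 32))
wrapped.clear()  # remove all hooks; the model behaves normally again
```
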
@@ -0,0 +1,18 @@
+ [build-system]
+ requires = ["setuptools>=77", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "rasnatune"
+ version = "0.0.1"
+ description = "Runtime compression tuning utilities for PyTorch"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ authors = [{ name = "rasnatune contributors" }]
+ license = "MIT"
+ keywords = ["pytorch", "compression", "sparsification", "quantization"]
+ dependencies = ["torch>=2.0"]
+
+ [tool.setuptools.packages.find]
+ include = ["rasnatune*"]
+ exclude = ["__pycache__*"]
@@ -0,0 +1,6 @@
+ from .base import Compressor
+ from .compression import Compression
+ from .quantization import QuantizeActivation, QuantizeWeight
+ from .sparsification import SparseActivationUnstructured, SparseWeightUnstructured
+
+ __version__ = "0.0.1"
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ import torch
+
+
+ class Compressor:
+     """Base class for hook-based compressors.
+
+     Subclasses register forward hooks in ``attach`` and store them in
+     ``_hooks``; ``_backup`` holds the original weights while a compressed
+     copy is swapped in during forward.
+     """
+
+     _backup: torch.Tensor
+     _hooks: list[torch.utils.hooks.RemovableHandle]
+
+     def attach(self, module: torch.nn.Module) -> None:
+         """Register the hooks that compress ``module`` during forward."""
+         raise NotImplementedError
+
+     def detach(self, module: torch.nn.Module) -> None:
+         """Remove all hooks registered by ``attach``."""
+         for hook in self._hooks:
+             hook.remove()
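
For illustration only: a custom compressor just needs to populate `_hooks` in `attach`. Here is a minimal sketch of a hypothetical `ClampWeight` subclass (not part of the package) that follows the same pre-/post-hook pattern as the built-in compressors:

```python
import torch

from rasnatune.base import Compressor


class ClampWeight(Compressor):
    """Hypothetical example: clamp weights to [-bound, bound] during forward only."""

    def __init__(self, bound: float = 1.0):
        self.bound = bound

    def attach(self, module: torch.nn.Module) -> None:
        def pre_hook(mod, inputs):
            # Back up the original weights and swap in a clamped copy.
            self._backup = mod.weight.data.detach().clone()
            mod.weight.data.clamp_(-self.bound, self.bound)

        def post_hook(mod, inputs, output):
            mod.weight.data.copy_(self._backup)  # restore the originals
            return output

        self._hooks = [
            module.register_forward_pre_hook(pre_hook),
            module.register_forward_hook(post_hook),
        ]
```
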
@@ -0,0 +1,35 @@
+ from __future__ import annotations
+
+ from typing import Callable
+
+ import torch
+
+ from .base import Compressor
+
+
+ class Compression(torch.nn.Module):
+     """Wraps a model and manages the compressors attached to its submodules."""
+
+     def __init__(self, model: torch.nn.Module):
+         super().__init__()
+         self.model = model
+         self.registrations: list[Compressor] = []
+
+     def forward(self, *args, **kwargs):
+         return self.model(*args, **kwargs)
+
+     def attach(self, compressor: type[Compressor], filter: Callable[[torch.nn.Module], bool], *args, **kwargs):
+         # Instantiate one compressor per submodule accepted by the filter.
+         for _name, module in self.model.named_modules():
+             if not filter(module):
+                 continue
+             instance = compressor(*args, **kwargs)
+             instance.attach(module)
+             self.registrations.append(instance)
+
+     def detach(self, compressor: Compressor):
+         compressor.detach(self.model)
+         if compressor in self.registrations:
+             self.registrations.remove(compressor)
+
+     def clear(self):
+         # Detach every registered compressor, restoring the uncompressed model.
+         for compressor in list(self.registrations):
+             self.detach(compressor)
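
A short sketch of the registration lifecycle that `Compression` manages (mirroring `tests/test_compression.py`): `attach` creates one compressor instance per matching submodule, and `detach`/`clear` remove the hooks so the wrapped model behaves like the original again.

```python
import torch

from rasnatune import Compression, SparseWeightUnstructured

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
wrapped = Compression(model)

wrapped.attach(SparseWeightUnstructured, filter=lambda m: isinstance(m, torch.nn.Linear), sparsity=0.5)
assert len(wrapped.registrations) == 2  # one instance per Linear layer

wrapped.detach(wrapped.registrations[0])  # remove the hooks of the first instance
wrapped.clear()                           # remove whatever is left
assert wrapped.registrations == []

x = torch.ones(4)
torch.testing.assert_close(wrapped(x), model(x))  # back to the uncompressed model
```
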
@@ -0,0 +1,52 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+
+ from .base import Compressor
+
+
+ def quantize(x: torch.Tensor, qmin: int, qmax: int) -> torch.Tensor:
+     """Fake-quantize ``x`` to the integer levels in ``[qmin, qmax]`` with a straight-through gradient."""
+     max_abs = x.detach().abs().max()
+     if max_abs.item() == 0:
+         return x
+     scale = max_abs / max(qmax, -qmin)
+     q = torch.clamp(torch.round(x / scale), qmin, qmax)
+     quantized = q * scale
+     # Subtracting the detached error keeps the forward value quantized while
+     # gradients flow through ``x`` unchanged (straight-through estimator).
+     error = x.detach() - quantized
+     return x - error
+
+
+ @dataclass
+ class QuantizeWeight(Compressor):
+     min: int = -128
+     max: int = 127
+
+     def attach(self, module: torch.nn.Module) -> None:
+         def pre_hook(mod: torch.nn.Module, inputs) -> None:
+             # Back up the weights, then swap in a quantized copy for this forward.
+             self._backup = mod.weight.data.detach().clone()
+             mod.weight.data.copy_(quantize(mod.weight.data, qmin=self.min, qmax=self.max))
+
+         def post_hook(mod: torch.nn.Module, inputs, output):
+             # Restore the original weights after forward.
+             mod.weight.data.copy_(self._backup)
+             return output
+
+         self._hooks = [
+             module.register_forward_pre_hook(pre_hook),
+             module.register_forward_hook(post_hook),
+         ]
+
+
+ @dataclass
+ class QuantizeActivation(Compressor):
+     min: int = -128
+     max: int = 127
+
+     def attach(self, module: torch.nn.Module) -> None:
+         def pre_hook(_module: torch.nn.Module, inputs):
+             # Quantize only the first positional input; pass the rest through.
+             return (quantize(inputs[0], qmin=self.min, qmax=self.max), *inputs[1:])
+
+         self._hooks = [
+             module.register_forward_pre_hook(pre_hook),
+         ]
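
To make the scaling concrete: `quantize` uses `scale = max_abs / max(qmax, -qmin)`, so with `qmin=-2, qmax=2` a tensor whose largest magnitude is 3.0 is snapped to multiples of 1.5. This is the worked example from `tests/test_quantization.py`:

```python
import torch

from rasnatune.quantization import quantize

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
# scale = 3.0 / max(2, 2) = 1.5; rounding x / scale to integer levels in [-2, 2]
# gives [-2, -1, 0, 1, 2], which maps back to multiples of the scale.
print(quantize(x, qmin=-2, qmax=2))  # tensor([-3.0000, -1.5000,  0.0000,  1.5000,  3.0000])
```
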
@@ -0,0 +1,61 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+
+ from .base import Compressor
+
+
+ def sparsify(x: torch.Tensor, sparsity: float) -> torch.Tensor:
+     """Zero out the ``sparsity`` fraction of entries with the smallest magnitude."""
+     if not 0.0 <= sparsity <= 1.0:
+         raise ValueError(f"sparsity must be between 0 and 1, got {sparsity}")
+
+     flat_abs = x.detach().abs().flatten()
+     if flat_abs.numel() == 0 or sparsity == 0.0:
+         return x
+     if sparsity == 1.0:
+         return torch.zeros_like(x)
+
+     num_zero = int(flat_abs.numel() * sparsity)
+     if num_zero == 0:
+         return x
+
+     # Mask out the smallest-magnitude entries; multiplying keeps gradients
+     # flowing through the surviving entries.
+     zero_indices = flat_abs.topk(num_zero, largest=False).indices
+     sparsity_mask = torch.ones_like(flat_abs, dtype=torch.bool)
+     sparsity_mask[zero_indices] = False
+     return x * sparsity_mask.view_as(x).to(x.dtype)
+
+
+ @dataclass
+ class SparseWeightUnstructured(Compressor):
+     sparsity: float = 0.5
+
+     def attach(self, module: torch.nn.Module) -> None:
+         def pre_hook(mod: torch.nn.Module, _inputs) -> None:
+             # Back up the weights, then swap in a sparsified copy for this forward.
+             self._backup = mod.weight.data.detach().clone()
+             mod.weight.data.copy_(
+                 sparsify(mod.weight.data, sparsity=self.sparsity)
+             )
+
+         def post_hook(mod: torch.nn.Module, inputs, output):
+             # Restore the original weights after forward.
+             mod.weight.data.copy_(self._backup)
+             return output
+
+         self._hooks = [
+             module.register_forward_pre_hook(pre_hook),
+             module.register_forward_hook(post_hook),
+         ]
+
+
+ @dataclass
+ class SparseActivationUnstructured(Compressor):
+     sparsity: float = 0.5
+
+     def attach(self, module: torch.nn.Module) -> None:
+         def pre_hook(_module: torch.nn.Module, inputs):
+             # Sparsify only the first positional input; pass the rest through.
+             return (sparsify(inputs[0], sparsity=self.sparsity), *inputs[1:])
+
+         self._hooks = [
+             module.register_forward_pre_hook(pre_hook),
+         ]
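
Likewise, `sparsify` keeps the largest-magnitude entries and zeroes the rest; with `sparsity=0.5` on four values, the two smallest magnitudes are dropped (matching `tests/test_sparsification.py`):

```python
import torch

from rasnatune.sparsification import sparsify

x = torch.tensor([1.0, -2.0, 3.0, -4.0])
print(sparsify(x, sparsity=0.5))  # tensor([ 0.,  0.,  3., -4.])
print(sparsify(x, sparsity=1.0))  # tensor([0., 0., 0., 0.])
```
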
@@ -0,0 +1,79 @@
+ Metadata-Version: 2.4
+ Name: rasnatune
+ Version: 0.0.1
+ Summary: Runtime compression tuning utilities for PyTorch
+ Author: rasnatune contributors
+ License-Expression: MIT
+ Keywords: pytorch,compression,sparsification,quantization
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0
+ Dynamic: license-file
+
+ # rasnatune
+
+ `rasnatune` is a lightweight PyTorch compression helper based on forward hooks.
+
+ ## Installation
+
+ ```bash
+ pip install rasnatune
+ ```
+
+ ## What It Does
+
+ - Applies compression at runtime using hooks (no permanent weight rewrite).
+ - Supports quantization and unstructured sparsification.
+ - Targets `torch.nn.Conv2d` and `torch.nn.Linear`.
+
+ ## Quick Start
+
+ ```python
+ import torch
+ from rasnatune import Compression, SparseWeightUnstructured
+
+ model = torch.nn.Sequential(
+     torch.nn.Linear(128, 64),
+     torch.nn.ReLU(),
+     torch.nn.Linear(64, 10),
+ )
+
+ wrapped = Compression(model)
+ wrapped.attach(
+     SparseWeightUnstructured,
+     filter=lambda m: isinstance(m, torch.nn.Linear),
+     sparsity=0.5,
+ )
+
+ x = torch.randn(128)
+ y = wrapped(x)
+ ```
+
+ ## Public API
+
+ Top-level exports:
+
+ - `rasnatune.Compression`
+ - `rasnatune.Compressor`
+ - `rasnatune.QuantizeWeight`
+ - `rasnatune.QuantizeActivation`
+ - `rasnatune.SparseWeightUnstructured`
+ - `rasnatune.SparseActivationUnstructured`
+
+ ## Compression Classes
+
+ - `QuantizeWeight(min=-128, max=127)`:
+   Quantizes layer weights only during forward, then restores original weights.
+ - `QuantizeActivation(min=-128, max=127)`:
+   Quantizes the first input activation of the layer before forward.
+ - `SparseWeightUnstructured(sparsity=0.5)`:
+   Applies unstructured sparsity to layer weights only during forward, then restores.
+ - `SparseActivationUnstructured(sparsity=0.5)`:
+   Applies unstructured sparsity to the first input activation before forward.
+
+ ## Notes
+
+ - `Compression.attach` uses `filter=` to choose target modules.
+ - `filter` is required; it should select only modules the attached compressor supports, since the built-in weight compressors read and restore the module's `weight` and fail on modules without one.
+ - Supported module types for built-in compressors are `torch.nn.Conv2d` and `torch.nn.Linear`.
@@ -0,0 +1,16 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ rasnatune/__init__.py
+ rasnatune/base.py
+ rasnatune/compression.py
+ rasnatune/quantization.py
+ rasnatune/sparsification.py
+ rasnatune.egg-info/PKG-INFO
+ rasnatune.egg-info/SOURCES.txt
+ rasnatune.egg-info/dependency_links.txt
+ rasnatune.egg-info/requires.txt
+ rasnatune.egg-info/top_level.txt
+ tests/test_compression.py
+ tests/test_quantization.py
+ tests/test_sparsification.py
@@ -0,0 +1 @@
+ torch>=2.0
@@ -0,0 +1 @@
+ rasnatune
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,172 @@
+ import torch
+ import torch.nn.functional as F
+
+ # Pre-register the torchvision::nms schema so that importing torchvision.models
+ # below also works when torchvision's compiled ops are unavailable; if the op
+ # (or namespace) is already defined, the RuntimeError is ignored.
+ try:
+     _TORCHVISION_NMS_LIB = torch.library.Library("torchvision", "DEF")
+     _TORCHVISION_NMS_LIB.define("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")
+ except RuntimeError:
+     _TORCHVISION_NMS_LIB = None
+
+ from torchvision.models import efficientnet_b0, resnet50
+
+ from rasnatune.compression import Compression
+ from rasnatune.quantization import QuantizeActivation, QuantizeWeight
+ from rasnatune.sparsification import SparseActivationUnstructured, SparseWeightUnstructured
+
+
+ def _supported_module_filter(module: torch.nn.Module) -> bool:
+     return isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
+
+
+ def _supported_modules(model: torch.nn.Module) -> list[torch.nn.Module]:
+     return [module for module in model.modules() if _supported_module_filter(module)]
+
+
+ def _make_toy_model() -> torch.nn.Sequential:
+     model = torch.nn.Sequential(
+         torch.nn.Linear(4, 4, bias=False),
+         torch.nn.ReLU(),
+         torch.nn.Linear(4, 2, bias=False),
+     )
+     with torch.no_grad():
+         model[0].weight.copy_(
+             torch.tensor(
+                 [
+                     [1.75, -0.25, 0.5, -1.1],
+                     [0.2, -0.6, 1.4, -0.8],
+                     [0.9, 0.3, -0.4, 1.1],
+                     [-1.2, 0.7, 0.6, -0.5],
+                 ],
+                 dtype=torch.float32,
+             )
+         )
+         model[2].weight.copy_(
+             torch.tensor(
+                 [
+                     [0.8, -1.4, 0.6, -0.2],
+                     [-0.5, 0.9, -1.1, 1.3],
+                 ],
+                 dtype=torch.float32,
+             )
+         )
+     return model
+
+
+ def _attach_quantization(wrapped: Compression) -> None:
+     wrapped.attach(QuantizeWeight, filter=_supported_module_filter, min=-8, max=7)
+     wrapped.attach(QuantizeActivation, filter=_supported_module_filter, min=-8, max=7)
+
+
+ def _attach_sparsification(wrapped: Compression) -> None:
+     wrapped.attach(SparseWeightUnstructured, filter=_supported_module_filter, sparsity=0.25)
+     wrapped.attach(SparseActivationUnstructured, filter=_supported_module_filter, sparsity=0.25)
+
+
+ def _run_compressed_training_step(model: torch.nn.Module, attach_fn) -> None:
+     wrapped = Compression(model)
+     modules = _supported_modules(wrapped.model)
+     first_module = modules[0]
+     original_weight = first_module.weight.detach().clone()
+
+     attach_fn(wrapped)
+
+     assert len(wrapped.registrations) == 2 * len(modules)
+
+     optimizer = torch.optim.SGD(wrapped.parameters(), lr=1e-2)
+     inputs = torch.randn(2, 3, 64, 64, dtype=torch.float32)
+     targets = torch.tensor([1, 3], dtype=torch.long)
+
+     optimizer.zero_grad(set_to_none=True)
+     logits = wrapped(inputs)
+     loss = F.cross_entropy(logits, targets)
+     loss.backward()
+
+     gradients = [parameter.grad for parameter in wrapped.parameters() if parameter.grad is not None]
+     assert gradients
+     assert all(torch.isfinite(gradient).all().item() for gradient in gradients)
+     torch.testing.assert_close(first_module.weight.detach(), original_weight)
+
+     parameters_before_step = [parameter.detach().clone() for parameter in wrapped.parameters()]
+     optimizer.step()
+
+     assert any(
+         not torch.equal(parameter.detach(), before_step)
+         for parameter, before_step in zip(wrapped.parameters(), parameters_before_step)
+     )
+
+     wrapped.clear()
+     assert wrapped.registrations == []
+
+     cleared_logits = wrapped(inputs)
+     assert cleared_logits.shape == logits.shape
+
+
+ def test_compression_attach_and_detach_manage_registrations() -> None:
+     model = _make_toy_model()
+     wrapped = Compression(model)
+     x = torch.tensor([[1.0, -2.0, 3.0, -4.0]], dtype=torch.float32)
+
+     wrapped.attach(QuantizeWeight, filter=lambda module: isinstance(module, torch.nn.Linear), min=-1, max=1)
+
+     assert len(wrapped.registrations) == 2
+
+     compressed_output = wrapped(x)
+     wrapped.detach(wrapped.registrations[0])
+
+     assert len(wrapped.registrations) == 1
+
+     wrapped.clear()
+     assert wrapped.registrations == []
+
+     restored_output = wrapped(x)
+     reference_output = model(x)
+
+     torch.testing.assert_close(restored_output, reference_output)
+     assert not torch.equal(compressed_output, reference_output)
+
+
+ def test_compression_clear_removes_sparse_registrations() -> None:
+     model = _make_toy_model()
+     wrapped = Compression(model)
+
+     _attach_sparsification(wrapped)
+
+     assert len(wrapped.registrations) == 4
+
+     wrapped.clear()
+
+     assert wrapped.registrations == []
+     output = wrapped(torch.randn(1, 4, dtype=torch.float32))
+     assert output.shape == (1, 2)
+
+
+ def test_resnet50_quantized_training_step_runs() -> None:
+     torch.manual_seed(0)
+     model = resnet50(weights=None, num_classes=5)
+     model.train()
+
+     _run_compressed_training_step(model, _attach_quantization)
+
+
+ def test_efficientnet_b0_quantized_training_step_runs() -> None:
+     torch.manual_seed(0)
+     model = efficientnet_b0(weights=None, num_classes=5)
+     model.train()
+
+     _run_compressed_training_step(model, _attach_quantization)
+
+
+ def test_resnet50_sparse_training_step_runs() -> None:
+     torch.manual_seed(0)
+     model = resnet50(weights=None, num_classes=5)
+     model.train()
+
+     _run_compressed_training_step(model, _attach_sparsification)
+
+
+ def test_efficientnet_b0_sparse_training_step_runs() -> None:
+     torch.manual_seed(0)
+     model = efficientnet_b0(weights=None, num_classes=5)
+     model.train()
+
+     _run_compressed_training_step(model, _attach_sparsification)
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn.functional as F
+
+ from rasnatune.quantization import QuantizeActivation, QuantizeWeight, quantize
+
+
+ def _make_linear() -> torch.nn.Linear:
+     module = torch.nn.Linear(2, 1, bias=False)
+     with torch.no_grad():
+         module.weight.copy_(torch.tensor([[1.75, -0.25]], dtype=torch.float32))
+     return module
+
+
+ def _make_conv2d() -> torch.nn.Conv2d:
+     module = torch.nn.Conv2d(1, 1, kernel_size=2, bias=False)
+     with torch.no_grad():
+         module.weight.copy_(
+             torch.tensor([[[[1.75, -0.25], [0.5, -1.1]]]], dtype=torch.float32)
+         )
+     return module
+
+
+ def test_quantize_maps_values_to_expected_levels() -> None:
+     x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=torch.float32)
+
+     actual = quantize(x, qmin=-2, qmax=2)
+     expected = torch.tensor([-3.0, -1.5, 0.0, 1.5, 3.0], dtype=torch.float32)
+
+     torch.testing.assert_close(actual, expected)
+
+
+ def test_quantize_clamps_to_asymmetric_range() -> None:
+     x = torch.tensor([-2.0, -0.6, 0.2, 1.7], dtype=torch.float32)
+
+     actual = quantize(x, qmin=-1, qmax=2)
+     expected = torch.tensor([-1.0, -1.0, 0.0, 2.0], dtype=torch.float32)
+
+     torch.testing.assert_close(actual, expected)
+
+
+ def test_quantize_preserves_straight_through_gradients() -> None:
+     x = torch.tensor([0.3, -0.6], dtype=torch.float32, requires_grad=True)
+     upstream = torch.tensor([2.0, -3.0], dtype=torch.float32)
+
+     loss = (quantize(x, qmin=-2, qmax=1) * upstream).sum()
+     loss.backward()
+
+     torch.testing.assert_close(x.grad, upstream)
+
+
+ def test_quantize_zero_tensor_returns_zero_without_nan() -> None:
+     x = torch.zeros(4, dtype=torch.float32, requires_grad=True)
+
+     actual = quantize(x, qmin=-128, qmax=127)
+     actual.sum().backward()
+
+     torch.testing.assert_close(actual, torch.zeros_like(x))
+     torch.testing.assert_close(x.grad, torch.ones_like(x))
+     assert torch.isfinite(actual).all().item()
+
+
+ def test_quantize_weight_applies_and_restores_on_linear() -> None:
+     module = _make_linear()
+     compressor = QuantizeWeight(min=-1, max=1)
+     x = torch.tensor([[2.0, 3.0]], dtype=torch.float32)
+     original_weight = module.weight.detach().clone()
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected_weight = quantize(original_weight, qmin=-1, qmax=1)
+     expected = F.linear(x, expected_weight)
+
+     torch.testing.assert_close(actual, expected)
+     torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+ def test_quantize_weight_applies_and_restores_on_conv2d() -> None:
+     module = _make_conv2d()
+     compressor = QuantizeWeight(min=-1, max=1)
+     x = torch.tensor([[[[1.0, 0.5], [-0.25, 0.75]]]], dtype=torch.float32)
+     original_weight = module.weight.detach().clone()
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected_weight = quantize(original_weight, qmin=-1, qmax=1)
+     expected = F.conv2d(x, expected_weight)
+
+     torch.testing.assert_close(actual, expected)
+     torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+ def test_quantize_activation_applies_to_linear_inputs() -> None:
+     module = _make_linear()
+     compressor = QuantizeActivation(min=-1, max=1)
+     x = torch.tensor([[1.75, -0.25]], dtype=torch.float32)
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected = F.linear(quantize(x, qmin=-1, qmax=1), module.weight.detach())
+
+     torch.testing.assert_close(actual, expected)
+
+     compressor.detach(module)
+     restored = module(x)
+     torch.testing.assert_close(restored, F.linear(x, module.weight.detach()))
+     assert not torch.equal(actual, restored)
+
+
+ def test_quantize_activation_applies_to_conv2d_inputs() -> None:
+     module = _make_conv2d()
+     compressor = QuantizeActivation(min=-1, max=1)
+     x = torch.tensor([[[[1.75, -0.25], [0.5, -1.1]]]], dtype=torch.float32)
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected = F.conv2d(quantize(x, qmin=-1, qmax=1), module.weight.detach())
+
+     torch.testing.assert_close(actual, expected)
+
+     compressor.detach(module)
+     restored = module(x)
+     torch.testing.assert_close(restored, F.conv2d(x, module.weight.detach()))
+     assert not torch.equal(actual, restored)
@@ -0,0 +1,123 @@
+ import pytest
+ import torch
+ import torch.nn.functional as F
+
+ from rasnatune.sparsification import (
+     SparseActivationUnstructured,
+     SparseWeightUnstructured,
+     sparsify,
+ )
+
+
+ def _make_linear() -> torch.nn.Linear:
+     module = torch.nn.Linear(4, 1, bias=False)
+     with torch.no_grad():
+         module.weight.copy_(torch.tensor([[1.0, -2.0, 3.0, -4.0]], dtype=torch.float32))
+     return module
+
+
+ def _make_conv2d() -> torch.nn.Conv2d:
+     module = torch.nn.Conv2d(1, 1, kernel_size=2, bias=False)
+     with torch.no_grad():
+         module.weight.copy_(
+             torch.tensor([[[[1.0, -2.0], [3.0, -4.0]]]], dtype=torch.float32)
+         )
+     return module
+
+
+ def test_sparsify_returns_input_for_zero_sparsity() -> None:
+     x = torch.tensor([1.0, -2.0, 3.0, -4.0], dtype=torch.float32)
+
+     actual = sparsify(x, sparsity=0.0)
+
+     torch.testing.assert_close(actual, x)
+
+
+ def test_sparsify_zeroes_smallest_magnitudes() -> None:
+     x = torch.tensor([1.0, -2.0, 3.0, -4.0], dtype=torch.float32)
+
+     actual = sparsify(x, sparsity=0.5)
+     expected = torch.tensor([0.0, 0.0, 3.0, -4.0], dtype=torch.float32)
+
+     torch.testing.assert_close(actual, expected)
+
+
+ def test_sparsify_zeros_all_values_for_full_sparsity() -> None:
+     x = torch.tensor([1.0, -2.0, 3.0, -4.0], dtype=torch.float32)
+
+     actual = sparsify(x, sparsity=1.0)
+
+     torch.testing.assert_close(actual, torch.zeros_like(x))
+
+
+ def test_sparsify_rejects_invalid_sparsity() -> None:
+     with pytest.raises(ValueError):
+         sparsify(torch.ones(2, dtype=torch.float32), sparsity=1.5)
+
+
+ def test_sparse_weight_applies_and_restores_on_linear() -> None:
+     module = _make_linear()
+     compressor = SparseWeightUnstructured(sparsity=0.5)
+     x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32)
+     original_weight = module.weight.detach().clone()
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected_weight = sparsify(original_weight, sparsity=0.5)
+     expected = F.linear(x, expected_weight)
+
+     torch.testing.assert_close(actual, expected)
+     torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+ def test_sparse_weight_applies_and_restores_on_conv2d() -> None:
+     module = _make_conv2d()
+     compressor = SparseWeightUnstructured(sparsity=0.5)
+     x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32)
+     original_weight = module.weight.detach().clone()
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected_weight = sparsify(original_weight, sparsity=0.5)
+     expected = F.conv2d(x, expected_weight)
+
+     torch.testing.assert_close(actual, expected)
+     torch.testing.assert_close(module.weight.detach(), original_weight)
+
+
+ def test_sparse_activation_applies_to_linear_inputs() -> None:
+     module = _make_linear()
+     compressor = SparseActivationUnstructured(sparsity=0.5)
+     x = torch.tensor([[1.0, -2.0, 3.0, -4.0]], dtype=torch.float32)
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected = F.linear(sparsify(x, sparsity=0.5), module.weight.detach())
+
+     torch.testing.assert_close(actual, expected)
+
+     compressor.detach(module)
+     restored = module(x)
+     torch.testing.assert_close(restored, F.linear(x, module.weight.detach()))
+     assert not torch.equal(actual, restored)
+
+
+ def test_sparse_activation_applies_to_conv2d_inputs() -> None:
+     module = _make_conv2d()
+     compressor = SparseActivationUnstructured(sparsity=0.5)
+     x = torch.tensor([[[[1.0, -2.0], [3.0, -4.0]]]], dtype=torch.float32)
+
+     compressor.attach(module)
+     actual = module(x)
+
+     expected = F.conv2d(sparsify(x, sparsity=0.5), module.weight.detach())
+
+     torch.testing.assert_close(actual, expected)
+
+     compressor.detach(module)
+     restored = module(x)
+     torch.testing.assert_close(restored, F.conv2d(x, module.weight.detach()))
+     assert not torch.equal(actual, restored)