saga-activation 0.1.0__tar.gz

@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,238 @@
1
+ Metadata-Version: 2.4
2
+ Name: saga-activation
3
+ Version: 0.1.0
4
+ Summary: Spatially-Adaptive Gated Activation (SAGA) for medical image restoration
5
+ Author: Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
6
+ Author-email: "Siju K.S." <sijuks@example.com>
7
+ License: MIT
8
+ Project-URL: Homepage, https://github.com/sijuswamyresearch/saga-activation
9
+ Project-URL: Documentation, https://sijuswamyresearch.github.io/saga-activation
10
+ Project-URL: Repository, https://github.com/sijuswamyresearch/saga-activation
11
+ Project-URL: Bug Tracker, https://github.com/sijuswamyresearch/saga-activation/issues
12
+ Project-URL: Paper, https://doi.org/10.1016/j.health.2026.100468
13
+ Keywords: deep learning,activation function,medical imaging,image restoration,deblurring,PyTorch
14
+ Classifier: Development Status :: 4 - Beta
15
+ Classifier: Intended Audience :: Science/Research
16
+ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Classifier: License :: OSI Approved :: MIT License
19
+ Classifier: Programming Language :: Python :: 3
20
+ Classifier: Programming Language :: Python :: 3.10
21
+ Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
23
+ Requires-Python: >=3.10
24
+ Description-Content-Type: text/markdown
25
+ License-File: LICENSE
26
+ Requires-Dist: torch>=2.0
27
+ Provides-Extra: dev
28
+ Requires-Dist: pytest>=7.0; extra == "dev"
29
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
30
+ Provides-Extra: docs
31
+ Requires-Dist: sphinx>=7.0; extra == "docs"
32
+ Requires-Dist: sphinx-rtd-theme>=2.0; extra == "docs"
33
+ Requires-Dist: myst-parser>=2.0; extra == "docs"
34
+ Dynamic: license-file
35
+
36
+ # SAGA — Spatially-Adaptive Gated Activation
37
+
38
+ [![CI](https://github.com/sijuswamyresearch/SAGA/actions/workflows/ci.yml/badge.svg)](https://github.com/sijuswamyresearch/SAGA/actions)
39
+ [![PyPI version](https://badge.fury.io/py/saga-activation.svg)](https://pypi.org/project/saga-activation/)
40
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
41
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.XXXXXXX.svg)](https://doi.org/10.5281/zenodo.XXXXXXX)
42
+ [![Paper](https://img.shields.io/badge/Paper-Healthcare%20Analytics-blue)](https://doi.org/10.1016/j.health.2026.100468)
43
+
44
+ > **An Interpretable Deep Learning Method for Medical Image Deblurring and Restoration**
45
+ > Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
46
+ > *Healthcare Analytics* 9 (2026) 100468 · [doi:10.1016/j.health.2026.100468](https://doi.org/10.1016/j.health.2026.100468)
47
+
48
+ ---
49
+
50
+ ## Overview
51
+
52
+ Standard activation functions (ReLU, SiLU, GELU) treat every spatial location in a feature map identically. In medical images — CT slices, DXA scans — the information content is *not* spatially uniform: anatomical boundaries carry high-frequency diagnostically-critical detail while homogeneous regions (background, soft tissue) require smooth suppression.
53
+
54
+ **SAGA** introduces a *learned spatial gating map* that modulates the activation response position-by-position:
55
+
56
+ ```
57
+ T(X) = BN(W_d * X),  G(X) = σ(W_g * T(X) + b_g)   # depthwise 3×3 context, 1×1 gate
58
+ SAGA(X) = X + G(X) ⊙ ReLU(T(X) − X)               # gated positive boost added to the input
59
+ ```
60
+
61
+ This two-path design lets the network selectively amplify high-frequency boundary signals while smoothly gating uniform background areas — without increasing the depth of the network.
62
+
63
+ ---
64
+
65
+ ## Installation
66
+
67
+ ```bash
68
+ pip install saga-activation
69
+ ```
70
+
71
+ Or install from source:
72
+
73
+ ```bash
74
+ git clone https://github.com/sijuswamyresearch/SAGA.git
75
+ cd SAGA
76
+ pip install -e ".[dev]"
77
+ ```
78
+
79
+ **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.0
80
+
81
+ ---
82
+
83
+ ## Quick Start
84
+
85
+ ### Drop-in activation replacement
86
+
87
+ ```python
88
+ import torch
89
+ from saga import SAGA
90
+
91
+ # Replace any activation layer with SAGA
92
+ act = SAGA(in_channels=64) # matches the channel dim of your feature map
93
+ x = torch.randn(2, 64, 256, 256) # (B, C, H, W)
94
+ y = act(x) # same shape: (2, 64, 256, 256)
95
+ ```
96
+
97
+ ### Inside a U-Net encoder block
98
+
99
+ ```python
100
+ import torch.nn as nn
101
+ from saga import SAGA
102
+
103
+ class EncoderBlock(nn.Module):
104
+     def __init__(self, in_ch, out_ch):
105
+         super().__init__()
106
+         self.block = nn.Sequential(
107
+             nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
108
+             nn.BatchNorm2d(out_ch),
109
+             SAGA(out_ch),  # ← swap in SAGA here
110
+             nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
111
+             nn.BatchNorm2d(out_ch),
112
+             SAGA(out_ch),
113
+         )
114
+         self.pool = nn.MaxPool2d(2)
115
+
116
+     def forward(self, x):
117
+         return self.pool(self.block(x))
118
+ ```
119
+
120
+ ### Pre-built residual blocks
121
+
122
+ ```python
123
+ from saga import SAGAResBlock, SAGABottleneck
124
+
125
+ res = SAGAResBlock(64) # standard residual block
126
+ bottle = SAGABottleneck(64, out_channels=128) # bottleneck variant
127
+ ```
128
+
129
+ ### Constructor arguments
130
+
131
+ ```python
132
+ from saga import SAGA, SAGALayer
133
+
134
+ # In 0.1.0 the constructor accepts only in_channels; passing base_activation raises TypeError.
135
+ act = SAGA(64)
136
+ same = SAGALayer(64)  # SAGALayer is an alias of SAGA
137
+ ```
138
+
139
+ ### Gate curriculum training
140
+
141
+ ```python
142
+ from saga.utils import freeze_gate, unfreeze_gate
143
+
144
+ # Phase 1 – train backbone only (`model` and `train` stand for your own network and training loop)
145
+ freeze_gate(model)
146
+ train(model, epochs=10, lr=1e-3)
147
+
148
+ # Phase 2 – fine-tune gates
149
+ unfreeze_gate(model)
150
+ train(model, epochs=5, lr=1e-4)
151
+ ```
152
+
153
+ ---
154
+
155
+ ## Repository Structure
156
+
157
+ ```
158
+ SAGA/
159
+ ├── saga/ # installable Python package
160
+ │ ├── __init__.py
161
+ │ ├── activation.py # SAGA operator (core)
162
+ │ ├── blocks.py # SAGAResBlock, SAGABottleneck
163
+ │ └── utils.py # parameter counting, gate freeze helpers
164
+
165
+ ├── tests/
166
+ │ ├── conftest.py
167
+ │ └── test_saga.py # pytest suite (shapes, edge cases, GPU, gradients)
168
+
169
+ ├── SAGA_Supplementary_Code/ # original experimental pipeline
170
+ │ ├── models/
171
+ │ │ ├── saga_layer.py # raw research implementation
172
+ │ │ ├── unet.py
173
+ │ │ ├── resnet.py
174
+ │ │ ├── edsr.py
175
+ │ │ └── vggnet.py
176
+ │ ├── generate_dataset.py
177
+ │ ├── train.py
178
+ │ ├── evaluate.py
179
+ │ ├── xai_analysis.py
180
+ │ └── clinical_validation.py
181
+
182
+ ├── docs/ # Sphinx documentation source
183
+ ├── .github/workflows/ci.yml # GitHub Actions CI
184
+ ├── pyproject.toml
185
+ └── README.md
186
+ ```
187
+
188
+ ---
189
+
190
+ ## Experimental Results (summary)
191
+
192
+ | Model | Activation | CT PSNR (dB) | CT SSIM | DXA PSNR (dB) | DXA SSIM |
193
+ |---------------|-----------|:------------:|:-------:|:-------------:|:--------:|
194
+ | U-Net | ReLU | 32.14 | 0.891 | 30.87 | 0.873 |
195
+ | U-Net | SiLU | 33.01 | 0.902 | 31.54 | 0.881 |
196
+ | **U-Net** | **SAGA** | **34.67** | **0.921** | **33.12** | **0.903** |
197
+ | DeblurResNet | ReLU | 31.89 | 0.883 | 30.21 | 0.864 |
198
+ | **DeblurResNet** | **SAGA** | **34.11** | **0.916** | **32.78** | **0.897** |
199
+
200
+ Full results and ablation studies are reported in the paper.
201
+
202
+ ---
203
+
204
+ ## Running the Tests
205
+
206
+ ```bash
207
+ pytest tests/ -v
208
+ ```
209
+
210
+ To run with coverage:
211
+
212
+ ```bash
213
+ pytest tests/ --cov=saga --cov-report=term-missing
214
+ ```
215
+
216
+ ---
217
+
218
+ ## Citing
219
+
220
+ If SAGA is useful in your research, please cite:
221
+
222
+ ```bibtex
223
+ @article{siju2026saga,
224
+ title = {An interpretable deep learning method for medical image deblurring and restoration},
225
+ author = {Siju K.S. and Vipin Venugopal and Mithun Kumar Kar and Jayakrishnan Anandakrishnan},
226
+ journal = {Healthcare Analytics},
227
+ volume = {9},
228
+ pages = {100468},
229
+ year = {2026},
230
+ doi = {10.1016/j.health.2026.100468}
231
+ }
232
+ ```
233
+
234
+ ---
235
+
236
+ ## License
237
+
238
+ [MIT](LICENSE)
@@ -0,0 +1,203 @@
1
+ # SAGA — Spatially-Adaptive Gated Activation
2
+
3
+ [![CI](https://github.com/sijuswamyresearch/SAGA/actions/workflows/ci.yml/badge.svg)](https://github.com/sijuswamyresearch/SAGA/actions)
4
+ [![PyPI version](https://badge.fury.io/py/saga-activation.svg)](https://pypi.org/project/saga-activation/)
5
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.XXXXXXX.svg)](https://doi.org/10.5281/zenodo.XXXXXXX)
7
+ [![Paper](https://img.shields.io/badge/Paper-Healthcare%20Analytics-blue)](https://doi.org/10.1016/j.health.2026.100468)
8
+
9
+ > **An Interpretable Deep Learning Method for Medical Image Deblurring and Restoration**
10
+ > Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
11
+ > *Healthcare Analytics* 9 (2026) 100468 · [doi:10.1016/j.health.2026.100468](https://doi.org/10.1016/j.health.2026.100468)
12
+
13
+ ---
14
+
15
+ ## Overview
16
+
17
+ Standard activation functions (ReLU, SiLU, GELU) treat every spatial location in a feature map identically. In medical images — CT slices, DXA scans — the information content is *not* spatially uniform: anatomical boundaries carry high-frequency diagnostically-critical detail while homogeneous regions (background, soft tissue) require smooth suppression.
18
+
19
+ **SAGA** introduces a *learned spatial gating map* that modulates the activation response position-by-position:
20
+
21
+ ```
22
+ T(X) = BN(W_d * X),  G(X) = σ(W_g * T(X) + b_g)   # depthwise 3×3 context, 1×1 gate
23
+ SAGA(X) = X + G(X) ⊙ ReLU(T(X) − X)               # gated positive boost added to the input
24
+ ```
25
+
26
+ This two-path design lets the network selectively amplify high-frequency boundary signals while smoothly gating uniform background areas — without increasing the depth of the network.
27
+
28
+ ---
29
+
30
+ ## Installation
31
+
32
+ ```bash
33
+ pip install saga-activation
34
+ ```
35
+
36
+ Or install from source:
37
+
38
+ ```bash
39
+ git clone https://github.com/sijuswamyresearch/SAGA.git
40
+ cd SAGA
41
+ pip install -e ".[dev]"
42
+ ```
43
+
44
+ **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.0
45
+
46
+ ---
47
+
48
+ ## Quick Start
49
+
50
+ ### Drop-in activation replacement
51
+
52
+ ```python
53
+ import torch
54
+ from saga import SAGA
55
+
56
+ # Replace any activation layer with SAGA
57
+ act = SAGA(in_channels=64) # matches the channel dim of your feature map
58
+ x = torch.randn(2, 64, 256, 256) # (B, C, H, W)
59
+ y = act(x) # same shape: (2, 64, 256, 256)
60
+ ```
61
+
62
+ ### Inside a U-Net encoder block
63
+
64
+ ```python
65
+ import torch.nn as nn
66
+ from saga import SAGA
67
+
68
+ class EncoderBlock(nn.Module):
69
+     def __init__(self, in_ch, out_ch):
70
+         super().__init__()
71
+         self.block = nn.Sequential(
72
+             nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
73
+             nn.BatchNorm2d(out_ch),
74
+             SAGA(out_ch),  # ← swap in SAGA here
75
+             nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
76
+             nn.BatchNorm2d(out_ch),
77
+             SAGA(out_ch),
78
+         )
79
+         self.pool = nn.MaxPool2d(2)
80
+
81
+     def forward(self, x):
82
+         return self.pool(self.block(x))
83
+ ```
84
+
85
+ ### Pre-built residual blocks
86
+
87
+ ```python
88
+ from saga import SAGAResBlock, SAGABottleneck
89
+
90
+ res = SAGAResBlock(64) # standard residual block
91
+ bottle = SAGABottleneck(64, out_channels=128) # bottleneck variant
92
+ ```
93
+
94
+ ### Constructor arguments
95
+
96
+ ```python
97
+ from saga import SAGA, SAGALayer
98
+
99
+ # In 0.1.0 the constructor accepts only in_channels; passing base_activation raises TypeError.
100
+ act = SAGA(64)
101
+ same = SAGALayer(64)  # SAGALayer is an alias of SAGA
102
+ ```
103
+
104
+ ### Gate curriculum training
105
+
106
+ ```python
107
+ from saga.utils import freeze_gate, unfreeze_gate
108
+
109
+ # Phase 1 – train backbone only (`model` and `train` stand for your own network and training loop)
110
+ freeze_gate(model)
111
+ train(model, epochs=10, lr=1e-3)
112
+
113
+ # Phase 2 – fine-tune gates
114
+ unfreeze_gate(model)
115
+ train(model, epochs=5, lr=1e-4)
116
+ ```
117
+
118
+ ---
119
+
120
+ ## Repository Structure
121
+
122
+ ```
123
+ SAGA/
124
+ ├── saga/ # installable Python package
125
+ │ ├── __init__.py
126
+ │ ├── activation.py # SAGA operator (core)
127
+ │ ├── blocks.py # SAGAResBlock, SAGABottleneck
128
+ │ └── utils.py # parameter counting, gate freeze helpers
129
+
130
+ ├── tests/
131
+ │ ├── conftest.py
132
+ │ └── test_saga.py # pytest suite (shapes, edge cases, GPU, gradients)
133
+
134
+ ├── SAGA_Supplementary_Code/ # original experimental pipeline
135
+ │ ├── models/
136
+ │ │ ├── saga_layer.py # raw research implementation
137
+ │ │ ├── unet.py
138
+ │ │ ├── resnet.py
139
+ │ │ ├── edsr.py
140
+ │ │ └── vggnet.py
141
+ │ ├── generate_dataset.py
142
+ │ ├── train.py
143
+ │ ├── evaluate.py
144
+ │ ├── xai_analysis.py
145
+ │ └── clinical_validation.py
146
+
147
+ ├── docs/ # Sphinx documentation source
148
+ ├── .github/workflows/ci.yml # GitHub Actions CI
149
+ ├── pyproject.toml
150
+ └── README.md
151
+ ```
152
+
153
+ ---
154
+
155
+ ## Experimental Results (summary)
156
+
157
+ | Model | Activation | CT PSNR (dB) | CT SSIM | DXA PSNR (dB) | DXA SSIM |
158
+ |---------------|-----------|:------------:|:-------:|:-------------:|:--------:|
159
+ | U-Net | ReLU | 32.14 | 0.891 | 30.87 | 0.873 |
160
+ | U-Net | SiLU | 33.01 | 0.902 | 31.54 | 0.881 |
161
+ | **U-Net** | **SAGA** | **34.67** | **0.921** | **33.12** | **0.903** |
162
+ | DeblurResNet | ReLU | 31.89 | 0.883 | 30.21 | 0.864 |
163
+ | **DeblurResNet** | **SAGA** | **34.11** | **0.916** | **32.78** | **0.897** |
164
+
165
+ Full results and ablation studies are reported in the paper.
166
+
167
+ ---
168
+
169
+ ## Running the Tests
170
+
171
+ ```bash
172
+ pytest tests/ -v
173
+ ```
174
+
175
+ To run with coverage:
176
+
177
+ ```bash
178
+ pytest tests/ --cov=saga --cov-report=term-missing
179
+ ```
180
+
181
+ ---
182
+
183
+ ## Citing
184
+
185
+ If SAGA is useful in your research, please cite:
186
+
187
+ ```bibtex
188
+ @article{siju2026saga,
189
+ title = {An interpretable deep learning method for medical image deblurring and restoration},
190
+ author = {Siju K.S. and Vipin Venugopal and Mithun Kumar Kar and Jayakrishnan Anandakrishnan},
191
+ journal = {Healthcare Analytics},
192
+ volume = {9},
193
+ pages = {100468},
194
+ year = {2026},
195
+ doi = {10.1016/j.health.2026.100468}
196
+ }
197
+ ```
198
+
199
+ ---
200
+
201
+ ## License
202
+
203
+ [MIT](LICENSE)
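The results table in the README reports PSNR in dB. As a reading aid, the sketch below uses the standard PSNR definition to convert a PSNR gain into the corresponding reduction in mean squared error. The function names are hypothetical, and the peak value and evaluation protocol used in the paper are not stated in this package, so the sketch only assumes that both rows being compared share the same peak value.

```python
import math


def psnr(mse: float, max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    return 10.0 * math.log10(max_value ** 2 / mse)


def mse_ratio(gain_db: float) -> float:
    """Factor by which MSE shrinks for a PSNR gain of gain_db (same peak value)."""
    return 10.0 ** (gain_db / 10.0)


# U-Net on CT, values copied from the README table: ReLU 32.14 dB, SAGA 34.67 dB.
gain = 34.67 - 32.14
print(f"gain = {gain:.2f} dB, MSE reduced by a factor of {mse_ratio(gain):.2f}")
```

Because PSNR is logarithmic, a gain of a few dB corresponds to a multiplicative change in MSE rather than an additive one.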
@@ -0,0 +1,63 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "saga-activation"
7
+ version = "0.1.0"
8
+ description = "Spatially-Adaptive Gated Activation (SAGA) for medical image restoration"
9
+ readme = "README.md"
10
+ license = { text = "MIT" }
11
+ authors = [
12
+ { name = "Siju K.S.", email = "sijuks@example.com" },
13
+ { name = "Vipin Venugopal" },
14
+ { name = "Mithun Kumar Kar" },
15
+ { name = "Jayakrishnan Anandakrishnan" },
16
+ ]
17
+ keywords = [
18
+ "deep learning", "activation function", "medical imaging",
19
+ "image restoration", "deblurring", "PyTorch",
20
+ ]
21
+ classifiers = [
22
+ "Development Status :: 4 - Beta",
23
+ "Intended Audience :: Science/Research",
24
+ "Topic :: Scientific/Engineering :: Medical Science Apps.",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ "License :: OSI Approved :: MIT License",
27
+ "Programming Language :: Python :: 3",
28
+ "Programming Language :: Python :: 3.10",
29
+ "Programming Language :: Python :: 3.11",
30
+ "Programming Language :: Python :: 3.12",
31
+ ]
32
+ requires-python = ">=3.10"
33
+ dependencies = [
34
+ "torch>=2.0",
35
+ ]
36
+
37
+ [project.optional-dependencies]
38
+ dev = [
39
+ "pytest>=7.0",
40
+ "pytest-cov>=4.0",
41
+ ]
42
+ docs = [
43
+ "sphinx>=7.0",
44
+ "sphinx-rtd-theme>=2.0",
45
+ "myst-parser>=2.0",
46
+ ]
47
+
48
+ [project.urls]
49
+ Homepage = "https://github.com/sijuswamyresearch/saga-activation"
50
+ Documentation = "https://sijuswamyresearch.github.io/saga-activation"
51
+ Repository = "https://github.com/sijuswamyresearch/saga-activation"
52
+ "Bug Tracker" = "https://github.com/sijuswamyresearch/saga-activation/issues"
53
+ Paper = "https://doi.org/10.1016/j.health.2026.100468"
54
+ [tool.setuptools.packages.find]
55
+ where = ["src"]
56
+ include = ["saga*"]
57
+
58
+ [tool.pytest.ini_options]
59
+ testpaths = ["tests"]
60
+ addopts = "-v --tb=short"
61
+
62
+ [tool.coverage.run]
63
+ source = ["saga"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,40 @@
1
+ """
2
+ SAGA — Spatially-Adaptive Gated Activation for Medical Image Restoration
3
+ =========================================================================
4
+
5
+ Quick start
6
+ -----------
7
+ >>> import torch
8
+ >>> from saga import SAGA
9
+ >>> act = SAGA(in_channels=64)
10
+ >>> x = torch.randn(2, 64, 256, 256)
11
+ >>> y = act(x) # same shape, spatially gated
12
+
13
+ Reference
14
+ ---------
15
+ Siju K.S. et al. "An interpretable deep learning method for medical image
16
+ deblurring and restoration." Healthcare Analytics 9 (2026) 100468.
17
+ https://doi.org/10.1016/j.health.2026.100468
18
+ """
19
+
20
+ from importlib.metadata import version, PackageNotFoundError
21
+
22
+ try:
23
+     __version__: str = version("saga-activation")
24
+ except PackageNotFoundError:
25
+     __version__ = "0.1.0"
26
+
27
+ from .activation import SAGA, SAGALayer
28
+ from .blocks import SAGAResBlock, SAGABottleneck
29
+ from .utils import count_parameters, freeze_gate, unfreeze_gate
30
+
31
+ __all__ = [
32
+     "SAGA",
33
+     "SAGALayer",
34
+     "SAGAResBlock",
35
+     "SAGABottleneck",
36
+     "count_parameters",
37
+     "freeze_gate",
38
+     "unfreeze_gate",
39
+     "__version__",
40
+ ]
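The version lookup in `saga/__init__.py` reads the installed distribution's metadata and falls back to a hard-coded string when the distribution is not installed, for example when the module is imported from a source checkout. A minimal standalone sketch of the same pattern follows. `resolve_version` and the distribution name in the usage line are hypothetical and not part of the package.

```python
from importlib.metadata import PackageNotFoundError, version


def resolve_version(dist_name: str, fallback: str) -> str:
    """Return the installed version of dist_name, or fallback if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return fallback


# A distribution name that is assumed not to be installed in any environment.
print(resolve_version("surely-not-an-installed-distribution-xyz", "0.1.0"))
```

One consequence of this design is that the fallback string has to be kept in step with `version` in `pyproject.toml` by hand.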
@@ -0,0 +1,114 @@
1
+ """
2
+ saga.activation
3
+ ===============
4
+ Spatially-Adaptive Gated Activation (SAGA) operator.
5
+
6
+ SAGA extracts spatial context via a depthwise convolution, calculates a
7
+ residual boost, and dynamically gates this boost before adding it back to
8
+ the original input. This enables the network to selectively route gradient
9
+ flow through high-frequency anatomical boundary regions.
10
+
11
+ Reference
12
+ ---------
13
+ Siju K.S., Venugopal V., Kar M.K., Anandakrishnan J.
14
+ "An interpretable deep learning method for medical image deblurring and
15
+ restoration." Healthcare Analytics 9 (2026) 100468.
16
+ https://doi.org/10.1016/j.health.2026.100468
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ __all__ = ["SAGA", "SAGALayer"]
26
+
27
+
28
+ class SAGA(nn.Module):
29
+     """Spatially-Adaptive Gated Activation (SAGA).
30
+
31
+     Parameters
32
+     ----------
33
+     in_channels : int
34
+         Number of input (and output) channels.
35
+
36
+     Examples
37
+     --------
38
+     >>> import torch
39
+     >>> from saga import SAGA
40
+     >>> act = SAGA(in_channels=64)
41
+     >>> x = torch.randn(2, 64, 32, 32)
42
+     >>> y = act(x)
43
+     >>> y.shape
44
+     torch.Size([2, 64, 32, 32])
45
+     """
46
+
47
+     def __init__(self, in_channels: int) -> None:
48
+         super().__init__()
49
+         self.in_channels = in_channels
50
+
51
+         # Spatial context extractor
52
+         self.spatial_conv = nn.Conv2d(
53
+             in_channels,
54
+             in_channels,
55
+             kernel_size=3,
56
+             padding=1,
57
+             groups=in_channels,
58
+             bias=False
59
+         )
60
+         self.spatial_bn = nn.BatchNorm2d(in_channels)
61
+
62
+         # Dynamic gate generator
63
+         self.gate_generator = nn.Conv2d(
64
+             in_channels,
65
+             in_channels,
66
+             kernel_size=1,
67
+             padding=0,
68
+             bias=True
69
+         )
70
+
71
+         self._init_weights()
72
+
73
+     def _init_weights(self) -> None:
74
+         """Initialize weights to ensure stable early-stage training."""
75
+         nn.init.kaiming_normal_(self.spatial_conv.weight, mode='fan_in', nonlinearity='relu')
76
+         nn.init.constant_(self.spatial_bn.weight, 1)
77
+         nn.init.constant_(self.spatial_bn.bias, 0)
78
+
79
+         # Gate generator initialization
80
+         nn.init.constant_(self.gate_generator.weight, 0)
81
+         # Starts gate near ~0.88 for stability during initial epochs
82
+         nn.init.constant_(self.gate_generator.bias, 2.0)
83
+
84
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+         """Forward pass.
86
+
87
+         Parameters
88
+         ----------
89
+         x : torch.Tensor
90
+             Shape ``(B, C, H, W)`` with ``C == in_channels``.
91
+
92
+         Returns
93
+         -------
94
+         torch.Tensor
95
+             Same shape as *x*.
96
+         """
97
+         # Extract spatial context
98
+         T_x = self.spatial_bn(self.spatial_conv(x))
99
+
100
+         # Calculate positive boost
101
+         boost = F.relu(T_x - x)
102
+
103
+         # Generate spatial gate
104
+         gate = torch.sigmoid(self.gate_generator(T_x))
105
+
106
+         # Add gated boost to identity
107
+         return x + (gate * boost)
108
+
109
+     def extra_repr(self) -> str:
110
+         return f"in_channels={self.in_channels}"
111
+
112
+
113
+ # Alias for drop-in use inside sequential blocks
114
+ SAGALayer = SAGA
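To make the arithmetic in `SAGA.forward` and `_init_weights` concrete, here is a scalar sketch in plain Python. It is an illustration rather than the package's code: `saga_pointwise` is a hypothetical helper, `t` stands in for the batch-normalised depthwise-conv response `T_x` at one position, and `gate_logit` stands in for `gate_generator(T_x)` at that position.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def saga_pointwise(x: float, t: float, gate_logit: float) -> float:
    """Scalar analogue of SAGA.forward: x + sigmoid(gate_logit) * relu(t - x)."""
    boost = max(t - x, 0.0)  # mirrors F.relu(T_x - x)
    return x + sigmoid(gate_logit) * boost


# At initialisation the gate weights are 0 and the bias is 2.0,
# so the gate logit equals 2.0 at every position.
initial_gate = sigmoid(2.0)
print(round(initial_gate, 4))          # 0.8808
print(saga_pointwise(1.0, 0.5, 2.0))   # t < x: boost is zero, the input passes through
print(saga_pointwise(0.0, 1.0, 2.0))   # t > x: input is raised by gate * (t - x)
```

Since the gate lies in (0, 1) and the boost is non-negative, the output of the layer as written is never smaller than its input at any position.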
@@ -0,0 +1,88 @@
1
+ """
2
+ saga.blocks
3
+ ===========
4
+ Ready-made convolutional building blocks that use SAGA as their internal
5
+ activation function. These blocks can be used as drop-in replacements for
6
+ standard residual blocks in U-Net, ResNet, or EDSR style architectures.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .activation import SAGA
15
+
16
+ __all__ = ["SAGAResBlock", "SAGABottleneck"]
17
+
18
+
19
+ class SAGAResBlock(nn.Module):
20
+     """Residual block with SAGA activations."""
21
+
22
+     def __init__(
23
+         self,
24
+         in_channels: int,
25
+         out_channels: int | None = None,
26
+         stride: int = 1,
27
+     ) -> None:
28
+         super().__init__()
29
+         out_channels = out_channels or in_channels
30
+
31
+         self.conv1 = nn.Conv2d(
32
+             in_channels, out_channels, 3, stride=stride, padding=1, bias=False
33
+         )
34
+         self.bn1 = nn.BatchNorm2d(out_channels)
35
+         self.act1 = SAGA(out_channels)
36
+
37
+         self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
38
+         self.bn2 = nn.BatchNorm2d(out_channels)
39
+         self.act2 = SAGA(out_channels)
40
+
41
+         self.shortcut: nn.Module
42
+         if stride != 1 or in_channels != out_channels:
43
+             self.shortcut = nn.Sequential(
44
+                 nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
45
+                 nn.BatchNorm2d(out_channels),
46
+             )
47
+         else:
48
+             self.shortcut = nn.Identity()
49
+
50
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+         residual = self.shortcut(x)
52
+         out = self.act1(self.bn1(self.conv1(x)))
53
+         out = self.bn2(self.conv2(out))
54
+         return self.act2(out + residual)
55
+
56
+
57
+ class SAGABottleneck(nn.Module):
58
+     """Bottleneck block (1x1 -> 3x3 -> 1x1) with SAGA activations."""
59
+
60
+     def __init__(
61
+         self,
62
+         in_channels: int,
63
+         bottleneck_channels: int | None = None,
64
+         out_channels: int | None = None,
65
+     ) -> None:
66
+         super().__init__()
67
+         bottleneck_channels = bottleneck_channels or max(in_channels // 4, 1)
68
+         out_channels = out_channels or in_channels
69
+
70
+         self.net = nn.Sequential(
71
+             nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False),
72
+             nn.BatchNorm2d(bottleneck_channels),
73
+             SAGA(bottleneck_channels),
74
+             nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False),
75
+             nn.BatchNorm2d(bottleneck_channels),
76
+             SAGA(bottleneck_channels),
77
+             nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False),
78
+             nn.BatchNorm2d(out_channels),
79
+         )
80
+         self.skip = (
81
+             nn.Conv2d(in_channels, out_channels, 1, bias=False)
82
+             if in_channels != out_channels
83
+             else nn.Identity()
84
+         )
85
+         self.out_act = SAGA(out_channels)
86
+
87
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+         return self.out_act(self.net(x) + self.skip(x))
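`SAGAResBlock.forward` adds the main path to the shortcut, which only works if both produce the same spatial size when `stride > 1`. The sketch below checks this with the usual convolution output-size formula (dilation 1), using the kernel sizes and paddings from the block above. `conv_out_size` and `resblock_sizes` are hypothetical helpers written for this illustration.

```python
from __future__ import annotations


def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def resblock_sizes(size: int, stride: int) -> tuple[int, int]:
    """Spatial size of the main path and of the shortcut in SAGAResBlock."""
    main = conv_out_size(size, kernel=3, stride=stride, padding=1)  # conv1
    main = conv_out_size(main, kernel=3, stride=1, padding=1)       # conv2
    shortcut = conv_out_size(size, kernel=1, stride=stride, padding=0)
    return main, shortcut


for h in (255, 256):
    print(h, resblock_sizes(h, stride=2))
```

Both the 3×3 convolution with padding 1 and the 1×1 convolution with padding 0 reduce to `(size - 1) // stride + 1`, so the two paths agree for odd and even input sizes. The SAGA layers themselves preserve spatial size and do not affect this calculation.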
@@ -0,0 +1,59 @@
1
+ """
2
+ saga.utils
3
+ ==========
4
+ Lightweight helpers for parameter accounting and gate control.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch.nn as nn
10
+
11
+ from .activation import SAGA
12
+
13
+ __all__ = ["count_parameters", "freeze_gate", "unfreeze_gate"]
14
+
15
+
16
+ def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
17
+     """Return the total number of (trainable) parameters in *model*.
18
+
19
+     Parameters
20
+     ----------
21
+     model : nn.Module
22
+     trainable_only : bool, optional
23
+         When *True* (default) only count parameters with
24
+         ``requires_grad == True``.
25
+
26
+     Returns
27
+     -------
28
+     int
29
+     """
30
+     return sum(
31
+         p.numel()
32
+         for p in model.parameters()
33
+         if (not trainable_only) or p.requires_grad
34
+     )
35
+
36
+
37
+ def _set_gate_grad(model: nn.Module, requires_grad: bool) -> None:
38
+     """Recursively set requires_grad for all SAGA components."""
39
+     for module in model.modules():
40
+         if isinstance(module, SAGA):
41
+             for name in ("spatial_conv", "spatial_bn", "gate_generator"):
42
+                 sub = getattr(module, name, None)
43
+                 if sub is not None:
44
+                     for p in sub.parameters():
45
+                         p.requires_grad_(requires_grad)
46
+
47
+
48
+ def freeze_gate(model: nn.Module) -> None:
49
+     """Freeze all SAGA gating parameters in *model*.
50
+
51
+     Useful for curriculum training: first train the backbone, then unfreeze
52
+     the gates for fine-tuning.
53
+     """
54
+     _set_gate_grad(model, False)
55
+
56
+
57
+ def unfreeze_gate(model: nn.Module) -> None:
58
+     """Unfreeze all SAGA gating parameters in *model*."""
59
+     _set_gate_grad(model, True)
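`freeze_gate` turns off gradients for `spatial_conv`, `spatial_bn` and `gate_generator`, which together hold every parameter of a `SAGA` layer. The sketch below works out how many parameters that is per layer from the layer definitions in `saga/activation.py`. `saga_param_count` is a hypothetical helper, not part of the package, and the result is an arithmetic estimate rather than a measured value.

```python
def saga_param_count(channels: int) -> int:
    """Learnable parameters in one SAGA layer with the given channel count."""
    depthwise = channels * 1 * 3 * 3        # spatial_conv: groups=channels, no bias
    batch_norm = 2 * channels               # spatial_bn: weight and bias (running stats are buffers)
    gate = channels * channels + channels   # gate_generator: 1x1 conv with bias
    return depthwise + batch_norm + gate


for c in (16, 64, 128):
    print(c, saga_param_count(c))
```

The total is `C**2 + 12*C`, so the 1×1 gate convolution dominates for wide layers. Under these assumptions, `count_parameters(model)` with the default `trainable_only=True` should drop by this amount for each SAGA layer after `freeze_gate(model)`.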
@@ -0,0 +1,238 @@
1
+ Metadata-Version: 2.4
2
+ Name: saga-activation
3
+ Version: 0.1.0
4
+ Summary: Spatially-Adaptive Gated Activation (SAGA) for medical image restoration
5
+ Author: Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
6
+ Author-email: "Siju K.S." <sijuks@example.com>
7
+ License: MIT
8
+ Project-URL: Homepage, https://github.com/sijuswamyresearch/saga-activation
9
+ Project-URL: Documentation, https://sijuswamyresearch.github.io/saga-activation
10
+ Project-URL: Repository, https://github.com/sijuswamyresearch/saga-activation
11
+ Project-URL: Bug Tracker, https://github.com/sijuswamyresearch/saga-activation/issues
12
+ Project-URL: Paper, https://doi.org/10.1016/j.health.2026.100468
13
+ Keywords: deep learning,activation function,medical imaging,image restoration,deblurring,PyTorch
14
+ Classifier: Development Status :: 4 - Beta
15
+ Classifier: Intended Audience :: Science/Research
16
+ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Classifier: License :: OSI Approved :: MIT License
19
+ Classifier: Programming Language :: Python :: 3
20
+ Classifier: Programming Language :: Python :: 3.10
21
+ Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
23
+ Requires-Python: >=3.10
24
+ Description-Content-Type: text/markdown
25
+ License-File: LICENSE
26
+ Requires-Dist: torch>=2.0
27
+ Provides-Extra: dev
28
+ Requires-Dist: pytest>=7.0; extra == "dev"
29
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
30
+ Provides-Extra: docs
31
+ Requires-Dist: sphinx>=7.0; extra == "docs"
32
+ Requires-Dist: sphinx-rtd-theme>=2.0; extra == "docs"
33
+ Requires-Dist: myst-parser>=2.0; extra == "docs"
34
+ Dynamic: license-file
35
+
36
+ # SAGA — Spatially-Adaptive Gated Activation
37
+
38
+ [![CI](https://github.com/sijuswamyresearch/SAGA/actions/workflows/ci.yml/badge.svg)](https://github.com/sijuswamyresearch/SAGA/actions)
39
+ [![PyPI version](https://badge.fury.io/py/saga-activation.svg)](https://pypi.org/project/saga-activation/)
40
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
41
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.XXXXXXX.svg)](https://doi.org/10.5281/zenodo.XXXXXXX)
42
+ [![Paper](https://img.shields.io/badge/Paper-Healthcare%20Analytics-blue)](https://doi.org/10.1016/j.health.2026.100468)
43
+
44
+ > **An Interpretable Deep Learning Method for Medical Image Deblurring and Restoration**
45
+ > Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
46
+ > *Healthcare Analytics* 9 (2026) 100468 · [doi:10.1016/j.health.2026.100468](https://doi.org/10.1016/j.health.2026.100468)
47
+
48
+ ---
49
+
50
+ ## Overview
51
+
52
+ Standard activation functions (ReLU, SiLU, GELU) treat every spatial location in a feature map identically. In medical images — CT slices, DXA scans — the information content is *not* spatially uniform: anatomical boundaries carry high-frequency diagnostically-critical detail while homogeneous regions (background, soft tissue) require smooth suppression.
53
+
54
+ **SAGA** introduces a *learned spatial gating map* that modulates the activation response position-by-position:
55
+
56
+ ```
57
+ T(X) = BN(W_d * X),  G(X) = σ(W_g * T(X) + b_g)   # depthwise 3×3 context, 1×1 gate
58
+ SAGA(X) = X + G(X) ⊙ ReLU(T(X) − X)               # gated positive boost added to the input
59
+ ```
60
+
61
+ This two-path design lets the network selectively amplify high-frequency boundary signals while smoothly gating uniform background areas — without increasing the depth of the network.
62
+
63
+ ---
64
+
65
+ ## Installation
66
+
67
+ ```bash
68
+ pip install saga-activation
69
+ ```
70
+
71
+ Or install from source:
72
+
73
+ ```bash
74
+ git clone https://github.com/sijuswamyresearch/SAGA.git
75
+ cd SAGA
76
+ pip install -e ".[dev]"
77
+ ```
78
+
79
+ **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.0
80
+
81
+ ---
82
+
83
+ ## Quick Start
84
+
85
+ ### Drop-in activation replacement
86
+
87
+ ```python
88
+ import torch
89
+ from saga import SAGA
90
+
91
+ # Replace any activation layer with SAGA
92
+ act = SAGA(in_channels=64) # matches the channel dim of your feature map
93
+ x = torch.randn(2, 64, 256, 256) # (B, C, H, W)
94
+ y = act(x) # same shape: (2, 64, 256, 256)
95
+ ```
96
+
97
+ ### Inside a U-Net encoder block
98
+
99
+ ```python
100
+ import torch.nn as nn
101
+ from saga import SAGA
102
+
103
+ class EncoderBlock(nn.Module):
104
+     def __init__(self, in_ch, out_ch):
105
+         super().__init__()
106
+         self.block = nn.Sequential(
107
+             nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
108
+             nn.BatchNorm2d(out_ch),
109
+             SAGA(out_ch),  # ← swap in SAGA here
110
+             nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
111
+             nn.BatchNorm2d(out_ch),
112
+             SAGA(out_ch),
113
+         )
114
+         self.pool = nn.MaxPool2d(2)
115
+
116
+     def forward(self, x):
117
+         return self.pool(self.block(x))
118
+ ```
119
+
120
+ ### Pre-built residual blocks
121
+
122
+ ```python
123
+ from saga import SAGAResBlock, SAGABottleneck
124
+
125
+ res = SAGAResBlock(64) # standard residual block
126
+ bottle = SAGABottleneck(64, out_channels=128) # bottleneck variant
127
+ ```
128
+
129
+ ### Base-activation variants
130
+
131
+ ```python
132
+ from saga import SAGA
133
+
134
+ act_relu = SAGA(64, base_activation="relu")
135
+ act_gelu = SAGA(64, base_activation="gelu")
136
+ act_tanh = SAGA(64, base_activation="tanh")
137
+ ```
138
+
139
+ ### Gate curriculum training
140
+
141
+ ```python
142
+ from saga.utils import freeze_gate, unfreeze_gate
143
+
144
+ # Phase 1 – train backbone only (`model` and `train` are your own network and training loop, not part of saga)
145
+ freeze_gate(model)
146
+ train(model, epochs=10, lr=1e-3)
147
+
148
+ # Phase 2 – fine-tune gates
149
+ unfreeze_gate(model)
150
+ train(model, epochs=5, lr=1e-4)
151
+ ```
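
One practical detail of such a two-phase schedule is the optimizer: parameters frozen in phase 1 receive no gradients, and the optimizer is usually rebuilt once they are unfrozen. The sketch below illustrates the idea on a toy layer so that it runs without `saga` installed. `ToyGated` and `set_gate_trainable` are hypothetical stand-ins, not the package's `freeze_gate` / `unfreeze_gate`.

```python
import torch
import torch.nn as nn


class ToyGated(nn.Module):
    """Hypothetical stand-in for a gated layer (not the saga implementation)."""

    def __init__(self, ch: int):
        super().__init__()
        self.body = nn.Conv2d(ch, ch, 3, padding=1)
        self.gate = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        return torch.sigmoid(self.gate(x)) * self.body(x)


def set_gate_trainable(model: nn.Module, flag: bool) -> None:
    """Toggle requires_grad on every gate sub-module found in the model."""
    for m in model.modules():
        if isinstance(m, ToyGated):
            for p in m.gate.parameters():
                p.requires_grad = flag


model = nn.Sequential(ToyGated(4), ToyGated(4))

# Phase 1: gates frozen, optimizer sees only the remaining parameters.
set_gate_trainable(model, False)
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=1e-3)

gate_before = model[0].gate.weight.clone()
loss = model(torch.randn(2, 4, 8, 8)).pow(2).mean()
loss.backward()
opt.step()
gate_unchanged = torch.equal(model[0].gate.weight, gate_before)

# Phase 2: unfreeze and rebuild the optimizer so the gate parameters are included.
set_gate_trainable(model, True)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
```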
152
+
153
+ ---
154
+
155
+ ## Repository Structure
156
+
157
+ ```
158
+ SAGA/
159
+ ├── src/saga/                    # installable Python package (src layout)
160
+ │ ├── __init__.py
161
+ │ ├── activation.py # SAGA operator (core)
162
+ │ ├── blocks.py # SAGAResBlock, SAGABottleneck
163
+ │ └── utils.py # parameter counting, gate freeze helpers
164
+
165
+ ├── tests/
166
+ │ ├── conftest.py
167
+ │ └── test_saga.py # pytest suite (shapes, edge cases, GPU, gradients)
168
+
169
+ ├── SAGA_Supplementary_Code/ # original experimental pipeline
170
+ │ ├── models/
171
+ │ │ ├── saga_layer.py # raw research implementation
172
+ │ │ ├── unet.py
173
+ │ │ ├── resnet.py
174
+ │ │ ├── edsr.py
175
+ │ │ └── vggnet.py
176
+ │ ├── generate_dataset.py
177
+ │ ├── train.py
178
+ │ ├── evaluate.py
179
+ │ ├── xai_analysis.py
180
+ │ └── clinical_validation.py
181
+
182
+ ├── docs/ # Sphinx documentation source
183
+ ├── .github/workflows/ci.yml # GitHub Actions CI
184
+ ├── pyproject.toml
185
+ └── README.md
186
+ ```
187
+
188
+ ---
189
+
190
+ ## Experimental Results (summary)
191
+
192
+ | Model | Activation | CT PSNR (dB) | CT SSIM | DXA PSNR (dB) | DXA SSIM |
193
+ |---------------|-----------|:------------:|:-------:|:-------------:|:--------:|
194
+ | U-Net | ReLU | 32.14 | 0.891 | 30.87 | 0.873 |
195
+ | U-Net | SiLU | 33.01 | 0.902 | 31.54 | 0.881 |
196
+ | **U-Net** | **SAGA** | **34.67** | **0.921** | **33.12** | **0.903** |
197
+ | DeblurResNet | ReLU | 31.89 | 0.883 | 30.21 | 0.864 |
198
+ | **DeblurResNet** | **SAGA** | **34.11** | **0.916** | **32.78** | **0.897** |
199
+
200
+ Full results and ablation studies are reported in the paper.
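
For readers less familiar with the PSNR columns, the metric is `10 * log10(MAX² / MSE)`, so higher is better. The snippet below is a minimal sketch that assumes images scaled to `[0, 1]`; the function name `psnr` and the `data_range` default are illustrative assumptions, and the evaluation protocol used for the table is the one described in the paper, not this code.

```python
import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    """Peak signal-to-noise ratio in dB for tensors sharing the same value range."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(data_range ** 2 / mse)


# A constant error of 0.1 on a [0, 1] image gives MSE = 0.01, i.e. about 20 dB.
target = torch.zeros(1, 1, 8, 8)
pred = torch.full_like(target, 0.1)
value = psnr(pred, target).item()
```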
201
+
202
+ ---
203
+
204
+ ## Running the Tests
205
+
206
+ ```bash
207
+ pytest tests/ -v
208
+ ```
209
+
210
+ To run with coverage:
211
+
212
+ ```bash
213
+ pytest tests/ --cov=saga --cov-report=term-missing
214
+ ```
215
+
216
+ ---
217
+
218
+ ## Citing
219
+
220
+ If SAGA is useful in your research, please cite:
221
+
222
+ ```bibtex
223
+ @article{siju2026saga,
224
+ title = {An interpretable deep learning method for medical image deblurring and restoration},
225
+ author = {Siju K.S. and Vipin Venugopal and Mithun Kumar Kar and Jayakrishnan Anandakrishnan},
226
+ journal = {Healthcare Analytics},
227
+ volume = {9},
228
+ pages = {100468},
229
+ year = {2026},
230
+ doi = {10.1016/j.health.2026.100468}
231
+ }
232
+ ```
233
+
234
+ ---
235
+
236
+ ## License
237
+
238
+ [MIT](LICENSE)
@@ -0,0 +1,13 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/saga/__init__.py
5
+ src/saga/activation.py
6
+ src/saga/blocks.py
7
+ src/saga/utils.py
8
+ src/saga_activation.egg-info/PKG-INFO
9
+ src/saga_activation.egg-info/SOURCES.txt
10
+ src/saga_activation.egg-info/dependency_links.txt
11
+ src/saga_activation.egg-info/requires.txt
12
+ src/saga_activation.egg-info/top_level.txt
13
+ tests/test_saga.py
@@ -0,0 +1,10 @@
1
+ torch>=2.0
2
+
3
+ [dev]
4
+ pytest>=7.0
5
+ pytest-cov>=4.0
6
+
7
+ [docs]
8
+ sphinx>=7.0
9
+ sphinx-rtd-theme>=2.0
10
+ myst-parser>=2.0
@@ -0,0 +1,244 @@
1
+ """
2
+ tests/test_saga.py
3
+ ==================
4
+ pytest test suite for the SAGA activation package.
5
+
6
+ Run with:
7
+ pytest tests/ -v
8
+ """
9
+
10
+ import pytest
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from saga import SAGA, SAGALayer, SAGAResBlock, SAGABottleneck
15
+ from saga.utils import count_parameters, freeze_gate, unfreeze_gate
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Fixtures
20
+ # ---------------------------------------------------------------------------
21
+
22
+ DEVICES = ["cpu"]
23
+ if torch.cuda.is_available():
24
+ DEVICES.append("cuda")
25
+
26
+
27
+ @pytest.fixture(params=DEVICES)
28
+ def device(request):
29
+ return torch.device(request.param)
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # 1. Output shape tests
34
+ # ---------------------------------------------------------------------------
35
+
36
+ class TestOutputShape:
37
+ @pytest.mark.parametrize("B,C,H,W", [
38
+ (2, 1, 8, 8), # B=2 to avoid BatchNorm crash during train mode
39
+ (2, 16, 32, 32),
40
+ (4, 64, 128, 128),
41
+ (2, 3, 256, 256),
42
+ ])
43
+ def test_shape_preserved(self, B, C, H, W):
44
+ """SAGA must return a tensor with the same shape as the input."""
45
+ act = SAGA(in_channels=C)
46
+ x = torch.randn(B, C, H, W)
47
+ y = act(x)
48
+ assert y.shape == x.shape, f"Expected {x.shape}, got {y.shape}"
49
+
50
+ def test_single_channel(self):
51
+ act = SAGA(in_channels=1)
52
+ x = torch.randn(2, 1, 4, 4)
53
+ assert act(x).shape == x.shape
54
+
55
+ def test_large_channel(self):
56
+ act = SAGA(in_channels=512)
57
+ x = torch.randn(2, 512, 8, 8)
58
+ assert act(x).shape == x.shape
59
+
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # 2. Edge-case tensor value tests
63
+ # ---------------------------------------------------------------------------
64
+
65
+ class TestEdgeCases:
66
+ def test_zero_input(self):
67
+ act = SAGA(in_channels=8)
68
+ x = torch.zeros(2, 8, 16, 16)
69
+ y = act(x)
70
+ assert y.shape == x.shape
71
+ assert torch.isfinite(y).all()
72
+
73
+ def test_negative_input(self):
74
+ act = SAGA(in_channels=8)
75
+ x = -torch.abs(torch.randn(2, 8, 16, 16))
76
+ y = act(x)
77
+ assert torch.isfinite(y).all()
78
+
79
+ def test_large_positive_input(self):
80
+ act = SAGA(in_channels=8)
81
+ x = torch.full((2, 8, 16, 16), 1e3)
82
+ y = act(x)
83
+ assert torch.isfinite(y).all()
84
+
85
+ def test_large_negative_input(self):
86
+ act = SAGA(in_channels=8)
87
+ x = torch.full((2, 8, 16, 16), -1e3)
88
+ y = act(x)
89
+ assert torch.isfinite(y).all()
90
+
91
+ def test_nan_free(self):
92
+ """Confirm no NaN leaks on random inputs."""
93
+ torch.manual_seed(42)
94
+ act = SAGA(in_channels=32)
95
+ for _ in range(10):
96
+ x = torch.randn(2, 32, 64, 64) * 10
97
+ y = act(x)
98
+ assert not torch.isnan(y).any(), "NaN detected in SAGA output"
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # 3. Gradient / backprop tests
103
+ # ---------------------------------------------------------------------------
104
+
105
+ class TestGradients:
106
+ def test_backward_cpu(self):
107
+ act = SAGA(in_channels=16)
108
+ x = torch.randn(2, 16, 32, 32, requires_grad=True)
109
+ y = act(x).sum()
110
+ y.backward()
111
+ assert x.grad is not None
112
+ assert torch.isfinite(x.grad).all()
113
+
114
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
115
+ def test_backward_cuda(self):
116
+ act = SAGA(in_channels=16).cuda()
117
+ x = torch.randn(2, 16, 32, 32, requires_grad=True, device="cuda")
118
+ y = act(x).sum()
119
+ y.backward()
120
+ assert x.grad is not None
121
+ assert torch.isfinite(x.grad).all()
122
+
123
+ def test_gate_parameters_get_gradients(self):
124
+ act = SAGA(in_channels=8)
125
+ x = torch.randn(2, 8, 16, 16)
126
+ act(x).sum().backward()
127
+ for name, p in act.named_parameters():
128
+ assert p.grad is not None, f"No gradient for parameter '{name}'"
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # 4. Device tests
133
+ # ---------------------------------------------------------------------------
134
+
135
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
136
+ class TestCUDA:
137
+ def test_forward_cuda(self):
138
+ act = SAGA(in_channels=32).cuda()
139
+ x = torch.randn(2, 32, 64, 64, device="cuda")
140
+ y = act(x)
141
+ assert y.device.type == "cuda"
142
+ assert y.shape == x.shape
143
+
144
+ def test_no_memory_leak(self):
145
+ """Check GPU memory does not grow unboundedly over repeated calls."""
146
+ act = SAGA(in_channels=64).cuda()
147
+ torch.cuda.reset_peak_memory_stats()
148
+ baseline = torch.cuda.memory_allocated()
149
+
150
+ for _ in range(50):
151
+ x = torch.randn(4, 64, 128, 128, device="cuda")
152
+ _ = act(x)
153
+ del x
154
+
155
+ torch.cuda.synchronize()
156
+ peak = torch.cuda.max_memory_allocated()
157
+ # Allow at most 200 MB overhead
158
+ assert (peak - baseline) < 200 * 1024 ** 2, "Potential memory leak detected"
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # 5. Building-block tests
163
+ # ---------------------------------------------------------------------------
164
+
165
+ class TestBlocks:
166
+ def test_resblock_shape(self):
167
+ block = SAGAResBlock(64)
168
+ x = torch.randn(2, 64, 32, 32)
169
+ assert block(x).shape == x.shape
170
+
171
+ def test_resblock_projection(self):
172
+ block = SAGAResBlock(32, out_channels=64, stride=2)
173
+ x = torch.randn(2, 32, 32, 32)
174
+ y = block(x)
175
+ assert y.shape == (2, 64, 16, 16)
176
+
177
+ def test_bottleneck_shape(self):
178
+ block = SAGABottleneck(64)
179
+ x = torch.randn(2, 64, 32, 32)
180
+ assert block(x).shape == x.shape
181
+
182
+ def test_bottleneck_channel_change(self):
183
+ block = SAGABottleneck(32, out_channels=128)
184
+ x = torch.randn(2, 32, 16, 16)
185
+ assert block(x).shape == (2, 128, 16, 16)
186
+
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # 6. Utility tests
190
+ # ---------------------------------------------------------------------------
191
+
192
+ class TestUtils:
193
+ def test_count_parameters(self):
194
+ act = SAGA(in_channels=64)
195
+ n = count_parameters(act)
196
+ assert n > 0
197
+
198
+ def test_freeze_unfreeze_gate(self):
199
+ model = nn.Sequential(SAGA(32), SAGA(32))
200
+
201
+ # Test Freeze
202
+ freeze_gate(model)
203
+ for module in model.modules():
204
+ if isinstance(module, SAGA):
205
+ for p in module.spatial_conv.parameters():
206
+ assert not p.requires_grad
207
+ for p in module.gate_generator.parameters():
208
+ assert not p.requires_grad
209
+
210
+ # Test Unfreeze
211
+ unfreeze_gate(model)
212
+ for module in model.modules():
213
+ if isinstance(module, SAGA):
214
+ for p in module.spatial_conv.parameters():
215
+ assert p.requires_grad
216
+ for p in module.gate_generator.parameters():
217
+ assert p.requires_grad
218
+
219
+
220
+ # ---------------------------------------------------------------------------
221
+ # 7. SAGALayer alias
222
+ # ---------------------------------------------------------------------------
223
+
224
+ def test_saga_layer_alias():
225
+ assert SAGALayer is SAGA
226
+
227
+
228
+ # ---------------------------------------------------------------------------
229
+ # 8. Serialisation / state-dict round-trip
230
+ # ---------------------------------------------------------------------------
231
+
232
+ def test_state_dict_round_trip(tmp_path):
233
+ act = SAGA(in_channels=16)
234
+ x = torch.randn(2, 16, 8, 8)
235
+ y_before = act(x)
236
+
237
+ path = tmp_path / "saga.pt"
238
+ torch.save(act.state_dict(), path)
239
+
240
+ act2 = SAGA(in_channels=16)
241
+ act2.load_state_dict(torch.load(path, map_location="cpu"))
242
+ y_after = act2(x)
243
+
244
+ assert torch.allclose(y_before, y_after, atol=1e-6)