saga-activation 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- saga_activation-0.1.0/LICENSE +21 -0
- saga_activation-0.1.0/PKG-INFO +238 -0
- saga_activation-0.1.0/README.md +203 -0
- saga_activation-0.1.0/pyproject.toml +63 -0
- saga_activation-0.1.0/setup.cfg +4 -0
- saga_activation-0.1.0/src/saga/__init__.py +40 -0
- saga_activation-0.1.0/src/saga/activation.py +114 -0
- saga_activation-0.1.0/src/saga/blocks.py +88 -0
- saga_activation-0.1.0/src/saga/utils.py +59 -0
- saga_activation-0.1.0/src/saga_activation.egg-info/PKG-INFO +238 -0
- saga_activation-0.1.0/src/saga_activation.egg-info/SOURCES.txt +13 -0
- saga_activation-0.1.0/src/saga_activation.egg-info/dependency_links.txt +1 -0
- saga_activation-0.1.0/src/saga_activation.egg-info/requires.txt +10 -0
- saga_activation-0.1.0/src/saga_activation.egg-info/top_level.txt +1 -0
- saga_activation-0.1.0/tests/test_saga.py +244 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: saga-activation
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Spatially-Adaptive Gated Activation (SAGA) for medical image restoration
|
|
5
|
+
Author: Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
|
|
6
|
+
Author-email: "Siju K.S." <sijuks@example.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
Project-URL: Homepage, https://github.com/sijuswamyresearch/saga-activation
|
|
9
|
+
Project-URL: Documentation, https://sijuswamyresearch.github.io/saga-activation
|
|
10
|
+
Project-URL: Repository, https://github.com/sijuswamyresearch/saga-activation
|
|
11
|
+
Project-URL: Bug Tracker, https://github.com/sijuswamyresearch/saga-activation/issues
|
|
12
|
+
Project-URL: Paper, https://doi.org/10.1016/j.health.2026.100468
|
|
13
|
+
Keywords: deep learning,activation function,medical imaging,image restoration,deblurring,PyTorch
|
|
14
|
+
Classifier: Development Status :: 4 - Beta
|
|
15
|
+
Classifier: Intended Audience :: Science/Research
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
19
|
+
Classifier: Programming Language :: Python :: 3
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
23
|
+
Requires-Python: >=3.10
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
License-File: LICENSE
|
|
26
|
+
Requires-Dist: torch>=2.0
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
30
|
+
Provides-Extra: docs
|
|
31
|
+
Requires-Dist: sphinx>=7.0; extra == "docs"
|
|
32
|
+
Requires-Dist: sphinx-rtd-theme>=2.0; extra == "docs"
|
|
33
|
+
Requires-Dist: myst-parser>=2.0; extra == "docs"
|
|
34
|
+
Dynamic: license-file
|
|
35
|
+
|
|
36
|
+
# SAGA — Spatially-Adaptive Gated Activation
|
|
37
|
+
|
|
38
|
+
[](https://github.com/sijuswamyresearch/SAGA/actions)
|
|
39
|
+
[](https://pypi.org/project/saga-activation/)
|
|
40
|
+
[](https://opensource.org/licenses/MIT)
|
|
41
|
+
[](https://doi.org/10.5281/zenodo.XXXXXXX)
|
|
42
|
+
[](https://doi.org/10.1016/j.health.2026.100468)
|
|
43
|
+
|
|
44
|
+
> **An Interpretable Deep Learning Method for Medical Image Deblurring and Restoration**
|
|
45
|
+
> Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
|
|
46
|
+
> *Healthcare Analytics* 9 (2026) 100468 · [doi:10.1016/j.health.2026.100468](https://doi.org/10.1016/j.health.2026.100468)
|
|
47
|
+
|
|
48
|
+
---
|
|
49
|
+
|
|
50
|
+
## Overview
|
|
51
|
+
|
|
52
|
+
Standard activation functions (ReLU, SiLU, GELU) treat every spatial location in a feature map identically. In medical images — CT slices, DXA scans — the information content is *not* spatially uniform: anatomical boundaries carry high-frequency diagnostically-critical detail while homogeneous regions (background, soft tissue) require smooth suppression.
|
|
53
|
+
|
|
54
|
+
**SAGA** introduces a *learned spatial gating map* that modulates the activation response position-by-position:
|
|
55
|
+
|
|
56
|
+
```
|
|
57
|
+
T(X) = BN(DWConv3×3(X));  G(X) = σ(W_g * T(X) + b_g)   # spatial context (depthwise 3×3 conv + BatchNorm) and 1×1-conv gate
|
|
58
|
+
SAGA(X) = X + G(X) ⊙ ReLU(T(X) − X)   # gated positive boost added to the identity, with T(X) = BN(DWConv3×3(X))
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
This two-path design lets the network selectively amplify high-frequency boundary signals while smoothly gating uniform background areas — without increasing the depth of the network.
|
|
62
|
+
|
|
63
|
+
---
|
|
64
|
+
|
|
65
|
+
## Installation
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
pip install saga-activation
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
Or install from source:
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
git clone https://github.com/sijuswamyresearch/SAGA.git
|
|
75
|
+
cd SAGA
|
|
76
|
+
pip install -e ".[dev]"
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
**Requirements:** Python ≥ 3.10, PyTorch ≥ 2.0
|
|
80
|
+
|
|
81
|
+
---
|
|
82
|
+
|
|
83
|
+
## Quick Start
|
|
84
|
+
|
|
85
|
+
### Drop-in activation replacement
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
import torch
|
|
89
|
+
from saga import SAGA
|
|
90
|
+
|
|
91
|
+
# Replace any activation layer with SAGA
|
|
92
|
+
act = SAGA(in_channels=64) # matches the channel dim of your feature map
|
|
93
|
+
x = torch.randn(2, 64, 256, 256) # (B, C, H, W)
|
|
94
|
+
y = act(x) # same shape: (2, 64, 256, 256)
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
### Inside a U-Net encoder block
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
import torch.nn as nn
|
|
101
|
+
from saga import SAGA
|
|
102
|
+
|
|
103
|
+
class EncoderBlock(nn.Module):
|
|
104
|
+
def __init__(self, in_ch, out_ch):
|
|
105
|
+
super().__init__()
|
|
106
|
+
self.block = nn.Sequential(
|
|
107
|
+
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
|
108
|
+
nn.BatchNorm2d(out_ch),
|
|
109
|
+
SAGA(out_ch), # ← swap in SAGA here
|
|
110
|
+
nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
|
|
111
|
+
nn.BatchNorm2d(out_ch),
|
|
112
|
+
SAGA(out_ch),
|
|
113
|
+
)
|
|
114
|
+
self.pool = nn.MaxPool2d(2)
|
|
115
|
+
|
|
116
|
+
def forward(self, x):
|
|
117
|
+
return self.pool(self.block(x))
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
### Pre-built residual blocks
|
|
121
|
+
|
|
122
|
+
```python
|
|
123
|
+
from saga import SAGAResBlock, SAGABottleneck
|
|
124
|
+
|
|
125
|
+
res = SAGAResBlock(64) # standard residual block
|
|
126
|
+
bottle = SAGABottleneck(64, out_channels=128) # bottleneck variant
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
### Base-activation variants (not supported in 0.1.0 — `SAGA.__init__` accepts only `in_channels`)
|
|
130
|
+
|
|
131
|
+
```python
|
|
132
|
+
from saga import SAGA
|
|
133
|
+
|
|
134
|
+
act_relu = SAGA(64, base_activation="relu")
|
|
135
|
+
act_gelu = SAGA(64, base_activation="gelu")
|
|
136
|
+
act_tanh = SAGA(64, base_activation="tanh")
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
### Gate curriculum training
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
from saga.utils import freeze_gate, unfreeze_gate
|
|
143
|
+
|
|
144
|
+
# Phase 1 – train backbone only
|
|
145
|
+
freeze_gate(model)
|
|
146
|
+
train(model, epochs=10, lr=1e-3)
|
|
147
|
+
|
|
148
|
+
# Phase 2 – fine-tune gates
|
|
149
|
+
unfreeze_gate(model)
|
|
150
|
+
train(model, epochs=5, lr=1e-4)
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
---
|
|
154
|
+
|
|
155
|
+
## Repository Structure
|
|
156
|
+
|
|
157
|
+
```
|
|
158
|
+
SAGA/
|
|
159
|
+
├── src/saga/ # installable Python package
|
|
160
|
+
│ ├── __init__.py
|
|
161
|
+
│ ├── activation.py # SAGA operator (core)
|
|
162
|
+
│ ├── blocks.py # SAGAResBlock, SAGABottleneck
|
|
163
|
+
│ └── utils.py # parameter counting, gate freeze helpers
|
|
164
|
+
│
|
|
165
|
+
├── tests/
|
|
166
|
+
│ ├── conftest.py
|
|
167
|
+
│ └── test_saga.py # pytest suite (shapes, edge cases, GPU, gradients)
|
|
168
|
+
│
|
|
169
|
+
├── SAGA_Supplementary_Code/ # original experimental pipeline
|
|
170
|
+
│ ├── models/
|
|
171
|
+
│ │ ├── saga_layer.py # raw research implementation
|
|
172
|
+
│ │ ├── unet.py
|
|
173
|
+
│ │ ├── resnet.py
|
|
174
|
+
│ │ ├── edsr.py
|
|
175
|
+
│ │ └── vggnet.py
|
|
176
|
+
│ ├── generate_dataset.py
|
|
177
|
+
│ ├── train.py
|
|
178
|
+
│ ├── evaluate.py
|
|
179
|
+
│ ├── xai_analysis.py
|
|
180
|
+
│ └── clinical_validation.py
|
|
181
|
+
│
|
|
182
|
+
├── docs/ # Sphinx documentation source
|
|
183
|
+
├── .github/workflows/ci.yml # GitHub Actions CI
|
|
184
|
+
├── pyproject.toml
|
|
185
|
+
└── README.md
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
---
|
|
189
|
+
|
|
190
|
+
## Experimental Results (summary)
|
|
191
|
+
|
|
192
|
+
| Model | Activation | CT PSNR (dB) | CT SSIM | DXA PSNR (dB) | DXA SSIM |
|
|
193
|
+
|---------------|-----------|:------------:|:-------:|:-------------:|:--------:|
|
|
194
|
+
| U-Net | ReLU | 32.14 | 0.891 | 30.87 | 0.873 |
|
|
195
|
+
| U-Net | SiLU | 33.01 | 0.902 | 31.54 | 0.881 |
|
|
196
|
+
| **U-Net** | **SAGA** | **34.67** | **0.921** | **33.12** | **0.903** |
|
|
197
|
+
| DeblurResNet | ReLU | 31.89 | 0.883 | 30.21 | 0.864 |
|
|
198
|
+
| **DeblurResNet** | **SAGA** | **34.11** | **0.916** | **32.78** | **0.897** |
|
|
199
|
+
|
|
200
|
+
Full results and ablation studies are reported in the paper.
|
|
201
|
+
|
|
202
|
+
---
|
|
203
|
+
|
|
204
|
+
## Running the Tests
|
|
205
|
+
|
|
206
|
+
```bash
|
|
207
|
+
pytest tests/ -v
|
|
208
|
+
```
|
|
209
|
+
|
|
210
|
+
To run with coverage:
|
|
211
|
+
|
|
212
|
+
```bash
|
|
213
|
+
pytest tests/ --cov=saga --cov-report=term-missing
|
|
214
|
+
```
|
|
215
|
+
|
|
216
|
+
---
|
|
217
|
+
|
|
218
|
+
## Citing
|
|
219
|
+
|
|
220
|
+
If SAGA is useful in your research, please cite:
|
|
221
|
+
|
|
222
|
+
```bibtex
|
|
223
|
+
@article{siju2026saga,
|
|
224
|
+
title = {An interpretable deep learning method for medical image deblurring and restoration},
|
|
225
|
+
author = {Siju K.S. and Vipin Venugopal and Mithun Kumar Kar and Jayakrishnan Anandakrishnan},
|
|
226
|
+
journal = {Healthcare Analytics},
|
|
227
|
+
volume = {9},
|
|
228
|
+
pages = {100468},
|
|
229
|
+
year = {2026},
|
|
230
|
+
doi = {10.1016/j.health.2026.100468}
|
|
231
|
+
}
|
|
232
|
+
```
|
|
233
|
+
|
|
234
|
+
---
|
|
235
|
+
|
|
236
|
+
## License
|
|
237
|
+
|
|
238
|
+
[MIT](LICENSE)
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
# SAGA — Spatially-Adaptive Gated Activation
|
|
2
|
+
|
|
3
|
+
[CI status](https://github.com/sijuswamyresearch/SAGA/actions)
|
|
4
|
+
[PyPI package](https://pypi.org/project/saga-activation/)
|
|
5
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
6
|
+
[DOI (Zenodo)](https://doi.org/10.5281/zenodo.XXXXXXX)
|
|
7
|
+
[Paper DOI](https://doi.org/10.1016/j.health.2026.100468)
|
|
8
|
+
|
|
9
|
+
> **An Interpretable Deep Learning Method for Medical Image Deblurring and Restoration**
|
|
10
|
+
> Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
|
|
11
|
+
> *Healthcare Analytics* 9 (2026) 100468 · [doi:10.1016/j.health.2026.100468](https://doi.org/10.1016/j.health.2026.100468)
|
|
12
|
+
|
|
13
|
+
---
|
|
14
|
+
|
|
15
|
+
## Overview
|
|
16
|
+
|
|
17
|
+
Standard activation functions (ReLU, SiLU, GELU) treat every spatial location in a feature map identically. In medical images — CT slices, DXA scans — the information content is *not* spatially uniform: anatomical boundaries carry high-frequency, diagnostically critical detail, while homogeneous regions (background, soft tissue) require smooth suppression.
|
|
18
|
+
|
|
19
|
+
**SAGA** introduces a *learned spatial gating map* that modulates the activation response position-by-position:
|
|
20
|
+
|
|
21
|
+
```
|
|
22
|
+
T(X) = BN(DWConv3×3(X));  G(X) = σ(W_g * T(X) + b_g)   # spatial context (depthwise 3×3 conv + BatchNorm) and 1×1-conv gate
|
|
23
|
+
SAGA(X) = X + G(X) ⊙ ReLU(T(X) − X)   # gated positive boost added to the identity, with T(X) = BN(DWConv3×3(X))
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
This two-path design lets the network selectively amplify high-frequency boundary signals while smoothly gating uniform background areas — without increasing the depth of the network.
|
|
27
|
+
|
|
28
|
+
---
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install saga-activation
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
Or install from source:
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
git clone https://github.com/sijuswamyresearch/SAGA.git
|
|
40
|
+
cd SAGA
|
|
41
|
+
pip install -e ".[dev]"
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
**Requirements:** Python ≥ 3.10, PyTorch ≥ 2.0
|
|
45
|
+
|
|
46
|
+
---
|
|
47
|
+
|
|
48
|
+
## Quick Start
|
|
49
|
+
|
|
50
|
+
### Drop-in activation replacement
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
import torch
|
|
54
|
+
from saga import SAGA
|
|
55
|
+
|
|
56
|
+
# Replace any activation layer with SAGA
|
|
57
|
+
act = SAGA(in_channels=64) # matches the channel dim of your feature map
|
|
58
|
+
x = torch.randn(2, 64, 256, 256) # (B, C, H, W)
|
|
59
|
+
y = act(x) # same shape: (2, 64, 256, 256)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
### Inside a U-Net encoder block
|
|
63
|
+
|
|
64
|
+
```python
|
|
65
|
+
import torch.nn as nn
|
|
66
|
+
from saga import SAGA
|
|
67
|
+
|
|
68
|
+
class EncoderBlock(nn.Module):
|
|
69
|
+
def __init__(self, in_ch, out_ch):
|
|
70
|
+
super().__init__()
|
|
71
|
+
self.block = nn.Sequential(
|
|
72
|
+
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
|
73
|
+
nn.BatchNorm2d(out_ch),
|
|
74
|
+
SAGA(out_ch), # ← swap in SAGA here
|
|
75
|
+
nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
|
|
76
|
+
nn.BatchNorm2d(out_ch),
|
|
77
|
+
SAGA(out_ch),
|
|
78
|
+
)
|
|
79
|
+
self.pool = nn.MaxPool2d(2)
|
|
80
|
+
|
|
81
|
+
def forward(self, x):
|
|
82
|
+
return self.pool(self.block(x))
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
### Pre-built residual blocks
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
from saga import SAGAResBlock, SAGABottleneck
|
|
89
|
+
|
|
90
|
+
res = SAGAResBlock(64) # standard residual block
|
|
91
|
+
bottle = SAGABottleneck(64, out_channels=128) # bottleneck variant
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
### Base-activation variants (not supported in 0.1.0 — `SAGA.__init__` accepts only `in_channels`)
|
|
95
|
+
|
|
96
|
+
```python
|
|
97
|
+
from saga import SAGA
|
|
98
|
+
|
|
99
|
+
act_relu = SAGA(64, base_activation="relu")
|
|
100
|
+
act_gelu = SAGA(64, base_activation="gelu")
|
|
101
|
+
act_tanh = SAGA(64, base_activation="tanh")
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Gate curriculum training
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
from saga.utils import freeze_gate, unfreeze_gate
|
|
108
|
+
|
|
109
|
+
# Phase 1 – train backbone only
|
|
110
|
+
freeze_gate(model)
|
|
111
|
+
train(model, epochs=10, lr=1e-3)
|
|
112
|
+
|
|
113
|
+
# Phase 2 – fine-tune gates
|
|
114
|
+
unfreeze_gate(model)
|
|
115
|
+
train(model, epochs=5, lr=1e-4)
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
---
|
|
119
|
+
|
|
120
|
+
## Repository Structure
|
|
121
|
+
|
|
122
|
+
```
|
|
123
|
+
SAGA/
|
|
124
|
+
├── src/saga/                    # installable Python package
|
|
125
|
+
│ ├── __init__.py
|
|
126
|
+
│ ├── activation.py # SAGA operator (core)
|
|
127
|
+
│ ├── blocks.py # SAGAResBlock, SAGABottleneck
|
|
128
|
+
│ └── utils.py # parameter counting, gate freeze helpers
|
|
129
|
+
│
|
|
130
|
+
├── tests/
|
|
131
|
+
│ ├── conftest.py
|
|
132
|
+
│ └── test_saga.py # pytest suite (shapes, edge cases, GPU, gradients)
|
|
133
|
+
│
|
|
134
|
+
├── SAGA_Supplementary_Code/ # original experimental pipeline
|
|
135
|
+
│ ├── models/
|
|
136
|
+
│ │ ├── saga_layer.py # raw research implementation
|
|
137
|
+
│ │ ├── unet.py
|
|
138
|
+
│ │ ├── resnet.py
|
|
139
|
+
│ │ ├── edsr.py
|
|
140
|
+
│ │ └── vggnet.py
|
|
141
|
+
│ ├── generate_dataset.py
|
|
142
|
+
│ ├── train.py
|
|
143
|
+
│ ├── evaluate.py
|
|
144
|
+
│ ├── xai_analysis.py
|
|
145
|
+
│ └── clinical_validation.py
|
|
146
|
+
│
|
|
147
|
+
├── docs/ # Sphinx documentation source
|
|
148
|
+
├── .github/workflows/ci.yml # GitHub Actions CI
|
|
149
|
+
├── pyproject.toml
|
|
150
|
+
└── README.md
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
---
|
|
154
|
+
|
|
155
|
+
## Experimental Results (summary)
|
|
156
|
+
|
|
157
|
+
| Model | Activation | CT PSNR (dB) | CT SSIM | DXA PSNR (dB) | DXA SSIM |
|
|
158
|
+
|---------------|-----------|:------------:|:-------:|:-------------:|:--------:|
|
|
159
|
+
| U-Net | ReLU | 32.14 | 0.891 | 30.87 | 0.873 |
|
|
160
|
+
| U-Net | SiLU | 33.01 | 0.902 | 31.54 | 0.881 |
|
|
161
|
+
| **U-Net** | **SAGA** | **34.67** | **0.921** | **33.12** | **0.903** |
|
|
162
|
+
| DeblurResNet | ReLU | 31.89 | 0.883 | 30.21 | 0.864 |
|
|
163
|
+
| **DeblurResNet** | **SAGA** | **34.11** | **0.916** | **32.78** | **0.897** |
|
|
164
|
+
|
|
165
|
+
Full results and ablation studies are reported in the paper.
|
|
166
|
+
|
|
167
|
+
---
|
|
168
|
+
|
|
169
|
+
## Running the Tests
|
|
170
|
+
|
|
171
|
+
```bash
|
|
172
|
+
pytest tests/ -v
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
To run with coverage:
|
|
176
|
+
|
|
177
|
+
```bash
|
|
178
|
+
pytest tests/ --cov=saga --cov-report=term-missing
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
---
|
|
182
|
+
|
|
183
|
+
## Citing
|
|
184
|
+
|
|
185
|
+
If SAGA is useful in your research, please cite:
|
|
186
|
+
|
|
187
|
+
```bibtex
|
|
188
|
+
@article{siju2026saga,
|
|
189
|
+
title = {An interpretable deep learning method for medical image deblurring and restoration},
|
|
190
|
+
author = {Siju K.S. and Vipin Venugopal and Mithun Kumar Kar and Jayakrishnan Anandakrishnan},
|
|
191
|
+
journal = {Healthcare Analytics},
|
|
192
|
+
volume = {9},
|
|
193
|
+
pages = {100468},
|
|
194
|
+
year = {2026},
|
|
195
|
+
doi = {10.1016/j.health.2026.100468}
|
|
196
|
+
}
|
|
197
|
+
```
|
|
198
|
+
|
|
199
|
+
---
|
|
200
|
+
|
|
201
|
+
## License
|
|
202
|
+
|
|
203
|
+
[MIT](LICENSE)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "saga-activation"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Spatially-Adaptive Gated Activation (SAGA) for medical image restoration"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { text = "MIT" }
|
|
11
|
+
authors = [
|
|
12
|
+
{ name = "Siju K.S.", email = "sijuks@example.com" },
|
|
13
|
+
{ name = "Vipin Venugopal" },
|
|
14
|
+
{ name = "Mithun Kumar Kar" },
|
|
15
|
+
{ name = "Jayakrishnan Anandakrishnan" },
|
|
16
|
+
]
|
|
17
|
+
keywords = [
|
|
18
|
+
"deep learning", "activation function", "medical imaging",
|
|
19
|
+
"image restoration", "deblurring", "PyTorch",
|
|
20
|
+
]
|
|
21
|
+
classifiers = [
|
|
22
|
+
"Development Status :: 4 - Beta",
|
|
23
|
+
"Intended Audience :: Science/Research",
|
|
24
|
+
"Topic :: Scientific/Engineering :: Medical Science Apps.",
|
|
25
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
26
|
+
"License :: OSI Approved :: MIT License",
|
|
27
|
+
"Programming Language :: Python :: 3",
|
|
28
|
+
"Programming Language :: Python :: 3.10",
|
|
29
|
+
"Programming Language :: Python :: 3.11",
|
|
30
|
+
"Programming Language :: Python :: 3.12",
|
|
31
|
+
]
|
|
32
|
+
requires-python = ">=3.10"
|
|
33
|
+
dependencies = [
|
|
34
|
+
"torch>=2.0",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.optional-dependencies]
|
|
38
|
+
dev = [
|
|
39
|
+
"pytest>=7.0",
|
|
40
|
+
"pytest-cov>=4.0",
|
|
41
|
+
]
|
|
42
|
+
docs = [
|
|
43
|
+
"sphinx>=7.0",
|
|
44
|
+
"sphinx-rtd-theme>=2.0",
|
|
45
|
+
"myst-parser>=2.0",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
[project.urls]
|
|
49
|
+
Homepage = "https://github.com/sijuswamyresearch/saga-activation"
|
|
50
|
+
Documentation = "https://sijuswamyresearch.github.io/saga-activation"
|
|
51
|
+
Repository = "https://github.com/sijuswamyresearch/saga-activation"
|
|
52
|
+
"Bug Tracker" = "https://github.com/sijuswamyresearch/saga-activation/issues"
|
|
53
|
+
Paper = "https://doi.org/10.1016/j.health.2026.100468"
|
|
54
|
+
[tool.setuptools.packages.find]
|
|
55
|
+
where = ["src"]
|
|
56
|
+
include = ["saga*"]
|
|
57
|
+
|
|
58
|
+
[tool.pytest.ini_options]
|
|
59
|
+
testpaths = ["tests"]
|
|
60
|
+
addopts = "-v --tb=short"
|
|
61
|
+
|
|
62
|
+
[tool.coverage.run]
|
|
63
|
+
source = ["saga"]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SAGA — Spatially-Adaptive Gated Activation for Medical Image Restoration
|
|
3
|
+
=========================================================================
|
|
4
|
+
|
|
5
|
+
Quick start
|
|
6
|
+
-----------
|
|
7
|
+
>>> import torch
|
|
8
|
+
>>> from saga import SAGA
|
|
9
|
+
>>> act = SAGA(in_channels=64)
|
|
10
|
+
>>> x = torch.randn(2, 64, 256, 256)
|
|
11
|
+
>>> y = act(x) # same shape, spatially gated
|
|
12
|
+
|
|
13
|
+
Reference
|
|
14
|
+
---------
|
|
15
|
+
Siju K.S. et al. "An interpretable deep learning method for medical image
|
|
16
|
+
deblurring and restoration." Healthcare Analytics 9 (2026) 100468.
|
|
17
|
+
https://doi.org/10.1016/j.health.2026.100468
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from importlib.metadata import version, PackageNotFoundError
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
__version__: str = version("saga-activation")
|
|
24
|
+
except PackageNotFoundError:
|
|
25
|
+
__version__ = "0.1.0"
|
|
26
|
+
|
|
27
|
+
from .activation import SAGA, SAGALayer
|
|
28
|
+
from .blocks import SAGAResBlock, SAGABottleneck
|
|
29
|
+
from .utils import count_parameters, freeze_gate, unfreeze_gate
|
|
30
|
+
|
|
31
|
+
__all__ = [
|
|
32
|
+
"SAGA",
|
|
33
|
+
"SAGALayer",
|
|
34
|
+
"SAGAResBlock",
|
|
35
|
+
"SAGABottleneck",
|
|
36
|
+
"count_parameters",
|
|
37
|
+
"freeze_gate",
|
|
38
|
+
"unfreeze_gate",
|
|
39
|
+
"__version__",
|
|
40
|
+
]
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""
|
|
2
|
+
saga.activation
|
|
3
|
+
===============
|
|
4
|
+
Spatially-Adaptive Gated Activation (SAGA) operator.
|
|
5
|
+
|
|
6
|
+
SAGA extracts spatial context via a depthwise convolution, calculates a
|
|
7
|
+
residual boost, and dynamically gates this boost before adding it back to
|
|
8
|
+
the original input. This enables the network to selectively route gradient
|
|
9
|
+
flow through high-frequency anatomical boundary regions.
|
|
10
|
+
|
|
11
|
+
Reference
|
|
12
|
+
---------
|
|
13
|
+
Siju K.S., Venugopal V., Kar M.K., Anandakrishnan J.
|
|
14
|
+
"An interpretable deep learning method for medical image deblurring and
|
|
15
|
+
restoration." Healthcare Analytics 9 (2026) 100468.
|
|
16
|
+
https://doi.org/10.1016/j.health.2026.100468
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
import torch.nn as nn
|
|
23
|
+
import torch.nn.functional as F
|
|
24
|
+
|
|
25
|
+
__all__ = ["SAGA", "SAGALayer"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SAGA(nn.Module):
    """Spatially-Adaptive Gated Activation (SAGA).

    The module computes ``y = x + gate * boost`` where

    * ``T = spatial_bn(spatial_conv(x))`` is the spatial context, produced
      by a depthwise 3x3 convolution followed by batch normalisation,
    * ``boost = relu(T - x)`` is the non-negative part of the difference
      between the context and the input, and
    * ``gate = sigmoid(gate_generator(T))`` is produced by a 1x1
      convolution applied to the context.

    Parameters
    ----------
    in_channels : int
        Number of input (and output) channels.

    Examples
    --------
    >>> import torch
    >>> from saga import SAGA
    >>> act = SAGA(in_channels=64)
    >>> x = torch.randn(2, 64, 32, 32)
    >>> y = act(x)
    >>> y.shape
    torch.Size([2, 64, 32, 32])
    """

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.in_channels = in_channels

        # Spatial context path: one 3x3 filter per channel
        # (groups == in_channels), padded so the spatial size is preserved,
        # followed by BatchNorm.
        depthwise_options = dict(
            kernel_size=3, padding=1, groups=in_channels, bias=False
        )
        self.spatial_conv = nn.Conv2d(
            in_channels, in_channels, **depthwise_options
        )
        self.spatial_bn = nn.BatchNorm2d(in_channels)

        # Gate path: 1x1 convolution (with bias) applied to the context.
        self.gate_generator = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, padding=0, bias=True
        )

        self._init_weights()

    def _init_weights(self) -> None:
        """Set the initial weights used for stable early-stage training."""
        nn.init.kaiming_normal_(
            self.spatial_conv.weight, mode="fan_in", nonlinearity="relu"
        )
        nn.init.ones_(self.spatial_bn.weight)
        nn.init.zeros_(self.spatial_bn.bias)

        # Zero gate weights plus a bias of 2.0 make the initial gate a
        # constant sigmoid(2.0) ~= 0.88 at every position.
        nn.init.zeros_(self.gate_generator.weight)
        nn.init.constant_(self.gate_generator.bias, 2.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated boost to *x*.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(B, C, H, W)`` with ``C == in_channels``.

        Returns
        -------
        torch.Tensor
            Same shape as *x*.
        """
        context = self.spatial_bn(self.spatial_conv(x))

        # Gate and boost are both derived from the same context tensor.
        gate = torch.sigmoid(self.gate_generator(context))
        boost = F.relu(context - x)

        # Identity plus the gated, non-negative boost.
        return x + gate * boost

    def extra_repr(self) -> str:
        """Return the constructor summary shown in the module's ``repr``."""
        return f"in_channels={self.in_channels}"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Alias for drop-in use inside sequential blocks
|
|
114
|
+
SAGALayer = SAGA
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""
|
|
2
|
+
saga.blocks
|
|
3
|
+
===========
|
|
4
|
+
Ready-made convolutional building blocks that use SAGA as their internal
|
|
5
|
+
activation function. These blocks can be used as drop-in replacements for
|
|
6
|
+
standard residual blocks in U-Net, ResNet, or EDSR style architectures.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
|
|
14
|
+
from .activation import SAGA
|
|
15
|
+
|
|
16
|
+
__all__ = ["SAGAResBlock", "SAGABottleneck"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SAGAResBlock(nn.Module):
    """Residual block with SAGA activations.

    The main path is ``conv3x3 -> BN -> SAGA -> conv3x3 -> BN``. Its output
    is summed with the shortcut branch and passed through a final SAGA.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int | None, optional
        Number of channels produced by the block. ``None`` (default) means
        "same as *in_channels*".
    stride : int, optional
        Stride of the first 3x3 convolution and of the projection shortcut.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        stride: int = 1,
    ) -> None:
        super().__init__()
        # Test for None explicitly: ``out_channels or in_channels`` would
        # also replace an explicit (invalid) 0 with the default instead of
        # letting the layer constructors report it.
        if out_channels is None:
            out_channels = in_channels

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = SAGA(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = SAGA(out_channels)

        # A 1x1 projection (conv + BN) is used only when the main path
        # changes the channel count or the stride; otherwise the input is
        # passed through unchanged.
        self.shortcut: nn.Module
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act2(bn2(conv2(act1(bn1(conv1(x))))) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.act2(out + residual)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class SAGABottleneck(nn.Module):
    """Bottleneck block (1x1 -> 3x3 -> 1x1) with SAGA activations.

    The main path is ``conv1x1 -> BN -> SAGA -> conv3x3 -> BN -> SAGA ->
    conv1x1 -> BN``. Its output is summed with the skip branch and passed
    through a final SAGA.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    bottleneck_channels : int | None, optional
        Width of the inner 1x1/3x3 stage. ``None`` (default) means
        ``max(in_channels // 4, 1)``.
    out_channels : int | None, optional
        Number of channels produced by the block. ``None`` (default) means
        "same as *in_channels*".
    """

    def __init__(
        self,
        in_channels: int,
        bottleneck_channels: int | None = None,
        out_channels: int | None = None,
    ) -> None:
        super().__init__()
        # Test for None explicitly: the ``value or default`` idiom would
        # also replace an explicit (invalid) 0 with the default instead of
        # letting the layer constructors report it.
        if bottleneck_channels is None:
            bottleneck_channels = max(in_channels // 4, 1)
        if out_channels is None:
            out_channels = in_channels

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False),
            nn.BatchNorm2d(bottleneck_channels),
            SAGA(bottleneck_channels),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(bottleneck_channels),
            SAGA(bottleneck_channels),
            nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # The skip branch is a bare 1x1 convolution when the channel count
        # changes, and the identity otherwise.
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1, bias=False)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.out_act = SAGA(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``out_act(net(x) + skip(x))``."""
        return self.out_act(self.net(x) + self.skip(x))
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
saga.utils
|
|
3
|
+
==========
|
|
4
|
+
Lightweight helpers for parameter accounting and gate control.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from .activation import SAGA
|
|
12
|
+
|
|
13
|
+
__all__ = ["count_parameters", "freeze_gate", "unfreeze_gate"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count the parameter elements of *model*.

    Parameters
    ----------
    model : nn.Module
        Module whose ``parameters()`` are counted.
    trainable_only : bool, optional
        When *True* (default) only parameters with
        ``requires_grad == True`` contribute to the total.

    Returns
    -------
    int
        Sum of ``numel()`` over the selected parameters.
    """
    selected = model.parameters()
    if trainable_only:
        # Drop frozen parameters before summing.
        selected = (param for param in selected if param.requires_grad)
    return sum(param.numel() for param in selected)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _set_gate_grad(model: nn.Module, requires_grad: bool) -> None:
    """Set ``requires_grad`` on the parameters of every SAGA in *model*.

    Walks ``model.modules()`` and, for each :class:`SAGA` instance, updates
    the parameters of its ``spatial_conv``, ``spatial_bn`` and
    ``gate_generator`` sub-modules. Only the ``requires_grad`` flags are
    changed; nothing else about the modules is modified.
    """
    component_names = ("spatial_conv", "spatial_bn", "gate_generator")
    saga_modules = [m for m in model.modules() if isinstance(m, SAGA)]

    for saga in saga_modules:
        for attr in component_names:
            component = getattr(saga, attr, None)
            if component is None:
                # Tolerate a SAGA instance that lacks this sub-module.
                continue
            for param in component.parameters():
                param.requires_grad_(requires_grad)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def freeze_gate(model: nn.Module) -> None:
    """Turn off gradients for all SAGA gating parameters in *model*.

    Intended for curriculum training: train the backbone first with the
    gates frozen, then unfreeze the gates for a fine-tuning phase.
    """
    _set_gate_grad(model, requires_grad=False)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def unfreeze_gate(model: nn.Module) -> None:
    """Turn gradients back on for all SAGA gating parameters in *model*."""
    _set_gate_grad(model, requires_grad=True)
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: saga-activation
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Spatially-Adaptive Gated Activation (SAGA) for medical image restoration
|
|
5
|
+
Author: Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
|
|
6
|
+
Author-email: "Siju K.S." <sijuks@example.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
Project-URL: Homepage, https://github.com/sijuswamyresearch/saga-activation
|
|
9
|
+
Project-URL: Documentation, https://sijuswamyresearch.github.io/saga-activation
|
|
10
|
+
Project-URL: Repository, https://github.com/sijuswamyresearch/saga-activation
|
|
11
|
+
Project-URL: Bug Tracker, https://github.com/sijuswamyresearch/saga-activation/issues
|
|
12
|
+
Project-URL: Paper, https://doi.org/10.1016/j.health.2026.100468
|
|
13
|
+
Keywords: deep learning,activation function,medical imaging,image restoration,deblurring,PyTorch
|
|
14
|
+
Classifier: Development Status :: 4 - Beta
|
|
15
|
+
Classifier: Intended Audience :: Science/Research
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
19
|
+
Classifier: Programming Language :: Python :: 3
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
23
|
+
Requires-Python: >=3.10
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
License-File: LICENSE
|
|
26
|
+
Requires-Dist: torch>=2.0
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
30
|
+
Provides-Extra: docs
|
|
31
|
+
Requires-Dist: sphinx>=7.0; extra == "docs"
|
|
32
|
+
Requires-Dist: sphinx-rtd-theme>=2.0; extra == "docs"
|
|
33
|
+
Requires-Dist: myst-parser>=2.0; extra == "docs"
|
|
34
|
+
Dynamic: license-file
|
|
35
|
+
|
|
36
|
+
# SAGA — Spatially-Adaptive Gated Activation
|
|
37
|
+
|
|
38
|
+
[CI status](https://github.com/sijuswamyresearch/SAGA/actions)
|
|
39
|
+
[PyPI package](https://pypi.org/project/saga-activation/)
|
|
40
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
41
|
+
[Zenodo DOI](https://doi.org/10.5281/zenodo.XXXXXXX)
|
|
42
|
+
[Paper DOI](https://doi.org/10.1016/j.health.2026.100468)
|
|
43
|
+
|
|
44
|
+
> **An Interpretable Deep Learning Method for Medical Image Deblurring and Restoration**
|
|
45
|
+
> Siju K.S., Vipin Venugopal, Mithun Kumar Kar, Jayakrishnan Anandakrishnan
|
|
46
|
+
> *Healthcare Analytics* 9 (2026) 100468 · [doi:10.1016/j.health.2026.100468](https://doi.org/10.1016/j.health.2026.100468)
|
|
47
|
+
|
|
48
|
+
---
|
|
49
|
+
|
|
50
|
+
## Overview
|
|
51
|
+
|
|
52
|
+
Standard activation functions (ReLU, SiLU, GELU) treat every spatial location in a feature map identically. In medical images — CT slices, DXA scans — the information content is *not* spatially uniform: anatomical boundaries carry high-frequency, diagnostically critical detail, while homogeneous regions (background, soft tissue) require smooth suppression.
|
|
53
|
+
|
|
54
|
+
**SAGA** introduces a *learned spatial gating map* that modulates the activation response position-by-position:
|
|
55
|
+
|
|
56
|
+
```
|
|
57
|
+
G(X) = σ(W_g * X) # spatial gate (depthwise-separable conv)
|
|
58
|
+
SAGA(X) = G(X) ⊙ φ(X) # φ = SiLU (default)
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
This two-path design lets the network selectively pass high-frequency boundary signals while smoothly attenuating uniform background areas (the sigmoid gate scales each position by a factor between 0 and 1) — without increasing the depth of the network.
|
|
62
|
+
|
|
63
|
+
---
|
|
64
|
+
|
|
65
|
+
## Installation
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
pip install saga-activation
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
Or install from source:
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
git clone https://github.com/sijuswamyresearch/SAGA.git
|
|
75
|
+
cd SAGA
|
|
76
|
+
pip install -e ".[dev]"
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
**Requirements:** Python ≥ 3.10, PyTorch ≥ 2.0
|
|
80
|
+
|
|
81
|
+
---
|
|
82
|
+
|
|
83
|
+
## Quick Start
|
|
84
|
+
|
|
85
|
+
### Drop-in activation replacement
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
import torch
|
|
89
|
+
from saga import SAGA
|
|
90
|
+
|
|
91
|
+
# Replace any activation layer with SAGA
|
|
92
|
+
act = SAGA(in_channels=64) # matches the channel dim of your feature map
|
|
93
|
+
x = torch.randn(2, 64, 256, 256) # (B, C, H, W)
|
|
94
|
+
y = act(x) # same shape: (2, 64, 256, 256)
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
### Inside a U-Net encoder block
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
import torch.nn as nn
|
|
101
|
+
from saga import SAGA
|
|
102
|
+
|
|
103
|
+
class EncoderBlock(nn.Module):
|
|
104
|
+
def __init__(self, in_ch, out_ch):
|
|
105
|
+
super().__init__()
|
|
106
|
+
self.block = nn.Sequential(
|
|
107
|
+
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
|
108
|
+
nn.BatchNorm2d(out_ch),
|
|
109
|
+
SAGA(out_ch), # ← swap in SAGA here
|
|
110
|
+
nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
|
|
111
|
+
nn.BatchNorm2d(out_ch),
|
|
112
|
+
SAGA(out_ch),
|
|
113
|
+
)
|
|
114
|
+
self.pool = nn.MaxPool2d(2)
|
|
115
|
+
|
|
116
|
+
def forward(self, x):
|
|
117
|
+
return self.pool(self.block(x))
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
### Pre-built residual blocks
|
|
121
|
+
|
|
122
|
+
```python
|
|
123
|
+
from saga import SAGAResBlock, SAGABottleneck
|
|
124
|
+
|
|
125
|
+
res = SAGAResBlock(64) # standard residual block
|
|
126
|
+
bottle = SAGABottleneck(64, out_channels=128) # bottleneck variant
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
### Base-activation variants
|
|
130
|
+
|
|
131
|
+
```python
|
|
132
|
+
from saga import SAGA
|
|
133
|
+
|
|
134
|
+
act_relu = SAGA(64, base_activation="relu")
|
|
135
|
+
act_gelu = SAGA(64, base_activation="gelu")
|
|
136
|
+
act_tanh = SAGA(64, base_activation="tanh")
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
### Gate curriculum training
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
from saga.utils import freeze_gate, unfreeze_gate
|
|
143
|
+
|
|
144
|
+
# Phase 1 – train backbone only
|
|
145
|
+
freeze_gate(model)
|
|
146
|
+
train(model, epochs=10, lr=1e-3)
|
|
147
|
+
|
|
148
|
+
# Phase 2 – fine-tune gates
|
|
149
|
+
unfreeze_gate(model)
|
|
150
|
+
train(model, epochs=5, lr=1e-4)
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
---
|
|
154
|
+
|
|
155
|
+
## Repository Structure
|
|
156
|
+
|
|
157
|
+
```
|
|
158
|
+
SAGA/
|
|
159
|
+
├── src/saga/ # installable Python package (src layout)
|
|
160
|
+
│ ├── __init__.py
|
|
161
|
+
│ ├── activation.py # SAGA operator (core)
|
|
162
|
+
│ ├── blocks.py # SAGAResBlock, SAGABottleneck
|
|
163
|
+
│ └── utils.py # parameter counting, gate freeze helpers
|
|
164
|
+
│
|
|
165
|
+
├── tests/
|
|
166
|
+
│ ├── conftest.py
|
|
167
|
+
│ └── test_saga.py # pytest suite (shapes, edge cases, GPU, gradients)
|
|
168
|
+
│
|
|
169
|
+
├── SAGA_Supplementary_Code/ # original experimental pipeline
|
|
170
|
+
│ ├── models/
|
|
171
|
+
│ │ ├── saga_layer.py # raw research implementation
|
|
172
|
+
│ │ ├── unet.py
|
|
173
|
+
│ │ ├── resnet.py
|
|
174
|
+
│ │ ├── edsr.py
|
|
175
|
+
│ │ └── vggnet.py
|
|
176
|
+
│ ├── generate_dataset.py
|
|
177
|
+
│ ├── train.py
|
|
178
|
+
│ ├── evaluate.py
|
|
179
|
+
│ ├── xai_analysis.py
|
|
180
|
+
│ └── clinical_validation.py
|
|
181
|
+
│
|
|
182
|
+
├── docs/ # Sphinx documentation source
|
|
183
|
+
├── .github/workflows/ci.yml # GitHub Actions CI
|
|
184
|
+
├── pyproject.toml
|
|
185
|
+
└── README.md
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
---
|
|
189
|
+
|
|
190
|
+
## Experimental Results (summary)
|
|
191
|
+
|
|
192
|
+
| Model | Activation | CT PSNR (dB) | CT SSIM | DXA PSNR (dB) | DXA SSIM |
|
|
193
|
+
|---------------|-----------|:------------:|:-------:|:-------------:|:--------:|
|
|
194
|
+
| U-Net | ReLU | 32.14 | 0.891 | 30.87 | 0.873 |
|
|
195
|
+
| U-Net | SiLU | 33.01 | 0.902 | 31.54 | 0.881 |
|
|
196
|
+
| **U-Net** | **SAGA** | **34.67** | **0.921** | **33.12** | **0.903** |
|
|
197
|
+
| DeblurResNet | ReLU | 31.89 | 0.883 | 30.21 | 0.864 |
|
|
198
|
+
| **DeblurResNet** | **SAGA** | **34.11** | **0.916** | **32.78** | **0.897** |
|
|
199
|
+
|
|
200
|
+
Full results and ablation studies are reported in the paper.
|
|
201
|
+
|
|
202
|
+
---
|
|
203
|
+
|
|
204
|
+
## Running the Tests
|
|
205
|
+
|
|
206
|
+
```bash
|
|
207
|
+
pytest tests/ -v
|
|
208
|
+
```
|
|
209
|
+
|
|
210
|
+
To run with coverage:
|
|
211
|
+
|
|
212
|
+
```bash
|
|
213
|
+
pytest tests/ --cov=saga --cov-report=term-missing
|
|
214
|
+
```
|
|
215
|
+
|
|
216
|
+
---
|
|
217
|
+
|
|
218
|
+
## Citing
|
|
219
|
+
|
|
220
|
+
If SAGA is useful in your research, please cite:
|
|
221
|
+
|
|
222
|
+
```bibtex
|
|
223
|
+
@article{siju2026saga,
|
|
224
|
+
title = {An interpretable deep learning method for medical image deblurring and restoration},
|
|
225
|
+
author = {Siju K.S. and Vipin Venugopal and Mithun Kumar Kar and Jayakrishnan Anandakrishnan},
|
|
226
|
+
journal = {Healthcare Analytics},
|
|
227
|
+
volume = {9},
|
|
228
|
+
pages = {100468},
|
|
229
|
+
year = {2026},
|
|
230
|
+
doi = {10.1016/j.health.2026.100468}
|
|
231
|
+
}
|
|
232
|
+
```
|
|
233
|
+
|
|
234
|
+
---
|
|
235
|
+
|
|
236
|
+
## License
|
|
237
|
+
|
|
238
|
+
[MIT](LICENSE)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/saga/__init__.py
|
|
5
|
+
src/saga/activation.py
|
|
6
|
+
src/saga/blocks.py
|
|
7
|
+
src/saga/utils.py
|
|
8
|
+
src/saga_activation.egg-info/PKG-INFO
|
|
9
|
+
src/saga_activation.egg-info/SOURCES.txt
|
|
10
|
+
src/saga_activation.egg-info/dependency_links.txt
|
|
11
|
+
src/saga_activation.egg-info/requires.txt
|
|
12
|
+
src/saga_activation.egg-info/top_level.txt
|
|
13
|
+
tests/test_saga.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
saga
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""
|
|
2
|
+
tests/test_saga.py
|
|
3
|
+
==================
|
|
4
|
+
pytest test suite for the SAGA activation package.
|
|
5
|
+
|
|
6
|
+
Run with:
|
|
7
|
+
pytest tests/ -v
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
|
|
14
|
+
from src.saga import SAGA, SAGALayer, SAGAResBlock, SAGABottleneck
|
|
15
|
+
from src.saga.utils import count_parameters, freeze_gate, unfreeze_gate
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
# Fixtures
|
|
20
|
+
# ---------------------------------------------------------------------------
|
|
21
|
+
|
|
22
|
+
# The CPU is always exercised; CUDA is added only when a GPU is usable.
DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture(params=DEVICES)
def device(request):
    """Provide a ``torch.device`` for each entry of ``DEVICES`` in turn."""
    device_name = request.param
    return torch.device(device_name)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
# 1. Output shape tests
|
|
34
|
+
# ---------------------------------------------------------------------------
|
|
35
|
+
|
|
36
|
+
class TestOutputShape:
    """SAGA must hand back a tensor shaped exactly like its input."""

    @pytest.mark.parametrize(
        "B,C,H,W",
        [
            (2, 1, 8, 8),  # B=2 so BatchNorm does not crash in train mode
            (2, 16, 32, 32),
            (4, 64, 128, 128),
            (2, 3, 256, 256),
        ],
    )
    def test_shape_preserved(self, B, C, H, W):
        """The output shape equals the input shape for several (B, C, H, W)."""
        layer = SAGA(in_channels=C)
        inputs = torch.randn(B, C, H, W)
        outputs = layer(inputs)
        assert outputs.shape == inputs.shape, f"Expected {inputs.shape}, got {outputs.shape}"

    def test_single_channel(self):
        """A one-channel feature map keeps its shape."""
        layer = SAGA(in_channels=1)
        inputs = torch.randn(2, 1, 4, 4)
        assert layer(inputs).shape == inputs.shape

    def test_large_channel(self):
        """A 512-channel feature map keeps its shape."""
        layer = SAGA(in_channels=512)
        inputs = torch.randn(2, 512, 8, 8)
        assert layer(inputs).shape == inputs.shape
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# ---------------------------------------------------------------------------
|
|
62
|
+
# 2. Edge-case tensor value tests
|
|
63
|
+
# ---------------------------------------------------------------------------
|
|
64
|
+
|
|
65
|
+
class TestEdgeCases:
    """Outputs must stay finite for degenerate and extreme input values."""

    def test_zero_input(self):
        """An all-zero input keeps its shape and yields finite values."""
        layer = SAGA(in_channels=8)
        zeros = torch.zeros(2, 8, 16, 16)
        out = layer(zeros)
        assert out.shape == zeros.shape
        assert torch.isfinite(out).all()

    def test_negative_input(self):
        """A strictly non-positive input yields finite values."""
        layer = SAGA(in_channels=8)
        negatives = -torch.abs(torch.randn(2, 8, 16, 16))
        out = layer(negatives)
        assert torch.isfinite(out).all()

    def test_large_positive_input(self):
        """A constant input of +1e3 yields finite values."""
        layer = SAGA(in_channels=8)
        big = torch.full((2, 8, 16, 16), 1e3)
        out = layer(big)
        assert torch.isfinite(out).all()

    def test_large_negative_input(self):
        """A constant input of -1e3 yields finite values."""
        layer = SAGA(in_channels=8)
        small = torch.full((2, 8, 16, 16), -1e3)
        out = layer(small)
        assert torch.isfinite(out).all()

    def test_nan_free(self):
        """Confirm no NaN leaks on random inputs."""
        torch.manual_seed(42)
        layer = SAGA(in_channels=32)
        n_trials = 10
        for _ in range(n_trials):
            # Scale by 10 to widen the range of the random inputs.
            sample = torch.randn(2, 32, 64, 64) * 10
            out = layer(sample)
            assert not torch.isnan(out).any(), "NaN detected in SAGA output"
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# ---------------------------------------------------------------------------
|
|
102
|
+
# 3. Gradient / backprop tests
|
|
103
|
+
# ---------------------------------------------------------------------------
|
|
104
|
+
|
|
105
|
+
class TestGradients:
    """Back-propagation through SAGA must produce usable gradients."""

    def test_backward_cpu(self):
        """On the CPU the input receives a finite gradient."""
        layer = SAGA(in_channels=16)
        inputs = torch.randn(2, 16, 32, 32, requires_grad=True)
        loss = layer(inputs).sum()
        loss.backward()
        grad = inputs.grad
        assert grad is not None
        assert torch.isfinite(grad).all()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_backward_cuda(self):
        """On the GPU the input receives a finite gradient."""
        layer = SAGA(in_channels=16).cuda()
        inputs = torch.randn(2, 16, 32, 32, requires_grad=True, device="cuda")
        loss = layer(inputs).sum()
        loss.backward()
        grad = inputs.grad
        assert grad is not None
        assert torch.isfinite(grad).all()

    def test_gate_parameters_get_gradients(self):
        """Every parameter of the module gets a gradient after backward."""
        layer = SAGA(in_channels=8)
        inputs = torch.randn(2, 8, 16, 16)
        layer(inputs).sum().backward()
        for param_name, param in layer.named_parameters():
            assert param.grad is not None, f"No gradient for parameter '{param_name}'"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ---------------------------------------------------------------------------
|
|
132
|
+
# 4. Device tests
|
|
133
|
+
# ---------------------------------------------------------------------------
|
|
134
|
+
|
|
135
|
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestCUDA:
    """GPU-only checks; the whole class is skipped when CUDA is unavailable."""

    def test_forward_cuda(self):
        """A forward pass on the GPU keeps the output on the GPU, same shape."""
        act = SAGA(in_channels=32).cuda()
        x = torch.randn(2, 32, 64, 64, device="cuda")
        y = act(x)
        assert y.device.type == "cuda"
        assert y.shape == x.shape

    def test_no_memory_leak(self):
        """Check GPU memory does not grow unboundedly over repeated calls.

        The check compares the memory still allocated *after* the loop with
        a baseline taken before it. Peak allocation is deliberately not
        used: it reflects the transient activations of a single forward
        pass, not memory that is retained from one call to the next.
        """
        act = SAGA(in_channels=64).cuda()

        # Warm-up call so that any one-time allocation made on first use is
        # already part of the baseline.
        warmup = act(torch.randn(4, 64, 128, 128, device="cuda"))
        del warmup
        torch.cuda.synchronize()
        baseline = torch.cuda.memory_allocated()

        for _ in range(50):
            x = torch.randn(4, 64, 128, 128, device="cuda")
            y = act(x)
            # Drop every reference created in this iteration; whatever is
            # still allocated afterwards is being held somewhere else.
            del x, y

        torch.cuda.synchronize()
        retained = torch.cuda.memory_allocated() - baseline
        # Allow at most 200 MB overhead
        assert retained < 200 * 1024 ** 2, "Potential memory leak detected"
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
# 5. Building-block tests
|
|
163
|
+
# ---------------------------------------------------------------------------
|
|
164
|
+
|
|
165
|
+
class TestBlocks:
    """Shape contracts of the pre-built SAGA residual blocks."""

    def test_resblock_shape(self):
        """With default arguments the residual block preserves the shape."""
        res = SAGAResBlock(64)
        feats = torch.randn(2, 64, 32, 32)
        assert res(feats).shape == feats.shape

    def test_resblock_projection(self):
        """stride=2 with more output channels halves H/W and widens C."""
        res = SAGAResBlock(32, out_channels=64, stride=2)
        feats = torch.randn(2, 32, 32, 32)
        out = res(feats)
        assert out.shape == (2, 64, 16, 16)

    def test_bottleneck_shape(self):
        """With default arguments the bottleneck preserves the shape."""
        bottle = SAGABottleneck(64)
        feats = torch.randn(2, 64, 32, 32)
        assert bottle(feats).shape == feats.shape

    def test_bottleneck_channel_change(self):
        """A different out_channels changes only the channel dimension."""
        bottle = SAGABottleneck(32, out_channels=128)
        feats = torch.randn(2, 32, 16, 16)
        assert bottle(feats).shape == (2, 128, 16, 16)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# ---------------------------------------------------------------------------
|
|
189
|
+
# 6. Utility tests
|
|
190
|
+
# ---------------------------------------------------------------------------
|
|
191
|
+
|
|
192
|
+
class TestUtils:
    """Tests for the helpers in ``saga.utils``."""

    def test_count_parameters(self):
        """A SAGA module reports a positive, consistent parameter count."""
        act = SAGA(in_channels=64)
        n = count_parameters(act)
        assert n > 0
        # The trainable count can never exceed the unfiltered count.
        assert n <= count_parameters(act, trainable_only=False)

    @staticmethod
    def _gate_parameters(model):
        """Yield the gate parameters of every SAGA module inside *model*.

        Covers ``spatial_conv`` and ``gate_generator`` unconditionally, plus
        ``spatial_bn`` when the module has one, i.e. the same components the
        freeze helpers toggle.
        """
        for module in model.modules():
            if not isinstance(module, SAGA):
                continue
            yield from module.spatial_conv.parameters()
            yield from module.gate_generator.parameters()
            spatial_bn = getattr(module, "spatial_bn", None)
            if spatial_bn is not None:
                yield from spatial_bn.parameters()

    def test_freeze_unfreeze_gate(self):
        """freeze_gate / unfreeze_gate flip requires_grad on all gate params."""
        model = nn.Sequential(SAGA(32), SAGA(32))

        # Test Freeze
        freeze_gate(model)
        for p in self._gate_parameters(model):
            assert not p.requires_grad

        # Test Unfreeze
        unfreeze_gate(model)
        for p in self._gate_parameters(model):
            assert p.requires_grad
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# ---------------------------------------------------------------------------
|
|
221
|
+
# 7. SAGALayer alias
|
|
222
|
+
# ---------------------------------------------------------------------------
|
|
223
|
+
|
|
224
|
+
def test_saga_layer_alias():
    """``SAGALayer`` must be the very same class object as ``SAGA``."""
    assert SAGA is SAGALayer
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# ---------------------------------------------------------------------------
|
|
229
|
+
# 8. Serialisation / state-dict round-trip
|
|
230
|
+
# ---------------------------------------------------------------------------
|
|
231
|
+
|
|
232
|
+
def test_state_dict_round_trip(tmp_path):
    """Saving and reloading the state dict must reproduce the same output.

    A module's state dict is written to a temporary file, loaded into a
    freshly constructed module, and both modules are run on the same input.
    """
    act = SAGA(in_channels=16)
    x = torch.randn(2, 16, 8, 8)
    y_before = act(x)

    path = tmp_path / "saga.pt"
    torch.save(act.state_dict(), path)

    act2 = SAGA(in_channels=16)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict holds, so no arbitrary objects get loaded.
    state = torch.load(path, map_location="cpu", weights_only=True)
    act2.load_state_dict(state)
    y_after = act2(x)

    assert torch.allclose(y_before, y_after, atol=1e-6)
|