elm-prune 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- elm_prune-0.1.0/LICENSE +30 -0
- elm_prune-0.1.0/PKG-INFO +31 -0
- elm_prune-0.1.0/README.md +1 -0
- elm_prune-0.1.0/elm_prune.egg-info/PKG-INFO +31 -0
- elm_prune-0.1.0/elm_prune.egg-info/SOURCES.txt +17 -0
- elm_prune-0.1.0/elm_prune.egg-info/dependency_links.txt +1 -0
- elm_prune-0.1.0/elm_prune.egg-info/requires.txt +9 -0
- elm_prune-0.1.0/elm_prune.egg-info/top_level.txt +1 -0
- elm_prune-0.1.0/elmprune/__init__.py +20 -0
- elm_prune-0.1.0/elmprune/elm.py +105 -0
- elm_prune-0.1.0/elmprune/elm_importance_processor.py +164 -0
- elm_prune-0.1.0/elmprune/feature_extractor.py +217 -0
- elm_prune-0.1.0/elmprune/importance_processor_config.py +20 -0
- elm_prune-0.1.0/elmprune/prune_config.py +20 -0
- elm_prune-0.1.0/elmprune/prune_pipeline.py +108 -0
- elm_prune-0.1.0/elmprune/prune_processor.py +338 -0
- elm_prune-0.1.0/elmprune/utils.py +191 -0
- elm_prune-0.1.0/pyproject.toml +55 -0
- elm_prune-0.1.0/setup.cfg +4 -0
elm_prune-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Iago Rodrigues
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
This license applies only to the source code, packaging files, and documentation
|
|
16
|
+
contained in this repository for the Python package "elm-prune".
|
|
17
|
+
|
|
18
|
+
It does not grant any rights over the MIRC dataset itself, including but not
|
|
19
|
+
limited to images, videos, annotations, derived data, redistributed copies, or
|
|
20
|
+
third-party mirrors of the dataset. Any use of the MIRC dataset must follow the
|
|
21
|
+
separate dataset terms, ethics requirements, and repository policies published
|
|
22
|
+
in the official dataset repository and/or official distribution channels.
|
|
23
|
+
|
|
24
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
25
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
26
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
27
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
28
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
29
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
30
|
+
SOFTWARE.
|
elm_prune-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: elm-prune
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: ELM-based structural pruning utilities for PyTorch segmentation models.
|
|
5
|
+
Author: Iago Rodrigues
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/iagorichard
|
|
8
|
+
Project-URL: Repository, https://github.com/iagorichard
|
|
9
|
+
Project-URL: Issues, https://github.com/iagorichard
|
|
10
|
+
Keywords: deep-learning,pytorch,pruning,structural-pruning,segmentation,elm
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.9
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: numpy>=1.24
|
|
22
|
+
Requires-Dist: tqdm>=4.66
|
|
23
|
+
Requires-Dist: torch>=2.5
|
|
24
|
+
Requires-Dist: segmentation-models-pytorch
|
|
25
|
+
Requires-Dist: torch-pruning==1.6.1
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: build>=1.2.2; extra == "dev"
|
|
28
|
+
Requires-Dist: twine>=5.1.1; extra == "dev"
|
|
29
|
+
Dynamic: license-file
|
|
30
|
+
|
|
31
|
+
# elm-pruning
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# elm-pruning
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: elm-prune
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: ELM-based structural pruning utilities for PyTorch segmentation models.
|
|
5
|
+
Author: Iago Rodrigues
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/iagorichard
|
|
8
|
+
Project-URL: Repository, https://github.com/iagorichard
|
|
9
|
+
Project-URL: Issues, https://github.com/iagorichard
|
|
10
|
+
Keywords: deep-learning,pytorch,pruning,structural-pruning,segmentation,elm
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.9
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: numpy>=1.24
|
|
22
|
+
Requires-Dist: tqdm>=4.66
|
|
23
|
+
Requires-Dist: torch>=2.5
|
|
24
|
+
Requires-Dist: segmentation-models-pytorch
|
|
25
|
+
Requires-Dist: torch-pruning==1.6.1
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: build>=1.2.2; extra == "dev"
|
|
28
|
+
Requires-Dist: twine>=5.1.1; extra == "dev"
|
|
29
|
+
Dynamic: license-file
|
|
30
|
+
|
|
31
|
+
# elm-pruning
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
elm_prune.egg-info/PKG-INFO
|
|
5
|
+
elm_prune.egg-info/SOURCES.txt
|
|
6
|
+
elm_prune.egg-info/dependency_links.txt
|
|
7
|
+
elm_prune.egg-info/requires.txt
|
|
8
|
+
elm_prune.egg-info/top_level.txt
|
|
9
|
+
elmprune/__init__.py
|
|
10
|
+
elmprune/elm.py
|
|
11
|
+
elmprune/elm_importance_processor.py
|
|
12
|
+
elmprune/feature_extractor.py
|
|
13
|
+
elmprune/importance_processor_config.py
|
|
14
|
+
elmprune/prune_config.py
|
|
15
|
+
elmprune/prune_pipeline.py
|
|
16
|
+
elmprune/prune_processor.py
|
|
17
|
+
elmprune/utils.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
elmprune
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from . import utils
|
|
2
|
+
from .importance_processor_config import ImportanceProcessorConfig
|
|
3
|
+
from .feature_extractor import FeatureExtractor
|
|
4
|
+
from .elm_importance_processor import ELMImportanceProcessor
|
|
5
|
+
from .elm import ELMRegressor
|
|
6
|
+
from .prune_processor import PruneProcessor
|
|
7
|
+
from .prune_config import PruneConfig
|
|
8
|
+
from .prune_config import PruneVerboseLevel
|
|
9
|
+
from .prune_pipeline import PrunePipeline
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Names exported by ``from elmprune import *`` (order kept as originally declared).
__all__ = [
    "utils",
    "ImportanceProcessorConfig",
    "FeatureExtractor",
    "ELMImportanceProcessor",
    "ELMRegressor",
    "PruneProcessor",
    "PruneConfig",
    "PruneVerboseLevel",
    "PrunePipeline",
]
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from typing import Dict, List
|
|
2
|
+
import math
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ELMRegressor:
    """Extreme Learning Machine regressor with a closed-form ridge readout.

    The hidden layer ``H = act(Xn @ W + b)`` uses random weights drawn from a
    generator seeded with ``seed`` and is never trained; only the output weights
    ``beta`` are learned, by solving ``(H^T H + reg_lambda * I) beta = H^T Yn``.
    Inputs are standardized per column and targets are mean-centered before the
    solve; ``predict`` adds the target mean back.

    Args:
        hidden_dim: Number of random hidden units.
        reg_lambda: Ridge (L2) regularization strength.
        activation_name: ``"tanh"``, ``"relu"`` or ``"sigmoid"`` (case-insensitive).
        seed: Seed for the random hidden weights; the same seed reproduces ``W`` and ``b``.
        eps: Lower bound applied to per-column standard deviations (avoids division by zero).
        use_double_for_solver: Solve the ridge system in float64, then cast ``beta``
            back to the hidden-layer dtype.
    """

    def __init__(self, hidden_dim: int, reg_lambda: float, activation_name: str, seed: int, eps: float, use_double_for_solver: bool):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hidden_dim = hidden_dim
        self.reg_lambda = reg_lambda
        self.activation_name = activation_name
        self.seed = seed
        self.eps = eps
        self.use_double_for_solver = use_double_for_solver

    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Fit the ridge readout on features ``X`` [N, D] and targets ``Y`` [N, K] or [N].

        A 1-D ``Y`` is treated as a single target column. The fitted state
        (``X_mean``, ``X_std``, ``Y_mean``, ``W``, ``b``, ``beta``) is stored on the instance.

        Raises:
            ValueError: If ``X`` is not 2-D, or ``X`` and ``Y`` have different row counts.
        """
        X = X.to(self.device)
        Y = self.__as_2d(Y).to(self.device)

        if X.dim() != 2:
            raise ValueError(f"Expected X with shape [N, D], got {tuple(X.shape)}")
        if X.shape[0] != Y.shape[0]:
            raise ValueError(f"X and Y must have the same number of rows, got {X.shape[0]} and {Y.shape[0]}")

        # Per-column standardization; eps keeps constant columns from dividing by zero.
        self.X_mean = X.mean(dim=0, keepdim=True)
        self.X_std = X.std(dim=0, keepdim=True, unbiased=False).clamp_min(self.eps)
        Xn = (X - self.X_mean) / self.X_std

        # Center the targets so the readout needs no bias term.
        self.Y_mean = Y.mean(dim=0, keepdim=True)
        Yn = Y - self.Y_mean

        generator = torch.Generator(device=self.device)
        generator.manual_seed(self.seed)

        in_dim = Xn.shape[1]
        # 1/sqrt(fan_in) keeps pre-activations at roughly unit scale; max(..., 1) guards D == 0.
        self.W = torch.randn((in_dim, self.hidden_dim), generator=generator, device=self.device) / math.sqrt(max(in_dim, 1))
        self.b = torch.randn((self.hidden_dim,), generator=generator, device=self.device)

        H = self.__apply_activation(Xn @ self.W + self.b, self.activation_name)

        # Ridge normal equations: (H^T H + lambda * I) beta = H^T Yn
        I = torch.eye(self.hidden_dim, device=self.device, dtype=H.dtype)
        lhs = H.T @ H + self.reg_lambda * I
        rhs = H.T @ Yn

        if self.use_double_for_solver:
            self.beta = torch.linalg.solve(lhs.double(), rhs.double()).to(H.dtype)
        else:
            self.beta = torch.linalg.solve(lhs, rhs)

    def predict(self, X: torch.Tensor) -> torch.Tensor:
        """Return predictions [N, K] for features ``X`` [N, D]; requires a prior ``fit``."""
        X = X.to(self.W.device)

        Xn = (X - self.X_mean) / self.X_std
        H = self.__apply_activation(Xn @ self.W + self.b, self.activation_name)
        return H @ self.beta + self.Y_mean

    def compute_ablation_importance(self, X: torch.Tensor, Y: torch.Tensor) -> List[float]:
        """Score each input column by how much the ELM loss grows when it is neutralized.

        A column is neutralized by replacing it with its training mean (which becomes 0
        after standardization). The score is ``max(ablated_loss - base_loss, 0)``;
        since pruning removes less important filters, low scores should be pruned first.

        Args:
            X: Features [N, D] with the same columns used in ``fit``.
            Y: Targets [N, K] or [N] (a 1-D ``Y`` is treated as one column, as in ``fit``).

        Returns:
            One non-negative float per column of ``X``.
        """
        X = X.to(self.W.device)
        # Match predict()'s [N, K] output; a 1-D Y would otherwise broadcast to [N, N] in the loss.
        Y = self.__as_2d(Y).to(self.W.device)

        importances: List[float] = []

        with torch.no_grad():
            base_loss = self.calculate_loss(self.predict(X), Y).item()
            X_work = X.clone()

            for feature_idx in range(X.shape[1]):
                original_column = X_work[:, feature_idx].clone()

                # Neutralize the feature by sending it to its training mean.
                X_work[:, feature_idx] = self.X_mean[0, feature_idx]
                ablated_loss = self.calculate_loss(self.predict(X_work), Y).item()
                importances.append(float(max(ablated_loss - base_loss, 0.0)))

                X_work[:, feature_idx] = original_column

        return importances

    def calculate_loss(self, Y_pred, Y_original):
        """Mean squared error between predictions and targets (scalar tensor)."""
        return self.__mse(Y_pred, Y_original)

    @staticmethod
    def __as_2d(t: torch.Tensor) -> torch.Tensor:
        """View a 1-D tensor [N] as a single column [N, 1]; other ranks are returned unchanged."""
        return t.unsqueeze(1) if t.dim() == 1 else t

    def __apply_activation(self, x: torch.Tensor, activation_name: str) -> torch.Tensor:
        """Apply the named activation; raises ValueError for unsupported names."""
        activation_name = activation_name.lower()

        if activation_name == "tanh":
            return torch.tanh(x)
        if activation_name == "relu":
            return torch.relu(x)
        if activation_name == "sigmoid":
            return torch.sigmoid(x)

        raise ValueError(f"Unsupported activation: {activation_name}")

    def __mse(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.mean((pred - target) ** 2)
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from typing import Dict, Iterable, List
|
|
2
|
+
from tqdm.auto import tqdm
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from .utils import get_all_conv_layer_names, compute_constant_baseline_loss, dump_dict, load_dict
|
|
6
|
+
from .elm import ELMRegressor
|
|
7
|
+
from .importance_processor_config import ImportanceProcessorConfig
|
|
8
|
+
from .feature_extractor import FeatureExtractor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ELMImportanceProcessor:
    """Score every output channel of the selected layers with ELM-based importance.

    Each strategy returns ``{layer_name: [score per channel]}``; a higher score means
    the channel matters more, so low scores are pruned first:

    * global     - one ELM over the concatenated features of all layers, scored by ablation;
    * layerwise  - one ELM per layer, scored by ablation;
    * filterwise - one small ELM per channel, scored by loss reduction vs. a constant baseline.

    Results are stored with ``dump_dict`` under ``<config.abs_path>/importances/``.
    When ``config.use_cache`` is true, existing files are loaded instead of recomputed.
    NOTE: the cache file names are fixed, so the cache is not keyed on the config or the
    layer selection; set ``use_cache=False`` after changing either.
    """

    IMPORTANCES_RELATIVE_FOLDER = "importances"
    IMPORTANCES_GLOBAL_CACHE_FILENAME = "importances_elm_global.json"
    IMPORTANCES_LAYERWISE_CACHE_FILENAME = "importances_elm_layerwise.json"
    IMPORTANCES_FILTERWISE_CACHE_FILENAME = "importances_elm_filterwise.json"

    def __init__(self, config: ImportanceProcessorConfig, model: nn.Module, dataloader: Iterable):
        """Resolve layers, inspect the cache and extract features if any strategy needs them.

        Raises:
            ValueError: If ``config.abs_path`` is not set (the cache folder lives under it).
        """
        from pathlib import Path

        if config.abs_path is None:
            raise ValueError("ImportanceProcessorConfig.abs_path must be set: importances are cached under it.")

        self.config = config
        self.layer_names = self.__resolve_layer_names(config.layer_names, model)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Path(...) also accepts a plain string abs_path.
        self.importances_path = Path(config.abs_path) / ELMImportanceProcessor.IMPORTANCES_RELATIVE_FOLDER
        self.importances_path.mkdir(parents=True, exist_ok=True)

        if config.use_cache:
            self.has_global_cache, self.has_layerwise_cache, self.has_filterwise_cache = self.__verify_cache()
        else:
            # use_cache=False must force recomputation: ignore files left by earlier runs
            # (they are overwritten once the importances are recomputed).
            self.has_global_cache = self.has_layerwise_cache = self.has_filterwise_cache = False

        if self.has_global_cache and self.has_layerwise_cache and self.has_filterwise_cache:
            print("[ELMImportanceProcessor] Not necessary to extract features here! Using cache for all importances type.")
        else:
            feature_extractor = FeatureExtractor(config, model, dataloader, self.layer_names)
            self.features_by_layer, self.targets = feature_extractor.extract_feature_and_targets()

    @staticmethod
    def __resolve_layer_names(layer_names, model: nn.Module) -> List[str]:
        """Turn the config value into a list: ""/None = all conv layers, a str = one layer."""
        if layer_names is None or layer_names == "":
            return get_all_conv_layer_names(model)
        if isinstance(layer_names, str):
            # A bare name must not be iterated character by character downstream.
            return [layer_names]
        return list(layer_names)

    def __verify_cache(self):
        """Return (global, layerwise, filterwise) flags telling which cache files exist."""
        has_global_cache = (self.importances_path / ELMImportanceProcessor.IMPORTANCES_GLOBAL_CACHE_FILENAME).exists()
        has_layerwise_cache = (self.importances_path / ELMImportanceProcessor.IMPORTANCES_LAYERWISE_CACHE_FILENAME).exists()
        has_filterwise_cache = (self.importances_path / ELMImportanceProcessor.IMPORTANCES_FILTERWISE_CACHE_FILENAME).exists()
        return has_global_cache, has_layerwise_cache, has_filterwise_cache

    def __require_features(self, strategy: str) -> None:
        """Raise RuntimeError when no layer features are available for ``strategy``."""
        if not getattr(self, "features_by_layer", None):
            raise RuntimeError(f"No features were collected for ELM {strategy} importance.")

    def __build_elm(self, hidden_dim: int, seed: int) -> ELMRegressor:
        """Create an ELMRegressor with the shared config settings and the given size/seed."""
        return ELMRegressor(
            hidden_dim=hidden_dim,
            reg_lambda=self.config.reg_lambda,
            activation_name=self.config.activation,
            seed=seed,
            eps=self.config.eps,
            use_double_for_solver=self.config.use_double_for_solver,
        )

    def compute_elm_global_importances(self) -> Dict[str, List[float]]:
        """
        One single ELM trained with features from all selected layers concatenated.
        Importance of each filter = increase in ELM loss when that feature is neutralized.
        """
        print("[ELMImportanceProcessor] Getting importances for global...")

        if self.has_global_cache:
            print("[ELMImportanceProcessor] Not necessary to calculate importances here! Using cache for this importance type.")
            return load_dict(self.importances_path / ELMImportanceProcessor.IMPORTANCES_GLOBAL_CACHE_FILENAME)

        self.__require_features("global")

        X = torch.cat([self.features_by_layer[layer_name] for layer_name in self.layer_names], dim=1)
        Y = self.targets

        elm_model = self.__build_elm(self.config.hidden_dim, self.config.seed)
        elm_model.fit(X, Y)
        importances = elm_model.compute_ablation_importance(X, Y)

        # Split the flat score vector back into per-layer slices, in the same order as the concat.
        result: Dict[str, List[float]] = {}
        offset = 0
        for layer_name in tqdm(self.layer_names, desc="ELM global feature ranking processing", dynamic_ncols=True, position=1, leave=False):
            channels = self.features_by_layer[layer_name].shape[1]
            result[layer_name] = importances[offset: offset + channels]
            offset += channels

        dump_dict(result, self.importances_path / ELMImportanceProcessor.IMPORTANCES_GLOBAL_CACHE_FILENAME)
        return result

    def compute_elm_layerwise_importances(self) -> Dict[str, List[float]]:
        """
        One ELM per layer.
        Importance of each filter = increase in ELM loss when that feature is neutralized.
        """
        print("[ELMImportanceProcessor] Getting importances for layerwise...")

        if self.has_layerwise_cache:
            print("[ELMImportanceProcessor]: Not necessary to calculate importances here! Using cache for this importance type.")
            return load_dict(self.importances_path / ELMImportanceProcessor.IMPORTANCES_LAYERWISE_CACHE_FILENAME)

        self.__require_features("layerwise")

        result: Dict[str, List[float]] = {}
        Y = self.targets.to(self.device)

        for layer_name in tqdm(self.layer_names, desc="ELM layerwise feature ranking processing", dynamic_ncols=True, position=1, leave=False):
            X = self.features_by_layer[layer_name].to(self.device)

            elm_model = self.__build_elm(self.config.hidden_dim, self.config.seed)
            elm_model.fit(X, Y)
            result[layer_name] = elm_model.compute_ablation_importance(X, Y)

        dump_dict(result, self.importances_path / ELMImportanceProcessor.IMPORTANCES_LAYERWISE_CACHE_FILENAME)
        return result

    def compute_elm_filterwise_importances(self) -> Dict[str, List[float]]:
        """
        One tiny ELM per filter.
        Importance of one filter = how much that single filter alone reduces target reconstruction loss
        compared with a constant baseline.
        """
        print("[ELMImportanceProcessor] Getting importances for filterwise...")

        if self.has_filterwise_cache:
            print("[ELMImportanceProcessor]: Not necessary to calculate importances here! Using cache for this importance type.")
            return load_dict(self.importances_path / ELMImportanceProcessor.IMPORTANCES_FILTERWISE_CACHE_FILENAME)

        self.__require_features("filterwise")

        result: Dict[str, List[float]] = {}
        Y = self.targets.to(self.device)
        baseline_loss = compute_constant_baseline_loss(Y)

        for layer_name in tqdm(self.layer_names, desc="ELM filterwise feature ranking processing", dynamic_ncols=True, position=1, leave=False):
            X_layer = self.features_by_layer[layer_name].to(self.device)
            layer_importances: List[float] = []

            for filter_idx in range(X_layer.shape[1]):
                # Keep the column 2-D ([N, 1]) so the ELM sees a single input feature.
                X_filter = X_layer[:, filter_idx:filter_idx + 1]

                # Offset the seed per filter so each tiny ELM gets its own random hidden layer.
                elm_model = self.__build_elm(self.config.hidden_dim_per_filter, self.config.seed + filter_idx)
                elm_model.fit(X_filter, Y)
                filter_loss = elm_model.calculate_loss(elm_model.predict(X_filter), Y).item()

                # Higher reduction => more important
                layer_importances.append(float(max(baseline_loss - filter_loss, 0.0)))

            result[layer_name] = layer_importances

        dump_dict(result, self.importances_path / ELMImportanceProcessor.IMPORTANCES_FILTERWISE_CACHE_FILENAME)
        return result
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from tqdm.auto import tqdm
|
|
5
|
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
|
6
|
+
from .utils import get_layer_by_string
|
|
7
|
+
from .importance_processor_config import ImportanceProcessorConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FeatureExtractor:
    """Collect per-sample layer activations (ELM inputs) and regression targets (ELM outputs).

    A forward hook on every layer in ``layer_names`` reduces that layer's output to one
    vector per sample: 4-D outputs are averaged over dims 2 and 3, 3-D outputs over
    dim 2, 2-D outputs are kept as is, anything else is flattened from dim 1.
    Targets are the ground-truth mask class histogram when
    ``config.feature_type == "segmentation"``, otherwise the model output reduced
    in the same way.
    """

    def __init__(self, config: ImportanceProcessorConfig, model: nn.Module, dataloader: Iterable, layer_names: List[str]):
        self.config = config
        self.model = model
        self.dataloader = dataloader
        self.layer_names = layer_names
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_feature_and_targets(self) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Run the model over the dataloader and return ``(features_by_layer, targets)``.

        ``features_by_layer`` maps each layer name to a CPU tensor [num_samples, feature_dim];
        ``targets`` is a CPU tensor [num_samples, target_dim].
        """
        if self.config.feature_type == "segmentation":
            target_extractor = self.__segmentation_mask_histogram_target_extractor()
        else:
            target_extractor = self.__default_logits_gap_target_extractor

        return self.__process_features_and_targets(target_extractor)

    def __process_features_and_targets(self, target_extractor: Callable[[Any, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Hook the selected layers, iterate the dataloader under no_grad and stack the results.

        Note: moves ``self.model`` to ``self.device`` and puts it in eval mode.

        Raises:
            RuntimeError: If a selected layer produced no activation in a forward pass,
                or if no batch was processed at all.
        """
        from itertools import islice

        print("[FeatureExtractor] Extracting features for ELM...")

        self.model.eval()
        self.model = self.model.to(self.device)
        # Collected tensors are kept on the CPU so they do not accumulate on the accelerator.
        storage_device = "cpu"

        hooks = []
        feature_storage: Dict[str, List[torch.Tensor]] = {layer_name: [] for layer_name in self.layer_names}
        target_storage: List[torch.Tensor] = []

        # Filled by the hooks during one forward pass, cleared before the next one.
        current_batch_features: Dict[str, torch.Tensor] = {}

        def make_hook(layer_name: str):
            def hook_fn(module, inputs, output):
                out = self.__extract_first_tensor(output)
                current_batch_features[layer_name] = self.__pool_to_vectors(out).detach().to(storage_device)
            return hook_fn

        max_batches = self.config.max_batches

        try:
            # Registered inside the try so hooks already attached are removed even if a
            # later layer name cannot be resolved.
            for layer_name in self.layer_names:
                layer = get_layer_by_string(self.model, layer_name)
                hooks.append(layer.register_forward_hook(make_hook(layer_name)))

            with torch.no_grad():
                total_batches = None
                if hasattr(self.dataloader, "__len__"):
                    total_batches = len(self.dataloader)
                    if max_batches is not None:
                        total_batches = min(total_batches, max_batches)

                # islice stops pulling from the loader once max_batches is reached instead of
                # loading one extra batch only to discard it.
                batches = self.dataloader if max_batches is None else islice(self.dataloader, max_batches)

                progress_bar = tqdm(
                    batches,
                    total=total_batches,
                    desc="Collecting features",
                    dynamic_ncols=True,
                    file=sys.stdout,
                    position=1,
                    leave=False
                )

                for batch in progress_bar:
                    inputs, targets = self.__unpack_batch(batch)

                    inputs = inputs.to(self.device)
                    targets = targets.to(self.device) if targets is not None else None

                    current_batch_features.clear()

                    model_output = self.model(inputs)
                    y = target_extractor(model_output, inputs, targets).to(storage_device)

                    for layer_name in self.layer_names:
                        if layer_name not in current_batch_features:
                            raise RuntimeError(f"No activation captured for layer '{layer_name}'.")

                        feature_storage[layer_name].append(current_batch_features[layer_name])

                    target_storage.append(y)

        finally:
            for hook in hooks:
                hook.remove()

        if not target_storage:
            raise RuntimeError("No batches were processed while extracting ELM features (empty dataloader or max_batches == 0).")

        features_by_layer = {
            layer_name: torch.cat(feature_storage[layer_name], dim=0)
            for layer_name in self.layer_names
        }
        targets = torch.cat(target_storage, dim=0)

        return features_by_layer, targets

    @staticmethod
    def __pool_to_vectors(out: torch.Tensor) -> torch.Tensor:
        """Reduce a batched tensor to [N, D] (GAP-like reduction over the trailing dims)."""
        if out.dim() == 4:
            # [N, C, H, W] -> [N, C]
            return out.mean(dim=(2, 3))
        if out.dim() == 3:
            # Averages dim 2; assumes channels on dim 1 — TODO confirm for [N, L, C] layouts.
            return out.mean(dim=2)
        if out.dim() == 2:
            return out
        return out.flatten(start_dim=1)

    def __default_logits_gap_target_extractor(self, model_output: Any, input_batch: Optional[torch.Tensor] = None, target_batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Default target extractor:
        converts model output into [N, D] using a GAP-like reduction.
        Very practical because it does not assume a specific task label encoding.

        ``input_batch`` and ``target_batch`` are accepted (and ignored) so this matches the
        ``(model_output, input_batch, target_batch)`` call used during feature extraction.
        """
        out = self.__extract_first_tensor(model_output)
        return self.__pool_to_vectors(out).detach()

    def __segmentation_mask_histogram_target_extractor(self) -> Callable[[Any, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """
        Returns a target extractor that uses the GT mask distribution per image.
        For each image, target becomes [p(class0), p(class1), ..., p(classK)]
        with K + 1 == config.num_classes.

        This is more task-aware for segmentation.
        """
        def _extractor(model_output: Any, input_batch: torch.Tensor, target_batch: Optional[torch.Tensor] = None) -> torch.Tensor:
            if target_batch is None:
                raise ValueError("Target batch is required for mask histogram target extractor.")

            y = target_batch

            if y.dim() == 4 and y.shape[1] == 1:
                y = y.squeeze(1)

            if y.dim() != 3:
                raise ValueError(f"Expected target mask with shape [N, H, W] or [N, 1, H, W], got {tuple(y.shape)}")

            y = y.long()
            histograms = []

            # Fraction of pixels of each class per image.
            for class_idx in range(self.config.num_classes):
                class_ratio = (y == class_idx).float().mean(dim=(1, 2))
                histograms.append(class_ratio)

            return torch.stack(histograms, dim=1).detach()

        return _extractor

    def __extract_first_tensor(self, data: Any) -> torch.Tensor:
        """Return ``data`` if it is a tensor, else its ``"out"`` entry or first tensor value/element."""
        if torch.is_tensor(data):
            return data

        if isinstance(data, dict):
            if "out" in data and torch.is_tensor(data["out"]):
                return data["out"]

            for value in data.values():
                if torch.is_tensor(value):
                    return value

        if isinstance(data, (list, tuple)):
            for value in data:
                if torch.is_tensor(value):
                    return value

        raise TypeError("Could not extract a torch.Tensor from model output / hook output.")

    def __unpack_batch(self, batch: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split a batch into ``(inputs, targets)``; targets is None when absent or not a tensor.

        Accepted forms: a bare tensor, a dict with ``"image"`` (and optional ``"mask"``),
        or a list/tuple ``(inputs, targets, ...)``.
        """
        if torch.is_tensor(batch):
            return batch, None

        if isinstance(batch, dict):
            if "image" not in batch:
                raise KeyError("Batch dict must contain key 'image'.")

            x = batch["image"]
            y = batch.get("mask", None)

            if not torch.is_tensor(x):
                raise TypeError("Batch['image'] is not a torch.Tensor.")

            if y is not None and not torch.is_tensor(y):
                y = None

            return x, y

        if isinstance(batch, (list, tuple)):
            if len(batch) == 0:
                raise ValueError("Empty batch received.")

            x = batch[0]
            y = batch[1] if len(batch) > 1 else None

            if not torch.is_tensor(x):
                raise TypeError("Batch input is not a torch.Tensor.")

            if y is not None and not torch.is_tensor(y):
                y = None

            return x, y

        raise TypeError(f"Unsupported batch type: {type(batch)}")
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class ImportanceProcessorConfig:
    """Settings for ELM-based importance extraction (feature collection + ELM fitting)."""

    hidden_dim: int = 128  # hidden units of the global and layerwise ELMs
    hidden_dim_per_filter: int = 16  # hidden units of each per-filter ELM (filterwise strategy)
    reg_lambda: float = 1e-3  # ridge regularization strength of the ELM readout
    activation: str = "tanh"  # tanh | relu | sigmoid
    max_batches: Optional[int] = None  # cap on dataloader batches used for feature extraction; None = all
    eps: float = 1e-8  # lower bound for per-feature std in the ELM standardization
    seed: int = 42  # seed of the ELM random hidden weights (filterwise adds the filter index)
    use_double_for_solver: bool = True  # solve the ridge system in float64
    feature_type: str = "segmentation"  # segmentation (GT mask class histogram targets) | logits (any other value: model-output targets)
    num_classes: int = 3  # number of mask classes; used only when feature_type == "segmentation"
    abs_path: Optional[Path] = None  # folder under which the "importances" cache directory is created
    use_cache: bool = True  # reuse cached importance files when present
    # "" to get importance for all conv layers, or a list of layer names to restrict it.
    # Declared last (and annotated, so it is a real dataclass field that __init__ accepts)
    # to keep the positional order of the fields above unchanged.
    layer_names: Union[str, List[str]] = ""
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Literal
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
class PruneVerboseLevel(Enum):
    """Verbosity level used by PruneConfig.verbose.

    NOTE(review): how each level is interpreted lives in the pruning code, which is not
    visible here; by name, presumably BASIC < BASIC_ERROR < ALL in amount of output — confirm.
    """
    BASIC = 1
    BASIC_ERROR = 2  # default level of PruneConfig.verbose
    ALL = 3
|
|
9
|
+
|
|
10
|
+
@dataclass
class PruneConfig:
    """Settings for ELM-driven structural pruning.

    NOTE(review): the code that consumes these fields is not visible here; the
    meanings noted as "presumably" are inferred from names — confirm against the pruner.
    """
    importance_type: Literal["elm_global", "elm_layerwise", "elm_filterwise"]  # which ELM importance scores drive pruning (global / layerwise / filterwise strategies)
    target_param_reduction: float  # e.g. 0.20 = reduce the actual parameter count by 20%
    selection_scope: Literal["global", "local"]  # presumably: rank channels across all layers vs. within each layer — confirm
    min_channels_abs: int = 16  # presumably the absolute minimum number of channels a layer keeps — confirm
    min_keep_ratio: float = 0.50  # presumably the minimum fraction of a layer's channels kept — confirm
    max_layer_prune_ratio: float = 0.35  # presumably the maximum fraction of one layer that may be pruned — confirm
    per_step_layer_ratio: float = 0.05  # small pruning amount per iteration
    round_to: int = 8  # presumably channel counts are rounded to a multiple of this — confirm
    verbose: PruneVerboseLevel = PruneVerboseLevel.BASIC_ERROR
|