elm-prune 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Iago Rodrigues
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ This license applies only to the source code, packaging files, and documentation
16
+ contained in this repository for the Python package "mirc-dataset-handler".
17
+
18
+ It does not grant any rights over the MIRC dataset itself, including but not
19
+ limited to images, videos, annotations, derived data, redistributed copies, or
20
+ third-party mirrors of the dataset. Any use of the MIRC dataset must follow the
21
+ separate dataset terms, ethics requirements, and repository policies published
22
+ in the official dataset repository and/or official distribution channels.
23
+
24
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ SOFTWARE.
@@ -0,0 +1,31 @@
1
+ Metadata-Version: 2.4
2
+ Name: elm-prune
3
+ Version: 0.1.0
4
+ Summary: ELM-based structural pruning utilities for PyTorch segmentation models.
5
+ Author: Iago Rodrigues
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/iagorichard
8
+ Project-URL: Repository, https://github.com/iagorichard
9
+ Project-URL: Issues, https://github.com/iagorichard
10
+ Keywords: deep-learning,pytorch,pruning,structural-pruning,segmentation,elm
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Operating System :: OS Independent
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.9
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Requires-Python: >=3.9
19
+ Description-Content-Type: text/markdown
20
+ License-File: LICENSE
21
+ Requires-Dist: numpy>=1.24
22
+ Requires-Dist: tqdm>=4.66
23
+ Requires-Dist: torch>=2.5
24
+ Requires-Dist: segmentation-models-pytorch
25
+ Requires-Dist: torch-pruning==1.6.1
26
+ Provides-Extra: dev
27
+ Requires-Dist: build>=1.2.2; extra == "dev"
28
+ Requires-Dist: twine>=5.1.1; extra == "dev"
29
+ Dynamic: license-file
30
+
31
+ # elm-pruning
@@ -0,0 +1 @@
1
+ # elm-pruning
@@ -0,0 +1,31 @@
1
+ Metadata-Version: 2.4
2
+ Name: elm-prune
3
+ Version: 0.1.0
4
+ Summary: ELM-based structural pruning utilities for PyTorch segmentation models.
5
+ Author: Iago Rodrigues
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/iagorichard
8
+ Project-URL: Repository, https://github.com/iagorichard
9
+ Project-URL: Issues, https://github.com/iagorichard
10
+ Keywords: deep-learning,pytorch,pruning,structural-pruning,segmentation,elm
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Operating System :: OS Independent
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.9
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Requires-Python: >=3.9
19
+ Description-Content-Type: text/markdown
20
+ License-File: LICENSE
21
+ Requires-Dist: numpy>=1.24
22
+ Requires-Dist: tqdm>=4.66
23
+ Requires-Dist: torch>=2.5
24
+ Requires-Dist: segmentation-models-pytorch
25
+ Requires-Dist: torch-pruning==1.6.1
26
+ Provides-Extra: dev
27
+ Requires-Dist: build>=1.2.2; extra == "dev"
28
+ Requires-Dist: twine>=5.1.1; extra == "dev"
29
+ Dynamic: license-file
30
+
31
+ # elm-pruning
@@ -0,0 +1,17 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ elm_prune.egg-info/PKG-INFO
5
+ elm_prune.egg-info/SOURCES.txt
6
+ elm_prune.egg-info/dependency_links.txt
7
+ elm_prune.egg-info/requires.txt
8
+ elm_prune.egg-info/top_level.txt
9
+ elmprune/__init__.py
10
+ elmprune/elm.py
11
+ elmprune/elm_importance_processor.py
12
+ elmprune/feature_extractor.py
13
+ elmprune/importance_processor_config.py
14
+ elmprune/prune_config.py
15
+ elmprune/prune_pipeline.py
16
+ elmprune/prune_processor.py
17
+ elmprune/utils.py
@@ -0,0 +1,9 @@
1
+ numpy>=1.24
2
+ tqdm>=4.66
3
+ torch>=2.5
4
+ segmentation-models-pytorch
5
+ torch-pruning==1.6.1
6
+
7
+ [dev]
8
+ build>=1.2.2
9
+ twine>=5.1.1
@@ -0,0 +1 @@
1
+ elmprune
@@ -0,0 +1,20 @@
1
+ from . import utils
2
+ from .importance_processor_config import ImportanceProcessorConfig
3
+ from .feature_extractor import FeatureExtractor
4
+ from .elm_importance_processor import ELMImportanceProcessor
5
+ from .elm import ELMRegressor
6
+ from .prune_processor import PruneProcessor
7
+ from .prune_config import PruneConfig
8
+ from .prune_config import PruneVerboseLevel
9
+ from .prune_pipeline import PrunePipeline
10
+
11
+
12
+ __all__ = ["utils",
13
+ "ImportanceProcessorConfig",
14
+ "FeatureExtractor",
15
+ "ELMImportanceProcessor",
16
+ "ELMRegressor",
17
+ "PruneProcessor",
18
+ "PruneConfig",
19
+ "PruneVerboseLevel",
20
+ "PrunePipeline"]
@@ -0,0 +1,105 @@
1
+ from typing import Dict, List
2
+ import math
3
+ import torch
4
+
5
+
6
class ELMRegressor:
    """Extreme Learning Machine regressor with a closed-form ridge readout.

    Inputs are standardized per feature, projected through a fixed random
    hidden layer, and the output weights ``beta`` are obtained by solving the
    regularized normal equations ``(H^T H + reg_lambda * I) beta = H^T Y``.

    Args:
        hidden_dim: Number of random hidden units.
        reg_lambda: Ridge coefficient added to the diagonal of ``H^T H``.
        activation_name: ``"tanh"``, ``"relu"`` or ``"sigmoid"``
            (matched case-insensitively).
        seed: Seed of the generator used to draw the random hidden layer.
        eps: Lower bound applied to the per-feature standard deviation.
        use_double_for_solver: If true, the linear system is solved in float64
            and the solution is cast back to the working dtype.
    """

    def __init__(self, hidden_dim: int, reg_lambda: float, activation_name: str, seed: int, eps: float, use_double_for_solver: bool):
        # All tensors live on CUDA when it is available, otherwise on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hidden_dim = hidden_dim
        self.reg_lambda = reg_lambda
        self.activation_name = activation_name
        self.seed = seed
        self.eps = eps
        self.use_double_for_solver = use_double_for_solver

    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Fit the normalization statistics, random layer and readout weights.

        The fitted state is stored on the instance (``X_mean``, ``X_std``,
        ``Y_mean``, ``W``, ``b``, ``beta``); nothing is returned.

        Args:
            X: Feature matrix of shape ``[samples, features]``.
            Y: Target matrix of shape ``[samples, outputs]``.

        Raises:
            ValueError: If ``activation_name`` is not supported.
        """
        X = X.to(self.device)
        if not X.is_floating_point():
            # mean/std are undefined for integer tensors.
            X = X.float()
        # The normal equations mix X- and Y-derived tensors, so both must
        # share one dtype.
        Y = Y.to(device=self.device, dtype=X.dtype)

        self.X_mean = X.mean(dim=0, keepdim=True)
        self.X_std = X.std(dim=0, keepdim=True, unbiased=False).clamp_min(self.eps)
        Xn = (X - self.X_mean) / self.X_std

        self.Y_mean = Y.mean(dim=0, keepdim=True)
        Yn = Y - self.Y_mean

        generator = torch.Generator(device=self.device)
        generator.manual_seed(self.seed)

        in_dim = Xn.shape[1]

        # Draw the random layer in the working dtype so that the matmul below
        # also works for non-float32 inputs. W is scaled by 1/sqrt(in_dim).
        self.W = torch.randn(
            (in_dim, self.hidden_dim),
            generator=generator,
            device=self.device,
            dtype=Xn.dtype,
        ) / math.sqrt(max(in_dim, 1))
        self.b = torch.randn(
            (self.hidden_dim,),
            generator=generator,
            device=self.device,
            dtype=Xn.dtype,
        )

        H = self.__apply_activation(Xn @ self.W + self.b, self.activation_name)

        identity = torch.eye(self.hidden_dim, device=self.device, dtype=H.dtype)
        lhs = H.T @ H + self.reg_lambda * identity
        rhs = H.T @ Yn

        if self.use_double_for_solver:
            self.beta = torch.linalg.solve(lhs.double(), rhs.double()).to(H.dtype)
        else:
            self.beta = torch.linalg.solve(lhs, rhs)

    def predict(self, X: torch.Tensor) -> torch.Tensor:
        """Return predictions for ``X`` using the state stored by ``fit``."""
        X = X.to(device=self.W.device, dtype=self.W.dtype)

        Xn = (X - self.X_mean) / self.X_std
        H = self.__apply_activation(Xn @ self.W + self.b, self.activation_name)
        return H @ self.beta + self.Y_mean

    def compute_ablation_importance(self, X: torch.Tensor, Y: torch.Tensor) -> List[float]:
        """Score each feature by the loss increase caused by neutralizing it.

        Each feature column is replaced in turn by its training mean; the
        score is ``max(ablated_loss - base_loss, 0)``. Since pruning removes
        the less important filters, low scores should be pruned first.

        Returns:
            One non-negative float per column of ``X``.
        """
        X = X.to(device=self.W.device, dtype=self.W.dtype)
        Y = Y.to(device=self.W.device, dtype=self.W.dtype)

        base_loss = self.calculate_loss(self.predict(X), Y).item()

        importances: List[float] = []
        # Work on a copy so the caller's tensor is never modified.
        X_work = X.clone()

        for feature_idx in range(X.shape[1]):
            original_column = X_work[:, feature_idx].clone()

            # Neutralize the feature by pinning it to its mean value.
            X_work[:, feature_idx] = self.X_mean[0, feature_idx]
            ablated_loss = self.calculate_loss(self.predict(X_work), Y).item()
            importances.append(float(max(ablated_loss - base_loss, 0.0)))

            # Restore the column before ablating the next one.
            X_work[:, feature_idx] = original_column

        return importances

    def calculate_loss(self, Y_pred, Y_original):
        """Return the mean squared error between predictions and targets."""
        return self.__mse(Y_pred, Y_original)

    def __apply_activation(self, x: torch.Tensor, activation_name: str) -> torch.Tensor:
        """Apply the named activation element-wise; raise ValueError if unknown."""
        activation_name = activation_name.lower()

        if activation_name == "tanh":
            return torch.tanh(x)
        if activation_name == "relu":
            return torch.relu(x)
        if activation_name == "sigmoid":
            return torch.sigmoid(x)

        raise ValueError(f"Unsupported activation: {activation_name}")

    def __mse(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Mean of the squared element-wise differences."""
        return torch.mean((pred - target) ** 2)
@@ -0,0 +1,164 @@
1
+ from typing import Dict, Iterable, List
2
+ from tqdm.auto import tqdm
3
+ import torch
4
+ from torch import nn
5
+ from .utils import get_all_conv_layer_names, compute_constant_baseline_loss, dump_dict, load_dict
6
+ from .elm import ELMRegressor
7
+ from .importance_processor_config import ImportanceProcessorConfig
8
+ from .feature_extractor import FeatureExtractor
9
+
10
+
11
class ELMImportanceProcessor:
    """Compute per-filter importance scores with ELM regressors.

    Three scoring strategies are offered (global, layerwise, filterwise).
    Each result is a dict mapping a layer name to one float per filter and is
    written as JSON under ``<config.abs_path>/importances``. When
    ``config.use_cache`` is true, an existing file is loaded instead of being
    recomputed; when it is false, existing files are ignored and overwritten.

    Args:
        config: Hyper-parameters, cache location and cache policy.
        model: Model whose layers are scored.
        dataloader: Iterable of batches handed to ``FeatureExtractor``.

    Raises:
        TypeError: If ``config.abs_path`` is ``None``.
    """

    IMPORTANCES_RELATIVE_FOLDER = "importances"
    IMPORTANCES_GLOBAL_CACHE_FILENAME = "importances_elm_global.json"
    IMPORTANCES_LAYERWISE_CACHE_FILENAME = "importances_elm_layerwise.json"
    IMPORTANCES_FILTERWISE_CACHE_FILENAME = "importances_elm_filterwise.json"

    def __init__(self, config: ImportanceProcessorConfig, model: nn.Module, dataloader: Iterable):
        if config.abs_path is None:
            # Without this guard the path join below fails with an opaque
            # "unsupported operand" TypeError.
            raise TypeError("ImportanceProcessorConfig.abs_path must be set to a directory path, got None.")

        self.config = config
        # An empty string means "score every conv layer of the model".
        self.layer_names = get_all_conv_layer_names(model) if config.layer_names == "" else config.layer_names
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.importances_path = config.abs_path / ELMImportanceProcessor.IMPORTANCES_RELATIVE_FOLDER
        self.importances_path.mkdir(parents=True, exist_ok=True)

        if config.use_cache:
            self.has_global_cache, self.has_layerwise_cache, self.has_filterwise_cache = self.__verify_cache()
        else:
            # With caching disabled, files already on disk must not
            # short-circuit the compute_* methods.
            self.has_global_cache = False
            self.has_layerwise_cache = False
            self.has_filterwise_cache = False

        # Stay empty when every importance type can be served from cache.
        self.features_by_layer = {}
        self.targets = None

        if self.has_global_cache and self.has_layerwise_cache and self.has_filterwise_cache:
            print("[ELMImportanceProcessor] Not necessary to extract features here! Using cache for all importances type.")
        else:
            feature_extractor = FeatureExtractor(config, model, dataloader, self.layer_names)
            self.features_by_layer, self.targets = feature_extractor.extract_feature_and_targets()

    def __verify_cache(self):
        """Return three booleans: whether the global, layerwise and filterwise cache files exist."""
        has_global_cache = (self.importances_path / ELMImportanceProcessor.IMPORTANCES_GLOBAL_CACHE_FILENAME).exists()
        has_layerwise_cache = (self.importances_path / ELMImportanceProcessor.IMPORTANCES_LAYERWISE_CACHE_FILENAME).exists()
        has_filterwise_cache = (self.importances_path / ELMImportanceProcessor.IMPORTANCES_FILTERWISE_CACHE_FILENAME).exists()
        return has_global_cache, has_layerwise_cache, has_filterwise_cache

    def __build_elm(self, hidden_dim: int, seed: int) -> ELMRegressor:
        """Create an ELMRegressor from the shared config with the given size and seed."""
        return ELMRegressor(
            hidden_dim=hidden_dim,
            reg_lambda=self.config.reg_lambda,
            activation_name=self.config.activation,
            seed=seed,
            eps=self.config.eps,
            use_double_for_solver=self.config.use_double_for_solver,
        )

    def compute_elm_global_importances(self) -> Dict[str, List[float]]:
        """
        One single ELM trained with features from all selected layers concatenated.
        Importance of each filter = increase in ELM loss when that feature is neutralized.

        Raises:
            RuntimeError: If no features were collected.
        """
        print("[ELMImportanceProcessor] Getting importances for global...")
        cache_file = self.importances_path / ELMImportanceProcessor.IMPORTANCES_GLOBAL_CACHE_FILENAME

        if self.has_global_cache:
            print("[ELMImportanceProcessor] Not necessary to calculate importances here! Using cache for this importance type.")
            return load_dict(cache_file)

        if len(self.features_by_layer) == 0:
            raise RuntimeError("No features were collected for ELM global importance.")

        X = torch.cat([self.features_by_layer[layer_name] for layer_name in self.layer_names], dim=1)
        Y = self.targets

        elm_model = self.__build_elm(self.config.hidden_dim, self.config.seed)
        elm_model.fit(X, Y)
        importances = elm_model.compute_ablation_importance(X, Y)

        # Split the flat score list back into per-layer slices, in the same
        # order the layers were concatenated.
        result: Dict[str, List[float]] = {}
        offset = 0
        for layer_name in tqdm(self.layer_names, desc="ELM global feature ranking processing", dynamic_ncols=True, position=1, leave=False):
            channels = self.features_by_layer[layer_name].shape[1]
            result[layer_name] = importances[offset: offset + channels]
            offset += channels

        dump_dict(result, cache_file)
        return result

    def compute_elm_layerwise_importances(self) -> Dict[str, List[float]]:
        """
        One ELM per layer.
        Importance of each filter = increase in ELM loss when that feature is neutralized.
        """
        print("[ELMImportanceProcessor] Getting importances for layerwise...")
        cache_file = self.importances_path / ELMImportanceProcessor.IMPORTANCES_LAYERWISE_CACHE_FILENAME

        if self.has_layerwise_cache:
            print("[ELMImportanceProcessor]: Not necessary to calculate importances here! Using cache for this importance type.")
            return load_dict(cache_file)

        result: Dict[str, List[float]] = {}
        Y = self.targets.to(self.device)

        for layer_name in tqdm(self.layer_names, desc="ELM layerwise feature ranking processing", dynamic_ncols=True, position=1, leave=False):
            X = self.features_by_layer[layer_name].to(self.device)

            elm_model = self.__build_elm(self.config.hidden_dim, self.config.seed)
            elm_model.fit(X, Y)
            result[layer_name] = elm_model.compute_ablation_importance(X, Y)

        dump_dict(result, cache_file)
        return result

    def compute_elm_filterwise_importances(self) -> Dict[str, List[float]]:
        """
        One tiny ELM per filter.
        Importance of one filter = how much that single filter alone reduces target reconstruction loss
        compared with a constant baseline.
        """
        print("[ELMImportanceProcessor] Getting importances for filterwise...")
        cache_file = self.importances_path / ELMImportanceProcessor.IMPORTANCES_FILTERWISE_CACHE_FILENAME

        if self.has_filterwise_cache:
            print("[ELMImportanceProcessor]: Not necessary to calculate importances here! Using cache for this importance type.")
            return load_dict(cache_file)

        result: Dict[str, List[float]] = {}
        Y = self.targets.to(self.device)
        # NOTE(review): compute_constant_baseline_loss is defined in utils and
        # not visible here; presumably the loss of predicting a constant.
        baseline_loss = compute_constant_baseline_loss(Y)

        for layer_name in tqdm(self.layer_names, desc="ELM filterwise feature ranking processing", dynamic_ncols=True, position=1, leave=False):
            X_layer = self.features_by_layer[layer_name].to(self.device)
            layer_importances: List[float] = []

            for filter_idx in range(X_layer.shape[1]):
                # Keep the column dimension so the ELM sees a [samples, 1] matrix.
                X_filter = X_layer[:, filter_idx:filter_idx + 1]

                # Each filter gets its own seed, offset by its index.
                elm_model = self.__build_elm(self.config.hidden_dim_per_filter, self.config.seed + filter_idx)
                elm_model.fit(X_filter, Y)
                filter_loss = elm_model.calculate_loss(elm_model.predict(X_filter), Y).item()

                # Higher reduction => more important
                layer_importances.append(float(max(baseline_loss - filter_loss, 0.0)))

            result[layer_name] = layer_importances

        dump_dict(result, cache_file)
        return result
@@ -0,0 +1,217 @@
1
+ import sys
2
+ import torch
3
+ import torch.nn as nn
4
+ from tqdm.auto import tqdm
5
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
6
+ from .utils import get_layer_by_string
7
+ from .importance_processor_config import ImportanceProcessorConfig
8
+
9
+
10
class FeatureExtractor:
    """Collect pooled layer activations and regression targets for ELM fitting.

    Forward hooks are attached to the requested layers, the dataloader is run
    through the model without gradients, and each hooked output is reduced to
    one row per sample. Targets come either from the ground-truth masks
    (``config.feature_type == "segmentation"``) or from the pooled model
    output (any other value).

    Args:
        config: Supplies ``feature_type``, ``num_classes`` and ``max_batches``.
        model: Model to run; it is switched to eval mode and moved to the
            compute device during extraction.
        dataloader: Iterable of batches (tensor, dict with ``"image"`` /
            ``"mask"``, or list/tuple of ``(input, target)``).
        layer_names: Names of the layers to hook.
    """

    def __init__(self, config: "ImportanceProcessorConfig", model: nn.Module, dataloader: Iterable, layer_names: List[str]):
        self.config = config
        self.model = model
        self.dataloader = dataloader
        self.layer_names = layer_names
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_feature_and_targets(self) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Run the extraction and return ``(features_by_layer, targets)``.

        Returns:
            A dict mapping each layer name to its stacked pooled activations,
            and the stacked targets, all stored on CPU.
        """
        if self.config.feature_type == "segmentation":
            target_extractor = self.__segmentation_mask_histogram_target_extractor()
        else:
            target_extractor = self.__default_logits_gap_target_extractor

        return self.__process_features_and_targets(target_extractor)

    def __process_features_and_targets(self, target_extractor: Callable[[Any, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Hook the layers, iterate the dataloader and stack the results.

        Raises:
            RuntimeError: If a hooked layer produced no activation for a
                batch, or if no batch was processed at all.
        """
        print("[FeatureExtractor] Extracting features for ELM...")

        self.model.eval()
        self.model = self.model.to(self.device)
        # Activations are moved to CPU right away so they do not accumulate on
        # the compute device.
        storage_device = "cpu"

        hooks = []
        feature_storage: Dict[str, List[torch.Tensor]] = {layer_name: [] for layer_name in self.layer_names}
        target_storage: List[torch.Tensor] = []

        current_batch_features: Dict[str, torch.Tensor] = {}

        def make_hook(layer_name: str):
            def hook_fn(module, inputs, output):
                pooled = self.__pool_per_sample(self.__extract_first_tensor(output))
                current_batch_features[layer_name] = pooled.detach().to(storage_device)
            return hook_fn

        # tqdm total: the dataloader length when known, capped by max_batches.
        max_batches = self.config.max_batches
        total_batches = len(self.dataloader) if hasattr(self.dataloader, "__len__") else None
        if max_batches is not None:
            total_batches = max_batches if total_batches is None else min(total_batches, max_batches)

        try:
            # Registration happens inside the try so that hooks already
            # attached are removed even if a later layer lookup fails.
            for layer_name in self.layer_names:
                layer = get_layer_by_string(self.model, layer_name)
                hooks.append(layer.register_forward_hook(make_hook(layer_name)))

            with torch.no_grad():
                progress_bar = tqdm(
                    enumerate(self.dataloader),
                    total=total_batches,
                    desc="Collecting features",
                    dynamic_ncols=True,
                    file=sys.stdout,
                    position=1,
                    leave=False
                )

                for batch_idx, batch in progress_bar:
                    if max_batches is not None and batch_idx >= max_batches:
                        break

                    inputs, targets = self.__unpack_batch(batch)

                    inputs = inputs.to(self.device)
                    targets = targets.to(self.device) if targets is not None else None

                    # Drop the previous batch's activations so a layer that
                    # did not fire is detected below.
                    current_batch_features.clear()

                    model_output = self.model(inputs)
                    y = target_extractor(model_output, inputs, targets).to(storage_device)

                    for layer_name in self.layer_names:
                        if layer_name not in current_batch_features:
                            raise RuntimeError(f"No activation captured for layer '{layer_name}'.")

                        feature_storage[layer_name].append(current_batch_features[layer_name])

                    target_storage.append(y)

        finally:
            for hook in hooks:
                hook.remove()

        if not target_storage:
            # torch.cat on an empty list would raise a far less helpful error.
            raise RuntimeError("No batches were processed; the dataloader yielded nothing to extract features from.")

        features_by_layer = {
            layer_name: torch.cat(feature_storage[layer_name], dim=0)
            for layer_name in self.layer_names
        }
        targets = torch.cat(target_storage, dim=0)

        return features_by_layer, targets

    def __pool_per_sample(self, out: torch.Tensor) -> torch.Tensor:
        """Reduce a tensor to two dimensions, keeping dim 0.

        4-D tensors are averaged over dims 2 and 3, 3-D tensors over dim 2,
        2-D tensors are returned as-is, anything else is flattened from dim 1.
        """
        if out.dim() == 4:
            # [N, C, H, W] -> [N, C]
            return out.mean(dim=(2, 3))
        if out.dim() == 3:
            return out.mean(dim=2)
        if out.dim() == 2:
            return out
        return out.flatten(start_dim=1)

    def __default_logits_gap_target_extractor(self, model_output: Any, input_batch: Optional[torch.Tensor] = None, target_batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Default target extractor:
        converts model output into [N, D] using a GAP-like reduction.
        Very practical because it does not assume a specific task label encoding.

        ``input_batch`` and ``target_batch`` are accepted only so that this
        extractor can be called with the same three arguments as the mask
        histogram extractor; they are ignored.
        """
        out = self.__extract_first_tensor(model_output)
        return self.__pool_per_sample(out).detach()

    def __segmentation_mask_histogram_target_extractor(self) -> Callable[[Any, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """
        Returns a target extractor that uses the GT mask distribution per image.
        For each image, target becomes [p(class0), p(class1), ..., p(classK)].

        This is more task-aware for segmentation.
        """
        def _extractor(model_output: Any, input_batch: torch.Tensor, target_batch: Optional[torch.Tensor] = None) -> torch.Tensor:
            if target_batch is None:
                raise ValueError("Target batch is required for mask histogram target extractor.")

            y = target_batch

            if y.dim() == 4 and y.shape[1] == 1:
                y = y.squeeze(1)

            if y.dim() != 3:
                raise ValueError(f"Expected target mask with shape [N, H, W] or [N, 1, H, W], got {tuple(y.shape)}")

            y = y.long()
            # Fraction of pixels per image that carry each class index.
            histograms = [
                (y == class_idx).float().mean(dim=(1, 2))
                for class_idx in range(self.config.num_classes)
            ]

            return torch.stack(histograms, dim=1).detach()

        return _extractor

    def __extract_first_tensor(self, data: Any) -> torch.Tensor:
        """Return ``data`` if it is a tensor, else the first tensor found in it.

        For dicts, the ``"out"`` entry is preferred when it is a tensor.

        Raises:
            TypeError: If no tensor can be found.
        """
        if torch.is_tensor(data):
            return data

        if isinstance(data, dict):
            if "out" in data and torch.is_tensor(data["out"]):
                return data["out"]

            for value in data.values():
                if torch.is_tensor(value):
                    return value

        if isinstance(data, (list, tuple)):
            for value in data:
                if torch.is_tensor(value):
                    return value

        raise TypeError("Could not extract a torch.Tensor from model output / hook output.")

    def __unpack_batch(self, batch: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split a batch into ``(inputs, targets)``.

        A target that is missing or is not a tensor is returned as ``None``.

        Raises:
            KeyError: If a dict batch has no ``"image"`` key.
            ValueError: If a list/tuple batch is empty.
            TypeError: If the input is not a tensor or the batch type is
                unsupported.
        """
        if torch.is_tensor(batch):
            return batch, None

        if isinstance(batch, dict):
            if "image" not in batch:
                raise KeyError("Batch dict must contain key 'image'.")

            x = batch["image"]
            y = batch.get("mask", None)

            if not torch.is_tensor(x):
                raise TypeError("Batch['image'] is not a torch.Tensor.")

            if y is not None and not torch.is_tensor(y):
                y = None

            return x, y

        if isinstance(batch, (list, tuple)):
            if len(batch) == 0:
                raise ValueError("Empty batch received.")

            x = batch[0]
            y = batch[1] if len(batch) > 1 else None

            if not torch.is_tensor(x):
                raise TypeError("Batch input is not a torch.Tensor.")

            if y is not None and not torch.is_tensor(y):
                y = None

            return x, y

        raise TypeError(f"Unsupported batch type: {type(batch)}")
@@ -0,0 +1,20 @@
1
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
4
+
5
+
6
@dataclass
class ImportanceProcessorConfig:
    """Settings for ELM-based importance computation and feature extraction."""

    # Hidden units of the ELM used for global and layerwise importances.
    hidden_dim: int = 128
    # Hidden units of the small per-filter ELM used for filterwise importances.
    hidden_dim_per_filter: int = 16
    # Ridge regularization strength of the ELM readout.
    reg_lambda: float = 1e-3
    activation: str = "tanh"  # tanh | relu | sigmoid
    # Maximum number of dataloader batches to consume; None means all of them.
    max_batches: Optional[int] = None
    # Lower bound for the per-feature standard deviation during normalization.
    eps: float = 1e-8
    # Seed for the random hidden layer of the ELM.
    seed: int = 42
    # Solve the ELM linear system in float64.
    use_double_for_solver: bool = True
    # "segmentation" builds targets from mask class histograms; any other
    # value (e.g. "logits") builds them from the pooled model output.
    feature_type: str = "segmentation"  # segmentation | logits
    num_classes: int = 3  # if segmentation
    # Directory under which the "importances" cache folder is created.
    abs_path: Optional[Path] = None
    # Reuse importance files found on disk instead of recomputing them.
    use_cache: bool = True
    # "" selects all conv layers; otherwise a list of layer names. Annotated so
    # that it is a real dataclass field and can be passed to the constructor.
    # Declared last so the positional order of the other fields is unchanged.
    layer_names: Union[str, List[str]] = ""
@@ -0,0 +1,20 @@
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+ from enum import Enum
4
+
5
class PruneVerboseLevel(Enum):
    """Verbosity levels selectable through ``PruneConfig.verbose``."""

    # NOTE(review): what each level prints is decided by code outside this
    # module; only the names and numeric values are defined here.
    BASIC = 1
    BASIC_ERROR = 2
    ALL = 3
9
+
10
@dataclass
class PruneConfig:
    """Parameters of a pruning run.

    The first three fields are required; the remaining ones have defaults.
    """

    # Which ELM importance strategy supplies the filter scores.
    importance_type: Literal["elm_global", "elm_layerwise", "elm_filterwise"]
    # e.g. 0.20 = reduce the real parameter count by 20%
    target_param_reduction: float
    selection_scope: Literal["global", "local"]

    # NOTE(review): the fields below are consumed by the prune processor,
    # which is not visible here; their exact semantics should be confirmed there.
    min_channels_abs: int = 16
    min_keep_ratio: float = 0.50
    max_layer_prune_ratio: float = 0.35
    # small amount of pruning per iteration
    per_step_layer_ratio: float = 0.05
    round_to: int = 8
    verbose: PruneVerboseLevel = PruneVerboseLevel.BASIC_ERROR