scope-rx 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scope_rx/__init__.py +178 -0
- scope_rx/cli.py +355 -0
- scope_rx/core/__init__.py +16 -0
- scope_rx/core/base.py +486 -0
- scope_rx/core/scope.py +349 -0
- scope_rx/core/wrapper.py +382 -0
- scope_rx/methods/__init__.py +60 -0
- scope_rx/methods/attention/__init__.py +18 -0
- scope_rx/methods/attention/flow.py +196 -0
- scope_rx/methods/attention/raw.py +165 -0
- scope_rx/methods/attention/rollout.py +235 -0
- scope_rx/methods/gradient/__init__.py +30 -0
- scope_rx/methods/gradient/gradcam.py +177 -0
- scope_rx/methods/gradient/gradcam_plusplus.py +170 -0
- scope_rx/methods/gradient/guided_backprop.py +133 -0
- scope_rx/methods/gradient/integrated_gradients.py +268 -0
- scope_rx/methods/gradient/layercam.py +141 -0
- scope_rx/methods/gradient/scorecam.py +201 -0
- scope_rx/methods/gradient/smoothgrad.py +170 -0
- scope_rx/methods/gradient/vanilla.py +113 -0
- scope_rx/methods/model_agnostic/__init__.py +15 -0
- scope_rx/methods/model_agnostic/kernel_shap.py +299 -0
- scope_rx/methods/model_agnostic/lime_explainer.py +250 -0
- scope_rx/methods/perturbation/__init__.py +18 -0
- scope_rx/methods/perturbation/meaningful_perturbation.py +217 -0
- scope_rx/methods/perturbation/occlusion.py +155 -0
- scope_rx/methods/perturbation/rise.py +194 -0
- scope_rx/metrics/__init__.py +42 -0
- scope_rx/metrics/faithfulness.py +271 -0
- scope_rx/metrics/sensitivity.py +114 -0
- scope_rx/metrics/stability.py +107 -0
- scope_rx/utils/__init__.py +38 -0
- scope_rx/utils/postprocessing.py +121 -0
- scope_rx/utils/preprocessing.py +150 -0
- scope_rx/utils/tensor.py +67 -0
- scope_rx/visualization/__init__.py +25 -0
- scope_rx/visualization/export.py +162 -0
- scope_rx/visualization/plots.py +248 -0
- scope_rx-2.0.0.dist-info/METADATA +361 -0
- scope_rx-2.0.0.dist-info/RECORD +44 -0
- scope_rx-2.0.0.dist-info/WHEEL +5 -0
- scope_rx-2.0.0.dist-info/entry_points.txt +2 -0
- scope_rx-2.0.0.dist-info/licenses/LICENSE +21 -0
- scope_rx-2.0.0.dist-info/top_level.txt +1 -0
scope_rx/__init__.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ScopeRX - Neural Network Explainability and Interpretability Library
|
|
3
|
+
====================================================================
|
|
4
|
+
|
|
5
|
+
ScopeRX is a comprehensive Python library for explaining and interpreting
|
|
6
|
+
neural network predictions. It provides state-of-the-art attribution methods,
|
|
7
|
+
evaluation metrics, and visualization tools.
|
|
8
|
+
|
|
9
|
+
Quick Start
|
|
10
|
+
-----------
|
|
11
|
+
>>> from scope_rx import ScopeRX
|
|
12
|
+
>>> import torch
|
|
13
|
+
>>> import torchvision.models as models
|
|
14
|
+
>>>
|
|
15
|
+
>>> # Load a model
|
|
16
|
+
>>> model = models.resnet50(pretrained=True)
|
|
17
|
+
>>> model.eval()
|
|
18
|
+
>>>
|
|
19
|
+
>>> # Create explainer
|
|
20
|
+
>>> explainer = ScopeRX(model)
|
|
21
|
+
>>>
|
|
22
|
+
>>> # Generate explanation
|
|
23
|
+
>>> result = explainer.explain(
|
|
24
|
+
... input_tensor,
|
|
25
|
+
... method='gradcam',
|
|
26
|
+
... target_layer='layer4',
|
|
27
|
+
... target_class=predicted_class
|
|
28
|
+
... )
|
|
29
|
+
>>>
|
|
30
|
+
>>> # Visualize
|
|
31
|
+
>>> result.visualize()
|
|
32
|
+
|
|
33
|
+
Available Methods
|
|
34
|
+
-----------------
|
|
35
|
+
Gradient-based:
|
|
36
|
+
- GradCAM, GradCAM++, ScoreCAM, LayerCAM
|
|
37
|
+
- SmoothGrad, VanillaGradients, GuidedBackprop
|
|
38
|
+
- IntegratedGradients
|
|
39
|
+
|
|
40
|
+
Perturbation-based:
|
|
41
|
+
- OcclusionSensitivity
|
|
42
|
+
- RISE (Randomized Input Sampling)
|
|
43
|
+
- MeaningfulPerturbation
|
|
44
|
+
|
|
45
|
+
Model-agnostic:
|
|
46
|
+
- KernelSHAP
|
|
47
|
+
- LIME
|
|
48
|
+
|
|
49
|
+
Attention-based:
|
|
50
|
+
- AttentionRollout
|
|
51
|
+
- AttentionFlow
|
|
52
|
+
- RawAttention
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
# Package metadata exposed as module-level dunder attributes.
__version__ = "2.0.0"
__author__ = "desenyon"
__email__ = "desenyon@gmail.com"
__license__ = "MIT"
|
|
59
|
+
|
|
60
|
+
# Core classes
|
|
61
|
+
from scope_rx.core.base import BaseExplainer, ExplanationResult
|
|
62
|
+
from scope_rx.core.scope import ScopeRX
|
|
63
|
+
from scope_rx.core.wrapper import ModelWrapper
|
|
64
|
+
|
|
65
|
+
# Attention-based methods
|
|
66
|
+
from scope_rx.methods.attention import (
|
|
67
|
+
AttentionFlow,
|
|
68
|
+
AttentionRollout,
|
|
69
|
+
RawAttention,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# Gradient-based methods
|
|
73
|
+
from scope_rx.methods.gradient import (
|
|
74
|
+
GradCAM,
|
|
75
|
+
GradCAMPlusPlus,
|
|
76
|
+
GuidedBackprop,
|
|
77
|
+
IntegratedGradients,
|
|
78
|
+
LayerCAM,
|
|
79
|
+
ScoreCAM,
|
|
80
|
+
SmoothGrad,
|
|
81
|
+
VanillaGradients,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
# Model-agnostic methods
|
|
85
|
+
from scope_rx.methods.model_agnostic import (
|
|
86
|
+
LIME,
|
|
87
|
+
KernelSHAP,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# Perturbation-based methods
|
|
91
|
+
from scope_rx.methods.perturbation import (
|
|
92
|
+
RISE,
|
|
93
|
+
MeaningfulPerturbation,
|
|
94
|
+
OcclusionSensitivity,
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
# Metrics
|
|
98
|
+
from scope_rx.metrics import (
|
|
99
|
+
faithfulness_score,
|
|
100
|
+
insertion_deletion_auc,
|
|
101
|
+
sensitivity_score,
|
|
102
|
+
stability_score,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Utilities
|
|
106
|
+
from scope_rx.utils import (
|
|
107
|
+
load_image,
|
|
108
|
+
normalize_attribution,
|
|
109
|
+
preprocess_image,
|
|
110
|
+
to_numpy,
|
|
111
|
+
to_tensor,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
# Visualization
|
|
115
|
+
from scope_rx.visualization import (
|
|
116
|
+
create_interactive_plot,
|
|
117
|
+
export_visualization,
|
|
118
|
+
overlay_attribution,
|
|
119
|
+
plot_attribution,
|
|
120
|
+
plot_comparison,
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
# Names exported by ``from scope_rx import *``; every entry is defined or
# imported earlier in this module.
# NOTE(review): ``__license__`` is defined in this module but is not listed
# here, unlike the other metadata dunders -- confirm whether that is intended.
__all__ = [
    # Version info
    "__version__",
    "__author__",
    "__email__",

    # Core
    "ScopeRX",
    "BaseExplainer",
    "ExplanationResult",
    "ModelWrapper",

    # Gradient methods
    "GradCAM",
    "GradCAMPlusPlus",
    "ScoreCAM",
    "LayerCAM",
    "SmoothGrad",
    "IntegratedGradients",
    "VanillaGradients",
    "GuidedBackprop",

    # Perturbation methods
    "OcclusionSensitivity",
    "RISE",
    "MeaningfulPerturbation",

    # Model-agnostic
    "KernelSHAP",
    "LIME",

    # Attention
    "AttentionRollout",
    "AttentionFlow",
    "RawAttention",

    # Visualization
    "plot_attribution",
    "plot_comparison",
    "overlay_attribution",
    "create_interactive_plot",
    "export_visualization",

    # Metrics
    "faithfulness_score",
    "sensitivity_score",
    "stability_score",
    "insertion_deletion_auc",

    # Utilities
    "preprocess_image",
    "load_image",
    "normalize_attribution",
    "to_numpy",
    "to_tensor",
]
|
scope_rx/cli.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ScopeRX Command Line Interface.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
scope-rx explain IMAGE --model MODEL [--method METHOD] [--output OUTPUT]
|
|
6
|
+
scope-rx compare IMAGE --model MODEL --methods METHODS [--output OUTPUT]
|
|
7
|
+
scope-rx list-methods
scope-rx show-layers --model MODEL
scope-rx --version
|
|
8
|
+
scope-rx --help
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import sys
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main(argv=None):
    """Parse command line arguments and dispatch to the chosen sub-command.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so a bare
            ``main()`` call behaves exactly as before. Passing an explicit
            list allows the CLI to be driven programmatically.

    Returns:
        int: The exit status returned by the selected ``cmd_*`` handler, or
        0 after printing the version or (when no sub-command was given) the
        help text.
    """
    parser = argparse.ArgumentParser(
        prog="scope-rx",
        description="ScopeRX: Neural Network Explainability Toolkit",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Generate a GradCAM explanation
scope-rx explain image.jpg --model resnet50 --method gradcam --output heatmap.png

# Compare multiple methods
scope-rx compare image.jpg --model resnet50 --methods gradcam,smoothgrad,ig

# List available methods
scope-rx list-methods

# Show model layers (for layer selection)
scope-rx show-layers --model resnet50
""",
    )

    parser.add_argument(
        "--version", "-v",
        action="store_true",
        help="Show version"
    )

    subparsers = parser.add_subparsers(dest="command")

    # Explain command
    explain_parser = subparsers.add_parser(
        "explain",
        help="Generate explanation for an image"
    )
    explain_parser.add_argument("image", help="Path to input image")
    explain_parser.add_argument(
        "--model", "-m",
        required=True,
        help="Model name (e.g., resnet50, vgg16) or path to saved model"
    )
    explain_parser.add_argument(
        "--method",
        default="gradcam",
        help="Explanation method (default: gradcam)"
    )
    explain_parser.add_argument(
        "--target-class", "-c",
        type=int,
        default=None,
        help="Target class index (default: predicted class)"
    )
    explain_parser.add_argument(
        "--layer", "-l",
        default=None,
        help="Target layer name (default: auto-detect)"
    )
    explain_parser.add_argument(
        "--output", "-o",
        default=None,
        help="Output file path"
    )
    explain_parser.add_argument(
        "--colormap",
        default="jet",
        help="Colormap for visualization (default: jet)"
    )
    explain_parser.add_argument(
        "--no-display",
        action="store_true",
        help="Don't display the result"
    )

    # Compare command
    compare_parser = subparsers.add_parser(
        "compare",
        help="Compare multiple explanation methods"
    )
    compare_parser.add_argument("image", help="Path to input image")
    compare_parser.add_argument(
        "--model", "-m",
        required=True,
        help="Model name or path"
    )
    compare_parser.add_argument(
        "--methods",
        required=True,
        help="Comma-separated list of methods"
    )
    compare_parser.add_argument(
        "--target-class", "-c",
        type=int,
        default=None,
        help="Target class index"
    )
    compare_parser.add_argument(
        "--output", "-o",
        default=None,
        help="Output file path"
    )
    compare_parser.add_argument(
        "--no-display",
        action="store_true",
        help="Don't display the result"
    )

    # List methods
    subparsers.add_parser(
        "list-methods",
        help="List available explanation methods"
    )

    # Show layers
    layers_parser = subparsers.add_parser(
        "show-layers",
        help="Show model layers"
    )
    layers_parser.add_argument(
        "--model", "-m",
        required=True,
        help="Model name"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # --version wins over any sub-command that was also supplied.
    if args.version:
        from scope_rx import __version__
        print(f"ScopeRX v{__version__}")
        return 0

    if args.command == "explain":
        return cmd_explain(args)
    elif args.command == "compare":
        return cmd_compare(args)
    elif args.command == "list-methods":
        return cmd_list_methods(args)
    elif args.command == "show-layers":
        return cmd_show_layers(args)
    else:
        # No sub-command given: show usage instead of failing.
        parser.print_help()
        return 0
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def cmd_explain(args):
    """Run the ``explain`` sub-command.

    Loads the model and image named in ``args``, computes a single
    attribution with ``args.method``, prints a short summary, and then
    exports and/or displays the result depending on ``args.output`` and
    ``args.no_display``.

    Args:
        args: Parsed arguments providing ``model``, ``image``, ``method``,
            ``target_class``, ``layer``, ``output``, ``colormap`` and
            ``no_display``.

    Returns:
        int: 0 on success; 1 if any exception was raised, in which case the
        message is printed to stderr.
    """
    try:
        import torch

        from scope_rx import ScopeRX
        from scope_rx.utils import load_image, preprocess_image
        from scope_rx.visualization import export_visualization, plot_attribution

        net = load_model(args.model)

        # The resized image is kept for visualization; the model input is
        # produced separately from the same path.
        display_image = load_image(args.image, size=(224, 224))
        model_input = preprocess_image(args.image)
        if not isinstance(model_input, torch.Tensor):
            model_input = torch.from_numpy(model_input).float()

        explainer = ScopeRX(net)
        explanation = explainer.explain(
            model_input,
            method=args.method,
            target_class=args.target_class,
            target_layer=args.layer,
        )

        print(f"Method: {args.method}")
        print(f"Target class: {explanation.target_class}")
        heatmap = explanation.attribution
        print(f"Attribution shape: {heatmap.shape}")

        save_target = args.output
        if save_target:
            export_visualization(
                heatmap,
                save_target,
                colormap=args.colormap,
                image=display_image,
            )
            print(f"Saved to: {save_target}")

        if not args.no_display:
            pretty_title = args.method.replace('_', ' ').title()
            plot_attribution(
                heatmap,
                image=display_image,
                title=pretty_title,
                colormap=args.colormap,
            )
    except Exception as exc:
        # CLI boundary: report the failure and signal it through the exit status.
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    return 0
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def cmd_compare(args):
    """Run the ``compare`` sub-command.

    Computes attributions for several methods on one image and plots them
    side by side.

    Args:
        args: Parsed arguments providing ``model``, ``image``, ``methods``
            (comma-separated string), ``target_class``, ``output`` and
            ``no_display``.

    Returns:
        int: 0 on success; 1 if any exception was raised (including an
        empty method list), in which case the message is printed to stderr.
    """
    try:
        # Parse the method list first so a malformed --methods value fails
        # before the model is loaded. Empty entries from stray or trailing
        # commas ("gradcam,") are dropped instead of being passed on as "".
        methods = [name.strip() for name in args.methods.split(',') if name.strip()]
        if not methods:
            raise ValueError(
                "No explanation methods specified; "
                "pass a comma-separated list via --methods."
            )

        import torch

        from scope_rx import ScopeRX
        from scope_rx.utils import load_image, preprocess_image
        from scope_rx.visualization import plot_comparison

        # Load model
        model = load_model(args.model)

        # Load and preprocess image
        original_image = load_image(args.image, size=(224, 224))
        input_tensor = preprocess_image(args.image)

        # Ensure tensor type
        if not isinstance(input_tensor, torch.Tensor):
            input_tensor = torch.from_numpy(input_tensor).float()

        # Create explainer
        scope = ScopeRX(model)

        # Generate explanations
        results = scope.compare_methods(
            input_tensor,
            methods=methods,
            target_class=args.target_class
        )

        # Extract attributions
        attributions = {name: r.attribution for name, r in results.results.items()}

        print(f"Compared methods: {', '.join(methods)}")

        # Visualize
        if not args.no_display:
            plot_comparison(
                attributions,
                image=original_image,
                save_path=args.output
            )
        elif args.output:
            plot_comparison(
                attributions,
                image=original_image,
                save_path=args.output,
                show=False
            )

        # Both branches above hand args.output to plot_comparison as
        # save_path, so confirm the save whether or not the plot was shown.
        if args.output:
            print(f"Saved to: {args.output}")

        return 0

    except Exception as e:
        # CLI boundary: report the failure and signal it through the exit status.
        print(f"Error: {e}", file=sys.stderr)
        return 1
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def cmd_list_methods(args):
    """Print the available explanation methods grouped by category.

    Args:
        args: Parsed command line arguments; not used by this command.

    Returns:
        int: Always 0.
    """
    print("Available explanation methods:")
    print("-" * 40)

    # Use static list of available methods
    methods = {
        'gradcam', 'gradcam++', 'scorecam', 'layercam',
        'smoothgrad', 'integrated_gradients', 'vanilla_gradients',
        'guided_backprop', 'occlusion', 'rise', 'meaningful_perturbation',
        'kernel_shap', 'lime', 'attention_rollout', 'attention_flow',
    }

    # Category entries must use the same spelling as the static list above:
    # anything that does not match is silently skipped by the filter below.
    # The gradient entry was previously spelled "vanilla", which never
    # matched "vanilla_gradients" and so was never printed.
    # NOTE(review): "raw_attention" is not in the static list and is
    # therefore not printed -- confirm whether it should be listed.
    categories = {
        "Gradient-based": ["gradcam", "gradcam++", "scorecam", "layercam",
                           "smoothgrad", "integrated_gradients",
                           "vanilla_gradients", "guided_backprop"],
        "Perturbation-based": ["occlusion", "rise", "meaningful_perturbation"],
        "Model-agnostic": ["kernel_shap", "lime"],
        "Attention-based": ["attention_rollout", "attention_flow", "raw_attention"],
    }

    for category, method_list in categories.items():
        print(f"\n{category}:")
        for method in method_list:
            if method in methods:
                print(f" - {method}")

    return 0
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def cmd_show_layers(args):
    """Print the name and class of every named sub-module of a model.

    Args:
        args: Parsed arguments providing ``model``, the name or path that
            is handed to ``load_model``.

    Returns:
        int: 0 on success; 1 if the model could not be loaded, in which
        case the message is printed to stderr (the same convention as the
        ``explain`` and ``compare`` commands).
    """
    try:
        model = load_model(args.model)
    except Exception as e:
        # CLI boundary: report the failure instead of dumping a traceback.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print(f"Layers in {args.model}:")
    print("-" * 40)

    for name, module in model.named_modules():
        # Entries with an empty name are skipped.
        if name:
            print(f" {name}: {module.__class__.__name__}")

    return 0
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def load_model(model_name: str):
    """Load a model from a saved file or by torchvision builder name.

    Args:
        model_name: Either the path to a file saved with ``torch.save`` or
            the name of a model builder exposed by ``torchvision.models``
            (e.g. ``"resnet50"``).

    Returns:
        The loaded model, switched to evaluation mode via ``model.eval()``.

    Raises:
        ValueError: If ``model_name`` is neither an existing file nor a
            callable model builder in ``torchvision.models``.
    """
    import torch
    import torchvision.models as models  # type: ignore[import-not-found]

    # Only regular files count as checkpoints, so a directory that shares a
    # name with a torchvision model (e.g. ./resnet50/) does not shadow it.
    if Path(model_name).is_file():
        # SECURITY: weights_only=False unpickles arbitrary Python objects,
        # which can execute code embedded in the file. Only load
        # checkpoints from trusted sources.
        model = torch.load(model_name, weights_only=False)
        model.eval()
        return model

    # torchvision.models also exposes submodules (e.g. "resnet") and classes
    # (e.g. "ResNet"); those are not model builders and fall through to the
    # "Unknown model" error below.
    model_fn = getattr(models, model_name, None)
    if callable(model_fn) and not isinstance(model_fn, type):
        try:
            # Try with weights parameter (newer torchvision)
            model = model_fn(weights="DEFAULT")
        except TypeError:
            # Fall back to pretrained parameter (older torchvision)
            model = model_fn(pretrained=True)
        model.eval()
        return model

    raise ValueError(
        f"Unknown model: {model_name}. "
        "Use a torchvision model name (e.g., resnet50) or path to saved model."
    )
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
# When executed as a script, run the CLI and use main()'s return value as
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core module exports.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from scope_rx.core.base import AttributionContext, BaseExplainer, ExplanationResult
|
|
6
|
+
from scope_rx.core.scope import ScopeRX
|
|
7
|
+
from scope_rx.core.wrapper import ModelWrapper, auto_detect_target_layer
|
|
8
|
+
|
|
9
|
+
# Names exported by ``from scope_rx.core import *``; each one is imported
# from a sibling module earlier in this file.
__all__ = [
    "BaseExplainer",
    "ExplanationResult",
    "AttributionContext",
    "ModelWrapper",
    "auto_detect_target_layer",
    "ScopeRX",
]
|