quantmllibrary 0.1.0__py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- quantml/__init__.py +74 -0
- quantml/autograd.py +154 -0
- quantml/cli/__init__.py +10 -0
- quantml/cli/run_experiment.py +385 -0
- quantml/config/__init__.py +28 -0
- quantml/config/config.py +259 -0
- quantml/data/__init__.py +33 -0
- quantml/data/cache.py +149 -0
- quantml/data/feature_store.py +234 -0
- quantml/data/futures.py +254 -0
- quantml/data/loaders.py +236 -0
- quantml/data/memory_optimizer.py +234 -0
- quantml/data/validators.py +390 -0
- quantml/experiments/__init__.py +23 -0
- quantml/experiments/logger.py +208 -0
- quantml/experiments/results.py +158 -0
- quantml/experiments/tracker.py +223 -0
- quantml/features/__init__.py +25 -0
- quantml/features/base.py +104 -0
- quantml/features/gap_features.py +124 -0
- quantml/features/registry.py +138 -0
- quantml/features/volatility_features.py +140 -0
- quantml/features/volume_features.py +142 -0
- quantml/functional.py +37 -0
- quantml/models/__init__.py +27 -0
- quantml/models/attention.py +258 -0
- quantml/models/dropout.py +130 -0
- quantml/models/gru.py +319 -0
- quantml/models/linear.py +112 -0
- quantml/models/lstm.py +353 -0
- quantml/models/mlp.py +286 -0
- quantml/models/normalization.py +289 -0
- quantml/models/rnn.py +154 -0
- quantml/models/tcn.py +238 -0
- quantml/online.py +209 -0
- quantml/ops.py +1707 -0
- quantml/optim/__init__.py +42 -0
- quantml/optim/adafactor.py +206 -0
- quantml/optim/adagrad.py +157 -0
- quantml/optim/adam.py +267 -0
- quantml/optim/lookahead.py +97 -0
- quantml/optim/quant_optimizer.py +228 -0
- quantml/optim/radam.py +192 -0
- quantml/optim/rmsprop.py +203 -0
- quantml/optim/schedulers.py +286 -0
- quantml/optim/sgd.py +181 -0
- quantml/py.typed +0 -0
- quantml/streaming.py +175 -0
- quantml/tensor.py +462 -0
- quantml/time_series.py +447 -0
- quantml/training/__init__.py +135 -0
- quantml/training/alpha_eval.py +203 -0
- quantml/training/backtest.py +280 -0
- quantml/training/backtest_analysis.py +168 -0
- quantml/training/cv.py +106 -0
- quantml/training/data_loader.py +177 -0
- quantml/training/ensemble.py +84 -0
- quantml/training/feature_importance.py +135 -0
- quantml/training/features.py +364 -0
- quantml/training/futures_backtest.py +266 -0
- quantml/training/gradient_clipping.py +206 -0
- quantml/training/losses.py +248 -0
- quantml/training/lr_finder.py +127 -0
- quantml/training/metrics.py +376 -0
- quantml/training/regularization.py +89 -0
- quantml/training/trainer.py +239 -0
- quantml/training/walk_forward.py +190 -0
- quantml/utils/__init__.py +51 -0
- quantml/utils/gradient_check.py +274 -0
- quantml/utils/logging.py +181 -0
- quantml/utils/ops_cpu.py +231 -0
- quantml/utils/profiling.py +364 -0
- quantml/utils/reproducibility.py +220 -0
- quantml/utils/serialization.py +335 -0
- quantmllibrary-0.1.0.dist-info/METADATA +536 -0
- quantmllibrary-0.1.0.dist-info/RECORD +79 -0
- quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
- quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
- quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
quantml/training/cv.py
ADDED
```python
"""
Cross-validation utilities for time-series data.

Provides time-series aware cross-validation to avoid lookahead bias.
"""

from typing import List, Iterator, Tuple, Optional


class TimeSeriesSplit:
    """
    Time-series aware cross-validation splitter.

    Splits data sequentially to avoid lookahead bias.
    """

    def __init__(self, n_splits: int = 5, test_size: Optional[int] = None):
        """
        Initialize TimeSeriesSplit.

        Args:
            n_splits: Number of splits
            test_size: Size of test set (if None, uses n_samples // (n_splits + 1))
        """
        self.n_splits = n_splits
        self.test_size = test_size

    def split(self, X: List, y: Optional[List] = None) -> Iterator[Tuple[List[int], List[int]]]:
        """
        Generate train/test splits.

        Args:
            X: Input data
            y: Optional target data

        Yields:
            Tuple of (train_indices, test_indices)
        """
        n_samples = len(X)
        if self.test_size is None:
            test_size = n_samples // (self.n_splits + 1)
        else:
            test_size = self.test_size

        for i in range(self.n_splits):
            test_start = (i + 1) * test_size
            test_end = min(test_start + test_size, n_samples)

            if test_start >= n_samples:
                break

            # Expanding window: train on everything strictly before the test block
            train_indices = list(range(test_start))
            test_indices = list(range(test_start, test_end))

            yield train_indices, test_indices


class PurgedKFold:
    """
    Purged K-fold for time-series.

    Removes overlapping periods to avoid leakage.
    """

    def __init__(self, n_splits: int = 5, purge_gap: int = 1):
        """
        Initialize PurgedKFold.

        Args:
            n_splits: Number of splits
            purge_gap: Number of samples to purge between train and test
        """
        self.n_splits = n_splits
        self.purge_gap = purge_gap

    def split(self, X: List, y: Optional[List] = None) -> Iterator[Tuple[List[int], List[int]]]:
        """
        Generate purged train/test splits.

        Args:
            X: Input data
            y: Optional target data

        Yields:
            Tuple of (train_indices, test_indices)
        """
        n_samples = len(X)
        fold_size = n_samples // self.n_splits

        for i in range(self.n_splits):
            test_start = i * fold_size
            test_end = min((i + 1) * fold_size, n_samples)

            # Purge gap before test set
            train_end = max(0, test_start - self.purge_gap)
            train_indices = list(range(train_end))

            # Purge gap after test set
            train_start_after = min(n_samples, test_end + self.purge_gap)
            train_indices.extend(range(train_start_after, n_samples))

            test_indices = list(range(test_start, test_end))

            yield train_indices, test_indices
```
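To see what these splitters yield in practice, here is a small illustrative driver; the 20-point toy series is made up for illustration, and the sketch assumes only the `split` generators defined above:

```python
from quantml.training.cv import TimeSeriesSplit, PurgedKFold

X = list(range(20))  # 20 time-ordered observations
y = list(range(20))

# Expanding-window splits: each fold trains strictly on the past.
for train_idx, test_idx in TimeSeriesSplit(n_splits=3).split(X, y):
    print(f"train=[0..{train_idx[-1]}] test=[{test_idx[0]}..{test_idx[-1]}]")
# With test_size = 20 // (3 + 1) = 5:
#   train=[0..4]  test=[5..9]
#   train=[0..9]  test=[10..14]
#   train=[0..14] test=[15..19]

# Purged folds: purge_gap=2 drops the 2 samples on each side of the test
# block from the training set, so train and test indices never touch.
for train_idx, test_idx in PurgedKFold(n_splits=4, purge_gap=2).split(X, y):
    assert not set(train_idx) & set(test_idx)
    gap = min(abs(t - s) for t in train_idx for s in test_idx)
    assert gap > 2  # nearest train sample is at least purge_gap + 1 away
```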
quantml/training/data_loader.py
ADDED
```python
"""
Efficient data loader for time-series aware batching.

This module provides data loaders optimized for quant training,
with support for time-series aware batching to prevent lookahead bias.
"""

from typing import List, Tuple, Optional, Iterator

from quantml.tensor import Tensor

# Try to import NumPy (optional acceleration dependency)
try:
    import numpy as np
    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False
    np = None


class QuantDataLoader:
    """
    Data loader for quant model training with time-series awareness.

    This loader ensures no lookahead bias by only using past data
    for training batches.

    Attributes:
        X: Feature data
        y: Target data
        batch_size: Batch size
        shuffle: Whether to shuffle (should be False for time-series)
        drop_last: Whether to drop last incomplete batch

    Examples:
        >>> loader = QuantDataLoader(X, y, batch_size=32)
        >>> for batch_x, batch_y in loader:
        ...     pass  # train on batch
    """

    def __init__(
        self,
        X: List,
        y: List,
        batch_size: int = 32,
        shuffle: bool = False,
        drop_last: bool = False
    ):
        """
        Initialize data loader.

        Args:
            X: Feature data
            y: Target data
            batch_size: Batch size
            shuffle: Whether to shuffle (False recommended for time-series)
            drop_last: Whether to drop last incomplete batch
        """
        if len(X) != len(y):
            raise ValueError("X and y must have the same length")

        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.n_samples = len(X)

    def __len__(self) -> int:
        """Get number of batches."""
        if self.drop_last:
            return self.n_samples // self.batch_size
        return (self.n_samples + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
        """Iterate over batches."""
        indices = list(range(self.n_samples))

        if self.shuffle:
            import random
            random.shuffle(indices)

        for i in range(0, self.n_samples, self.batch_size):
            batch_indices = indices[i:i + self.batch_size]

            if self.drop_last and len(batch_indices) < self.batch_size:
                break

            batch_x = [self.X[idx] for idx in batch_indices]
            batch_y = [self.y[idx] for idx in batch_indices]

            # Convert to tensors
            yield Tensor(batch_x), Tensor(batch_y)


class TimeSeriesDataLoader:
    """
    Time-series aware data loader with sequence support.

    This loader creates sequences for RNN/TCN models while ensuring
    no lookahead bias.

    Attributes:
        X: Feature data
        y: Target data
        sequence_length: Length of input sequences
        batch_size: Batch size
        stride: Stride for sequence creation

    Examples:
        >>> loader = TimeSeriesDataLoader(X, y, sequence_length=20, batch_size=16)
        >>> for seq_x, seq_y in loader:
        ...     pass  # train on sequences
    """

    def __init__(
        self,
        X: List,
        y: List,
        sequence_length: int = 20,
        batch_size: int = 32,
        stride: int = 1
    ):
        """
        Initialize time-series data loader.

        Args:
            X: Feature data
            y: Target data
            sequence_length: Length of input sequences
            batch_size: Batch size
            stride: Stride for sequence creation
        """
        if len(X) != len(y):
            raise ValueError("X and y must have the same length")

        self.X = X
        self.y = y
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.stride = stride

        # Create sequences
        self.sequences = self._create_sequences()
        self.n_sequences = len(self.sequences)

    def _create_sequences(self) -> List[Tuple[List, float]]:
        """Create (window, target) pairs: each target y[i] is paired with
        the sequence_length observations that strictly precede it."""
        sequences = []

        for i in range(self.sequence_length, len(self.X), self.stride):
            seq_x = self.X[i - self.sequence_length:i]
            seq_y = self.y[i]
            sequences.append((seq_x, seq_y))

        return sequences

    def __len__(self) -> int:
        """Get number of batches."""
        return (self.n_sequences + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
        """Iterate over batches."""
        for i in range(0, self.n_sequences, self.batch_size):
            batch_sequences = self.sequences[i:i + self.batch_size]

            batch_x = [seq[0] for seq in batch_sequences]
            batch_y = [seq[1] for seq in batch_sequences]

            # Convert to tensors
            yield Tensor(batch_x), Tensor(batch_y)
```
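A minimal sketch of driving both loaders, assuming `Tensor` accepts the nested Python lists these classes pass it (as the `__iter__` methods above do); the toy feature matrix and targets are made up for illustration:

```python
from quantml.training.data_loader import QuantDataLoader, TimeSeriesDataLoader

# Toy data: 100 time steps, 3 features each, scalar targets.
X = [[float(t), float(t) * 0.5, float(t % 7)] for t in range(100)]
y = [float(t + 1) for t in range(100)]

# Plain batching, in time order (shuffle=False by default).
loader = QuantDataLoader(X, y, batch_size=32)
print(len(loader))  # 4 batches: 32 + 32 + 32 + 4

# Sequence batching for recurrent/TCN models: each target y[i] is paired
# with the window X[i-20:i], so inputs strictly precede their target.
seq_loader = TimeSeriesDataLoader(X, y, sequence_length=20, batch_size=16)
print(len(seq_loader))  # 80 windows -> ceil(80 / 16) = 5 batches

for batch_x, batch_y in seq_loader:
    # batch_x wraps up to 16 windows of shape (20, 3); batch_y their targets
    break
```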
quantml/training/ensemble.py
ADDED
```python
"""
Model ensembling utilities for quant models.

Provides utilities for combining multiple models.
"""

from typing import Any, List, Optional

from quantml.tensor import Tensor


class EnsembleModel:
    """
    Ensemble model that combines multiple models.

    Accepts 'weighted_avg', 'voting', and 'stacking' strategies. Note that
    in this version all three reduce to a weighted average of the member
    predictions: true voting would need discrete class predictions, and
    true stacking would need a trained meta-learner.
    """

    def __init__(
        self,
        models: List[Any],
        weights: Optional[List[float]] = None,
        strategy: str = 'weighted_avg'
    ):
        """
        Initialize ensemble model.

        Args:
            models: List of models to ensemble
            weights: Optional weights for each model (defaults to equal weights)
            strategy: Ensemble strategy ('weighted_avg', 'voting', 'stacking')
        """
        self.models = models
        self.strategy = strategy

        if weights is None:
            self.weights = [1.0 / len(models)] * len(models)
        else:
            # Normalize weights so they sum to 1
            total = sum(weights)
            self.weights = [w / total for w in weights]

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through ensemble.

        Args:
            x: Input features

        Returns:
            Ensemble prediction (weighted average of member predictions,
            regardless of strategy; see class docstring)
        """
        from quantml import ops

        predictions = [model.forward(x) for model in self.models]
        result = ops.mul(predictions[0], self.weights[0])
        for i in range(1, len(predictions)):
            result = ops.add(result, ops.mul(predictions[i], self.weights[i]))
        return result

    def parameters(self):
        """Get all parameters from all models."""
        params = []
        for model in self.models:
            if hasattr(model, 'parameters'):
                params.extend(model.parameters())
        return params
```
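A short usage sketch: the `ConstantModel` stub below is hypothetical, standing in for any object with a `forward` method, and the example assumes `ops.mul`/`ops.add` combine a Tensor with a Python float the way `EnsembleModel.forward` itself does:

```python
from quantml.tensor import Tensor
from quantml.training.ensemble import EnsembleModel

class ConstantModel:
    """Hypothetical stand-in: always predicts a fixed value."""
    def __init__(self, value: float):
        self.value = value

    def forward(self, x: Tensor) -> Tensor:
        # One prediction per input row
        return Tensor([self.value] * len(x.data))

ensemble = EnsembleModel(
    models=[ConstantModel(1.0), ConstantModel(3.0)],
    weights=[3.0, 1.0],  # normalized internally to [0.75, 0.25]
)
pred = ensemble.forward(Tensor([[0.0], [0.0]]))
# Weighted average: 0.75 * 1.0 + 0.25 * 3.0 = 1.5 for each sample
```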
quantml/training/feature_importance.py
ADDED
```python
"""
Feature importance tracking and analysis.

Provides utilities for tracking feature contributions to model predictions.
"""

from typing import List, Dict, Any, Callable

from quantml.tensor import Tensor

# Try to import NumPy
try:
    import numpy as np
    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False
    np = None


class FeatureImportanceTracker:
    """
    Track feature importance for quant models.

    Tracks gradient-based and permutation-based feature importance.
    """

    def __init__(self):
        """Initialize feature importance tracker."""
        self.gradient_importance: Dict[int, List[float]] = {}
        self.permutation_importance: Dict[int, List[float]] = {}

    def compute_gradient_importance(self, model: Any, x: Tensor, y: Tensor) -> List[float]:
        """
        Compute gradient-based feature importance.

        Args:
            model: Model to analyze
            x: Input features
            y: Targets

        Returns:
            List of importance scores for each feature
        """
        # Forward pass
        pred = model.forward(x)
        loss = model.loss_fn(pred, y) if hasattr(model, 'loss_fn') else None

        if loss is None:
            from quantml.training.losses import mse_loss
            loss = mse_loss(pred, y)

        # Backward pass
        if loss.requires_grad:
            loss.backward()

        # Extract gradients from the input. This is simplified: it only
        # works if gradients are tracked through the model to the input.
        importance: List[float] = []
        if hasattr(x, 'grad') and x.grad is not None:
            grad = x.grad
            if HAS_NUMPY and isinstance(grad, np.ndarray):
                importance = np.abs(grad).mean(axis=0).tolist()
            elif isinstance(grad, list):
                if isinstance(grad[0], list):
                    # 2D: average absolute gradient over samples
                    importance = [
                        sum(abs(grad[i][j]) for i in range(len(grad))) / len(grad)
                        for j in range(len(grad[0]))
                    ]
                else:
                    importance = [abs(g) for g in grad]

        return importance

    def compute_permutation_importance(
        self,
        model: Any,
        x: Tensor,
        y: Tensor,
        metric_fn: Callable,
        n_permutations: int = 10
    ) -> List[float]:
        """
        Compute permutation-based feature importance.

        Args:
            model: Model to analyze
            x: Input features
            y: Targets
            metric_fn: Metric function to use
            n_permutations: Number of permutations per feature

        Returns:
            List of importance scores
        """
        # Baseline metric
        pred_baseline = model.forward(x)
        baseline_score = metric_fn(pred_baseline, y)

        importance = []
        x_data = x.data

        # For each feature
        num_features = len(x_data[0]) if isinstance(x_data[0], list) else len(x_data)

        for feat_idx in range(num_features):
            scores = []
            for _ in range(n_permutations):
                # Permute feature and re-score the model
                x_permuted = self._permute_feature(x_data, feat_idx)
                pred_perm = model.forward(Tensor(x_permuted))
                scores.append(metric_fn(pred_perm, y))

            # Importance is the decrease in performance under permutation
            avg_score = sum(scores) / len(scores)
            importance.append(baseline_score - avg_score)

        return importance

    def _permute_feature(self, data: List, feat_idx: int) -> List:
        """Permute a single feature in the data."""
        import random
        if isinstance(data[0], list):
            # 2D case: shuffle column feat_idx, keep the other columns fixed
            values = [row[feat_idx] for row in data]
            random.shuffle(values)
            return [
                [row[j] if j != feat_idx else values[i] for j in range(len(row))]
                for i, row in enumerate(data)
            ]
        else:
            # 1D case: a single feature vector, shuffled wholesale
            values = data.copy()
            random.shuffle(values)
            return values
```
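To make the permutation API concrete, here is a hedged sketch with a stub model and a higher-is-better metric. `FirstFeatureModel` and `neg_mse` are hypothetical, and the example assumes `Tensor.data` round-trips the nested lists it was built from, as `compute_permutation_importance` itself expects:

```python
from quantml.tensor import Tensor
from quantml.training.feature_importance import FeatureImportanceTracker

class FirstFeatureModel:
    """Hypothetical stand-in: prediction is just the first feature."""
    def forward(self, x: Tensor) -> Tensor:
        return Tensor([row[0] for row in x.data])

def neg_mse(pred: Tensor, y: Tensor) -> float:
    # Higher-is-better metric, so shuffling a useful feature lowers it.
    errs = [(p - t) ** 2 for p, t in zip(pred.data, y.data)]
    return -sum(errs) / len(errs)

X = Tensor([[float(i), 0.0] for i in range(50)])  # feature 0 is the signal
y = Tensor([float(i) for i in range(50)])          # feature 1 is constant noise

tracker = FeatureImportanceTracker()
scores = tracker.compute_permutation_importance(
    FirstFeatureModel(), X, y, metric_fn=neg_mse, n_permutations=5
)
# scores[0] > 0: permuting the signal feature hurts the metric.
# scores[1] == 0: the constant second feature never affects predictions.
```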