oikan 0.0.1.11-py3-none-any.whl → 0.0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oikan/exceptions.py +15 -0
- oikan/model.py +422 -83
- oikan/symbolic.py +24 -125
- oikan/utils.py +39 -36
- oikan-0.0.2.2.dist-info/METADATA +223 -0
- oikan-0.0.2.2.dist-info/RECORD +10 -0
- {oikan-0.0.1.11.dist-info → oikan-0.0.2.2.dist-info}/WHEEL +1 -1
- oikan/metrics.py +0 -48
- oikan/regularization.py +0 -30
- oikan/trainer.py +0 -49
- oikan/visualize.py +0 -69
- oikan-0.0.1.11.dist-info/METADATA +0 -105
- oikan-0.0.1.11.dist-info/RECORD +0 -13
- {oikan-0.0.1.11.dist-info → oikan-0.0.2.2.dist-info/licenses}/LICENSE +0 -0
- {oikan-0.0.1.11.dist-info → oikan-0.0.2.2.dist-info}/top_level.txt +0 -0
oikan/exceptions.py
ADDED
@@ -0,0 +1,15 @@
+class OikanError(Exception):
+    """Base exception class for OIKAN"""
+    pass
+
+class NotFittedError(OikanError):
+    """Raised when prediction is attempted on unfitted model"""
+    pass
+
+class DataError(OikanError):
+    """Raised when there are issues with input data"""
+    pass
+
+class InitializationError(OikanError):
+    """Raised when model initialization fails"""
+    pass
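
Usage sketch (illustrative, not part of the diff; assumes oikan 0.0.2.2 is installed): all four exceptions share the OikanError base class, so callers can catch OIKAN-specific failures uniformly.

from oikan.exceptions import OikanError, NotFittedError
from oikan.model import OIKANRegressor

model = OIKANRegressor()
try:
    model.predict([[0.0, 1.0]])        # predict() before fit()
except NotFittedError as e:            # also catchable as plain OikanError
    print(f"{type(e).__name__}: {e}")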
oikan/model.py
CHANGED
@@ -1,99 +1,438 @@
 import torch
 import torch.nn as nn
-
+import numpy as np
+from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
+from .utils import ADVANCED_LIB, EdgeActivation
+from .exceptions import *
+from datetime import datetime as dt
 
-class
-
-    def __init__(self
+class SymbolicEdge(nn.Module):
+    """Edge-based activation function learner"""
+    def __init__(self):
         super().__init__()
-        self.
-        self.bias = nn.Parameter(torch.zeros(hidden_dim))
+        self.activation = EdgeActivation()
 
     def forward(self, x):
-
-
+        return self.activation(x)
+
+    def get_symbolic_repr(self, threshold=1e-4):
+        return self.activation.get_symbolic_repr(threshold)
 
-class
-
-    def __init__(self, input_dim,
+class KANLayer(nn.Module):
+    """Kolmogorov-Arnold Network layer with interpretable edges"""
+    def __init__(self, input_dim, output_dim):
         super().__init__()
         self.input_dim = input_dim
-        self.
-
-
-
-
-
-            self.basis_output_dim = input_dim * (hidden_units - 4)
-        elif basis_type == 'fourier':
-            # Use Fourier basis transformation for each feature
-            self.basis_functions = nn.ModuleList([FourierBasis(hidden_units // 2) for _ in range(input_dim)])
-            self.basis_output_dim = input_dim * hidden_units
-        elif basis_type == 'combo':
-            # Combine BSpline and Fourier basis on a per-feature basis
-            self.basis_functions_bspline = nn.ModuleList([BSplineBasis(hidden_units) for _ in range(input_dim)])
-            self.basis_functions_fourier = nn.ModuleList([FourierBasis(hidden_units // 2) for _ in range(input_dim)])
-            self.basis_output_dim = input_dim * ((hidden_units - 4) + hidden_units)
-        else:
-            raise ValueError(f"Unsupported basis_type: {basis_type}")
+        self.output_dim = output_dim
+
+        self.edges = nn.ModuleList([
+            nn.ModuleList([SymbolicEdge() for _ in range(output_dim)])
+            for _ in range(input_dim)
+        ])
 
-
-        self.interaction_weights = nn.Parameter(torch.randn(input_dim, input_dim))
+        self.combination_weights = nn.Parameter(torch.randn(input_dim, output_dim) * 0.1)
 
     def forward(self, x):
-        #
-
-
-
-
-
-
-        basis_output = torch.cat(transformed_features, dim=1)
-
-        # Compute interaction features via fixed matrix multiplication
-        batch_size = x.size(0)
-        x_reshaped = x.view(batch_size, self.input_dim, 1)  # Reshape to [batch_size, input_dim, 1]
-        interaction_matrix = torch.sigmoid(self.interaction_weights)  # Normalize interaction weights
-        interaction_features = torch.bmm(x_reshaped.transpose(1, 2),
-                                         x_reshaped * interaction_matrix.unsqueeze(0))  # Result: [batch_size, 1, 1]
-        interaction_features = interaction_features.view(batch_size, -1)  # Flatten interaction output
-
-        return torch.cat([basis_output, interaction_features], dim=1)
+        x_split = x.split(1, dim=1)  # list of (batch, 1) tensors for each input feature
+        edge_outputs = torch.stack([
+            torch.stack([edge(x_i).squeeze() for edge in edge_list], dim=1)
+            for x_i, edge_list in zip(x_split, self.edges)
+        ], dim=1)  # shape: (batch, input_dim, output_dim)
+        combined = edge_outputs * self.combination_weights.unsqueeze(0)
+        return combined.sum(dim=1)
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def get_symbolic_formula(self):
+        """Extract interpretable formulas for each output"""
+        formulas = []
+        for j in range(self.output_dim):
+            terms = []
+            for i in range(self.input_dim):
+                weight = self.combination_weights[i, j].item()
+                if abs(weight) > 1e-4:
+                    edge_formula = self.edges[i][j].get_symbolic_repr()
+                    if edge_formula != "0":
+                        terms.append(f"({weight:.4f} * ({edge_formula}))")
+            formulas.append(" + ".join(terms) if terms else "0")
+        return formulas
+
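In the new KANLayer, every (input feature, output unit) pair gets its own SymbolicEdge, and the edge outputs are mixed by the learnable combination_weights matrix. A minimal shape-check sketch of that combination step, using identity functions as stand-ins for the learned edges (an assumption for illustration only):

import torch

batch, input_dim, output_dim = 4, 3, 2
x = torch.randn(batch, input_dim)
weights = torch.randn(input_dim, output_dim) * 0.1

# edge_outputs[b, i, j] = edge_ij(x[b, i]); identity edges stand in here.
edge_outputs = x.unsqueeze(-1).expand(batch, input_dim, output_dim)
out = (edge_outputs * weights.unsqueeze(0)).sum(dim=1)  # sum over input features
print(out.shape)  # torch.Size([4, 2])

One caveat worth noting: edge(x_i).squeeze() in the new forward() drops every singleton dimension, so a batch of size 1 would collapse to a scalar before the stack.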
+class BaseOIKAN(BaseEstimator):
+    """Base OIKAN model implementing common functionality"""
+    def __init__(self, hidden_dims=[64, 32], num_basis=10, degree=3, dropout=0.1):
+        self.hidden_dims = hidden_dims
+        self.num_basis = num_basis
+        self.degree = degree
+        self.dropout = dropout  # Dropout probability for uncertainty quantification
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Auto device chooser
+        self.model = None
+        self._is_fitted = False
+        self.__name = "OIKAN v0.0.2"  # Version info (manually configured)
+        self.loss_history = []  # <-- new attribute to store loss values
+
+    def _build_network(self, input_dim, output_dim):
+        layers = []
+        prev_dim = input_dim
+        for hidden_dim in self.hidden_dims:
+            layers.append(KANLayer(prev_dim, hidden_dim))
+            layers.append(nn.Dropout(self.dropout))  # Apply dropout for uncertainty quantification
+            prev_dim = hidden_dim
+        layers.append(KANLayer(prev_dim, output_dim))
+        return nn.Sequential(*layers).to(self.device)
+
+    def _validate_data(self, X, y=None):
+        if not isinstance(X, torch.Tensor):
+            X = torch.FloatTensor(X)
+        if y is not None and not isinstance(y, torch.Tensor):
+            y = torch.FloatTensor(y)
+        return X.to(self.device), (y.to(self.device) if y is not None else None)
+
+    def get_symbolic_formula(self):
+        """Generate and cache symbolic formulas for production‐ready inference."""
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before extracting formulas")
+        if hasattr(self, "symbolic_formula"):
+            return self.symbolic_formula
+        if hasattr(self, 'classes_'):  # Classifier
+            n_features = self.model[0].input_dim
+            n_classes = len(self.classes_)
+            formulas = [[None for _ in range(n_classes)] for _ in range(n_features)]
+            first_layer = self.model[0]
+            for i in range(n_features):
+                for j in range(n_classes):
+                    weight = first_layer.combination_weights[i, j].item()
+                    if abs(weight) > 1e-4:
+                        edge_formula = first_layer.edges[i][j].get_symbolic_repr()
+                        terms = []
+                        for term in edge_formula.split(" + "):
+                            if term and term != "0":
+                                if "*" in term:
+                                    coef, rest = term.split("*", 1)
+                                    coef = float(coef) * weight
+                                    terms.append(f"{coef:.4f}*{rest}")
+                                else:
+                                    terms.append(f"{float(term)*weight:.4f}")
+                        formulas[i][j] = " + ".join(terms) if terms else "0"
+                    else:
+                        formulas[i][j] = "0"
+            self.symbolic_formula = formulas
+            return formulas
+        else:  # Regressor
+            formulas = []
+            first_layer = self.model[0]
+            for i in range(first_layer.input_dim):
+                formula = first_layer.edges[i][0].get_symbolic_repr()
+                formulas.append(formula)
+            self.symbolic_formula = formulas
+            return formulas
+
+    def save_symbolic_formula(self, filename="outputs/symbolic_formula.txt"):
+        """Save the cached symbolic formulas to file for production use.
+
+        The file will contain:
+        - A header with the version and timestamp
+        - The symbolic formulas for each feature (and class for classification)
+        - A general formula, including softmax for classification
+        - Recommendations for production use.
+        """
+        header = f"Generated by {self.__name} | Timestamp: {dt.now()}\n\n"
+        header += "Symbolic Formulas:\n"
+        header += "====================\n"
+        formulas = self.get_symbolic_formula()
+        formulas_text = ""
+        if hasattr(self, 'classes_'):
+            # For classifiers: formulas is a 2D list [feature][class]
+            for i, feature in enumerate(formulas):
+                for j, form in enumerate(feature):
+                    formulas_text += f"Feature {i} - Class {j}: {form}\n"
+            general = ("\nGeneral Formula (with softmax):\n"
+                       "For each class j: y_j = softmax( sum_i [ symbolic_formula(feature_i, class_j) ] )\n")
+            recs = ("\nRecommendations:\n"
+                    "• Use the symbolic formulas for streamlined inference in production.\n"
+                    "• Verify predictions with both the neural network and the compiled symbolic predictor.\n")
         else:
-
-
-
-
-
-
-
-
-
+            # For regressors: formulas is a list
+            for i, form in enumerate(formulas):
+                formulas_text += f"Feature {i}: {form}\n"
+            general = ("\nGeneral Formula:\n"
+                       "y = sum_i [ symbolic_formula(feature_i) ]\n")
+            recs = ("\nRecommendations:\n"
+                    "• Consider the symbolic formula for lightweight and interpretable inference.\n"
+                    "• Validate approximation accuracy against the neural model.\n")
+
+        output = header + formulas_text + general + recs
+        with open(filename, "w") as f:
+            f.write(output)
+        print(f"Symbolic formulas saved to {filename}")
 
-    def
-
-
-
-
+    def get_feature_scores(self):
+        """Get feature importance scores based on edge weights."""
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before computing scores")
+
+        weights = self.model[0].combination_weights.detach().cpu().numpy()
+        return np.mean(np.abs(weights), axis=1)
+
+    def _eval_formula(self, formula, x):
+        """Helper to evaluate a symbolic formula for an input vector x using ADVANCED_LIB basis functions."""
+        import re
+        total = 0
+        pattern = re.compile(r"(-?\d+\.\d+)\*?([\w\(\)\^]+)")
+        matches = pattern.findall(formula)
+        for coef_str, func_name in matches:
+            try:
+                coef = float(coef_str)
+                for key, (notation, func) in ADVANCED_LIB.items():
+                    if notation.strip() == func_name.strip():
+                        total += coef * func(x)
+                        break
+            except Exception:
+                continue
+        return total
+
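_eval_formula recovers (coefficient, basis-notation) pairs from a formula string with that regex, then dispatches each notation to its callable in ADVANCED_LIB. A self-contained sketch of the parsing step, with a toy two-entry table standing in for the real ADVANCED_LIB from oikan/utils.py (whose contents are not shown in full in this diff):

import re
import numpy as np

# Toy stand-in: maps a key to (printed notation, callable).
ADVANCED_LIB = {'linear': ('x', lambda x: x), 'square': ('x^2', lambda x: x ** 2)}

def eval_formula(formula, x):
    total = 0
    for coef_str, func_name in re.findall(r"(-?\d+\.\d+)\*?([\w\(\)\^]+)", formula):
        for notation, func in ADVANCED_LIB.values():
            if notation == func_name:
                total += float(coef_str) * func(x)
                break
    return total

print(eval_formula("0.5000*x^2 + -1.2500*x", np.array([1.0, 2.0])))  # [-0.75 -0.5 ]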
+    def symbolic_predict(self, X):
+        """Predict using only the extracted symbolic formula (regressor)."""
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before prediction")
+        X = np.array(X) if not isinstance(X, np.ndarray) else X
+        formulas = self.get_symbolic_formula()  # For regressor: list of formula strings.
+        predictions = np.zeros((X.shape[0], 1))
+        for i, formula in enumerate(formulas):
+            x = X[:, i]
+            predictions[:, 0] += self._eval_formula(formula, x)
+        return predictions
+
+    def compile_symbolic_formula(self, filename="output/final_symbolic_formula.txt"):
+        import re
+        from .utils import ADVANCED_LIB  # needed to retrieve basis functions
+        with open(filename, "r") as f:
+            content = f.read()
+        # Regex to extract coefficient and function notation.
+        # Matches patterns like: "(-?\d+\.\d+)\*?([\w\(\)\^]+)"
+        matches = re.findall(r"(-?\d+\.\d+)\*?([\w\(\)\^]+)", content)
+        compiled_terms = []
+        for coef_str, func_name in matches:
+            try:
+                coef = float(coef_str)
+                # Search for a matching basis function in ADVANCED_LIB (e.g. 'x', 'x^2', etc.)
+                for key, (notation, func) in ADVANCED_LIB.items():
+                    if notation.strip() == func_name.strip():
+                        compiled_terms.append((coef, func))
+                        break
+            except Exception:
+                continue
+        def prediction_function(x):
+            pred = 0
+            for coef, func in compiled_terms:
+                pred += coef * func(x)
+            return pred
+        return prediction_function
+
+    def save_model(self, filepath="models/oikan_model.pth"):
+        """Save the current model's state dictionary and extra attributes to a file."""
+        if self.model is None:
+            raise NotFittedError("No model to save. Build and train a model first.")
+        save_dict = {'state_dict': self.model.state_dict()}
+        if hasattr(self, "classes_"):
+            # Save classes_ as a list so that it can be reloaded.
+            save_dict['classes_'] = self.classes_.tolist()
+        torch.save(save_dict, filepath)
+        print(f"Model saved to {filepath}")
+
+    def load_model(self, filepath="models/oikan_model.pth", input_dim=None, output_dim=None):
+        """Load the model's state dictionary and extra attributes from a file.
+
+        If the model architecture does not exist, it is automatically rebuilt using provided
+        input_dim and output_dim.
+        """
+        if self.model is None:
+            if input_dim is None or output_dim is None:
+                raise NotFittedError("No model architecture available. Provide input_dim and output_dim to rebuild the model.")
+            self.model = self._build_network(input_dim, output_dim)
+        loaded = torch.load(filepath, map_location=self.device)
+        if isinstance(loaded, dict) and 'state_dict' in loaded:
+            self.model.load_state_dict(loaded['state_dict'])
+            if 'classes_' in loaded:
+                self.classes_ = torch.tensor(loaded['classes_'])
         else:
-
-
-
-
+            self.model.load_state_dict(loaded)
+        self._is_fitted = True  # Mark model as fitted after loading
+        print(f"Model loaded from {filepath}")
+
+    def get_loss_history(self):
+        """Retrieve training loss history."""
+        return self.loss_history
+
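save_model stores the state dict (plus classes_ for classifiers), and load_model can rebuild the architecture from input_dim/output_dim before loading. A hypothetical round trip (assumes oikan 0.0.2.2; a local filename is used because the default "models/..." path requires that directory to exist):

import numpy as np
from oikan.model import OIKANRegressor

X = np.random.randn(128, 2).astype(np.float32)
y = (2.0 * X[:, 0] - X[:, 1]).astype(np.float32)

model = OIKANRegressor(hidden_dims=[16, 8])
model.fit(X, y, epochs=30, verbose=False)
model.save_model("oikan_model.pth")

fresh = OIKANRegressor(hidden_dims=[16, 8])   # must match the saved architecture
fresh.load_model("oikan_model.pth", input_dim=2, output_dim=1)
preds = fresh.predict(X)                      # usable without re-fitting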
+class OIKANRegressor(BaseOIKAN, RegressorMixin):
+    """OIKAN implementation for regression tasks"""
+    def fit(self, X, y, epochs=100, lr=0.01, batch_size=32, verbose=True):
+        X, y = self._validate_data(X, y)
+        if len(y.shape) == 1:
+            y = y.reshape(-1, 1)
+
+        if self.model is None:
+            self.model = self._build_network(X.shape[1], y.shape[1])
+
+        criterion = nn.MSELoss()
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5)
+
+        self.model.train()
+        self.loss_history = []  # <-- reset loss history at start of training
+        for epoch in range(epochs):
+            optimizer.zero_grad()
+            y_pred = self.model(X)
+            loss = criterion(y_pred, y)
+
+            if torch.isnan(loss):
+                print("Warning: NaN loss detected, reinitializing model...")
+                self.model = None
+                return self.fit(X, y, epochs, lr/10, batch_size, verbose)
+
+            loss.backward()
+
+            # Clip gradients
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+            optimizer.step()
+
+            self.loss_history.append(loss.item())  # <-- save loss value for epoch
+
+            if verbose and (epoch + 1) % 10 == 0:
+                print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
+
+        self._is_fitted = True
+        return self
+
+    def predict(self, X):
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before prediction")
+
+        X = self._validate_data(X)[0]
+        self.model.eval()
+        with torch.no_grad():
+            return self.model(X).cpu().numpy()
+
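End-to-end regressor sketch (illustrative; assumes oikan 0.0.2.2 is installed; the data and hyperparameters are arbitrary choices): fit on synthetic data, then inspect the interpretable artifacts the new API exposes.

import numpy as np
from oikan.model import OIKANRegressor

X = np.random.randn(256, 2).astype(np.float32)
y = (X[:, 0] ** 2 + np.sin(X[:, 1])).astype(np.float32)

model = OIKANRegressor(hidden_dims=[16, 8])
model.fit(X, y, epochs=100, lr=0.01, verbose=False)

print(model.get_symbolic_formula())  # one formula string per input feature
print(model.get_feature_scores())    # mean |combination weight| per feature
y_sym = model.symbolic_predict(X)    # formula-only inference, shape (256, 1)
print(model.get_loss_history()[-1])  # final training loss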
+class OIKANClassifier(BaseOIKAN, ClassifierMixin):
+    """OIKAN implementation for classification tasks"""
+    def fit(self, X, y, epochs=100, lr=0.01, batch_size=32, verbose=True):
+        X, y = self._validate_data(X, y)
+        self.classes_ = torch.unique(y)
+        n_classes = len(self.classes_)
+
+        if self.model is None:
+            self.model = self._build_network(X.shape[1], 1 if n_classes == 2 else n_classes)
+
+        criterion = (nn.BCEWithLogitsLoss() if n_classes == 2
+                     else nn.CrossEntropyLoss())
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
+
+        self.model.train()
+        self.loss_history = []  # <-- reset loss history at start of training
+        for epoch in range(epochs):
+            optimizer.zero_grad()
+            logits = self.model(X)
+            if n_classes == 2:
+                y_tensor = y.float()
+                logits = logits.squeeze()
+            else:
+                y_tensor = y.long()
+            loss = criterion(logits, y_tensor)
+            loss.backward()
+            optimizer.step()
+
+            self.loss_history.append(loss.item())  # <-- save loss value for epoch
+
+            if verbose and (epoch + 1) % 10 == 0:
+                print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
+
+        self._is_fitted = True
+        return self
+
+    def predict_proba(self, X):
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before prediction")
+
+        X = self._validate_data(X)[0]
+        self.model.eval()
+        with torch.no_grad():
+            logits = self.model(X)
+            if len(self.classes_) == 2:
+                probs = torch.sigmoid(logits)
+                return np.column_stack([1 - probs.cpu().numpy(), probs.cpu().numpy()])
+            else:
+                return torch.softmax(logits, dim=1).cpu().numpy()
+
+    def predict(self, X):
+        proba = self.predict_proba(X)
+        return self.classes_[np.argmax(proba, axis=1)]
+
+    def symbolic_predict_proba(self, X):
+        """Predict class probabilities using only the extracted symbolic formula."""
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before prediction")
+
+        if not isinstance(X, np.ndarray):
+            X = np.array(X)
+
+        # Scale input data similar to training
+        X_scaled = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
+
+        formulas = self.get_symbolic_formula()
+        n_classes = len(self.classes_)
+        predictions = np.zeros((X.shape[0], n_classes))
+
+        # Evaluate each feature's contribution to each class
+        for i in range(X.shape[1]):  # For each feature
+            x = X_scaled[:, i]  # Use scaled data
+            for j in range(n_classes):  # For each class
+                formula = formulas[i][j]
+                if formula and formula != "0":
+                    predictions[:, j] += self._eval_formula(formula, x)
+
+        # Apply softmax with temperature for better separation
+        temperature = 1.0
+        exp_preds = np.exp(predictions / temperature)
+        probas = exp_preds / exp_preds.sum(axis=1, keepdims=True)
+
+        # Clip probabilities to avoid numerical issues
+        probas = np.clip(probas, 1e-7, 1.0)
+        probas = probas / probas.sum(axis=1, keepdims=True)
+
+        return probas
+
+    def get_symbolic_formula(self):
+        """Extract symbolic formulas for all features and outputs."""
+        if not self._is_fitted:
+            raise NotFittedError("Model must be fitted before extracting formulas")
+
+        n_features = self.model[0].input_dim
+        n_classes = len(self.classes_)
+        formulas = [[[] for _ in range(n_classes)] for _ in range(n_features)]
+
+        first_layer = self.model[0]
+        for i in range(n_features):
+            for j in range(n_classes):
+                edge = first_layer.edges[i][j]
+                weight = first_layer.combination_weights[i, j].item()
+
+                if abs(weight) > 1e-4:
+                    # Get the edge formula and scale by the weight
+                    edge_formula = edge.get_symbolic_repr()
+                    terms = []
+                    for term in edge_formula.split(" + "):
+                        if term and term != "0":
+                            if "*" in term:
+                                coef, rest = term.split("*", 1)
+                                coef = float(coef) * weight
+                                terms.append(f"{coef:.4f}*{rest}")
+                            else:
+                                terms.append(f"{float(term) * weight:.4f}")
+
+                    formulas[i][j] = " + ".join(terms) if terms else "0"
+                else:
+                    formulas[i][j] = "0"
+
+        return formulas
+
+    def symbolic_predict(self, X):
+        """Predict classes using only the extracted symbolic formula."""
+        proba = self.symbolic_predict_proba(X)
+        return self.classes_[np.argmax(proba, axis=1)]