oikan 0.0.1.9__py3-none-any.whl → 0.0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oikan/metrics.py +32 -7
- oikan/model.py +42 -20
- oikan/trainer.py +6 -6
- oikan/visualize.py +20 -20
- {oikan-0.0.1.9.dist-info → oikan-0.0.1.11.dist-info}/METADATA +1 -1
- oikan-0.0.1.11.dist-info/RECORD +13 -0
- {oikan-0.0.1.9.dist-info → oikan-0.0.1.11.dist-info}/WHEEL +1 -1
- oikan-0.0.1.9.dist-info/RECORD +0 -13
- {oikan-0.0.1.9.dist-info → oikan-0.0.1.11.dist-info}/LICENSE +0 -0
- {oikan-0.0.1.9.dist-info → oikan-0.0.1.11.dist-info}/top_level.txt +0 -0
oikan/metrics.py
CHANGED
@@ -1,23 +1,48 @@
|
|
1
1
|
import numpy as np
|
2
2
|
import torch
|
3
|
+
from sklearn.metrics import precision_score, recall_score, f1_score, hamming_loss
|
3
4
|
|
4
5
|
def evaluate_regression(model, X, y):
|
5
|
-
'''Evaluate regression performance by computing MSE, MAE, and RMSE.'''
|
6
|
+
'''Evaluate regression performance by computing MSE, MAE, and RMSE, and print in table format.'''
|
6
7
|
with torch.no_grad():
|
7
8
|
y_pred = model(torch.FloatTensor(X)).numpy().ravel()
|
8
9
|
mse = np.mean((y - y_pred)**2)
|
9
10
|
mae = np.mean(np.abs(y - y_pred))
|
10
11
|
rmse = np.sqrt(mse)
|
11
|
-
|
12
|
-
|
13
|
-
|
12
|
+
|
13
|
+
# Print table
|
14
|
+
header = f"+{'-'*23}+{'-'*12}+"
|
15
|
+
print(header)
|
16
|
+
print(f"| {'Metric':21} | {'Value':9} |")
|
17
|
+
print(header)
|
18
|
+
print(f"| {'Mean Squared Error':21} | {mse:9.4f} |")
|
19
|
+
print(f"| {'Mean Absolute Error':21} | {mae:9.4f} |")
|
20
|
+
print(f"| {'Root Mean Squared Error':21} | {rmse:9.4f} |")
|
21
|
+
print(header)
|
22
|
+
|
14
23
|
return mse, mae, rmse
|
15
24
|
|
16
25
|
def evaluate_classification(model, X, y):
|
17
|
-
'''Evaluate classification
|
26
|
+
'''Evaluate classification performance by computing accuracy, precision, recall, f1-score, and hamming_loss, and printing in table format.'''
|
18
27
|
with torch.no_grad():
|
19
28
|
logits = model(torch.FloatTensor(X))
|
20
29
|
y_pred = torch.argmax(logits, dim=1).numpy()
|
21
30
|
accuracy = np.mean(y_pred == y)
|
22
|
-
|
23
|
-
|
31
|
+
precision = precision_score(y, y_pred, average='weighted', zero_division=0)
|
32
|
+
recall = recall_score(y, y_pred, average='weighted', zero_division=0)
|
33
|
+
f1 = f1_score(y, y_pred, average='weighted', zero_division=0)
|
34
|
+
h_loss = hamming_loss(y, y_pred)
|
35
|
+
|
36
|
+
# Print table
|
37
|
+
header = f"+{'-'*15}+{'-'*12}+"
|
38
|
+
print(header)
|
39
|
+
print(f"| {'Metric':13} | {'Value':9} |")
|
40
|
+
print(header)
|
41
|
+
print(f"| {'Accuracy':13} | {accuracy:9.4f} |")
|
42
|
+
print(f"| {'Precision':13} | {precision:9.4f} |")
|
43
|
+
print(f"| {'Recall':13} | {recall:9.4f} |")
|
44
|
+
print(f"| {'F1-score':13} | {f1:9.4f} |")
|
45
|
+
print(f"| {'Hamming Loss':13} | {h_loss:9.4f} |")
|
46
|
+
print(header)
|
47
|
+
|
48
|
+
return accuracy, precision, recall, f1, h_loss
|
oikan/model.py
CHANGED
@@ -15,28 +15,40 @@ class AdaptiveBasisLayer(nn.Module):
|
|
15
15
|
|
16
16
|
class EfficientKAN(nn.Module):
|
17
17
|
'''Module computing feature transformations using nonlinear basis functions and interaction terms.'''
|
18
|
-
def __init__(self, input_dim, hidden_units=10, basis_type='
|
18
|
+
def __init__(self, input_dim, hidden_units=10, basis_type='bsplines'):
|
19
19
|
super().__init__()
|
20
20
|
self.input_dim = input_dim
|
21
21
|
self.hidden_units = hidden_units
|
22
22
|
self.basis_type = basis_type
|
23
23
|
|
24
|
-
if basis_type == '
|
25
|
-
# One BSpline per feature
|
24
|
+
if basis_type == 'bsplines':
|
25
|
+
# One BSpline per feature with adjusted output dimensions
|
26
26
|
self.basis_functions = nn.ModuleList([BSplineBasis(hidden_units) for _ in range(input_dim)])
|
27
|
-
self.basis_output_dim = input_dim * (hidden_units - 4)
|
27
|
+
self.basis_output_dim = input_dim * (hidden_units - 4)
|
28
28
|
elif basis_type == 'fourier':
|
29
29
|
# Use Fourier basis transformation for each feature
|
30
|
-
self.basis_functions = nn.ModuleList([FourierBasis(hidden_units//2) for _ in range(input_dim)])
|
30
|
+
self.basis_functions = nn.ModuleList([FourierBasis(hidden_units // 2) for _ in range(input_dim)])
|
31
31
|
self.basis_output_dim = input_dim * hidden_units
|
32
|
-
|
32
|
+
elif basis_type == 'combo':
|
33
|
+
# Combine BSpline and Fourier basis on a per-feature basis
|
34
|
+
self.basis_functions_bspline = nn.ModuleList([BSplineBasis(hidden_units) for _ in range(input_dim)])
|
35
|
+
self.basis_functions_fourier = nn.ModuleList([FourierBasis(hidden_units // 2) for _ in range(input_dim)])
|
36
|
+
self.basis_output_dim = input_dim * ((hidden_units - 4) + hidden_units)
|
37
|
+
else:
|
38
|
+
raise ValueError(f"Unsupported basis_type: {basis_type}")
|
39
|
+
|
33
40
|
# Interaction layer: captures pairwise feature interactions
|
34
41
|
self.interaction_weights = nn.Parameter(torch.randn(input_dim, input_dim))
|
35
42
|
|
36
43
|
def forward(self, x):
|
37
|
-
#
|
38
|
-
|
39
|
-
|
44
|
+
# Process basis functions per type
|
45
|
+
if self.basis_type == 'combo':
|
46
|
+
transformed_bspline = [bf(x[:, i].unsqueeze(1)) for i, bf in enumerate(self.basis_functions_bspline)]
|
47
|
+
transformed_fourier = [bf(x[:, i].unsqueeze(1)) for i, bf in enumerate(self.basis_functions_fourier)]
|
48
|
+
basis_output = torch.cat(transformed_bspline + transformed_fourier, dim=1)
|
49
|
+
else:
|
50
|
+
transformed_features = [bf(x[:, i].unsqueeze(1)) for i, bf in enumerate(self.basis_functions)]
|
51
|
+
basis_output = torch.cat(transformed_features, dim=1)
|
40
52
|
|
41
53
|
# Compute interaction features via fixed matrix multiplication
|
42
54
|
batch_size = x.size(0)
|
@@ -53,25 +65,35 @@ class EfficientKAN(nn.Module):
|
|
53
65
|
return self.basis_output_dim + self.input_dim
|
54
66
|
|
55
67
|
class OIKAN(nn.Module):
|
56
|
-
'''Main OIKAN model combining nonlinear transformations, SVD-projection, and interpretable layers.
|
57
|
-
|
68
|
+
'''Main OIKAN model combining nonlinear transformations, SVD-projection, and interpretable layers.
|
69
|
+
Supports time series forecasting when forecast_mode is True.
|
70
|
+
'''
|
71
|
+
def __init__(self, input_dim, output_dim, hidden_units=10, reduced_dim=32, basis_type='bsplines', forecast_mode=False):
|
58
72
|
super().__init__()
|
59
|
-
self.
|
73
|
+
self.forecast_mode = forecast_mode
|
74
|
+
if self.forecast_mode:
|
75
|
+
# LSTM encoder for time series forecasting; expects input shape [batch, seq_len, input_dim]
|
76
|
+
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=input_dim, batch_first=True)
|
77
|
+
# Process the last hidden state with EfficientKAN
|
78
|
+
self.efficientkan = EfficientKAN(input_dim, hidden_units, basis_type)
|
79
|
+
else:
|
80
|
+
self.efficientkan = EfficientKAN(input_dim, hidden_units, basis_type)
|
60
81
|
feature_dim = self.efficientkan.get_output_dim()
|
61
|
-
|
62
|
-
# Apply SVD projection to compress high-dimensional features
|
63
82
|
self.svd_projection = nn.Linear(feature_dim, reduced_dim, bias=False)
|
64
|
-
feature_dim = reduced_dim
|
65
|
-
|
66
|
-
# Interpretable layers for final mapping
|
83
|
+
feature_dim = reduced_dim
|
67
84
|
self.interpretable_layers = nn.Sequential(
|
68
85
|
AdaptiveBasisLayer(feature_dim, 32),
|
69
86
|
nn.ReLU(),
|
70
87
|
AdaptiveBasisLayer(32, output_dim)
|
71
88
|
)
|
72
|
-
|
89
|
+
|
73
90
|
def forward(self, x):
|
74
|
-
|
75
|
-
|
91
|
+
if self.forecast_mode:
|
92
|
+
# x shape: [batch, seq_len, input_dim]
|
93
|
+
lstm_out, (hidden, _) = self.lstm(x)
|
94
|
+
x_in = hidden[-1] # Use the last hidden state for forecasting
|
95
|
+
else:
|
96
|
+
x_in = x
|
97
|
+
transformed_x = self.efficientkan(x_in)
|
76
98
|
transformed_x = self.svd_projection(transformed_x)
|
77
99
|
return self.interpretable_layers(transformed_x)
|
oikan/trainer.py
CHANGED
@@ -2,7 +2,7 @@ import torch
|
|
2
2
|
import torch.nn as nn
|
3
3
|
from .regularization import RegularizedLoss
|
4
4
|
|
5
|
-
def train(model, train_data, epochs=100, lr=0.01, save_path=None):
|
5
|
+
def train(model, train_data, epochs=100, lr=0.01, save_path=None, verbose=True):
|
6
6
|
'''Train regression model using MSE loss with regularization.
|
7
7
|
Optionally save the model when training is finished if save_path is provided.
|
8
8
|
'''
|
@@ -19,13 +19,13 @@ def train(model, train_data, epochs=100, lr=0.01, save_path=None):
|
|
19
19
|
loss.backward() # Backpropagate errors
|
20
20
|
optimizer.step() # Update parameters
|
21
21
|
|
22
|
-
if (epoch + 1) % 10 == 0:
|
22
|
+
if (epoch + 1) % 10 == 0 and verbose:
|
23
23
|
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
|
24
|
-
if save_path:
|
24
|
+
if save_path is not None:
|
25
25
|
torch.save(model.state_dict(), save_path)
|
26
26
|
print(f"Model saved to {save_path}")
|
27
27
|
|
28
|
-
def train_classification(model, train_data, epochs=100, lr=0.01, save_path=None):
|
28
|
+
def train_classification(model, train_data, epochs=100, lr=0.01, save_path=None, verbose=True):
|
29
29
|
'''Train classification model using CrossEntropy loss with regularization.
|
30
30
|
Optionally save the model when training is finished if save_path is provided.
|
31
31
|
'''
|
@@ -42,8 +42,8 @@ def train_classification(model, train_data, epochs=100, lr=0.01, save_path=None)
|
|
42
42
|
loss.backward() # Backpropagation
|
43
43
|
optimizer.step() # Parameter update
|
44
44
|
|
45
|
-
if (epoch + 1) % 10 == 0:
|
45
|
+
if (epoch + 1) % 10 == 0 and verbose:
|
46
46
|
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
|
47
|
-
if save_path:
|
47
|
+
if save_path is not None:
|
48
48
|
torch.save(model.state_dict(), save_path)
|
49
49
|
print(f"Model saved to {save_path}")
|
oikan/visualize.py
CHANGED
@@ -7,8 +7,6 @@ def visualize_regression(model, X, y):
|
|
7
7
|
model.eval()
|
8
8
|
with torch.no_grad():
|
9
9
|
y_pred = model(torch.FloatTensor(X)).numpy()
|
10
|
-
|
11
|
-
|
12
10
|
plt.figure(figsize=(10, 6))
|
13
11
|
plt.scatter(X[:, 0], y, color='blue', label='True')
|
14
12
|
plt.scatter(X[:, 0], y_pred, color='red', label='Predicted')
|
@@ -18,9 +16,7 @@ def visualize_regression(model, X, y):
|
|
18
16
|
def visualize_classification(model, X, y):
|
19
17
|
'''Visualize classification decision boundaries. For high-dimensional data, uses SVD projection.'''
|
20
18
|
model.eval()
|
21
|
-
|
22
19
|
if X.shape[1] > 2:
|
23
|
-
|
24
20
|
X_mean = np.mean(X, axis=0)
|
25
21
|
X_centered = X - X_mean
|
26
22
|
_, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
|
@@ -28,42 +24,46 @@ def visualize_classification(model, X, y):
|
|
28
24
|
X_proj = (X - X_mean) @ principal.T
|
29
25
|
x_min, x_max = X_proj[:, 0].min() - 1, X_proj[:, 0].max() + 1
|
30
26
|
y_min, y_max = X_proj[:, 1].min() - 1, X_proj[:, 1].max() + 1
|
31
|
-
|
32
|
-
|
33
27
|
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
|
34
28
|
np.linspace(y_min, y_max, 100))
|
35
29
|
grid_2d = np.c_[xx.ravel(), yy.ravel()]
|
36
|
-
|
37
30
|
X_grid = X_mean + grid_2d @ principal
|
38
|
-
|
39
31
|
with torch.no_grad():
|
40
32
|
Z = model(torch.FloatTensor(X_grid))
|
41
33
|
Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
|
42
|
-
|
43
|
-
|
44
|
-
|
45
34
|
plt.figure(figsize=(10, 8))
|
46
35
|
plt.contourf(xx, yy, Z, alpha=0.4)
|
47
36
|
plt.scatter(X_proj[:, 0], X_proj[:, 1], c=y, alpha=0.8)
|
48
37
|
plt.title("Classification Visualization (SVD Projection)")
|
49
38
|
plt.show()
|
50
|
-
|
51
39
|
else:
|
52
40
|
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
|
53
41
|
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
|
54
|
-
|
55
|
-
|
56
42
|
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
|
57
43
|
np.linspace(y_min, y_max, 100))
|
58
44
|
grid_2d = np.c_[xx.ravel(), yy.ravel()]
|
59
|
-
|
60
|
-
|
61
45
|
with torch.no_grad():
|
62
46
|
Z = model(torch.FloatTensor(grid_2d))
|
63
47
|
Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
|
64
|
-
|
65
|
-
|
66
|
-
|
67
48
|
plt.figure(figsize=(10, 8))
|
68
49
|
plt.contourf(xx, yy, Z, alpha=0.4)
|
69
|
-
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
|
50
|
+
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
|
51
|
+
|
52
|
+
def visualize_time_series_forecasting(model, X, y):
|
53
|
+
'''
|
54
|
+
Visualize time series forecasting results by plotting true vs predicted values.
|
55
|
+
Expected X shape: [samples, seq_len, features] and y: true targets.
|
56
|
+
'''
|
57
|
+
model.eval()
|
58
|
+
with torch.no_grad():
|
59
|
+
y_pred = model(X).detach().cpu().numpy()
|
60
|
+
if isinstance(y, torch.Tensor):
|
61
|
+
y = y.detach().cpu().numpy()
|
62
|
+
plt.figure(figsize=(10, 5))
|
63
|
+
plt.plot(y, label='True', marker='o', linestyle='-')
|
64
|
+
plt.plot(y_pred, label='Predicted', marker='x', linestyle='--')
|
65
|
+
plt.xlabel("Time Step")
|
66
|
+
plt.ylabel("Value")
|
67
|
+
plt.title("Time Series Forecasting Visualization")
|
68
|
+
plt.legend()
|
69
|
+
plt.show()
|
@@ -0,0 +1,13 @@
|
|
1
|
+
oikan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
oikan/metrics.py,sha256=IF13bW3evsyKfZC2jhI-MPRu2Rl77Elo3of68OF_JW8,1928
|
3
|
+
oikan/model.py,sha256=blpTiAFQ-LxhvWedP5Yf5TgdwlOb4t1BuBMe9d-kJZ0,5342
|
4
|
+
oikan/regularization.py,sha256=xt8JNnPdHRAQgzF_vnyme005hWLunz9Vo2qw6m08NMM,1145
|
5
|
+
oikan/symbolic.py,sha256=RRYHOCOCJr5KXRhdcCPvT_OqyNcCnWCWt7fOtos8rRI,5765
|
6
|
+
oikan/trainer.py,sha256=PwA8PnVUiv5wYlQqj3DTplCAUZljZ4iWJUKUDvmIvX0,2062
|
7
|
+
oikan/utils.py,sha256=xbVgrbhXYj57RdD3uNPchjyfmP6Kur7tngoZPa3qWOw,2094
|
8
|
+
oikan/visualize.py,sha256=ZZiRf0P8cuBiC0reBBGVnSTotBq5oxQIRIEgqSrN6u8,2916
|
9
|
+
oikan-0.0.1.11.dist-info/LICENSE,sha256=75ASVmU-XIpN-M4LbVmJ_ibgbzbvRLVti8FhnR0BTf8,1096
|
10
|
+
oikan-0.0.1.11.dist-info/METADATA,sha256=5EpY9clgm3iQ2nLrtLesX-H8sUhZU_lL7bTEPDFj54U,3848
|
11
|
+
oikan-0.0.1.11.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
12
|
+
oikan-0.0.1.11.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
|
13
|
+
oikan-0.0.1.11.dist-info/RECORD,,
|
oikan-0.0.1.9.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
|
|
1
|
-
oikan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
oikan/metrics.py,sha256=65txPbjhTz7lCXCLtTAJTS4E5Hx42wzZ3jKar3lH_bY,860
|
3
|
-
oikan/model.py,sha256=zlw_4HbSK3IiQhE8M4NitvrXa7vffBWsOR-HLRSJADA,3944
|
4
|
-
oikan/regularization.py,sha256=xt8JNnPdHRAQgzF_vnyme005hWLunz9Vo2qw6m08NMM,1145
|
5
|
-
oikan/symbolic.py,sha256=RRYHOCOCJr5KXRhdcCPvT_OqyNcCnWCWt7fOtos8rRI,5765
|
6
|
-
oikan/trainer.py,sha256=M7F5FLCXB5HUKhMRJiX92o0DB9NHWZIb25bGgyYwXrM,1986
|
7
|
-
oikan/utils.py,sha256=xbVgrbhXYj57RdD3uNPchjyfmP6Kur7tngoZPa3qWOw,2094
|
8
|
-
oikan/visualize.py,sha256=sA__nLB35y6tuWDAM3aoC7VezxJAvdSlwzmoPfrnFhQ,2249
|
9
|
-
oikan-0.0.1.9.dist-info/LICENSE,sha256=75ASVmU-XIpN-M4LbVmJ_ibgbzbvRLVti8FhnR0BTf8,1096
|
10
|
-
oikan-0.0.1.9.dist-info/METADATA,sha256=ZyH9i_F8ymdalTVQxVbyyEFUa0xPLt7byLXP5zVfjO0,3847
|
11
|
-
oikan-0.0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
12
|
-
oikan-0.0.1.9.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
|
13
|
-
oikan-0.0.1.9.dist-info/RECORD,,
|
File without changes
|
File without changes
|