oikan 0.0.1.8__py3-none-any.whl → 0.0.1.10__py3-none-any.whl
This diff compares publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- oikan/trainer.py +16 -8
- oikan/visualize.py +54 -24
- {oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/METADATA +3 -2
- oikan-0.0.1.10.dist-info/RECORD +13 -0
- oikan-0.0.1.8.dist-info/RECORD +0 -13
- {oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/LICENSE +0 -0
- {oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/WHEEL +0 -0
- {oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/top_level.txt +0 -0
oikan/trainer.py
CHANGED
@@ -2,8 +2,10 @@ import torch
 import torch.nn as nn
 from .regularization import RegularizedLoss
 
-def train(model, train_data, epochs=100, lr=0.01):
-    '''Train regression model using MSE loss with regularization.'''
+def train(model, train_data, epochs=100, lr=0.01, save_path=None, verbose=True):
+    '''Train regression model using MSE loss with regularization.
+    Optionally save the model when training is finished if save_path is provided.
+    '''
     X_train, y_train = train_data
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     criterion = nn.MSELoss()
@@ -13,16 +15,20 @@ def train(model, train_data, epochs=100, lr=0.01):
     for epoch in range(epochs):
         optimizer.zero_grad()  # Reset gradients
         outputs = model(X_train)
-        # Compute loss including regularization penalties
         loss = reg_loss(outputs, y_train, X_train)
         loss.backward()  # Backpropagate errors
         optimizer.step()  # Update parameters
 
-        if (epoch + 1) % 10 == 0:
+        if (epoch + 1) % 10 == 0 and verbose:
             print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
+    if save_path is not None:
+        torch.save(model.state_dict(), save_path)
+        print(f"Model saved to {save_path}")
 
-def train_classification(model, train_data, epochs=100, lr=0.01):
-    '''Train classification model using CrossEntropy loss with regularization.'''
+def train_classification(model, train_data, epochs=100, lr=0.01, save_path=None, verbose=True):
+    '''Train classification model using CrossEntropy loss with regularization.
+    Optionally save the model when training is finished if save_path is provided.
+    '''
     X_train, y_train = train_data
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     criterion = nn.CrossEntropyLoss()
@@ -32,10 +38,12 @@ def train_classification(model, train_data, epochs=100, lr=0.01):
     for epoch in range(epochs):
         optimizer.zero_grad()  # Reset gradients each epoch
         outputs = model(X_train)
-        # Loss includes both cross-entropy and regularization terms
        loss = reg_loss(outputs, y_train, X_train)
         loss.backward()  # Backpropagation
         optimizer.step()  # Parameter update
 
-        if (epoch + 1) % 10 == 0:
+        if (epoch + 1) % 10 == 0 and verbose:
             print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
+    if save_path is not None:
+        torch.save(model.state_dict(), save_path)
+        print(f"Model saved to {save_path}")
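For context, a minimal usage sketch of the updated trainer API. The OIKAN class name, its import path, and its constructor arguments are assumptions inferred from the package layout (oikan/model.py); only the train() signature and the save behaviour are confirmed by the diff above.

import torch
from oikan.model import OIKAN      # assumed class name and import path
from oikan.trainer import train

# Synthetic regression data shaped the way trainer.py consumes it: an (X, y) tuple of tensors
X_train = torch.randn(256, 4)
y_train = X_train.sum(dim=1, keepdim=True) + 0.1 * torch.randn(256, 1)

model = OIKAN(input_dim=4, output_dim=1)  # hypothetical constructor arguments

# New in 0.0.1.10: verbose gates the per-10-epoch logging, save_path persists the weights
train(model, (X_train, y_train), epochs=100, lr=0.01,
      save_path="oikan_regression.pt", verbose=True)

# The checkpoint is a plain state_dict, so it reloads with the standard PyTorch idiom
model.load_state_dict(torch.load("oikan_regression.pt"))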
oikan/visualize.py
CHANGED
@@ -1,39 +1,69 @@
+import torch
 import numpy as np
 import matplotlib.pyplot as plt
-import torch
 
 def visualize_regression(model, X, y):
-    '''Visualize regression results'''
+    '''Visualize regression results using true vs predicted scatter plots.'''
     model.eval()
     with torch.no_grad():
         y_pred = model(torch.FloatTensor(X)).numpy()
+
+
     plt.figure(figsize=(10, 6))
-    # Plot true values vs predictions
     plt.scatter(X[:, 0], y, color='blue', label='True')
     plt.scatter(X[:, 0], y_pred, color='red', label='Predicted')
     plt.legend()
-    plt.title("Regression: True vs Predicted")
-    plt.xlabel("Feature 1")
-    plt.ylabel("Output")
     plt.show()
 
 def visualize_classification(model, X, y):
-    '''Visualize classification decision boundaries'''
+    '''Visualize classification decision boundaries. For high-dimensional data, uses SVD projection.'''
     model.eval()
-    [old lines 23-39 of the previous implementation are not rendered in this diff view]
+
+    if X.shape[1] > 2:
+
+        X_mean = np.mean(X, axis=0)
+        X_centered = X - X_mean
+        _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
+        principal = Vt[:2]
+        X_proj = (X - X_mean) @ principal.T
+        x_min, x_max = X_proj[:, 0].min() - 1, X_proj[:, 0].max() + 1
+        y_min, y_max = X_proj[:, 1].min() - 1, X_proj[:, 1].max() + 1
+
+
+        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
+                             np.linspace(y_min, y_max, 100))
+        grid_2d = np.c_[xx.ravel(), yy.ravel()]
+
+        X_grid = X_mean + grid_2d @ principal
+
+        with torch.no_grad():
+            Z = model(torch.FloatTensor(X_grid))
+            Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
+
+
+
+        plt.figure(figsize=(10, 8))
+        plt.contourf(xx, yy, Z, alpha=0.4)
+        plt.scatter(X_proj[:, 0], X_proj[:, 1], c=y, alpha=0.8)
+        plt.title("Classification Visualization (SVD Projection)")
+        plt.show()
+
+    else:
+        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
+        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
+
+
+        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
+                             np.linspace(y_min, y_max, 100))
+        grid_2d = np.c_[xx.ravel(), yy.ravel()]
+
+
+        with torch.no_grad():
+            Z = model(torch.FloatTensor(grid_2d))
+            Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
+
+
+
+        plt.figure(figsize=(10, 8))
+        plt.contourf(xx, yy, Z, alpha=0.4)
+        plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
{oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: oikan
-Version: 0.0.1.8
+Version: 0.0.1.10
 Summary: OIKAN: Optimized Interpretable Kolmogorov-Arnold Networks
 Author: Arman Zhalgasbayev
 License: MIT
@@ -22,7 +22,8 @@ Optimized Interpretable Kolmogorov-Arnold Networks (OIKAN)
 A deep learning framework for interpretable neural networks using advanced basis functions.
 
 [](https://badge.fury.io/py/oikan)
-[](https://pypistats.org/packages/oikan)
+[added line 25 is not rendered in this diff view]
+[](https://pepy.tech/projects/oikan)
 [](https://opensource.org/licenses/MIT)
 [](https://github.com/silvermete0r/oikan/issues)
 [](https://silvermete0r.github.io/oikan/)
oikan-0.0.1.10.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
+oikan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+oikan/metrics.py,sha256=65txPbjhTz7lCXCLtTAJTS4E5Hx42wzZ3jKar3lH_bY,860
+oikan/model.py,sha256=zlw_4HbSK3IiQhE8M4NitvrXa7vffBWsOR-HLRSJADA,3944
+oikan/regularization.py,sha256=xt8JNnPdHRAQgzF_vnyme005hWLunz9Vo2qw6m08NMM,1145
+oikan/symbolic.py,sha256=RRYHOCOCJr5KXRhdcCPvT_OqyNcCnWCWt7fOtos8rRI,5765
+oikan/trainer.py,sha256=PwA8PnVUiv5wYlQqj3DTplCAUZljZ4iWJUKUDvmIvX0,2062
+oikan/utils.py,sha256=xbVgrbhXYj57RdD3uNPchjyfmP6Kur7tngoZPa3qWOw,2094
+oikan/visualize.py,sha256=sA__nLB35y6tuWDAM3aoC7VezxJAvdSlwzmoPfrnFhQ,2249
+oikan-0.0.1.10.dist-info/LICENSE,sha256=75ASVmU-XIpN-M4LbVmJ_ibgbzbvRLVti8FhnR0BTf8,1096
+oikan-0.0.1.10.dist-info/METADATA,sha256=Mzmf6JP3taLsxytaSz8TI30M-uvmaq7jcEkZORGCatw,3848
+oikan-0.0.1.10.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+oikan-0.0.1.10.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
+oikan-0.0.1.10.dist-info/RECORD,,
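Each RECORD row has the form path,sha256=<digest>,size-in-bytes, where the digest is the file's SHA-256 hash encoded as unpadded urlsafe base64 (the standard wheel RECORD convention). A small sketch for recomputing one entry, assuming the wheel has been unpacked locally:

import base64
import hashlib
from pathlib import Path

def record_entry(path):
    """Rebuild the 'path,sha256=...,size' line a wheel RECORD uses for one file."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode("ascii")
    return f"{path},sha256={digest},{len(data)}"

# Example: verify the packaged trainer module against the 0.0.1.10 RECORD above
print(record_entry("oikan/trainer.py"))
# Expected: oikan/trainer.py,sha256=PwA8PnVUiv5wYlQqj3DTplCAUZljZ4iWJUKUDvmIvX0,2062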
oikan-0.0.1.8.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
-oikan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-oikan/metrics.py,sha256=65txPbjhTz7lCXCLtTAJTS4E5Hx42wzZ3jKar3lH_bY,860
-oikan/model.py,sha256=zlw_4HbSK3IiQhE8M4NitvrXa7vffBWsOR-HLRSJADA,3944
-oikan/regularization.py,sha256=xt8JNnPdHRAQgzF_vnyme005hWLunz9Vo2qw6m08NMM,1145
-oikan/symbolic.py,sha256=RRYHOCOCJr5KXRhdcCPvT_OqyNcCnWCWt7fOtos8rRI,5765
-oikan/trainer.py,sha256=S-23uwmQ3Kx1FnE-dKd76zTZjvaV0VUZoChUsNzjcwk,1672
-oikan/utils.py,sha256=xbVgrbhXYj57RdD3uNPchjyfmP6Kur7tngoZPa3qWOw,2094
-oikan/visualize.py,sha256=VpIzWpwoZihQ0gPSRjsEuKSxqHf1SiKxLynOzZ4P6HE,1539
-oikan-0.0.1.8.dist-info/LICENSE,sha256=75ASVmU-XIpN-M4LbVmJ_ibgbzbvRLVti8FhnR0BTf8,1096
-oikan-0.0.1.8.dist-info/METADATA,sha256=aDbshPny4TxIE_tZt0I8I_T_9tRwkomC6zC1BHT4knw,3738
-oikan-0.0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-oikan-0.0.1.8.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
-oikan-0.0.1.8.dist-info/RECORD,,
{oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/LICENSE
File without changes
{oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/WHEEL
File without changes
{oikan-0.0.1.8.dist-info → oikan-0.0.1.10.dist-info}/top_level.txt
File without changes