oikan 0.0.1.8__py3-none-any.whl → 0.0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oikan/trainer.py CHANGED
@@ -2,8 +2,10 @@ import torch
2
2
  import torch.nn as nn
3
3
  from .regularization import RegularizedLoss
4
4
 
5
- def train(model, train_data, epochs=100, lr=0.01):
6
- '''Train regression model using MSE loss with regularization.'''
5
+ def train(model, train_data, epochs=100, lr=0.01, save_path=None):
6
+ '''Train regression model using MSE loss with regularization.
7
+ Optionally save the model when training is finished if save_path is provided.
8
+ '''
7
9
  X_train, y_train = train_data
8
10
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
9
11
  criterion = nn.MSELoss()
@@ -13,16 +15,20 @@ def train(model, train_data, epochs=100, lr=0.01):
13
15
  for epoch in range(epochs):
14
16
  optimizer.zero_grad() # Reset gradients
15
17
  outputs = model(X_train)
16
- # Compute loss including regularization penalties
17
18
  loss = reg_loss(outputs, y_train, X_train)
18
19
  loss.backward() # Backpropagate errors
19
20
  optimizer.step() # Update parameters
20
21
 
21
22
  if (epoch + 1) % 10 == 0:
22
23
  print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
24
+ if save_path:
25
+ torch.save(model.state_dict(), save_path)
26
+ print(f"Model saved to {save_path}")
23
27
 
24
- def train_classification(model, train_data, epochs=100, lr=0.01):
25
- '''Train classification model using CrossEntropy loss with regularization.'''
28
+ def train_classification(model, train_data, epochs=100, lr=0.01, save_path=None):
29
+ '''Train classification model using CrossEntropy loss with regularization.
30
+ Optionally save the model when training is finished if save_path is provided.
31
+ '''
26
32
  X_train, y_train = train_data
27
33
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
28
34
  criterion = nn.CrossEntropyLoss()
@@ -32,10 +38,12 @@ def train_classification(model, train_data, epochs=100, lr=0.01):
32
38
  for epoch in range(epochs):
33
39
  optimizer.zero_grad() # Reset gradients each epoch
34
40
  outputs = model(X_train)
35
- # Loss includes both cross-entropy and regularization terms
36
41
  loss = reg_loss(outputs, y_train, X_train)
37
42
  loss.backward() # Backpropagation
38
43
  optimizer.step() # Parameter update
39
44
 
40
45
  if (epoch + 1) % 10 == 0:
41
46
  print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
47
+ if save_path:
48
+ torch.save(model.state_dict(), save_path)
49
+ print(f"Model saved to {save_path}")
oikan/visualize.py CHANGED
@@ -1,39 +1,69 @@
1
+ import torch
1
2
  import numpy as np
2
3
  import matplotlib.pyplot as plt
3
- import torch
4
4
 
5
5
  def visualize_regression(model, X, y):
6
- '''Visualize regression results via a scatter plot, comparing true vs predicted values.'''
6
+ '''Visualize regression results using true vs predicted scatter plots.'''
7
7
  model.eval()
8
8
  with torch.no_grad():
9
9
  y_pred = model(torch.FloatTensor(X)).numpy()
10
+
11
+
10
12
  plt.figure(figsize=(10, 6))
11
- # Plot true values vs predictions
12
13
  plt.scatter(X[:, 0], y, color='blue', label='True')
13
14
  plt.scatter(X[:, 0], y_pred, color='red', label='Predicted')
14
15
  plt.legend()
15
- plt.title("Regression: True vs Predicted")
16
- plt.xlabel("Feature 1")
17
- plt.ylabel("Output")
18
16
  plt.show()
19
17
 
20
18
  def visualize_classification(model, X, y):
21
- '''Visualize classification decision boundaries for 2D input data.'''
19
+ '''Visualize classification decision boundaries. For high-dimensional data, uses SVD projection.'''
22
20
  model.eval()
23
- x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
24
- y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
25
- xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
26
- np.linspace(y_min, y_max, 100))
27
- grid_2d = np.c_[xx.ravel(), yy.ravel()]
28
- with torch.no_grad():
29
- # Compute prediction for each point in the grid
30
- Z = model(torch.FloatTensor(grid_2d))
31
- Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
32
- plt.figure(figsize=(10, 8))
33
- # Draw decision boundaries and scatter the data
34
- plt.contourf(xx, yy, Z, alpha=0.4)
35
- plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
36
- plt.title("Classification Visualization")
37
- plt.xlabel("Feature 1")
38
- plt.ylabel("Feature 2")
39
- plt.show()
21
+
22
+ if X.shape[1] > 2:
23
+
24
+ X_mean = np.mean(X, axis=0)
25
+ X_centered = X - X_mean
26
+ _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
27
+ principal = Vt[:2]
28
+ X_proj = (X - X_mean) @ principal.T
29
+ x_min, x_max = X_proj[:, 0].min() - 1, X_proj[:, 0].max() + 1
30
+ y_min, y_max = X_proj[:, 1].min() - 1, X_proj[:, 1].max() + 1
31
+
32
+
33
+ xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
34
+ np.linspace(y_min, y_max, 100))
35
+ grid_2d = np.c_[xx.ravel(), yy.ravel()]
36
+
37
+ X_grid = X_mean + grid_2d @ principal
38
+
39
+ with torch.no_grad():
40
+ Z = model(torch.FloatTensor(X_grid))
41
+ Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
42
+
43
+
44
+
45
+ plt.figure(figsize=(10, 8))
46
+ plt.contourf(xx, yy, Z, alpha=0.4)
47
+ plt.scatter(X_proj[:, 0], X_proj[:, 1], c=y, alpha=0.8)
48
+ plt.title("Classification Visualization (SVD Projection)")
49
+ plt.show()
50
+
51
+ else:
52
+ x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
53
+ y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
54
+
55
+
56
+ xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
57
+ np.linspace(y_min, y_max, 100))
58
+ grid_2d = np.c_[xx.ravel(), yy.ravel()]
59
+
60
+
61
+ with torch.no_grad():
62
+ Z = model(torch.FloatTensor(grid_2d))
63
+ Z = torch.argmax(Z, dim=1).numpy().reshape(xx.shape)
64
+
65
+
66
+
67
+ plt.figure(figsize=(10, 8))
68
+ plt.contourf(xx, yy, Z, alpha=0.4)
69
+ plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: oikan
3
- Version: 0.0.1.8
3
+ Version: 0.0.1.9
4
4
  Summary: OIKAN: Optimized Interpretable Kolmogorov-Arnold Networks
5
5
  Author: Arman Zhalgasbayev
6
6
  License: MIT
@@ -22,7 +22,8 @@ Optimized Interpretable Kolmogorov-Arnold Networks (OIKAN)
22
22
  A deep learning framework for interpretable neural networks using advanced basis functions.
23
23
 
24
24
  [![PyPI version](https://badge.fury.io/py/oikan.svg)](https://badge.fury.io/py/oikan)
25
- [![PyPI downloads](https://img.shields.io/pypi/dm/oikan.svg)](https://pypistats.org/packages/oikan)
25
+ [![PyPI Downloads per month](https://img.shields.io/pypi/dm/oikan.svg)](https://pypistats.org/packages/oikan)
26
+ [![PyPI Total Downloads](https://static.pepy.tech/badge/oikan)](https://pepy.tech/projects/oikan)
26
27
  [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
27
28
  [![GitHub issues](https://img.shields.io/github/issues/silvermete0r/OIKAN.svg)](https://github.com/silvermete0r/oikan/issues)
28
29
  [![Docs](https://img.shields.io/badge/docs-passing-brightgreen)](https://silvermete0r.github.io/oikan/)
@@ -3,11 +3,11 @@ oikan/metrics.py,sha256=65txPbjhTz7lCXCLtTAJTS4E5Hx42wzZ3jKar3lH_bY,860
3
3
  oikan/model.py,sha256=zlw_4HbSK3IiQhE8M4NitvrXa7vffBWsOR-HLRSJADA,3944
4
4
  oikan/regularization.py,sha256=xt8JNnPdHRAQgzF_vnyme005hWLunz9Vo2qw6m08NMM,1145
5
5
  oikan/symbolic.py,sha256=RRYHOCOCJr5KXRhdcCPvT_OqyNcCnWCWt7fOtos8rRI,5765
6
- oikan/trainer.py,sha256=S-23uwmQ3Kx1FnE-dKd76zTZjvaV0VUZoChUsNzjcwk,1672
6
+ oikan/trainer.py,sha256=M7F5FLCXB5HUKhMRJiX92o0DB9NHWZIb25bGgyYwXrM,1986
7
7
  oikan/utils.py,sha256=xbVgrbhXYj57RdD3uNPchjyfmP6Kur7tngoZPa3qWOw,2094
8
- oikan/visualize.py,sha256=VpIzWpwoZihQ0gPSRjsEuKSxqHf1SiKxLynOzZ4P6HE,1539
9
- oikan-0.0.1.8.dist-info/LICENSE,sha256=75ASVmU-XIpN-M4LbVmJ_ibgbzbvRLVti8FhnR0BTf8,1096
10
- oikan-0.0.1.8.dist-info/METADATA,sha256=aDbshPny4TxIE_tZt0I8I_T_9tRwkomC6zC1BHT4knw,3738
11
- oikan-0.0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
12
- oikan-0.0.1.8.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
13
- oikan-0.0.1.8.dist-info/RECORD,,
8
+ oikan/visualize.py,sha256=sA__nLB35y6tuWDAM3aoC7VezxJAvdSlwzmoPfrnFhQ,2249
9
+ oikan-0.0.1.9.dist-info/LICENSE,sha256=75ASVmU-XIpN-M4LbVmJ_ibgbzbvRLVti8FhnR0BTf8,1096
10
+ oikan-0.0.1.9.dist-info/METADATA,sha256=ZyH9i_F8ymdalTVQxVbyyEFUa0xPLt7byLXP5zVfjO0,3847
11
+ oikan-0.0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
12
+ oikan-0.0.1.9.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
13
+ oikan-0.0.1.9.dist-info/RECORD,,