oikan 0.0.1.1__py3-none-any.whl → 0.0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oikan/model.py CHANGED
@@ -1,28 +1,65 @@
  import torch
  import torch.nn as nn
+ from .utils import BSplineBasis, FourierBasis
+
+ class AdaptiveBasisLayer(nn.Module):
+     def __init__(self, input_dim, hidden_dim):
+         super().__init__()
+         self.weights = nn.Parameter(torch.randn(input_dim, hidden_dim))
+         self.bias = nn.Parameter(torch.zeros(hidden_dim))
+
+     def forward(self, x):
+         return torch.matmul(x, self.weights) + self.bias

- # EfficientKAN Layer
  class EfficientKAN(nn.Module):
-     def __init__(self, input_dim, hidden_units=10):
-         super(EfficientKAN, self).__init__()
-         self.basis_functions = nn.ModuleList([nn.Linear(1, hidden_units) for _ in range(input_dim)])
-         self.activations = nn.ReLU()
+     def __init__(self, input_dim, hidden_units=10, basis_type='bspline'):
+         super().__init__()
+         self.input_dim = input_dim
+         self.hidden_units = hidden_units
+         self.basis_type = basis_type
+
+         if basis_type == 'bspline':
+             self.basis_functions = nn.ModuleList([BSplineBasis(hidden_units) for _ in range(input_dim)])
+             self.basis_output_dim = input_dim * (hidden_units - 4)  # BSpline emits num_knots - degree - 1 values per feature
+         elif basis_type == 'fourier':
+             self.basis_functions = nn.ModuleList([FourierBasis(hidden_units//2) for _ in range(input_dim)])
+             self.basis_output_dim = input_dim * hidden_units
+
+         # Grid-based interaction layer
+         self.interaction_weights = nn.Parameter(torch.randn(input_dim, input_dim))

      def forward(self, x):
-         transformed_features = [self.activations(bf(x[:, i].unsqueeze(1))) for i, bf in enumerate(self.basis_functions)]
-         return torch.cat(transformed_features, dim=1)
+         # Transform each feature using basis functions
+         transformed_features = [bf(x[:, i].unsqueeze(1)) for i, bf in enumerate(self.basis_functions)]
+         basis_output = torch.cat(transformed_features, dim=1)
+
+         # Compute feature interactions - fixed matrix multiplication
+         batch_size = x.size(0)
+         x_reshaped = x.view(batch_size, self.input_dim, 1)             # [batch_size, input_dim, 1]
+         interaction_matrix = torch.sigmoid(self.interaction_weights)   # [input_dim, input_dim]
+         interaction_features = torch.bmm(x_reshaped.transpose(1, 2),
+                                          x_reshaped * interaction_matrix.unsqueeze(0))  # [batch_size, 1, input_dim]
+         interaction_features = interaction_features.view(batch_size, -1)                # [batch_size, input_dim]
+
+         return torch.cat([basis_output, interaction_features], dim=1)
+
+     def get_output_dim(self):
+         return self.basis_output_dim + self.input_dim

- # OIKAN Model
  class OIKAN(nn.Module):
      def __init__(self, input_dim, output_dim, hidden_units=10):
-         super(OIKAN, self).__init__()
+         super().__init__()
          self.efficientkan = EfficientKAN(input_dim, hidden_units)
-         self.mlp = nn.Sequential(
-             nn.Linear(input_dim * hidden_units, 32),
+
+         # Get actual feature dimension after transformation
+         feature_dim = self.efficientkan.get_output_dim()
+
+         self.interpretable_layers = nn.Sequential(
+             AdaptiveBasisLayer(feature_dim, 32),
              nn.ReLU(),
-             nn.Linear(32, output_dim)
+             AdaptiveBasisLayer(32, output_dim)
          )

      def forward(self, x):
          transformed_x = self.efficientkan(x)
-         return self.mlp(transformed_x)
+         return self.interpretable_layers(transformed_x)
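For orientation, a minimal usage sketch of the new model API (not part of the diff); the batch size and feature count are hypothetical, and the default `basis_type='bspline'` path is exercised:

```python
import torch
from oikan.model import OIKAN

# Hypothetical setup: 4 input features, 1 regression target.
model = OIKAN(input_dim=4, output_dim=1, hidden_units=10)

x = torch.randn(32, 4)   # batch of 32 samples
y_hat = model(x)         # basis features + interactions -> AdaptiveBasisLayer stack
print(y_hat.shape)       # torch.Size([32, 1])
```

With these defaults each feature contributes hidden_units - 4 = 6 B-spline values plus one interaction column, so EfficientKAN feeds 4*6 + 4 = 28 features into the interpretable layers.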
oikan/regularization.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn as nn
+
+ class RegularizedLoss:
+     def __init__(self, base_criterion, model, l1_lambda=0.01, gradient_lambda=0.01):
+         self.base_criterion = base_criterion
+         self.model = model
+         self.l1_lambda = l1_lambda
+         self.gradient_lambda = gradient_lambda
+
+     def __call__(self, pred, target, inputs):
+         base_loss = self.base_criterion(pred, target)
+
+         # L1 regularization
+         l1_loss = 0
+         for param in self.model.parameters():
+             l1_loss += torch.norm(param, p=1)
+
+         # Gradient penalty
+         grad_penalty = 0
+         inputs.requires_grad_(True)
+         outputs = self.model(inputs)
+         gradients = torch.autograd.grad(
+             outputs=outputs.sum(),
+             inputs=inputs,
+             create_graph=True
+         )[0]
+         grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+
+         return base_loss + self.l1_lambda * l1_loss + self.gradient_lambda * grad_penalty
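A minimal sketch of how `RegularizedLoss` wraps a base criterion, assuming the `OIKAN` model from model.py above; the data and lambda values are hypothetical:

```python
import torch
import torch.nn as nn
from oikan.model import OIKAN
from oikan.regularization import RegularizedLoss

model = OIKAN(input_dim=4, output_dim=1)
X = torch.randn(64, 4)   # hypothetical inputs
y = torch.randn(64, 1)   # hypothetical targets

loss_fn = RegularizedLoss(nn.MSELoss(), model, l1_lambda=0.01, gradient_lambda=0.01)
pred = model(X)
loss = loss_fn(pred, y, X)   # base MSE + L1 on parameters + input-gradient penalty
loss.backward()
```

Note the third argument: the loss runs a second forward pass on `inputs` to penalize input gradients whose norm deviates from 1, so it needs the raw inputs, not just the predictions.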
oikan/symbolic.py CHANGED
@@ -1,36 +1,129 @@
  import torch
- from sympy import symbols, simplify, Add
+ import numpy as np
+ import networkx as nx
+ import matplotlib.pyplot as plt

- # Regression symbolic extraction
- def extract_symbolic_formula_regression(model, input_data):
-     symbolic_vars = symbols([f'x{i}' for i in range(input_data.shape[1])])
-
+ ADVANCED_LIB = {
+     'x': lambda x: x,
+     'x^2': lambda x: x**2,
+     'x^3': lambda x: x**3,
+     'x^4': lambda x: x**4,
+     'x^5': lambda x: x**5,
+     'exp': lambda x: np.exp(x),
+     'log': lambda x: np.log(np.abs(x) + 1e-8),
+     'sqrt': lambda x: np.sqrt(np.abs(x)),
+     'tanh': lambda x: np.tanh(x),
+     'sin': lambda x: np.sin(x),
+     'abs': lambda x: np.abs(x)
+ }
+
+ # STEP-1: Helper functions
+ def get_model_predictions(model, X, mode):
+     """Compute model predictions and return target values (and raw preds for classification)."""
+     X_tensor = torch.FloatTensor(X)
      with torch.no_grad():
-         weights = model.mlp[0].weight.cpu().numpy()
-         if weights.size == 0:
-             print("Warning: Extracted weights are empty.")
-             return "NaN"
+         preds = model(X_tensor)
+     if mode == 'regression':
+         return preds.detach().cpu().numpy().flatten(), None
+     elif mode == 'classification':
+         out = preds.detach().cpu().numpy()
+         target = (out[:, 0] - out[:, 1]).flatten() if (out.ndim > 1 and out.shape[1] > 1) else out.flatten()
+         return target, out
+     else:
+         raise ValueError("Unknown mode")

-     formula = sum(weights[0, i] * symbolic_vars[i] for i in range(len(symbolic_vars)))
-     return simplify(formula)
+ def build_design_matrix(X, return_names=False):
+     """Build the design matrix using the advanced nonlinear bases."""
+     X_np = np.array(X)
+     n_samples, d = X_np.shape
+     F_parts = [np.ones((n_samples, 1))]
+     names = ['1'] if return_names else None
+     for j in range(d):
+         xj = X_np[:, j:j+1]
+         for key, func in ADVANCED_LIB.items():
+             F_parts.append(func(xj))
+             if return_names:
+                 names.append(f"{key}(x{j+1})")
+     return (np.hstack(F_parts), names) if return_names else np.hstack(F_parts)

- # Classification symbolic extraction
- def extract_symbolic_formula_classification(model, input_data):
+ # STEP-2: Main functions using helpers
+ def extract_symbolic_formula(model, X, mode='regression'):
      """
-     Extracts a symbolic decision boundary for a two-class classifier.
-     Approximates:
-         decision = (w[0] - w[1]) · x + (b[0] - b[1])
-     where w and b are from the model's final linear layer.
+     Approximate a symbolic formula from the model using advanced nonlinear bases.
      """
-     symbolic_vars = symbols([f'x{i}' for i in range(input_data.shape[1])])
-     with torch.no_grad():
-         final_layer = model.mlp[-1]
-         w = final_layer.weight.cpu().numpy()
-         b = final_layer.bias.cpu().numpy()
-     if w.shape[0] < 2:
-         print("Classification symbolic extraction requires at least 2 classes.")
-         return "NaN"
-     w_diff = w[0] - w[1]
-     b_diff = b[0] - b[1]
-     formula = sum(w_diff[i] * symbolic_vars[i] for i in range(len(symbolic_vars))) + b_diff
-     return simplify(formula)
+     n_samples = np.array(X).shape[0]
+     y_target, _ = get_model_predictions(model, X, mode)
+     F, func_names = build_design_matrix(X, return_names=True)
+     beta, _, _, _ = np.linalg.lstsq(F, y_target, rcond=None)
+     terms = [f"({c:.2f}*{name})" for c, name in zip(beta, func_names) if abs(c) > 1e-4]
+     return " + ".join(terms)
+
+ def test_symbolic_formula(model, X, mode='regression'):
+     """
+     Evaluate the extracted symbolic formula against model outputs.
+     """
+     n_samples = np.array(X).shape[0]
+     y_target, out = get_model_predictions(model, X, mode)
+     F = build_design_matrix(X, return_names=False)
+     beta, _, _, _ = np.linalg.lstsq(F, y_target, rcond=None)
+     symbolic_vals = F.dot(beta)
+     if mode == 'regression':
+         mse = np.mean((symbolic_vals - y_target) ** 2)
+         mae = np.mean(np.abs(symbolic_vals - y_target))
+         rmse = np.sqrt(mse)
+         print(f"(Advanced) MSE: {mse:.4f}, MAE: {mae:.4f}, RMSE: {rmse:.4f}")
+         return mse, mae, rmse
+     elif mode == 'classification':
+         sym_preds = np.where(symbolic_vals >= 0, 0, 1)
+         model_classes = np.argmax(out, axis=1) if (out.ndim > 1) else (out >= 0.5).astype(int)
+         if model_classes.shape[0] != sym_preds.shape[0]:
+             raise ValueError("Shape mismatch between symbolic and model predictions.")
+         accuracy = np.mean(sym_preds == model_classes)
+         print(f"(Advanced) Accuracy: {accuracy:.4f}")
+         return accuracy
+
+ def plot_symbolic_formula(model, X, mode='regression'):
+     """
+     Plot a graph representation of the extracted symbolic formula.
+     """
+     formula = extract_symbolic_formula(model, X, mode)
+     G = nx.DiGraph()
+     G.add_node("Output")
+     terms = formula.split(" + ")
+     for term in terms:
+         expr = term.strip("()")
+         coeff_str, basis = expr.split("*", 1) if "*" in expr else (expr, "unknown")
+         node_label = f"{basis}\n({float(coeff_str):.2f})"
+         G.add_node(node_label)
+         G.add_edge(node_label, "Output", weight=float(coeff_str))
+     left_nodes = [n for n in G.nodes() if n != "Output"]
+     pos = {}
+     n_left = len(left_nodes)
+     for i, node in enumerate(sorted(left_nodes)):
+         pos[node] = (0, 1 - (i / max(n_left - 1, 1)))
+     pos["Output"] = (1, 0.5)
+     plt.figure(figsize=(12, 8))
+     nx.draw(G, pos, with_labels=True, node_color="skyblue", node_size=2500, font_size=10,
+             arrows=True, arrowstyle='->', arrowsize=20)
+     edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)}
+     nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red', font_size=10)
+     plt.title("OIKAN Symbolic Formula Graph")
+     plt.axis("off")
+     plt.show()
+
+ def extract_latex_formula(model, X, mode='regression'):
+     """
+     Return the extracted symbolic formula as LaTeX code.
+     """
+     formula = extract_symbolic_formula(model, X, mode)
+     terms = formula.split(" + ")
+     latex_terms = []
+     for term in terms:
+         expr = term.strip("()")
+         coeff_str, basis = expr.split("*", 1) if "*" in expr else (expr, "")
+         coeff = float(coeff_str)
+         coeff_latex = f"{abs(coeff):.2f}".rstrip("0").rstrip(".")
+         term_latex = coeff_latex if basis.strip() == "1" else f"{coeff_latex} \\cdot {basis.strip()}"
+         latex_terms.append(f"- {term_latex}" if coeff < 0 else f"+ {term_latex}")
+     latex_formula = " ".join(latex_terms).lstrip("+ ").strip()
+     return f"$$ {latex_formula} $$"
oikan/trainer.py CHANGED
@@ -1,32 +1,37 @@
- import torch.optim as optim
+ import torch
  import torch.nn as nn
+ from .regularization import RegularizedLoss

- # Regression training
- def train(model, train_loader, epochs=100, lr=0.01):
+ def train(model, train_data, epochs=100, lr=0.01):
+     X_train, y_train = train_data
+     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
      criterion = nn.MSELoss()
-     optimizer = optim.LBFGS(model.parameters(), lr=lr)
-
-     def closure():
+     reg_loss = RegularizedLoss(criterion, model)
+
+     model.train()
+     for epoch in range(epochs):
          optimizer.zero_grad()
-         outputs = model(train_loader[0])
-         loss = criterion(outputs, train_loader[1])
+         outputs = model(X_train)
+         loss = reg_loss(outputs, y_train, X_train)
          loss.backward()
-         print(f"Loss: {loss.item()}")
-         return loss
-
-     for epoch in range(epochs):
-         optimizer.step(closure)
-         print(f"Epoch {epoch+1}/{epochs}")
+         optimizer.step()
+
+         if (epoch + 1) % 10 == 0:
+             print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

- # Classification training
- def train_classification(model, train_loader, epochs=100, lr=0.01):
+ def train_classification(model, train_data, epochs=100, lr=0.01):
+     X_train, y_train = train_data
+     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
      criterion = nn.CrossEntropyLoss()
-     optimizer = optim.Adam(model.parameters(), lr=lr)
+     reg_loss = RegularizedLoss(criterion, model)

+     model.train()
      for epoch in range(epochs):
          optimizer.zero_grad()
-         outputs = model(train_loader[0])
-         loss = criterion(outputs, train_loader[1])
+         outputs = model(X_train)
+         loss = reg_loss(outputs, y_train, X_train)
          loss.backward()
          optimizer.step()
-         print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")
+
+         if (epoch + 1) % 10 == 0:
+             print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
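Note that `train_data` is now a plain `(X, y)` tuple of tensors rather than a DataLoader, and both trainers switch from LBFGS closures to Adam with the regularized loss. A minimal sketch on a hypothetical toy problem:

```python
import torch
from oikan.model import OIKAN
from oikan.trainer import train

# Hypothetical toy regression: y is the sum of squared features.
X = torch.rand(256, 3) * 2 - 1
y = (X ** 2).sum(dim=1, keepdim=True)

model = OIKAN(input_dim=3, output_dim=1)
train(model, (X, y), epochs=100, lr=0.01)   # logs the regularized loss every 10 epochs
```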
oikan/utils.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from scipy.interpolate import BSpline
+
+ class BSplineBasis(nn.Module):
+     def __init__(self, num_knots=10, degree=3):
+         super().__init__()
+         self.num_knots = max(num_knots, degree + 5)  # Ensure minimum number of knots
+         self.degree = degree
+
+         # Create knot vector with proper padding
+         inner_knots = np.linspace(0, 1, self.num_knots - 2 * degree)
+         left_pad = np.zeros(degree)
+         right_pad = np.ones(degree)
+         knots = np.concatenate([left_pad, inner_knots, right_pad])
+
+         self.register_buffer('knots', torch.FloatTensor(knots))
+
+     def forward(self, x):
+         x_np = x.detach().cpu().numpy()
+         basis_values = np.zeros((x_np.shape[0], self.num_knots - self.degree - 1))
+
+         # Normalize input to [0,1] range
+         x_normalized = (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-8)
+
+         for i in range(self.num_knots - self.degree - 1):
+             spl = BSpline.basis_element(self.knots[i:i+self.degree+2])
+             basis_values[:, i] = spl(x_normalized.squeeze())
+
+         # Replace NaN values with 0
+         basis_values = np.nan_to_num(basis_values, 0)
+         return torch.FloatTensor(basis_values).to(x.device)
+
+ class FourierBasis(nn.Module):
+     def __init__(self, num_frequencies=5):
+         super().__init__()
+         self.num_frequencies = num_frequencies
+
+     def forward(self, x):
+         frequencies = torch.arange(1, self.num_frequencies + 1, device=x.device).float()
+         x_expanded = x * frequencies.view(1, -1) * 2 * np.pi
+         return torch.cat([torch.sin(x_expanded), torch.cos(x_expanded)], dim=1)
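A minimal sketch of the two basis modules in isolation; the input shape is hypothetical. Each module expects a single feature column of shape [batch, 1]:

```python
import torch
from oikan.utils import BSplineBasis, FourierBasis

x = torch.linspace(-1, 1, 100).unsqueeze(1)   # one feature, shape [100, 1]

bspline = BSplineBasis(num_knots=10, degree=3)
print(bspline(x).shape)   # torch.Size([100, 6]): num_knots - degree - 1 basis values

fourier = FourierBasis(num_frequencies=5)
print(fourier(x).shape)   # torch.Size([100, 10]): one sin and one cos per frequency
```

Two caveats visible in the code: the B-spline path converts through NumPy (detaching gradients, so only the Fourier and interaction paths stay differentiable), and it normalizes each batch by its own min/max.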
oikan/visualize.py CHANGED
@@ -1,20 +1,37 @@
+ import numpy as np
  import matplotlib.pyplot as plt
  import torch

- # Regression Visualization Function
  def visualize_regression(model, X, y):
+     model.eval()
      with torch.no_grad():
-         y_pred = model(torch.tensor(X, dtype=torch.float32)).numpy()
-         plt.scatter(X[:, 0], y, label='True Data')
-         plt.scatter(X[:, 0], y_pred, label='OIKAN Predictions', color='r')
+         X_tensor = torch.FloatTensor(X)
+         y_pred = model(X_tensor).numpy()
+
+     plt.figure(figsize=(10, 6))
+     plt.scatter(X[:, 0], y, color='blue', label='True')
+     plt.scatter(X[:, 0], y_pred, color='red', label='Predicted')
      plt.legend()
      plt.show()

- # Classification visualization
  def visualize_classification(model, X, y):
+     model.eval()
+
+     # Create a mesh grid
+     x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
+     y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
+     xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
+                          np.linspace(y_min, y_max, 100))
+
+     # Make predictions
      with torch.no_grad():
-         outputs = model(torch.tensor(X, dtype=torch.float32))
-         preds = torch.argmax(outputs, dim=1).numpy()
-         plt.scatter(X[:, 0], X[:, 1], c=preds, cmap='viridis', edgecolor='k')
-         plt.title("Classification Results")
+         X_grid = torch.FloatTensor(np.c_[xx.ravel(), yy.ravel()])
+         Z = model(X_grid)
+         Z = torch.argmax(Z, dim=1).numpy()
+         Z = Z.reshape(xx.shape)
+
+     # Plot
+     plt.figure(figsize=(10, 8))
+     plt.contourf(xx, yy, Z, alpha=0.4)
+     plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
      plt.show()
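`visualize_classification` now draws decision regions over a mesh grid, so it assumes exactly two input features and at least two output logits. A minimal sketch with hypothetical toy data:

```python
import numpy as np
import torch
from oikan.model import OIKAN
from oikan.trainer import train_classification
from oikan.visualize import visualize_classification

# Hypothetical XOR-like toy problem: two features, two classes.
X = np.random.randn(200, 2).astype(np.float32)
y = (X[:, 0] * X[:, 1] > 0).astype(np.int64)

model = OIKAN(input_dim=2, output_dim=2)
train_classification(model, (torch.from_numpy(X), torch.from_numpy(y)), epochs=50)
visualize_classification(model, X, y)   # decision regions plus the data scatter
```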
{oikan-0.0.1.1 → oikan-0.0.1.3}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: oikan
- Version: 0.0.1.1
+ Version: 0.0.1.3
  Summary: OIKAN: Optimized Interpretable Kolmogorov-Arnold Networks
  Author: Arman Zhalgasbayev
  License: MIT
oikan-0.0.1.3.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
+ oikan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ oikan/model.py,sha256=9_U3jh1YwASbLOgHpFm4F80J3QGEhzIgQHNkqbZCPJs,2920
+ oikan/regularization.py,sha256=D0Xc2lr5X5ORdA5ltvWDbNDuN8z0hkyoGzFo7pum2XE,1033
+ oikan/symbolic.py,sha256=SGYWwNIQYjc_ik2bIF-_0LckWImHGECzn773btbqees,5394
+ oikan/trainer.py,sha256=itFCHSR_T6KHqa0D5RLRCmqFHa4lUIamsFGWKHmUZuI,1258
+ oikan/utils.py,sha256=XwY6pgAgfYlUI9SOjdop91wh0_t6LfPLCiHretlw2Wg,1754
+ oikan/visualize.py,sha256=8Dlk-tsqGZb63NyZBpZsLSlcsC51m2nXblQaS2Jf1y0,1142
+ oikan-0.0.1.3.dist-info/METADATA,sha256=3vY37GVJC0yuOQJCM0gggAu7FXyRu8WMje3Gfs9_XpA,1872
+ oikan-0.0.1.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ oikan-0.0.1.3.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
+ oikan-0.0.1.3.dist-info/RECORD,,
oikan-0.0.1.1.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
- oikan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- oikan/model.py,sha256=LTWlXTlmeTwNe70Q7vjGOG6MUukCuWoHryvHB_yPzjc,1035
- oikan/symbolic.py,sha256=QjNGWU6LpPzZ35b-WYmSEYPM5FH9tKMS5pKgCujFd64,1431
- oikan/trainer.py,sha256=FmZ2TtcPiaam4ip0AzpzL6BXzDtsouh34GjhIxl0btw,1033
- oikan/visualize.py,sha256=J58pbWYaqV5vWkkRNUem0Jse5gHjQ-rRDKQDPIJouW0,729
- oikan-0.0.1.1.dist-info/METADATA,sha256=F77-yv451wCW6hzQsb9nJPHfI2YBDLFyK6S2mSn69JY,1872
- oikan-0.0.1.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- oikan-0.0.1.1.dist-info/top_level.txt,sha256=XwnwKwTJddZwIvtrUsAz-l-58BJRj6HjAGWrfYi_3QY,6
- oikan-0.0.1.1.dist-info/RECORD,,