physkan 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- physkan/__init__.py +16 -0
- physkan/demonstrator.py +92 -0
- physkan/interaction.py +116 -0
- physkan/kan.py +260 -0
- physkan-0.1.0.dist-info/METADATA +184 -0
- physkan-0.1.0.dist-info/RECORD +8 -0
- physkan-0.1.0.dist-info/WHEEL +4 -0
- physkan-0.1.0.dist-info/licenses/LICENSE +22 -0
physkan/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""PhysKAN: Physics-Informed Kolmogorov-Arnold Networks.
|
|
2
|
+
|
|
3
|
+
A neural architecture designed for safety-critical system identification and control.
|
|
4
|
+
PhysKAN enforces rigorous physical extrapolation by combining bounded B-splines,
|
|
5
|
+
neuro-symbolic polynomial skip-connections, and interval arithmetic to mathematically
|
|
6
|
+
track out-of-bounds (OOB) severity.
|
|
7
|
+
|
|
8
|
+
Key Features:
|
|
9
|
+
- Bounded Spline Grids: Mechanical clamping outside the nominal data range.
|
|
10
|
+
- Dual Severity Tracking: Mathematical compounding of out-of-bounds errors.
|
|
11
|
+
- Hybrid Symbolic Routing: Automated discovery of stable macro-physics via polynomials.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from .demonstrator import KANDemonstrator as KANDemonstrator
|
|
15
|
+
from .kan import KAN as KAN
|
|
16
|
+
from .kan import KANLinear as KANLinear
|
physkan/demonstrator.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class KANDemonstrator:
    """Train a model against a known target function and plot primal/dual outputs.

    The wrapped model must be callable as ``model(features, return_dual=True)``
    and return a ``(prediction, severity)`` pair of tensors.

    Args:
        model: The network to train and evaluate.
        target_fn: Callable mapping raw inputs to the ground-truth targets.
        feature_fn: Optional callable mapping raw inputs to model features.
            When omitted (or falsy), raw inputs are passed through unchanged.
    """

    def __init__(self, model, target_fn, feature_fn=None):
        self.model = model
        self.target_fn = target_fn
        # If no feature engineering is provided, pass raw features through
        self.feature_fn = feature_fn if feature_fn else lambda x: x

    @staticmethod
    def _to_numpy(tensor):
        """Convert a tensor to a NumPy array regardless of device or grad state."""
        return tensor.detach().cpu().numpy()

    def train(self, x_raw_train, epochs=500, lr=0.05, weight_decay=1e-4):
        """Fit the model to ``target_fn`` on the given raw inputs with AdamW + MSE.

        Args:
            x_raw_train: Raw input tensor; fed to both ``target_fn`` and ``feature_fn``.
            epochs: Number of full-batch optimisation steps (must be >= 1).
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay.

        Returns:
            float: The loss computed in the last epoch (evaluated before that
            epoch's optimizer step).

        Raises:
            ValueError: If ``epochs`` is smaller than 1. Without this guard the
                loop body never runs and there is no loss to report.
        """
        if epochs < 1:
            raise ValueError(f"epochs must be a positive integer, got {epochs}")

        y_train = self.target_fn(x_raw_train)
        features = self.feature_fn(x_raw_train)

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.MSELoss()

        self.model.train()
        for _ in range(epochs):
            optimizer.zero_grad()
            # The network returns (primal_prediction, dual_severity); only the
            # primal prediction contributes to the loss.
            y_pred, _ = self.model(features, return_dual=True)
            loss = criterion(y_pred, y_train)
            loss.backward()
            optimizer.step()

        return loss.item()

    def predict(self, x_raw):
        """Evaluate the model in eval mode without tracking gradients.

        Returns:
            tuple: ``(prediction, severity)`` as returned by the model.
        """
        self.model.eval()
        features = self.feature_fn(x_raw)
        with torch.no_grad():
            y_pred, d_pred = self.model(features, return_dual=True)
        return y_pred, d_pred

    def plot(
        self,
        x_raw_eval,
        title="KAN Demonstration",
        x_axis_idx=0,
        nominal_range=(-1.0, 1.0),
    ):
        """Plot the prediction (primal) and the severity signal (dual).

        Args:
            x_raw_eval: Raw evaluation inputs, indexed as ``x_raw_eval[:, x_axis_idx]``.
            title: Figure title.
            x_axis_idx: Which raw feature column to plot on the x-axis.
            nominal_range: ``(lower, upper)`` x-interval shaded as the nominal
                range on both axes. Defaults to ``(-1.0, 1.0)``.
        """
        # Sort by the primary plotting axis for clean lines
        sort_idx = torch.argsort(x_raw_eval[:, x_axis_idx])
        x_raw_eval = x_raw_eval[sort_idx]

        y_true = self.target_fn(x_raw_eval)
        y_pred, d_pred = self.predict(x_raw_eval)

        x_plot = self._to_numpy(x_raw_eval[:, x_axis_idx])
        nominal_lo, nominal_hi = nominal_range

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
        fig.suptitle(title, fontsize=14)

        # 1. Primal Plot (Physics)
        num_outputs = y_true.shape[1]
        for i in range(num_outputs):
            # Braces around the index keep multi-digit subscripts intact in mathtext.
            label_true = f"True $y_{{{i}}}$" if num_outputs > 1 else "True Physics"
            label_pred = f"Pred $y_{{{i}}}$" if num_outputs > 1 else "KAN Prediction"
            ax1.plot(x_plot, self._to_numpy(y_pred[:, i]), "-", linewidth=2, label=label_pred)
            ax1.plot(x_plot, self._to_numpy(y_true[:, i]), "k--", alpha=0.7, label=label_true)

        ax1.axvspan(nominal_lo, nominal_hi, color="gray", alpha=0.1, label="Nominal Range")
        ax1.set_ylabel("Physical Value")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 2. Dual Plot (Severity)
        num_targets = d_pred.shape[1]

        # Use a color palette that stands out for warnings (reds, oranges, purples)
        severity_colors = ["#ff0000", "#ff7f0e", "#800080", "#d62728"]

        for i in range(num_targets):
            severity = self._to_numpy(d_pred[:, i])
            label_str = "Dual Severity ($D$)" if num_targets == 1 else f"Severity $y_{{{i}}}$"
            # Cycle through the palette when there are more targets than colors.
            color = severity_colors[i % len(severity_colors)]

            ax2.plot(x_plot, severity, color=color, linestyle="-", linewidth=2, label=label_str)

        ax2.axvspan(nominal_lo, nominal_hi, color="gray", alpha=0.1)
        ax2.set_ylabel("OOB Severity")
        ax2.set_xlabel(f"Raw Input (Feature index {x_axis_idx})")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
|
physkan/interaction.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class KANInteraction(torch.nn.Module):
    """Append explicit multiplicative feature interactions and their OOB duals.

    For every input column the out-of-bounds severity is the distance by which
    the value exceeds ``grid_range``. Each interaction term is the product of
    the selected columns; its severity is the product of ``|x| + severity``
    over those columns minus the product of ``|x|``.

    Args:
        interaction_map (list[list[int]]): One list of column indices per
            interaction term. Example: ``[[0, 0], [0, 1]]`` appends
            ``x[:, 0] ** 2`` and ``x[:, 0] * x[:, 1]``.
        grid_range (tuple): ``(lower, upper)`` nominal boundaries used to
            compute the severity.
    """

    def __init__(self, interaction_map: list[list[int]], grid_range=(-1.0, 1.0)):
        super().__init__()
        self.interaction_map = interaction_map
        self.grid_range = grid_range

    def extra_repr(self) -> str:
        return f"interaction_map={self.interaction_map}, grid_range={self.grid_range}"

    def forward(self, x: torch.Tensor):
        """Return ``(augmented_x, augmented_severity)``.

        Both outputs keep the original columns first, followed by one column
        per interaction term. With an empty ``interaction_map`` the input and
        its severity are returned without any extra columns.
        """
        lo, hi = self.grid_range
        # Per-column severity: how far each value lies outside [lo, hi].
        severity = F.relu(x - hi) + F.relu(lo - x)
        if not self.interaction_map:
            return x, severity

        primal_parts = [x]
        dual_parts = [severity]
        for indices in self.interaction_map:
            # Columns taking part in this term (repeats allowed, e.g. [0, 0, 1]).
            factors = x[:, indices]
            magnitudes = factors.abs()
            widened = magnitudes + severity[:, indices]

            # Primal: plain product of the selected features.
            primal_parts.append(torch.prod(factors, dim=1, keepdim=True))

            # Dual: widened-magnitude product minus the nominal-magnitude product.
            widened_product = torch.prod(widened, dim=1, keepdim=True)
            nominal_product = torch.prod(magnitudes, dim=1, keepdim=True)
            dual_parts.append(widened_product - nominal_product)

        return torch.cat(primal_parts, dim=1), torch.cat(dual_parts, dim=1)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class PolynomialSkip(nn.Module):
    """Gated polynomial skip connection with optional dual (severity) routing.

    Builds every monomial of the inputs from degree 1 up to ``order`` and maps
    them to the outputs through weights that are scaled by a sigmoid gate.

    Args:
        in_features: Number of input columns.
        out_features: Number of output columns.
        order: Highest monomial degree to generate.
    """

    def __init__(self, in_features, out_features, order=2):
        super().__init__()

        # All multi-indices up to 'order', lowest degree first.
        # e.g. two inputs, order 2: (0,), (1,), (0, 0), (0, 1), (1, 1)
        self.combinations = [
            combo
            for degree in range(1, order + 1)
            for combo in itertools.combinations_with_replacement(range(in_features), degree)
        ]
        term_count = len(self.combinations)

        # Linear weights over the monomial features.
        self.weights = nn.Parameter(torch.randn(out_features, term_count) * 0.1)

        # Gate logits, initialised to 1.0; shifted by -5.0 before the sigmoid in forward().
        self.gates = nn.Parameter(1.0 * torch.ones(out_features, term_count))

    def forward(self, x, dual_input=None):
        """Evaluate the gated polynomial.

        Args:
            x: Input tensor, indexed by column as ``x[:, idx]``.
            dual_input: Optional per-column severity matching ``x``.

        Returns:
            The polynomial output, or ``(output, output_severity)`` when
            ``dual_input`` is given.
        """
        track_dual = dual_input is not None
        monomials = []
        monomial_duals = []

        for combo in self.combinations:
            product = torch.ones_like(x[:, 0:1])
            product_dual = torch.zeros_like(x[:, 0:1]) if track_dual else None

            for col in combo:
                factor = x[:, col : col + 1]

                if track_dual:
                    factor_dual = dual_input[:, col : col + 1]
                    # Severity of a product: (|A| + D_A) * (|B| + D_B) - |A * B|
                    widened = (torch.abs(product) + product_dual) * (torch.abs(factor) + factor_dual)
                    product_dual = widened - torch.abs(product * factor)

                product = product * factor

            monomials.append(product)
            if track_dual:
                monomial_duals.append(product_dual)

        features = torch.cat(monomials, dim=1)

        # Gate the weights: sigmoid(gates - 5.0) scales each weight individually.
        gated_weights = self.weights * torch.sigmoid(self.gates - 5.0)

        out = F.linear(features, gated_weights)
        if not track_dual:
            return out

        # The dual is routed through the absolute value of the gated weights.
        dual_features = torch.cat(monomial_duals, dim=1)
        out_dual = F.linear(dual_features, torch.abs(gated_weights))
        return out, out_dual
|
physkan/kan.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import inspect
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
from .interaction import KANInteraction, PolynomialSkip
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class KANLinear(torch.nn.Module):
    """Single KAN layer with a clamped B-spline track and a parallel linear track.

    The layer computes two outputs: the primal value (linear track plus spline
    track) and a dual "severity" that is propagated only through the absolute
    values of the linear-track weights.

    Args:
        in_features: Number of input columns.
        out_features: Number of output columns.
        grid_size: Number of intervals between ``grid_range[0]`` and ``grid_range[1]``.
        spline_order: Degree of the B-spline bases.
        base_activation: Activation for the linear track. May be a class
            (instantiated with no arguments), an ``nn.Module`` instance
            (deep-copied) or any other callable (used as-is).
        grid_range: ``(lower, upper)`` bounds; spline inputs are clamped to them.
        spline_dropout: Dropout probability applied to the spline output.
        pure_spline_mode: If True, the linear track (and its ``base_weight``
            parameter) is omitted entirely.
        _quiet_init: If True, the initial ``base_weight`` is divided by 100.

    Raises:
        ValueError: If ``base_activation`` is not a class, module or callable.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        base_activation=torch.nn.Identity,
        grid_range=(-1.0, 1.0),
        spline_dropout=0.0,
        pure_spline_mode=False,
        _quiet_init=False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.grid_range = grid_range
        self.spline_dropout = spline_dropout
        self.pure_spline_mode = pure_spline_mode

        if inspect.isclass(base_activation):
            self.base_activation = base_activation()
        elif isinstance(base_activation, nn.Module):
            # Deepcopy guarantees isolation if the same instance is passed to multiple layers.
            self.base_activation = copy.deepcopy(base_activation)
        elif callable(base_activation):
            self.base_activation = base_activation
        else:
            raise ValueError("base_activation must be a class, module instance, or callable.")

        # Static, uniformly spaced knot vector with `spline_order` padding knots
        # on each side of the nominal range; identical for every input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0])
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # The two parallel tracks. Storage is allocated uninitialised here and
        # filled in by reset_parameters().
        if not self.pure_spline_mode:
            self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )

        self.reset_parameters(_quiet_init)

    def reset_parameters(self, quiet=False):
        """(Re-)initialise the weights.

        ``base_weight`` is Kaiming-uniform initialised and then mean-centred per
        row, so each row sums to zero (up to floating-point rounding) and the
        effective weight used in the forward pass (``base_weight + 1/in_features``)
        has rows summing to one. With ``quiet=True`` the centred weights are
        additionally divided by 100. ``spline_weight`` starts at zero.
        """
        if not self.pure_spline_mode:
            torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))
            with torch.no_grad():
                self.base_weight -= self.base_weight.mean(dim=1, keepdim=True)
                if quiet:
                    self.base_weight /= 100.0
        torch.nn.init.zeros_(self.spline_weight)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"grid_size={self.grid_size}, "
            f"spline_order={self.spline_order}, "
            f"grid_range={self.grid_range}"
        )

    def b_splines(self, x: torch.Tensor):
        """Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)

        # Degree-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the degree one step at a time by blending neighbouring bases.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1]
            ) + ((grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:])

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def forward_internal(self, x: torch.Tensor, d: torch.Tensor):
        """Internal forward pass computing both primal (physics) and dual (severity).

        Args:
            x: Primal input whose last dimension is ``in_features``.
            d: Incoming severity; reshaped to ``(-1, in_features)`` like ``x``.

        Returns:
            tuple: ``(x_final, d_final)``, both shaped like ``x`` with the last
            dimension replaced by ``out_features``.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)
        d = d.reshape(-1, self.in_features)

        # 1. Linear track (primal and dual)
        if self.pure_spline_mode:
            base_output = 0.0
            # No linear track means no path for the severity. It must still be a
            # tensor because it is reduced and reshaped below; a Python float
            # here made every pure-spline forward pass fail.
            dual_output = x.new_zeros((x.size(0), self.out_features))
        else:
            effective_weight = self.base_weight + (1.0 / self.in_features)
            base_output = F.linear(self.base_activation(x), effective_weight)
            # Dual severity propagates purely through absolute weights
            dual_output = F.linear(d, effective_weight.abs())

        # 2. Spline track (*clamped* primal only)
        lower_bound, upper_bound = self.grid_range
        if self.spline_order == 0:
            # Deg-0 splines lack padding knots, so force the interval open
            upper_bound -= 1e-6
        if torch.jit.is_tracing() or x.min() < lower_bound or x.max() > upper_bound:
            x = x.clamp(lower_bound, upper_bound)

        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.spline_weight.view(self.out_features, -1),
        )
        if self.spline_dropout > 0.0:
            spline_output = F.dropout(spline_output, p=self.spline_dropout, training=self.training)

        if torch.is_grad_enabled() and dual_output.max() > 1e-6:
            # Blend the spline output with its detached copy so that the forward
            # value is unchanged while the gradient is scaled by `trust`.
            # Gaussian drop-off: exp(-(severity * 2.5)^2)
            # 0.01 -> 99.9% trust (noise is ignored)
            # 0.10 -> 93.9% trust (minor leakage allowed)
            # 0.50 -> 20.9% trust (aggressively pinching off)
            # 1.00 ->  0.2% trust (effectively a hard detach)
            trust = torch.exp(-((dual_output.detach() * 2.5) ** 2))
            spline_output = trust * spline_output + (1.0 - trust) * spline_output.detach()

        # 3. Combine and return tuple
        x_final = (base_output + spline_output).reshape(*original_shape[:-1], self.out_features)
        d_final = dual_output.reshape(*original_shape[:-1], self.out_features)
        return x_final, d_final

    def forward(self, x: torch.Tensor, return_dual: bool = False):
        """Public API. Assumes zero incoming severity if used as a standalone layer.

        Returns:
            The primal output, or ``(primal, dual)`` when ``return_dual`` is True.
        """
        d = torch.zeros_like(x)
        x_out, d_out = self.forward_internal(x, d)
        return (x_out, d_out) if return_dual else x_out
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class KAN(torch.nn.Module):
    """Kolmogorov-Arnold Network (KAN) macro-architecture composed of sequentially stacked KANLinear layers.

    Coordinates deep layer propagation by chaining self-contained Bounded KAN blocks. If
    `pure_spline_mode` is False, the underlying layers maintain an internal scale-preserving
    uniform baseline that seamlessly conserves signal magnitude across dimensional changes
    (expansions/contractions) during extreme out-of-bounds anomalies.

    Args:
        layer_dims: Architectural dimensions mapping from input to output features (e.g., [input_dim, hidden_dim, output_dim]).
        grid_size: Number of inner intervals partitioning the spline domain.
        spline_order: Polynomial degree of the local B-spline bases.
        base_activation: Activation function applied exclusively to the linear track. Change with caution - see README!
        grid_range: Physical bounds `(lower, upper)` defining the spline evaluation domain.
        pure_spline_mode: If True, completely disables the linear track across all child layers, forcing hard
            saturation/clipping at boundaries instead of proportional linear extrapolation.
        spline_dropout: Dropout probability, to encourage asymptote learning.
        interaction_map: Multiplicative feature interaction indices. Use to define explicit cross-terms while preserving
            strict OOB propagation. ``None`` (the default) means no interactions.
        symbolic_order: Highest polynomial degree of the shallow polynomial skip connection
            added from the (interaction-augmented) input to the output. ``0`` disables it.
    """

    def __init__(
        self,
        layer_dims: list[int],
        grid_size: int = 5,
        spline_order: int = 3,
        base_activation: type[torch.nn.Module] | torch.nn.Module = torch.nn.Identity,
        grid_range: tuple[float, float] = (-1.0, 1.0),
        pure_spline_mode: bool = False,
        spline_dropout: float = 0.0,
        interaction_map: list[list[int]] | None = None,
        symbolic_order: int = 0,
    ):
        super().__init__()
        # Take a private copy: a shared default list (or the caller's own list)
        # must not be aliased by the interactor, which keeps a reference to it.
        interaction_map = [] if interaction_map is None else list(interaction_map)

        self.interactor = KANInteraction(interaction_map, grid_range)
        self.layer_dims = layer_dims
        # The first layer also receives one extra column per interaction term.
        eff_layer_dims = list(layer_dims)
        eff_layer_dims[0] += len(interaction_map)

        # 1. The stacked KANLinear layers
        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(eff_layer_dims, eff_layer_dims[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    base_activation=base_activation,
                    grid_range=grid_range,
                    spline_dropout=spline_dropout,
                    pure_spline_mode=pure_spline_mode,
                    # Start the layers with a scaled-down linear track when a
                    # polynomial skip connection is present.
                    _quiet_init=symbolic_order > 0,
                )
            )

        # 2. The Shallow Polynomial Skip Connection
        self.symbolic_order = symbolic_order
        if self.symbolic_order > 0:
            self.poly_skip = PolynomialSkip(
                in_features=eff_layer_dims[0], out_features=eff_layer_dims[-1], order=symbolic_order
            )

    def extra_repr(self) -> str:
        # Most information is already in the layers, just add the pre-interactions dim
        return f"in_features={self.layer_dims[0]}"

    def forward(
        self, x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], return_dual: bool = False
    ):
        """Run the network.

        Args:
            x: Either a raw input tensor (interactions and severity are computed
                internally), or a ``(features, severity)`` tuple that is fed to
                the first layer directly.
            return_dual: If True, also return the propagated severity.

        Returns:
            The prediction, or ``(prediction, severity)`` when ``return_dual`` is True.

        Raises:
            ValueError: If a tuple is supplied while an ``interaction_map`` is
                configured, or if the input's second dimension does not match
                the expected feature count.
        """
        if isinstance(x, tuple):
            if self.interactor.interaction_map:
                raise ValueError(
                    "Both explicit and implicit interactions supplied. This is not supported."
                )
            x, d = x
            if x.shape[1] != self.layers[0].in_features:
                raise ValueError(
                    f"Wrong input dimension {x.shape[1]}, expected {self.layers[0].in_features}"
                )
        else:
            if x.shape[1] != self.layer_dims[0]:
                raise ValueError(
                    f"Wrong input dimension {x.shape[1]}, expected {self.layer_dims[0]}"
                )
            x, d = self.interactor(x)

        # The skip connection reads the augmented input, before any layer transforms it.
        if self.symbolic_order > 0:
            poly_out, poly_dual = self.poly_skip(x, d)

        # Route features along with dual through the layers
        for layer in self.layers:
            x, d = layer.forward_internal(x, d)
        if self.symbolic_order > 0:
            x = x + poly_out
            d = d + poly_dual

        return (x, d) if return_dual else x
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: physkan
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A physics-constrained Kolmogorov-Arnold Network with bounded latent spaces.
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Python: >=3.12
|
|
7
|
+
Requires-Dist: torch>=2.9.0
|
|
8
|
+
Provides-Extra: dev
|
|
9
|
+
Requires-Dist: ruff; extra == 'dev'
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# PhysKAN
|
|
13
|
+
|
|
14
|
+
**Physics-constrained Kolmogorov-Arnold Networks for stable system identification**
|
|
15
|
+
|
|
16
|
+
This repository provides a structural adaptation of the B-spline Kolmogorov-Arnold Network (KAN) architecture, designed for physical system identification, digital twins, and robust regression.
|
|
17
|
+
|
|
18
|
+
While standard KANs perform well at function approximation in purely mathematical domains, applying them to physical telemetry often requires interventions, like dynamic grid updates or statistical normalization such as LayerNorm, to handle out-of-bounds (OOB) anomalies.
|
|
19
|
+
In this context, OOB refers to any data point that exceeds the nominal operational range of the system, whether caused by a real but long-tail phenomenon (e.g., unseen weather regimes) or a transient sensor failure (e.g., signal spikes).
|
|
20
|
+
Unfortunately, these standard deep learning techniques remove the spatial meaning of the network's internal variables.
|
|
21
|
+
|
|
22
|
+
This architecture addresses this by freezing the spatial grid and enforcing strict physical bounds natively, prioritizing metric stability and OOB safety over localized curve-fitting flexibility.
|
|
23
|
+
It also uses forward uncertainty propagation with interval arithmetic to track the OOB state through the network.
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
## Core design philosophy
|
|
27
|
+
|
|
28
|
+
PhysKAN is built on three central ideas, meant to bridge the gap between theoretical non-linear mapping and the robust fail-safes required for physical engineering:
|
|
29
|
+
|
|
30
|
+
1. **Progressive Koopman-style unbending:** Rather than relying on black-box MLP node activations, the model acts as a structural filter.
|
|
31
|
+
It uses constrained B-splines to progressively unbend non-linear physical inputs layer-by-layer, lifting them into a linearized latent space (analogous to finding "observables" in Koopman Operator Theory).
|
|
32
|
+
|
|
33
|
+
2. **Embrace out-of-bounds (OOB) values:** Real-world physics does not stay neatly within standardized grids.
|
|
34
|
+
Instead of arbitrarily squashing long-tail events or sensor glitches with clamps or global activations, the architecture uses the grid range to explicitly define the boundary between the dense, well-modeled operational regime and the sparse, asymptotic tail.
|
|
35
|
+
OOB states are safely clamped on the non-linear spline track and routed unclamped through a parallel linear track, ensuring mathematically stable extrapolation.
|
|
36
|
+
|
|
37
|
+
3. **Epistemic uncertainty tracking:** The network computes a continuous dual property alongside the physical prediction.
|
|
38
|
+
This signal forward-propagates the mathematical severity of any out-of-bounds state, providing a deterministic measure of when the network is forced to extrapolate.
|
|
39
|
+
|
|
40
|
+
---
|
|
41
|
+
|
|
42
|
+
## Under the hood: the OOB routing mechanism
|
|
43
|
+
|
|
44
|
+
To safely execute this philosophy, the network requires a specific mental model for how it routes data—especially during the backward pass.
|
|
45
|
+
|
|
46
|
+
In standard implementations, out-of-bounds data either "falls off" the spline grid entirely (dropping to zero) or requires the input to be clamped or bounded.
|
|
47
|
+
However, if clamped *without* gradient detachment, the boundary knot absorbs the training loss for all out-of-bounds states.
|
|
48
|
+
It becomes a wastebasket for outlying values, compressing the long-tail distribution into a single coordinate and warping predictions for nominal operations.
|
|
49
|
+
|
|
50
|
+
The PhysKAN architecture acts as a traffic cop for physical regimes:
|
|
51
|
+
* **The nominal regime (non-linear track):** Dense, expected data operates inside the grid, shaping the non-linear B-splines.
|
|
52
|
+
* **The out-of-bounds regime (linear track):** OOB data are clamped on the non-linear track (with detached gradients to protect the nominal-range knots).
|
|
53
|
+
The excess signal flows entirely through the linear track.
|
|
54
|
+
|
|
55
|
+
This ensures the non-linear splines strictly learn the nominal physics, while the linear track safely catches long-tail events.
|
|
56
|
+
|
|
57
|
+
## Architectural constraints
|
|
58
|
+
|
|
59
|
+
To maintain the absolute physical meaning of these latent observables during deployment, the model relies on two structural constraints:
|
|
60
|
+
|
|
61
|
+
### 1. Static grid boundaries
|
|
62
|
+
|
|
63
|
+
KAN architectures often rely on dynamic grid updates (knot insertion or movement) during training.
|
|
64
|
+
This architecture disables this.
|
|
65
|
+
Dynamic updates shift the underlying coordinate system of the network mid-training, causing downstream layers to lose their physical calibration.
|
|
66
|
+
By enforcing a static grid, the model sacrifices some theoretical curve-fitting capacity to guarantee that a specific latent state retains its exact metric meaning from initialization to deployment.
|
|
67
|
+
|
|
68
|
+
### 2. Linear skip connections as safety valves
|
|
69
|
+
Because the spline gradients are detached for OOB values, the network routes the excess gradients entirely through the parallel linear skip connection.
|
|
70
|
+
This serves as a vital safety valve: it protects the non-linear splines from gradient pollution, and it ensures that OOB inputs extrapolate linearly and predictably.
|
|
71
|
+
This limits the downstream impact of anomalies, making system filtering more reliable.
|
|
72
|
+
|
|
73
|
+
#### Justification for linear extrapolation (physical basis functions)
|
|
74
|
+
|
|
75
|
+
While real-world OOB events often exhibit higher-order scaling (e.g., cubic wave resistance), the model enforces a linear default for OOB extrapolation.
|
|
76
|
+
This is a deliberate design choice to prevent mathematical instability caused by sensor faults.
|
|
77
|
+
|
|
78
|
+
To safely capture higher-order OOB physics, domain knowledge should be embedded directly via feature engineering.
|
|
79
|
+
As long as the input features form a sufficient physical basis, particularly for asymptotic behaviours, the linear skip connection will naturally capture higher-order OOB phenomena as a linear combination of features without compromising the nominal operating region.
|
|
80
|
+
|
|
81
|
+
Applying a post-summation node activation (such as `SiLU` or `tanh`) fundamentally sabotages this mechanism.
|
|
82
|
+
A non-linear activation will warp the magnitude of the OOB event, rendering the linear skip connection unable to model it.
|
|
83
|
+
For this reason, activations are disabled by default (using `Identity`).
|
|
84
|
+
Other activations may be selected, but beware that the guarantees provided by "standard" PhysKAN may be weakened or destroyed.
|
|
85
|
+
|
|
86
|
+
## Feature engineering and explicit interactions
|
|
87
|
+
|
|
88
|
+
Deep architectures can theoretically learn multiplicative interactions (such as computing `x * y` by combining multiple layers).
|
|
89
|
+
Making the network deduce these relationships from scratch consumes capacity and degrades poorly when out-of-bounds.
|
|
90
|
+
Instead, to capture known physical behaviors, domain knowledge should be embedded directly via feature engineering.
|
|
91
|
+
Providing the network with a dictionary of physical basis functions (e.g., `x^2` or `cos(θ)`) allows the linear skip connection to latch onto these engineered features as a stable baseline.
|
|
92
|
+
This leaves the splines to map the local residuals, ensuring safe extrapolation when the splines saturate.
|
|
93
|
+
|
|
94
|
+
However, combining features naively can mask out-of-bounds anomalies.
|
|
95
|
+
If you manually pre-compute an interaction like `wave_height * cos(wind_dir)` and pass it to the network as a raw input, the anomaly signal is suppressed.
|
|
96
|
+
For instance, if `wave_height` is OOB (e.g., twice nominal range) but `cos(wind_dir)` is near zero, their product is well within nominal bounds.
|
|
97
|
+
The model treats this as a regular in-bounds prediction, and uses the data point to update its nominal-range spline.
|
|
98
|
+
|
|
99
|
+
To prevent this suppression, the network requires interaction terms to be defined internally via an `interaction_map` rather than expanded manually beforehand.
|
|
100
|
+
|
|
101
|
+
The network computes a continuous dual property alongside the standard physical prediction.
|
|
102
|
+
This dual represents the mathematical severity of the out-of-bounds state.
|
|
103
|
+
* The *physical prediction* is computed using the non-linear splines and the linear track.
|
|
104
|
+
* The *dual severity* strictly bypasses the splines and propagates via the absolute values of the linear weights, ensuring that uncertainties compound and never cancel out.
|
|
105
|
+
|
|
106
|
+
By defining interactions explicitly through the `interaction_map`, the model correctly applies the uncertainty product rule to the input features before they enter the network.
|
|
107
|
+
If a large wave anomaly interacts with a nominal-range cosine, the resulting interaction term inherits a proportional severity score.
|
|
108
|
+
This deterministic distress signal persists through the entire depth of the network, ensuring that the non-linear splines are firewalled from learning from the anomaly, while the linear track safely handles the extrapolated magnitude.
|
|
109
|
+
It also provides downstream consumers with a clear indicator of when the model is operating on dodgy data.
|
|
110
|
+
|
|
111
|
+
### Defining the nominal range: data density vs. physical limits
|
|
112
|
+
|
|
113
|
+
When defining the `grid_range` and normalizing inputs, the boundaries should reflect the density of the training data rather than the theoretical limits of the physical system.
|
|
114
|
+
|
|
115
|
+
B-splines require consistent data distribution across their internal grid to form a stable curve.
|
|
116
|
+
If for example a physical feature (such as wave height) has a theoretical operational limit of 5.0 meters, but the training dataset becomes sparse above 2.0 meters, setting the spline boundary to 5.0 meters forces the model to fit curves in an under-constrained region.
|
|
117
|
+
This often causes the splines to oscillate or overfit to a handful of isolated data points.
|
|
118
|
+
|
|
119
|
+
Instead, the grid boundary should be placed where the data density noticeably drops off (e.g., at 2.0 meters).
|
|
120
|
+
By treating the sparse region as out-of-bounds, the network safely clamps the splines in the dense region and relies on the linear track to extrapolate smoothly through the sparse tail.
|
|
121
|
+
The working principle is to treat the nominal range strictly as the bounds of the dense training data.
|
|
122
|
+
|
|
123
|
+
## Installation
|
|
124
|
+
|
|
125
|
+
You can install the package directly from GitHub:
|
|
126
|
+
|
|
127
|
+
```bash
|
|
128
|
+
pip install git+https://github.com/simula/physkan.git
```
|
|
129
|
+
|
|
130
|
+
## Usage example
|
|
131
|
+
|
|
132
|
+
The model handles explicit feature expansion and interval arithmetic internally. A standard linear layer should be used as the final readout.
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
import torch
|
|
136
|
+
import torch.nn as nn
|
|
137
|
+
from physkan import KAN
|
|
138
|
+
|
|
139
|
+
# Define explicit cross-terms using indices
|
|
140
|
+
# e.g., for features [wave, wind, cos_dir]:
|
|
141
|
+
# [0, 0] adds wave^2
|
|
142
|
+
# [0, 2] adds wave * cos_dir
|
|
143
|
+
interactions = [[0, 0], [0, 2]]
|
|
144
|
+
|
|
145
|
+
# The KAN model automatically expands the initial input dimension
|
|
146
|
+
# and sets up the continuous dual routing.
|
|
147
|
+
kan_encoder = KAN(
|
|
148
|
+
layers_dims=[3, 16, 8], # Input dim is 3 (wave, wind, cos_dir)
|
|
149
|
+
grid_range=(0.0, 1.0),
|
|
150
|
+
interaction_map=interactions
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# The readout: Linear combination of the final observables of a zero-at-rest (unbiased) system
|
|
154
|
+
linear_mixer = nn.Linear(in_features=8, out_features=1, bias=False)
|
|
155
|
+
|
|
156
|
+
model = nn.Sequential(
|
|
157
|
+
kan_encoder,
|
|
158
|
+
linear_mixer
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
# Nominal physical data
|
|
162
|
+
x_nominal = torch.tensor([[0.5, 0.8, 0.1]])
|
|
163
|
+
|
|
164
|
+
# Pass data through the encoder, requesting the dual distress signal
|
|
165
|
+
latent_features, severity_signal = kan_encoder(x_nominal, return_dual=True)
|
|
166
|
+
prediction = linear_mixer(latent_features)
|
|
167
|
+
|
|
168
|
+
# For an out-of-bounds event (e.g., wave height sensor reads 5.0)
|
|
169
|
+
x_oob = torch.tensor([[5.0, 0.8, 0.1]])
|
|
170
|
+
latent_oob, severity_oob = kan_encoder(x_oob, return_dual=True)
|
|
171
|
+
|
|
172
|
+
# severity_oob > 0 indicates the prediction relies on mathematically
|
|
173
|
+
# extrapolated values, allowing downstream logic to trigger heuristics.
|
|
174
|
+
if severity_oob.mean() > 0.0:
|
|
175
|
+
print("Warning: operating in uncharted physical regime.")
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
## Attribution
|
|
179
|
+
|
|
180
|
+
This repository is an adaptation of the excellent **[efficient-kan](https://github.com/Blealtan/efficient-kan)** library by Blealtan.
|
|
181
|
+
|
|
182
|
+
The core B-spline evaluation mechanics, memory-efficient tensor formulation, and foundational matrix operations are directly derived from `efficient-kan`.
|
|
183
|
+
The modifications introduced here are strictly architectural (specifically the detached routing, strict boundary clamping, interval arithmetic dual, and default identity activations) and are designed to constrain the network for physical system identification.
|
|
184
|
+
Full credit for the underlying efficiency and base implementation belongs to the original author.
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
physkan/__init__.py,sha256=RRNVtHe-4aeA8CqQGi_dS6-JCkkj8DUPTlfvIL6jHKo,736
|
|
2
|
+
physkan/demonstrator.py,sha256=TWom4dYMxyZ5I6gkFyEYJ4-0zDIFUW9IpLnSKwETR3Y,3637
|
|
3
|
+
physkan/interaction.py,sha256=9Y5k2H0q60m-tKexvTmDS3BQ473RsciTWW4MRVQEIb8,4423
|
|
4
|
+
physkan/kan.py,sha256=67YDX_IjKiN7f0k8r9YO4or1kwBy2OHgAcLWX--fqMY,10727
|
|
5
|
+
physkan-0.1.0.dist-info/METADATA,sha256=T_QzvnGvIFKtk0-K_5Ja9Q-21eX40GuTwcUWrbJeMxM,12274
|
|
6
|
+
physkan-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
7
|
+
physkan-0.1.0.dist-info/licenses/LICENSE,sha256=cKUELmUqTmMhIckxOzk8eFnSKbq7D8KUfZLnMuOg_MQ,1153
|
|
8
|
+
physkan-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Based on efficient-kan, copyright (c) 2024 Huanqi Cao.
|
|
4
|
+
Modifications copyright (c) 2026 Simula Research Laboratory.
|
|
5
|
+
|
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
in the Software without restriction, including without limitation the rights
|
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
furnished to do so, subject to the following conditions:
|
|
12
|
+
|
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
copies or substantial portions of the Software.
|
|
15
|
+
|
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
SOFTWARE.
|