nous: 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nous/__init__.py +96 -19
- nous/data/__init__.py +4 -0
- nous/data/california.py +32 -0
- nous/data/wine.py +29 -0
- nous/explain/__init__.py +26 -0
- nous/explain/aggregator.py +34 -0
- nous/explain/cf.py +137 -0
- nous/explain/facts_desc.py +23 -0
- nous/explain/fidelity.py +56 -0
- nous/explain/generate.py +86 -0
- nous/explain/global_book.py +52 -0
- nous/explain/loo.py +130 -0
- nous/explain/mse.py +93 -0
- nous/explain/pruning.py +117 -0
- nous/explain/stability.py +42 -0
- nous/explain/traces.py +285 -0
- nous/explain/utils.py +15 -0
- nous/export/__init__.py +13 -0
- nous/export/numpy_infer.py +412 -0
- nous/facts.py +112 -0
- nous/model.py +226 -0
- nous/prototypes.py +43 -0
- nous/rules/__init__.py +11 -0
- nous/rules/blocks.py +63 -0
- nous/rules/fixed.py +26 -0
- nous/rules/softmax.py +93 -0
- nous/rules/sparse.py +142 -0
- nous/training/__init__.py +5 -0
- nous/training/evaluation.py +57 -0
- nous/training/schedulers.py +34 -0
- nous/training/train.py +177 -0
- nous/types.py +4 -0
- nous/utils/__init__.py +3 -0
- nous/utils/metrics.py +2 -0
- nous/utils/seed.py +13 -0
- nous/version.py +1 -0
- nous-0.2.0.dist-info/METADATA +150 -0
- nous-0.2.0.dist-info/RECORD +41 -0
- nous/causal.py +0 -63
- nous/interpret.py +0 -111
- nous/layers.py +0 -117
- nous/models.py +0 -65
- nous-0.1.0.dist-info/METADATA +0 -138
- nous-0.1.0.dist-info/RECORD +0 -10
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/WHEEL +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/top_level.txt +0 -0
nous/export/numpy_infer.py
ADDED
@@ -0,0 +1,412 @@
from __future__ import annotations
import os
import re
import textwrap
import importlib.util
from types import ModuleType
from typing import Dict, Any

import numpy as np
import torch
import torch.nn.functional as F

from ..model import NousNet

def export_numpy_inference(model: NousNet, file_path: str = "nous_numpy_infer.py") -> str:
    """
    Export a NousNet instance to a self-contained NumPy inference Python module.

    The generated module exposes:
    - a dict P with parameters and metadata,
    - predict(X, return_logits=False) using only numpy.

    Returns the code as a string and writes it to file_path if provided.
    """
    model.eval()
    cfg = model.config

    def npy(x):
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.array(x)

    def tolist(x, dtype=np.float32):
        return npy(x).astype(dtype).tolist()

    use_cals = model.calibrators is not None
    calibrators = []
    if use_cals:
        for c in model.calibrators:
            with torch.no_grad():
                deltas_sp = F.softplus(c.deltas).cpu().numpy()
                cum = np.cumsum(deltas_sp)
                cum = np.concatenate([np.array([0.0], dtype=np.float32), cum.astype(np.float32)], axis=0)
            calibrators.append(dict(
                input_min=float(c.input_min),
                input_max=float(c.input_max),
                num_bins=int(c.num_bins),
                bias=float(c.bias.item()),
                cum=cum.tolist()
            ))

    Lw, Rw, th, k, nu = model.fact.get_rule_parameters()
    fact_dict = dict(
        L=Lw.astype(np.float32).tolist(),
        R=Rw.astype(np.float32).tolist(),
        th=th.astype(np.float32).tolist(),
        k=k.astype(np.float32).tolist(),
        nu=nu.astype(np.float32).tolist()
    )

    blocks = []
    for blk in model.blocks:
        ln_w = tolist(blk.norm.weight)
        ln_b = tolist(blk.norm.bias)
        proj_W = None
        if not isinstance(blk.proj, torch.nn.Identity):
            proj_W = tolist(blk.proj.weight)

        from ..rules.blocks import SimpleNousBlock
        from ..rules.softmax import SoftmaxRuleLayer
        from ..rules.sparse import SparseRuleLayer

        if isinstance(blk, SimpleNousBlock):
            idx_pairs = blk.rule.idx.detach().cpu().numpy().astype(np.int64).tolist()
            rule_w = tolist(blk.rule.weight)
            blocks.append(dict(
                kind="simple",
                proj_W=proj_W,
                ln_w=ln_w, ln_b=ln_b,
                idx_pairs=idx_pairs,
                rule_w=rule_w
            ))

        elif isinstance(blk, SoftmaxRuleLayer):
            with torch.no_grad():
                fl = F.softmax(blk.fact_logits, dim=1).cpu().numpy()
                kf = int(min(blk.top_k_facts, blk.input_dim))
                topk_idx = np.argpartition(fl, -kf, axis=1)[:, -kf:]
                mask = np.zeros_like(fl, dtype=np.float32)
                rows = np.arange(fl.shape[0])[:, None]
                mask[rows, topk_idx] = 1.0

                agg_w = F.softmax(blk.aggregator_logits, dim=1).cpu().numpy().astype(np.float32)  # [R, 3]
                rule_strength = torch.sigmoid(blk.rule_strength_raw).cpu().numpy().astype(np.float32)
            blocks.append(dict(
                kind="softmax",
                proj_W=proj_W,
                ln_w=ln_w, ln_b=ln_b,
                mask=mask.tolist(),
                agg_w=agg_w.tolist(),
                rule_strength=rule_strength.tolist(),
                top_k_rules=int(blk.top_k_rules)
            ))

        elif isinstance(blk, SparseRuleLayer):
            with torch.no_grad():
                beta = blk.hard_concrete.beta.detach().cpu().numpy()
                mask = (1.0 / (1.0 + np.exp(-beta)) > 0.5).astype(np.float32)  # eval behavior
                agg_w = F.softmax(blk.aggregator_logits, dim=1).cpu().numpy().astype(np.float32)  # [R, 4]
                rule_strength = torch.sigmoid(blk.rule_strength_raw).cpu().numpy().astype(np.float32)
            blocks.append(dict(
                kind="sparse",
                proj_W=proj_W,
                ln_w=ln_w, ln_b=ln_b,
                mask=mask.tolist(),
                agg_w=agg_w.tolist(),
                rule_strength=rule_strength.tolist(),
                top_k_rules=int(blk.top_k_rules)
            ))
        else:
            raise ValueError(f"Unknown block type: {type(blk)}")

    if isinstance(model.head, torch.nn.Linear):
        with torch.no_grad():
            W = model.head.weight.detach().cpu().numpy().astype(np.float32)
            b = model.head.bias.detach().cpu().numpy().astype(np.float32)
        head = dict(kind="linear", W=W.tolist(), b=b.tolist())
    else:
        from ..prototypes import ScaledPrototypeLayer
        assert isinstance(model.head, ScaledPrototypeLayer)
        with torch.no_grad():
            Pm = model.head.prototypes.detach()
            Pn = F.normalize(Pm, p=2, dim=1).cpu().numpy().astype(np.float32)
            W = model.head.prototype_class.detach().cpu().numpy().astype(np.float32)
            tau = float(F.softplus(model.head.temperature).item())
        head = dict(kind="prototypes", P_norm=Pn.tolist(), W=W.tolist(), tau=tau)

    P_dict = dict(
        task=cfg['task_type'],
        use_calibrators=bool(use_cals),
        calibrators=calibrators,
        fact=fact_dict,
        blocks=blocks,
        head=head
    )

    code = f"""# Auto-generated NumPy inference for NousNet
# This file is self-contained and requires only numpy.
import numpy as np

P = {repr(P_dict)}

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def logsigmoid(z):
    # stable: -log(1+exp(-z))
    return -np.log1p(np.exp(-z))

def layernorm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    xhat = (x - mu) / np.sqrt(var + eps)
    return xhat * gamma + beta

def calibrate(X):
    if not P['use_calibrators']:
        return X.astype(np.float32)
    X = X.astype(np.float32)
    Xc = np.empty_like(X, dtype=np.float32)
    for j, c in enumerate(P['calibrators']):
        xmin, xmax = float(c['input_min']), float(c['input_max'])
        nb = int(c['num_bins'])
        cum = np.array(c['cum'], dtype=np.float32)
        bias = float(c['bias'])
        xj = X[:, j]
        xn = (xj - xmin) / (xmax - xmin + 1e-8)
        xn = np.clip(xn, 0.0, 1.0)
        t = xn * nb
        bin_idx = np.floor(t).astype(np.int32)
        bin_idx = np.clip(bin_idx, 0, nb-1)
        left = bias + cum[bin_idx]
        right = bias + cum[bin_idx + 1]
        frac = t - bin_idx.astype(np.float32)
        Xc[:, j] = left + frac * (right - left)
    return Xc

def beta_facts(Xc):
    L = np.array(P['fact']['L'], dtype=np.float32)
    R = np.array(P['fact']['R'], dtype=np.float32)
    th = np.array(P['fact']['th'], dtype=np.float32)
    k = np.array(P['fact']['k'], dtype=np.float32)
    nu = np.array(P['fact']['nu'], dtype=np.float32)
    diff = Xc @ L.T - Xc @ R.T - th
    z = k * diff
    log_sig = logsigmoid(z)
    log_beta = nu * log_sig
    log_beta = np.maximum(log_beta, -80.0)
    return np.exp(log_beta).astype(np.float32)

def run_block(block, H):
    kind = block['kind']
    if block['proj_W'] is None:
        proj = H
    else:
        Wp = np.array(block['proj_W'], dtype=np.float32)
        proj = H @ Wp.T

    if kind == 'simple':
        idx = np.array(block['idx_pairs'], dtype=np.int64)
        w = np.array(block['rule_w'], dtype=np.float32)
        rs = sigmoid(w)
        f1 = H[:, idx[:,0]]
        f2 = H[:, idx[:,1]]
        rule_act = (f1 * f2) * rs
        pre = proj + rule_act

    elif kind == 'softmax':
        mask = np.array(block['mask'], dtype=np.float32)  # [R,F]
        sel = H[:, None, :] * mask[None, :, :]
        and_agg = np.prod(sel + (1.0 - mask)[None, :, :], axis=2)
        or_agg = 1.0 - np.prod((1.0 - sel) * mask[None, :, :] + (1.0 - mask)[None, :, :], axis=2)
        denom = np.maximum(mask.sum(axis=1), 1e-8)  # [R]
        kofn = (sel.sum(axis=2)) / denom[None, :]
        agg_w = np.array(block['agg_w'], dtype=np.float32)  # [R,3]
        aggs = np.stack([and_agg, or_agg, kofn], axis=2)
        mixed = (aggs * agg_w[None, :, :]).sum(axis=2)
        rs = np.array(block['rule_strength'], dtype=np.float32)
        rule_act = mixed * rs[None, :]
        R = mask.shape[0]
        k_rules = int(block['top_k_rules'])
        if k_rules < R:
            gate = np.zeros_like(rule_act, dtype=np.float32)
            idx_top = np.argpartition(rule_act, -k_rules, axis=1)[:, -k_rules:]
            for i in range(rule_act.shape[0]):
                gate[i, idx_top[i]] = 1.0
            rule_act = rule_act * gate
        pre = proj + rule_act

    elif kind == 'sparse':
        mask = np.array(block['mask'], dtype=np.float32)  # [R,F]
        sel = H[:, None, :] * mask[None, :, :]
        and_agg = np.prod(sel + (1.0 - mask)[None, :, :], axis=2)
        or_agg = 1.0 - np.prod((1.0 - sel) * mask[None, :, :] + (1.0 - mask)[None, :, :], axis=2)
        denom = np.maximum(mask.sum(axis=1), 1e-8)  # [R]
        kofn = (sel.sum(axis=2)) / denom[None, :]
        not_agg = 1.0 - kofn
        agg_w = np.array(block['agg_w'], dtype=np.float32)  # [R,4]
        aggs = np.stack([and_agg, or_agg, kofn, not_agg], axis=2)
        mixed = (aggs * agg_w[None, :, :]).sum(axis=2)
        rs = np.array(block['rule_strength'], dtype=np.float32)
        rule_act = mixed * rs[None, :]
        R = mask.shape[0]
        k_rules = int(block['top_k_rules'])
        if k_rules < R:
            gate = np.zeros_like(rule_act, dtype=np.float32)
            idx_top = np.argpartition(rule_act, -k_rules, axis=1)[:, -k_rules:]
            for i in range(rule_act.shape[0]):
                gate[i, idx_top[i]] = 1.0
            rule_act = rule_act * gate
        pre = proj + rule_act

    else:
        raise ValueError("Unknown block kind: " + str(kind))

    gamma = np.array(block['ln_w'], dtype=np.float32)
    beta = np.array(block['ln_b'], dtype=np.float32)
    return layernorm(pre, gamma, beta)

def head_forward(H):
    head = P['head']
    if head['kind'] == 'linear':
        W = np.array(head['W'], dtype=np.float32)
        b = np.array(head['b'], dtype=np.float32)
        return H @ W.T + b
    elif head['kind'] == 'prototypes':
        Pn = np.array(head['P_norm'], dtype=np.float32)
        W = np.array(head['W'], dtype=np.float32)
        Hn = H / (np.linalg.norm(H, axis=1, keepdims=True) + 1e-8)
        dot = Hn @ Pn.T
        d = np.sqrt(np.clip(2.0 - 2.0 * dot, 1e-12, None))
        tau = float(head['tau'])
        act = np.exp(-tau * d)
        return act @ W
    else:
        raise ValueError("Unknown head kind: " + str(head['kind']))

def predict(X, return_logits=False):
    X = np.array(X, dtype=np.float32)
    Xc = calibrate(X) if P['use_calibrators'] else X
    H = beta_facts(Xc)
    for blk in P['blocks']:
        H = run_block(blk, H)
    logits = head_forward(H)
    if P['task'] == 'classification':
        e = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs = e / e.sum(axis=1, keepdims=True)
        return (probs, logits) if return_logits else probs
    else:
        return logits.reshape(-1)
"""
    code = textwrap.dedent(code)
    if file_path is not None:
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(code)
    return code

def slugify(s: str) -> str:
    return re.sub(r'[^a-zA-Z0-9]+', '_', s).strip('_').lower()

def load_numpy_module(path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location("nous_numpy_infer_mod", path)
    mod = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(mod)  # type: ignore[attr-defined]
    return mod

def _softmax_np(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def _kl_div(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.sum(p * (np.log(p) - np.log(q)), axis=1)

def _js_div(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    m = 0.5*(p+q)
    return 0.5*_kl_div(p, m, eps) + 0.5*_kl_div(q, m, eps)

def validate_numpy_vs_torch(
    model: NousNet, npmod: ModuleType, X, task: str, n: int = 512,
    tol_prob_max: float = 1e-3, tol_prob_mean: float = 2e-4, tol_l1_mean: float = 3e-4, tol_js_mean: float = 5e-6,
    tol_logit_centered: float = 1e-3,
    tol_reg_max: float = 1e-4, tol_reg_mean: float = 2e-5
) -> dict:
    """
    Probability-first validation:
    - classification: PASS if prob metrics and centered-logit diff are within tolerances.
    - regression: PASS if absolute prediction diffs are within tolerances.
    """
    model.eval()
    device = next(model.parameters()).device
    Xs = X[:min(len(X), n)].astype(np.float32)

    if task == "classification":
        with torch.no_grad():
            torch_logits = model(torch.tensor(Xs, device=device)).cpu().numpy()
        torch_probs = _softmax_np(torch_logits)

        np_probs, np_logits = npmod.predict(Xs, return_logits=True)
        np_probs = np.asarray(np_probs)
        np_logits = np.asarray(np_logits)

        torch_pred = np.argmax(torch_probs, axis=1)
        numpy_pred = np.argmax(np_probs, axis=1)
        fidelity = float((torch_pred == numpy_pred).mean())

        dprob = np.abs(torch_probs - np_probs)
        max_dprob = float(dprob.max())
        mean_dprob = float(dprob.mean())
        l1_per_sample = np.sum(dprob, axis=1)
        l1_mean = float(l1_per_sample.mean())

        js = _js_div(torch_probs, np_probs)
        js_mean = float(js.mean())

        tl = torch_logits - torch_logits.mean(axis=1, keepdims=True)
        nl = np_logits - np_logits.mean(axis=1, keepdims=True)
        dlog = np.abs(tl - nl)
        max_dlog_centered = float(dlog.max())

        passed = (
            max_dprob <= tol_prob_max and
            mean_dprob <= tol_prob_mean and
            l1_mean <= tol_l1_mean and
            js_mean <= tol_js_mean and
            max_dlog_centered <= tol_logit_centered
        )
        return {
            "fidelity_info": fidelity,
            "max_abs_prob_diff": max_dprob,
            "mean_abs_prob_diff": mean_dprob,
            "mean_L1_prob": l1_mean,
            "mean_JS": js_mean,
            "max_abs_centered_logit_diff": max_dlog_centered,
            "pass": passed
        }

    else:
        with torch.no_grad():
            torch_pred = model(torch.tensor(Xs, device=device)).cpu().numpy().ravel()
        np_pred = np.asarray(npmod.predict(Xs)).ravel()

        dp = np.abs(torch_pred - np_pred)
        max_dp = float(dp.max()) if dp.size else 0.0
        mean_dp = float(dp.mean()) if dp.size else 0.0

        passed = (max_dp <= tol_reg_max and mean_dp <= tol_reg_mean)
        return {
            "max_abs_pred_diff": max_dp,
            "mean_abs_pred_diff": mean_dp,
            "pass": passed
        }

def export_and_validate(model: NousNet, name: str, X, base_path: str = "./exports") -> dict:
    os.makedirs(base_path, exist_ok=True)
    file_path = os.path.join(base_path, f"nous_numpy_infer_{slugify(name)}.py")
    export_numpy_inference(model, file_path=file_path)
    npmod = load_numpy_module(file_path)
    return validate_numpy_vs_torch(model, npmod, X, model.config['task_type'])
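For orientation, here is a minimal usage sketch of the export-and-validate path defined above. The names `model` and `X_val` are placeholders for a trained NousNet and a float32 feature matrix (they are not part of this diff), and the import goes through `nous.export.numpy_infer`, the module shown above.

# Hedged usage sketch; `model` and `X_val` are hypothetical placeholders.
from nous.export.numpy_infer import export_and_validate, load_numpy_module

# Writes ./exports/nous_numpy_infer_demo.py, reloads it, and compares it against the Torch model.
report = export_and_validate(model, name="demo", X=X_val)
print(report["pass"])

# The generated module needs only numpy at inference time.
npmod = load_numpy_module("./exports/nous_numpy_infer_demo.py")
probs = npmod.predict(X_val)                       # classification: class probabilities
probs, logits = npmod.predict(X_val, return_logits=True)

For a regression model, `predict` returns a 1-D array of predictions instead of probabilities, and the validation report uses the absolute-prediction-difference keys rather than the probability metrics.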
nous/facts.py
ADDED
@@ -0,0 +1,112 @@
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class BetaFactLayer(nn.Module):
    """
    Beta-like fact activation layer.

    Computes β = exp(nu * log(sigmoid(k * (Lx - Rx - th)))) with numerical clamps.
    """
    def __init__(self, input_dim: int, num_facts: int) -> None:
        super().__init__()
        self.L = nn.Linear(input_dim, num_facts, bias=False)
        self.R = nn.Linear(input_dim, num_facts, bias=False)
        self.th = nn.Parameter(torch.randn(num_facts) * 0.1)
        self.kraw = nn.Parameter(torch.ones(num_facts) * 0.5)
        self.nuraw = nn.Parameter(torch.zeros(num_facts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        diff = (self.L(x) - self.R(x)) - self.th
        k = F.softplus(self.kraw) + 1e-4
        nu = F.softplus(self.nuraw) + 1e-4
        log_beta = torch.clamp(nu * F.logsigmoid(k * diff), min=-80.0)
        return torch.exp(log_beta)

    @torch.no_grad()
    def get_rule_parameters(self):
        k = F.softplus(self.kraw).cpu().numpy() + 1e-4
        nu = F.softplus(self.nuraw).cpu().numpy() + 1e-4
        L_weights = self.L.weight.detach().cpu().numpy()
        R_weights = self.R.weight.detach().cpu().numpy()
        thresholds = self.th.detach().cpu().numpy()
        return L_weights, R_weights, thresholds, k, nu

    @torch.no_grad()
    def compute_diff_and_params(self, x_cal: torch.Tensor):
        """
        Returns
        -------
        diff : torch.Tensor
            (Lx - Rx - th) of shape [B, F].
        k : torch.Tensor
            Softplus(kraw) + eps of shape [F].
        nu : torch.Tensor
            Softplus(nuraw) + eps of shape [F].
        net_w : torch.Tensor
            L.weight - R.weight of shape [F, D].
        """
        diff = (self.L(x_cal) - self.R(x_cal)) - self.th
        k = F.softplus(self.kraw) + 1e-4
        nu = F.softplus(self.nuraw) + 1e-4
        net_w = (self.L.weight - self.R.weight)
        return diff, k, nu, net_w


class PiecewiseLinearCalibrator(nn.Module):
    """
    Monotonic piecewise-linear per-feature calibrator using cumulative positive deltas.
    """
    def __init__(self, num_bins: int = 8, input_range=(-3.0, 3.0)) -> None:
        super().__init__()
        self.num_bins = num_bins
        self.input_min, self.input_max = input_range
        self.register_buffer('bin_edges', torch.linspace(self.input_min, self.input_max, num_bins + 1))
        self.deltas = nn.Parameter(torch.ones(num_bins) * 0.1)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_normalized = (x - self.input_min) / (self.input_max - self.input_min + 1e-8)
        x_normalized = torch.clamp(x_normalized, 0.0, 1.0)
        bin_idx = torch.floor(x_normalized * self.num_bins).long().clamp(0, self.num_bins - 1)
        cum_deltas = torch.cumsum(F.softplus(self.deltas), dim=0)
        cum_deltas = torch.cat([torch.zeros(1, device=x.device), cum_deltas])
        left_vals = self.bias + cum_deltas[bin_idx]
        right_vals = self.bias + cum_deltas[bin_idx + 1]
        t = (x_normalized * self.num_bins) - bin_idx.float()
        y = left_vals + t * (right_vals - left_vals)
        return y

    def local_slope(self, x: torch.Tensor) -> torch.Tensor:
        x_normalized = (x - self.input_min) / (self.input_max - self.input_min + 1e-8)
        x_normalized = torch.clamp(x_normalized, 0.0, 1.0)
        bin_idx = torch.floor(x_normalized * self.num_bins).long().clamp(0, self.num_bins - 1)
        deltas_sp = F.softplus(self.deltas)
        cum = torch.cumsum(deltas_sp, dim=0)
        cum = torch.cat([torch.zeros(1, device=x.device), cum])
        left_vals = self.bias + cum[bin_idx]
        right_vals = self.bias + cum[bin_idx + 1]
        slope_y_vs_xnorm = (right_vals - left_vals)
        slope = slope_y_vs_xnorm * (self.num_bins / (self.input_max - self.input_min + 1e-8))
        return torch.clamp(slope, min=1e-6)

    @torch.no_grad()
    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        device = y.device
        deltas_sp = F.softplus(self.deltas)
        cum = torch.cumsum(deltas_sp, dim=0)
        cum = torch.cat([torch.zeros(1, device=device), cum])
        vals = self.bias + cum

        y_flat = y.view(-1)
        idx = torch.searchsorted(vals, y_flat, right=True) - 1
        idx = torch.clamp(idx, 0, self.num_bins - 1)

        y_left = vals[idx]
        y_right = vals[idx + 1]
        t = (y_flat - y_left) / torch.clamp((y_right - y_left), min=1e-6)
        x_norm = (idx.float() + t) / self.num_bins
        x = x_norm * (self.input_max - self.input_min) + self.input_min
        return x.view_as(y)
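A small usage sketch of the two modules defined in this file; the dimensions, batch size, and the inverse round-trip are illustrative assumptions rather than anything taken from the diff.

import torch
from nous.facts import BetaFactLayer, PiecewiseLinearCalibrator

fact = BetaFactLayer(input_dim=4, num_facts=8)
x = torch.randn(16, 4)
h = fact(x)                       # fact activations in (0, 1], shape [16, 8]

cal = PiecewiseLinearCalibrator(num_bins=8, input_range=(-3.0, 3.0))
y = cal(x[:, 0])                  # monotone piecewise-linear transform of one feature
x_back = cal.inverse(y)           # recovers the (clamped) input within the bin grid

Because the deltas pass through softplus before being accumulated, the calibrator is monotone by construction, which is what makes `inverse` well defined over the calibrated range.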