nous-0.1.0-py3-none-any.whl → nous-0.2.0-py3-none-any.whl

Files changed (47)
  1. nous/__init__.py +96 -19
  2. nous/data/__init__.py +4 -0
  3. nous/data/california.py +32 -0
  4. nous/data/wine.py +29 -0
  5. nous/explain/__init__.py +26 -0
  6. nous/explain/aggregator.py +34 -0
  7. nous/explain/cf.py +137 -0
  8. nous/explain/facts_desc.py +23 -0
  9. nous/explain/fidelity.py +56 -0
  10. nous/explain/generate.py +86 -0
  11. nous/explain/global_book.py +52 -0
  12. nous/explain/loo.py +130 -0
  13. nous/explain/mse.py +93 -0
  14. nous/explain/pruning.py +117 -0
  15. nous/explain/stability.py +42 -0
  16. nous/explain/traces.py +285 -0
  17. nous/explain/utils.py +15 -0
  18. nous/export/__init__.py +13 -0
  19. nous/export/numpy_infer.py +412 -0
  20. nous/facts.py +112 -0
  21. nous/model.py +226 -0
  22. nous/prototypes.py +43 -0
  23. nous/rules/__init__.py +11 -0
  24. nous/rules/blocks.py +63 -0
  25. nous/rules/fixed.py +26 -0
  26. nous/rules/softmax.py +93 -0
  27. nous/rules/sparse.py +142 -0
  28. nous/training/__init__.py +5 -0
  29. nous/training/evaluation.py +57 -0
  30. nous/training/schedulers.py +34 -0
  31. nous/training/train.py +177 -0
  32. nous/types.py +4 -0
  33. nous/utils/__init__.py +3 -0
  34. nous/utils/metrics.py +2 -0
  35. nous/utils/seed.py +13 -0
  36. nous/version.py +1 -0
  37. nous-0.2.0.dist-info/METADATA +150 -0
  38. nous-0.2.0.dist-info/RECORD +41 -0
  39. nous/causal.py +0 -63
  40. nous/interpret.py +0 -111
  41. nous/layers.py +0 -117
  42. nous/models.py +0 -65
  43. nous-0.1.0.dist-info/METADATA +0 -138
  44. nous-0.1.0.dist-info/RECORD +0 -10
  45. {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/WHEEL +0 -0
  46. {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/top_level.txt +0 -0
nous/export/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .numpy_infer import (
+     export_numpy_inference,
+     validate_numpy_vs_torch,
+     export_and_validate,
+     load_numpy_module,
+ )
+
+ __all__ = [
+     "export_numpy_inference",
+     "validate_numpy_vs_torch",
+     "export_and_validate",
+     "load_numpy_module",
+ ]
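
A minimal usage sketch of the re-exported helpers above (not part of the diff; `model` is assumed to be a trained NousNet instance and `X_val` a NumPy feature matrix, both hypothetical names):

from nous.export import export_and_validate

# Writes ./exports/nous_numpy_infer_<slugified-name>.py, reloads it,
# and compares the NumPy predictions against the torch model.
report = export_and_validate(model, name="demo", X=X_val)
print(report["pass"], report)
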
nous/export/numpy_infer.py ADDED
@@ -0,0 +1,412 @@
+ from __future__ import annotations
+ import os
+ import re
+ import textwrap
+ import importlib.util
+ from types import ModuleType
+ from typing import Dict, Any
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from ..model import NousNet
+
+ def export_numpy_inference(model: NousNet, file_path: str = "nous_numpy_infer.py") -> str:
+     """
+     Export a NousNet instance to a self-contained NumPy inference Python module.
+
+     The generated module exposes:
+     - a dict P with parameters and metadata,
+     - predict(X, return_logits=False) using only numpy.
+
+     Returns the code as a string and writes it to file_path if provided.
+     """
+     model.eval()
+     cfg = model.config
+
+     def npy(x):
+         if isinstance(x, torch.Tensor):
+             return x.detach().cpu().numpy()
+         return np.array(x)
+
+     def tolist(x, dtype=np.float32):
+         return npy(x).astype(dtype).tolist()
+
+     use_cals = model.calibrators is not None
+     calibrators = []
+     if use_cals:
+         for c in model.calibrators:
+             with torch.no_grad():
+                 deltas_sp = F.softplus(c.deltas).cpu().numpy()
+                 cum = np.cumsum(deltas_sp)
+                 cum = np.concatenate([np.array([0.0], dtype=np.float32), cum.astype(np.float32)], axis=0)
+             calibrators.append(dict(
+                 input_min=float(c.input_min),
+                 input_max=float(c.input_max),
+                 num_bins=int(c.num_bins),
+                 bias=float(c.bias.item()),
+                 cum=cum.tolist()
+             ))
+
+     Lw, Rw, th, k, nu = model.fact.get_rule_parameters()
+     fact_dict = dict(
+         L=Lw.astype(np.float32).tolist(),
+         R=Rw.astype(np.float32).tolist(),
+         th=th.astype(np.float32).tolist(),
+         k=k.astype(np.float32).tolist(),
+         nu=nu.astype(np.float32).tolist()
+     )
+
+     blocks = []
+     for blk in model.blocks:
+         ln_w = tolist(blk.norm.weight)
+         ln_b = tolist(blk.norm.bias)
+         proj_W = None
+         if not isinstance(blk.proj, torch.nn.Identity):
+             proj_W = tolist(blk.proj.weight)
+
+         from ..rules.blocks import SimpleNousBlock
+         from ..rules.softmax import SoftmaxRuleLayer
+         from ..rules.sparse import SparseRuleLayer
+
+         if isinstance(blk, SimpleNousBlock):
+             idx_pairs = blk.rule.idx.detach().cpu().numpy().astype(np.int64).tolist()
+             rule_w = tolist(blk.rule.weight)
+             blocks.append(dict(
+                 kind="simple",
+                 proj_W=proj_W,
+                 ln_w=ln_w, ln_b=ln_b,
+                 idx_pairs=idx_pairs,
+                 rule_w=rule_w
+             ))
+
+         elif isinstance(blk, SoftmaxRuleLayer):
+             with torch.no_grad():
+                 fl = F.softmax(blk.fact_logits, dim=1).cpu().numpy()
+                 kf = int(min(blk.top_k_facts, blk.input_dim))
+                 topk_idx = np.argpartition(fl, -kf, axis=1)[:, -kf:]
+                 mask = np.zeros_like(fl, dtype=np.float32)
+                 rows = np.arange(fl.shape[0])[:, None]
+                 mask[rows, topk_idx] = 1.0
+
+                 agg_w = F.softmax(blk.aggregator_logits, dim=1).cpu().numpy().astype(np.float32)  # [R, 3]
+                 rule_strength = torch.sigmoid(blk.rule_strength_raw).cpu().numpy().astype(np.float32)
+             blocks.append(dict(
+                 kind="softmax",
+                 proj_W=proj_W,
+                 ln_w=ln_w, ln_b=ln_b,
+                 mask=mask.tolist(),
+                 agg_w=agg_w.tolist(),
+                 rule_strength=rule_strength.tolist(),
+                 top_k_rules=int(blk.top_k_rules)
+             ))
+
+         elif isinstance(blk, SparseRuleLayer):
+             with torch.no_grad():
+                 beta = blk.hard_concrete.beta.detach().cpu().numpy()
+                 mask = (1.0 / (1.0 + np.exp(-beta)) > 0.5).astype(np.float32)  # eval behavior
+                 agg_w = F.softmax(blk.aggregator_logits, dim=1).cpu().numpy().astype(np.float32)  # [R, 4]
+                 rule_strength = torch.sigmoid(blk.rule_strength_raw).cpu().numpy().astype(np.float32)
+             blocks.append(dict(
+                 kind="sparse",
+                 proj_W=proj_W,
+                 ln_w=ln_w, ln_b=ln_b,
+                 mask=mask.tolist(),
+                 agg_w=agg_w.tolist(),
+                 rule_strength=rule_strength.tolist(),
+                 top_k_rules=int(blk.top_k_rules)
+             ))
+         else:
+             raise ValueError(f"Unknown block type: {type(blk)}")
+
+     if isinstance(model.head, torch.nn.Linear):
+         with torch.no_grad():
+             W = model.head.weight.detach().cpu().numpy().astype(np.float32)
+             b = model.head.bias.detach().cpu().numpy().astype(np.float32)
+         head = dict(kind="linear", W=W.tolist(), b=b.tolist())
+     else:
+         from ..prototypes import ScaledPrototypeLayer
+         assert isinstance(model.head, ScaledPrototypeLayer)
+         with torch.no_grad():
+             Pm = model.head.prototypes.detach()
+             Pn = F.normalize(Pm, p=2, dim=1).cpu().numpy().astype(np.float32)
+             W = model.head.prototype_class.detach().cpu().numpy().astype(np.float32)
+             tau = float(F.softplus(model.head.temperature).item())
+         head = dict(kind="prototypes", P_norm=Pn.tolist(), W=W.tolist(), tau=tau)
+
+     P_dict = dict(
+         task=cfg['task_type'],
+         use_calibrators=bool(use_cals),
+         calibrators=calibrators,
+         fact=fact_dict,
+         blocks=blocks,
+         head=head
+     )
+
+     code = f"""# Auto-generated NumPy inference for NousNet
+ # This file is self-contained and requires only numpy.
+ import numpy as np
+
+ P = {repr(P_dict)}
+
+ def sigmoid(x):
+     return 1.0 / (1.0 + np.exp(-x))
+
+ def logsigmoid(z):
+     # stable: -log(1+exp(-z))
+     return -np.log1p(np.exp(-z))
+
+ def layernorm(x, gamma, beta, eps=1e-5):
+     mu = x.mean(axis=1, keepdims=True)
+     var = x.var(axis=1, keepdims=True)
+     xhat = (x - mu) / np.sqrt(var + eps)
+     return xhat * gamma + beta
+
+ def calibrate(X):
+     if not P['use_calibrators']:
+         return X.astype(np.float32)
+     X = X.astype(np.float32)
+     Xc = np.empty_like(X, dtype=np.float32)
+     for j, c in enumerate(P['calibrators']):
+         xmin, xmax = float(c['input_min']), float(c['input_max'])
+         nb = int(c['num_bins'])
+         cum = np.array(c['cum'], dtype=np.float32)
+         bias = float(c['bias'])
+         xj = X[:, j]
+         xn = (xj - xmin) / (xmax - xmin + 1e-8)
+         xn = np.clip(xn, 0.0, 1.0)
+         t = xn * nb
+         bin_idx = np.floor(t).astype(np.int32)
+         bin_idx = np.clip(bin_idx, 0, nb-1)
+         left = bias + cum[bin_idx]
+         right = bias + cum[bin_idx + 1]
+         frac = t - bin_idx.astype(np.float32)
+         Xc[:, j] = left + frac * (right - left)
+     return Xc
+
+ def beta_facts(Xc):
+     L = np.array(P['fact']['L'], dtype=np.float32)
+     R = np.array(P['fact']['R'], dtype=np.float32)
+     th = np.array(P['fact']['th'], dtype=np.float32)
+     k = np.array(P['fact']['k'], dtype=np.float32)
+     nu = np.array(P['fact']['nu'], dtype=np.float32)
+     diff = Xc @ L.T - Xc @ R.T - th
+     z = k * diff
+     log_sig = logsigmoid(z)
+     log_beta = nu * log_sig
+     log_beta = np.maximum(log_beta, -80.0)
+     return np.exp(log_beta).astype(np.float32)
+
+ def run_block(block, H):
+     kind = block['kind']
+     if block['proj_W'] is None:
+         proj = H
+     else:
+         Wp = np.array(block['proj_W'], dtype=np.float32)
+         proj = H @ Wp.T
+
+     if kind == 'simple':
+         idx = np.array(block['idx_pairs'], dtype=np.int64)
+         w = np.array(block['rule_w'], dtype=np.float32)
+         rs = sigmoid(w)
+         f1 = H[:, idx[:,0]]
+         f2 = H[:, idx[:,1]]
+         rule_act = (f1 * f2) * rs
+         pre = proj + rule_act
+
+     elif kind == 'softmax':
+         mask = np.array(block['mask'], dtype=np.float32)  # [R,F]
+         sel = H[:, None, :] * mask[None, :, :]
+         and_agg = np.prod(sel + (1.0 - mask)[None, :, :], axis=2)
+         or_agg = 1.0 - np.prod((1.0 - sel) * mask[None, :, :] + (1.0 - mask)[None, :, :], axis=2)
+         denom = np.maximum(mask.sum(axis=1), 1e-8)  # [R]
+         kofn = (sel.sum(axis=2)) / denom[None, :]
+         agg_w = np.array(block['agg_w'], dtype=np.float32)  # [R,3]
+         aggs = np.stack([and_agg, or_agg, kofn], axis=2)
+         mixed = (aggs * agg_w[None, :, :]).sum(axis=2)
+         rs = np.array(block['rule_strength'], dtype=np.float32)
+         rule_act = mixed * rs[None, :]
+         R = mask.shape[0]
+         k_rules = int(block['top_k_rules'])
+         if k_rules < R:
+             gate = np.zeros_like(rule_act, dtype=np.float32)
+             idx_top = np.argpartition(rule_act, -k_rules, axis=1)[:, -k_rules:]
+             for i in range(rule_act.shape[0]):
+                 gate[i, idx_top[i]] = 1.0
+             rule_act = rule_act * gate
+         pre = proj + rule_act
+
+     elif kind == 'sparse':
+         mask = np.array(block['mask'], dtype=np.float32)  # [R,F]
+         sel = H[:, None, :] * mask[None, :, :]
+         and_agg = np.prod(sel + (1.0 - mask)[None, :, :], axis=2)
+         or_agg = 1.0 - np.prod((1.0 - sel) * mask[None, :, :] + (1.0 - mask)[None, :, :], axis=2)
+         denom = np.maximum(mask.sum(axis=1), 1e-8)  # [R]
+         kofn = (sel.sum(axis=2)) / denom[None, :]
+         not_agg = 1.0 - kofn
+         agg_w = np.array(block['agg_w'], dtype=np.float32)  # [R,4]
+         aggs = np.stack([and_agg, or_agg, kofn, not_agg], axis=2)
+         mixed = (aggs * agg_w[None, :, :]).sum(axis=2)
+         rs = np.array(block['rule_strength'], dtype=np.float32)
+         rule_act = mixed * rs[None, :]
+         R = mask.shape[0]
+         k_rules = int(block['top_k_rules'])
+         if k_rules < R:
+             gate = np.zeros_like(rule_act, dtype=np.float32)
+             idx_top = np.argpartition(rule_act, -k_rules, axis=1)[:, -k_rules:]
+             for i in range(rule_act.shape[0]):
+                 gate[i, idx_top[i]] = 1.0
+             rule_act = rule_act * gate
+         pre = proj + rule_act
+
+     else:
+         raise ValueError("Unknown block kind: " + str(kind))
+
+     gamma = np.array(block['ln_w'], dtype=np.float32)
+     beta = np.array(block['ln_b'], dtype=np.float32)
+     return layernorm(pre, gamma, beta)
+
+ def head_forward(H):
+     head = P['head']
+     if head['kind'] == 'linear':
+         W = np.array(head['W'], dtype=np.float32)
+         b = np.array(head['b'], dtype=np.float32)
+         return H @ W.T + b
+     elif head['kind'] == 'prototypes':
+         Pn = np.array(head['P_norm'], dtype=np.float32)
+         W = np.array(head['W'], dtype=np.float32)
+         Hn = H / (np.linalg.norm(H, axis=1, keepdims=True) + 1e-8)
+         dot = Hn @ Pn.T
+         d = np.sqrt(np.clip(2.0 - 2.0 * dot, 1e-12, None))
+         tau = float(head['tau'])
+         act = np.exp(-tau * d)
+         return act @ W
+     else:
+         raise ValueError("Unknown head kind: " + str(head['kind']))
+
+ def predict(X, return_logits=False):
+     X = np.array(X, dtype=np.float32)
+     Xc = calibrate(X) if P['use_calibrators'] else X
+     H = beta_facts(Xc)
+     for blk in P['blocks']:
+         H = run_block(blk, H)
+     logits = head_forward(H)
+     if P['task'] == 'classification':
+         e = np.exp(logits - logits.max(axis=1, keepdims=True))
+         probs = e / e.sum(axis=1, keepdims=True)
+         return (probs, logits) if return_logits else probs
+     else:
+         return logits.reshape(-1)
+ """
+     code = textwrap.dedent(code)
+     if file_path is not None:
+         with open(file_path, "w", encoding="utf-8") as f:
+             f.write(code)
+     return code
+
+ def slugify(s: str) -> str:
+     return re.sub(r'[^a-zA-Z0-9]+', '_', s).strip('_').lower()
+
+ def load_numpy_module(path: str) -> ModuleType:
+     spec = importlib.util.spec_from_file_location("nous_numpy_infer_mod", path)
+     mod = importlib.util.module_from_spec(spec)
+     assert spec.loader is not None
+     spec.loader.exec_module(mod)  # type: ignore[attr-defined]
+     return mod
+
+ def _softmax_np(z: np.ndarray) -> np.ndarray:
+     z = z - z.max(axis=1, keepdims=True)
+     e = np.exp(z)
+     return e / e.sum(axis=1, keepdims=True)
+
+ def _kl_div(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+     p = np.clip(p, eps, 1.0)
+     q = np.clip(q, eps, 1.0)
+     return np.sum(p * (np.log(p) - np.log(q)), axis=1)
+
+ def _js_div(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+     m = 0.5*(p+q)
+     return 0.5*_kl_div(p, m, eps) + 0.5*_kl_div(q, m, eps)
+
+ def validate_numpy_vs_torch(
+     model: NousNet, npmod: ModuleType, X, task: str, n: int = 512,
+     tol_prob_max: float = 1e-3, tol_prob_mean: float = 2e-4, tol_l1_mean: float = 3e-4, tol_js_mean: float = 5e-6,
+     tol_logit_centered: float = 1e-3,
+     tol_reg_max: float = 1e-4, tol_reg_mean: float = 2e-5
+ ) -> dict:
+     """
+     Probability-first validation:
+     - classification: PASS if prob metrics and centered-logit diff are within tolerances.
+     - regression: PASS if absolute prediction diffs are within tolerances.
+     """
+     model.eval()
+     device = next(model.parameters()).device
+     Xs = X[:min(len(X), n)].astype(np.float32)
+
+     if task == "classification":
+         with torch.no_grad():
+             torch_logits = model(torch.tensor(Xs, device=device)).cpu().numpy()
+         torch_probs = _softmax_np(torch_logits)
+
+         np_probs, np_logits = npmod.predict(Xs, return_logits=True)
+         np_probs = np.asarray(np_probs)
+         np_logits = np.asarray(np_logits)
+
+         torch_pred = np.argmax(torch_probs, axis=1)
+         numpy_pred = np.argmax(np_probs, axis=1)
+         fidelity = float((torch_pred == numpy_pred).mean())
+
+         dprob = np.abs(torch_probs - np_probs)
+         max_dprob = float(dprob.max())
+         mean_dprob = float(dprob.mean())
+         l1_per_sample = np.sum(dprob, axis=1)
+         l1_mean = float(l1_per_sample.mean())
+
+         js = _js_div(torch_probs, np_probs)
+         js_mean = float(js.mean())
+
+         tl = torch_logits - torch_logits.mean(axis=1, keepdims=True)
+         nl = np_logits - np_logits.mean(axis=1, keepdims=True)
+         dlog = np.abs(tl - nl)
+         max_dlog_centered = float(dlog.max())
+
+         passed = (
+             max_dprob <= tol_prob_max and
+             mean_dprob <= tol_prob_mean and
+             l1_mean <= tol_l1_mean and
+             js_mean <= tol_js_mean and
+             max_dlog_centered <= tol_logit_centered
+         )
+         return {
+             "fidelity_info": fidelity,
+             "max_abs_prob_diff": max_dprob,
+             "mean_abs_prob_diff": mean_dprob,
+             "mean_L1_prob": l1_mean,
+             "mean_JS": js_mean,
+             "max_abs_centered_logit_diff": max_dlog_centered,
+             "pass": passed
+         }
+
+     else:
+         with torch.no_grad():
+             torch_pred = model(torch.tensor(Xs, device=device)).cpu().numpy().ravel()
+         np_pred = np.asarray(npmod.predict(Xs)).ravel()
+
+         dp = np.abs(torch_pred - np_pred)
+         max_dp = float(dp.max()) if dp.size else 0.0
+         mean_dp = float(dp.mean()) if dp.size else 0.0
+
+         passed = (max_dp <= tol_reg_max and mean_dp <= tol_reg_mean)
+         return {
+             "max_abs_pred_diff": max_dp,
+             "mean_abs_pred_diff": mean_dp,
+             "pass": passed
+         }
+
+ def export_and_validate(model: NousNet, name: str, X, base_path: str = "./exports") -> dict:
+     os.makedirs(base_path, exist_ok=True)
+     file_path = os.path.join(base_path, f"nous_numpy_infer_{slugify(name)}.py")
+     export_numpy_inference(model, file_path=file_path)
+     npmod = load_numpy_module(file_path)
+     return validate_numpy_vs_torch(model, npmod, X, model.config['task_type'])
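
Because the generated file depends only on NumPy, it can be consumed on a machine without torch installed. A minimal sketch (not part of the diff; the file name and feature count below are assumptions):

import importlib.util
import numpy as np

# Load the exported module directly from its file path (default name: nous_numpy_infer.py).
spec = importlib.util.spec_from_file_location("nous_numpy_infer", "nous_numpy_infer.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

n_features = 13  # assumption: must match the feature count the model was trained on
X_new = np.random.rand(4, n_features).astype(np.float32)

out = mod.predict(X_new)  # class probabilities for classification, 1-D predictions for regression
if mod.P["task"] == "classification":
    probs, logits = mod.predict(X_new, return_logits=True)
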
nous/facts.py ADDED
@@ -0,0 +1,112 @@
+ from __future__ import annotations
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Tuple
+
+ class BetaFactLayer(nn.Module):
+     """
+     Beta-like fact activation layer.
+
+     Computes β = exp(nu * log(sigmoid(k * (Lx - Rx - th)))) with numerical clamps.
+     """
+     def __init__(self, input_dim: int, num_facts: int) -> None:
+         super().__init__()
+         self.L = nn.Linear(input_dim, num_facts, bias=False)
+         self.R = nn.Linear(input_dim, num_facts, bias=False)
+         self.th = nn.Parameter(torch.randn(num_facts) * 0.1)
+         self.kraw = nn.Parameter(torch.ones(num_facts) * 0.5)
+         self.nuraw = nn.Parameter(torch.zeros(num_facts))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         diff = (self.L(x) - self.R(x)) - self.th
+         k = F.softplus(self.kraw) + 1e-4
+         nu = F.softplus(self.nuraw) + 1e-4
+         log_beta = torch.clamp(nu * F.logsigmoid(k * diff), min=-80.0)
+         return torch.exp(log_beta)
+
+     @torch.no_grad()
+     def get_rule_parameters(self):
+         k = F.softplus(self.kraw).cpu().numpy() + 1e-4
+         nu = F.softplus(self.nuraw).cpu().numpy() + 1e-4
+         L_weights = self.L.weight.detach().cpu().numpy()
+         R_weights = self.R.weight.detach().cpu().numpy()
+         thresholds = self.th.detach().cpu().numpy()
+         return L_weights, R_weights, thresholds, k, nu
+
+     @torch.no_grad()
+     def compute_diff_and_params(self, x_cal: torch.Tensor):
+         """
+         Returns
+         -------
+         diff : torch.Tensor
+             (Lx - Rx - th) of shape [B, F].
+         k : torch.Tensor
+             Softplus(kraw) + eps of shape [F].
+         nu : torch.Tensor
+             Softplus(nuraw) + eps of shape [F].
+         net_w : torch.Tensor
+             L.weight - R.weight of shape [F, D].
+         """
+         diff = (self.L(x_cal) - self.R(x_cal)) - self.th
+         k = F.softplus(self.kraw) + 1e-4
+         nu = F.softplus(self.nuraw) + 1e-4
+         net_w = (self.L.weight - self.R.weight)
+         return diff, k, nu, net_w
+
+
+ class PiecewiseLinearCalibrator(nn.Module):
+     """
+     Monotonic piecewise-linear per-feature calibrator using cumulative positive deltas.
+     """
+     def __init__(self, num_bins: int = 8, input_range=(-3.0, 3.0)) -> None:
+         super().__init__()
+         self.num_bins = num_bins
+         self.input_min, self.input_max = input_range
+         self.register_buffer('bin_edges', torch.linspace(self.input_min, self.input_max, num_bins + 1))
+         self.deltas = nn.Parameter(torch.ones(num_bins) * 0.1)
+         self.bias = nn.Parameter(torch.zeros(1))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_normalized = (x - self.input_min) / (self.input_max - self.input_min + 1e-8)
+         x_normalized = torch.clamp(x_normalized, 0.0, 1.0)
+         bin_idx = torch.floor(x_normalized * self.num_bins).long().clamp(0, self.num_bins - 1)
+         cum_deltas = torch.cumsum(F.softplus(self.deltas), dim=0)
+         cum_deltas = torch.cat([torch.zeros(1, device=x.device), cum_deltas])
+         left_vals = self.bias + cum_deltas[bin_idx]
+         right_vals = self.bias + cum_deltas[bin_idx + 1]
+         t = (x_normalized * self.num_bins) - bin_idx.float()
+         y = left_vals + t * (right_vals - left_vals)
+         return y
+
+     def local_slope(self, x: torch.Tensor) -> torch.Tensor:
+         x_normalized = (x - self.input_min) / (self.input_max - self.input_min + 1e-8)
+         x_normalized = torch.clamp(x_normalized, 0.0, 1.0)
+         bin_idx = torch.floor(x_normalized * self.num_bins).long().clamp(0, self.num_bins - 1)
+         deltas_sp = F.softplus(self.deltas)
+         cum = torch.cumsum(deltas_sp, dim=0)
+         cum = torch.cat([torch.zeros(1, device=x.device), cum])
+         left_vals = self.bias + cum[bin_idx]
+         right_vals = self.bias + cum[bin_idx + 1]
+         slope_y_vs_xnorm = (right_vals - left_vals)
+         slope = slope_y_vs_xnorm * (self.num_bins / (self.input_max - self.input_min + 1e-8))
+         return torch.clamp(slope, min=1e-6)
+
+     @torch.no_grad()
+     def inverse(self, y: torch.Tensor) -> torch.Tensor:
+         device = y.device
+         deltas_sp = F.softplus(self.deltas)
+         cum = torch.cumsum(deltas_sp, dim=0)
+         cum = torch.cat([torch.zeros(1, device=device), cum])
+         vals = self.bias + cum
+
+         y_flat = y.view(-1)
+         idx = torch.searchsorted(vals, y_flat, right=True) - 1
+         idx = torch.clamp(idx, 0, self.num_bins - 1)
+
+         y_left = vals[idx]
+         y_right = vals[idx + 1]
+         t = (y_flat - y_left) / torch.clamp((y_right - y_left), min=1e-6)
+         x_norm = (idx.float() + t) / self.num_bins
+         x = x_norm * (self.input_max - self.input_min) + self.input_min
+         return x.view_as(y)
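
A small sanity-check sketch of the two layers above (not part of the diff): fact activations are exp(nu * logsigmoid(k * diff)), so they lie in (0, 1], and the calibrator is monotone because its bin increments come from cumulative softplus outputs.

import torch
from nous.facts import BetaFactLayer, PiecewiseLinearCalibrator

torch.manual_seed(0)

# Beta-like facts: exp of a clamped non-positive quantity, so values fall in (0, 1].
fact = BetaFactLayer(input_dim=5, num_facts=7)
h = fact(torch.randn(32, 5))
assert h.min() > 0.0 and h.max() <= 1.0

# Monotone piecewise-linear calibrator: non-decreasing and invertible on its input range.
cal = PiecewiseLinearCalibrator(num_bins=8, input_range=(-3.0, 3.0))
x = torch.linspace(-3.0, 3.0, 50)
y = cal(x)
assert torch.all(y[1:] >= y[:-1])
assert torch.allclose(cal.inverse(y), x, atol=1e-4)
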