nous 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nous might be problematic. Click here for more details.

nous/model.py ADDED
@@ -0,0 +1,226 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
7
+
8
+ from .facts import BetaFactLayer, PiecewiseLinearCalibrator
9
+ from .prototypes import ScaledPrototypeLayer
10
+ from .rules.blocks import SimpleNousBlock
11
+ from .rules.softmax import SoftmaxRuleLayer
12
+ from .rules.sparse import SparseRuleLayer
13
+
14
class NousNet(nn.Module):
    """
    NousNet: rule-based neural network with honest interpretability for classification and regression.

    Data flow, as wired below: optional per-feature ``PiecewiseLinearCalibrator`` modules ->
    ``BetaFactLayer`` -> one rule block per entry of ``rules_per_layer`` -> head
    (``ScaledPrototypeLayer`` or ``nn.Linear``).
    """

    def __init__(
        self,
        input_dim: int,
        num_outputs: int,
        task_type: str = "classification",  # "classification" | "regression"
        feature_names: Optional[Sequence[str]] = None,
        num_facts: int = 48,
        rules_per_layer: Sequence[int] = (24, 12),
        use_calibrators: bool = False,
        rule_selection_method: str = "fixed",  # "fixed" | "softmax" | "sparse"
        use_prototypes: bool = False,
        l0_lambda: float = 1e-3,
        hc_temperature: float = 0.1
    ) -> None:
        """
        Build calibrators (optional), the fact layer, the rule blocks and the head.

        Args:
            input_dim: number of input features (one calibrator per feature when enabled).
            num_outputs: width of the head; forced to 1 for ``task_type == "regression"``.
            task_type: "classification" or "regression".
            feature_names: optional feature names; defaults to ``Feature_0 .. Feature_{input_dim-1}``.
            num_facts: width passed to ``BetaFactLayer`` and input width of the first block.
            rules_per_layer: number of rules of each successive block.
            use_calibrators: create one ``PiecewiseLinearCalibrator`` per input feature.
            rule_selection_method: "fixed" -> ``SimpleNousBlock``, "softmax" ->
                ``SoftmaxRuleLayer``, "sparse" -> ``SparseRuleLayer``.
            use_prototypes: use ``ScaledPrototypeLayer`` as head; only honoured for classification.
            l0_lambda: forwarded to ``SparseRuleLayer``.
            hc_temperature: forwarded to ``SparseRuleLayer``.

        Raises:
            ValueError: if ``rule_selection_method`` is unknown (checked while building blocks).
        """
        super().__init__()
        self.config = {
            'input_dim': input_dim, 'num_outputs': num_outputs, 'task_type': task_type,
            'feature_names': list(feature_names) if feature_names is not None else [f"Feature_{i}" for i in range(input_dim)],
            'num_facts': num_facts, 'rules_per_layer': tuple(rules_per_layer),
            'use_calibrators': use_calibrators, 'rule_selection_method': rule_selection_method,
            'use_prototypes': bool(use_prototypes and task_type == "classification"),
            'l0_lambda': l0_lambda, 'hc_temperature': hc_temperature
        }

        if self.config['use_calibrators']:
            self.calibrators = nn.ModuleList([PiecewiseLinearCalibrator() for _ in range(input_dim)])
        else:
            self.calibrators = None

        self.fact = BetaFactLayer(input_dim, num_facts)

        blocks: List[nn.Module] = []
        cur = num_facts
        # Iterate the stored tuple: the raw argument may be a one-shot iterable that
        # was already consumed by tuple(...) above.
        for r in self.config['rules_per_layer']:
            if rule_selection_method == 'fixed':
                blocks.append(SimpleNousBlock(cur, r))
            elif rule_selection_method == 'softmax':
                blocks.append(SoftmaxRuleLayer(cur, r))
            elif rule_selection_method == 'sparse':
                blocks.append(SparseRuleLayer(cur, r, l0_lambda=l0_lambda, hc_temperature=hc_temperature))
            else:
                raise ValueError(f"Unknown rule_selection_method: {rule_selection_method}")
            cur = r
        self.blocks = nn.ModuleList(blocks)

        if self.config['use_prototypes']:
            self.head = ScaledPrototypeLayer(cur, num_prototypes=10, num_classes=num_outputs)
        else:
            self.head = nn.Linear(cur, num_outputs)

        # Regression always predicts a single value; the head is rebuilt after the
        # first construction (kept in this order so seeded initialisation is unchanged).
        if self.config['task_type'] == "regression" and self.config['num_outputs'] != 1:
            self.config['num_outputs'] = 1
            if isinstance(self.head, nn.Linear):
                self.head = nn.Linear(cur, 1)

    def _apply_calibrators(self, x: torch.Tensor) -> torch.Tensor:
        """Run column ``i`` of ``x`` through calibrator ``i`` and re-stack; no-op without calibrators."""
        if self.calibrators is None:
            return x
        return torch.stack([calib(x[:, i]) for i, calib in enumerate(self.calibrators)], dim=1)

    def forward(self, x: torch.Tensor, return_internals: bool = False):
        """
        Standard forward pass.

        Args:
            x: input batch; columns are indexed as ``x[:, i]`` when calibrators are used.
            return_internals: also return a dict with detached facts and per-block details.

        Returns:
            ``logits`` (last dim squeezed for regression), or ``(logits, internals)``.
        """
        internals: Dict[str, torch.Tensor] = {}
        h = self.fact(self._apply_calibrators(x))
        if return_internals:
            internals['facts'] = h.detach()

        for i, blk in enumerate(self.blocks):
            if return_internals:
                h, details = blk(h, return_details=True)
                internals[f'block_{i}'] = details
            else:
                h = blk(h)

        logits = self.head(h)
        if self.config['task_type'] == "regression":
            logits = logits.squeeze(-1)

        if return_internals:
            return logits, internals
        return logits

    @torch.no_grad()
    def forward_explain(
        self,
        x: Union[np.ndarray, torch.Tensor],
        drop_rule_spec: Optional[Tuple[int, int]] = None,
        restrict_masks: Optional[List[torch.Tensor]] = None,
        apply_pruning: bool = False,
        pruning_threshold: float = 0.0,
        device: Optional[torch.device] = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        """
        Honest forward for explanations with interventions/gating recompute.

        Switches the module to eval mode (it is not switched back) and runs without
        gradients.

        Args:
            x: a single sample (1-D) or a batch (2-D) as ndarray or tensor. 1-D input
                gets a leading batch axis; batched input is used as is.
            drop_rule_spec: ``(block_index, rule_index)`` of one rule to drop.
            restrict_masks: optional per-block rule masks (entries may be ``None``).
            apply_pruning: when true, ``pruning_threshold`` is passed to every block
                as ``prune_below``.
            pruning_threshold: threshold used when ``apply_pruning`` is set.
            device: target device; defaults to the device of the first parameter.
            explain_disable_norm: forwarded to every block.
            explain_exclude_proj: forwarded to every block.

        Returns:
            - classification: (probas, logits, internals)
            - regression: (pred, pred, internals)
            Arrays are NumPy; a leading batch axis of size 1 is squeezed away.
        """
        self.eval()
        device = device or next(self.parameters()).device

        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32, device=device)
        else:
            x = x.to(device)
        # Only a single flat sample needs a batch axis; unconditionally unsqueezing
        # (as was done for ndarrays) turned 2-D batches into 3-D tensors.
        if x.dim() == 1:
            x = x.unsqueeze(0)

        h = self.fact(self._apply_calibrators(x))
        internals: Dict[str, torch.Tensor] = {'facts': h.detach()}

        prune = pruning_threshold if apply_pruning else None
        for i, blk in enumerate(self.blocks):
            drop_idx = None
            if drop_rule_spec is not None and drop_rule_spec[0] == i:
                drop_idx = int(drop_rule_spec[1])
            restrict = None
            if restrict_masks is not None and i < len(restrict_masks) and restrict_masks[i] is not None:
                restrict = restrict_masks[i].to(device)

            h, details = blk(
                h,
                return_details=True,
                drop_rule_idx=drop_idx,
                restrict_mask=restrict,
                prune_below=prune,
                explain_disable_norm=explain_disable_norm,
                explain_exclude_proj=explain_exclude_proj
            )
            internals[f'block_{i}'] = details

        logits = self.head(h)
        if self.config['task_type'] == "classification":
            probas = F.softmax(logits, dim=-1)
            return probas.squeeze(0).cpu().numpy(), logits.squeeze(0).cpu().numpy(), internals
        pred = logits.squeeze(-1).squeeze(0).cpu().numpy()
        return np.array([pred]), np.array([pred]), internals

    def compute_total_l0_loss(self) -> torch.Tensor:
        """
        Sum ``compute_l0_loss()`` over all blocks that define it.

        Returns:
            A tensor on the model's device; zero unless ``rule_selection_method`` is
            "sparse". Always a tensor (never a Python float), even when no block
            contributes.
        """
        total_l0 = torch.tensor(0.0, device=next(self.parameters()).device)
        if self.config['rule_selection_method'] != 'sparse':
            return total_l0
        for blk in self.blocks:
            if hasattr(blk, 'compute_l0_loss'):
                total_l0 = total_l0 + blk.compute_l0_loss()
        return total_l0

    def model_summary(self) -> Dict[str, object]:
        """Return a small dict describing the configuration and the total parameter count."""
        summary = {
            "Task": self.config['task_type'],
            "Rule Selection": self.config['rule_selection_method'],
            "Use Calibrators": self.config['use_calibrators'],
            "Use Prototypes": self.config['use_prototypes'],
            "Num Facts": self.config['num_facts'],
            "Rules per Layer": self.config['rules_per_layer'],
            "Total Parameters": sum(p.numel() for p in self.parameters())
        }
        if self.config['rule_selection_method'] == 'sparse':
            summary["L0 Lambda"] = self.config['l0_lambda']
        return summary

    @torch.no_grad()
    def encode(
        self,
        X: Union[np.ndarray, torch.Tensor, Sequence[Sequence[float]]],
        device: Optional[torch.device] = None,
        batch_size: int = 2048,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ) -> torch.Tensor:
        """
        Return H [N, D_last] — representations at the head input.

        Flags allow clean representations (without LayerNorm / without residual projection).
        Switches the module to eval mode (it is not switched back). The input is
        processed in chunks of ``batch_size`` rows; the result is a CPU tensor.
        Empty input yields an empty ``[0, D_last]`` tensor instead of failing.
        """
        self.eval()
        device = device or next(self.parameters()).device

        if isinstance(X, np.ndarray):
            X_tensor = torch.tensor(X, dtype=torch.float32)
        elif isinstance(X, torch.Tensor):
            X_tensor = X.detach().cpu().float()
        else:
            X_tensor = torch.tensor(np.asarray(X), dtype=torch.float32)

        if len(X_tensor) == 0:
            # torch.cat cannot concatenate an empty list; the head-input width is the
            # last block's rule count, or the fact count when there are no blocks.
            layers = self.config['rules_per_layer']
            width = layers[-1] if layers else self.config['num_facts']
            return torch.empty((0, width), dtype=torch.float32)

        H_list = []
        for start in range(0, len(X_tensor), batch_size):
            xb = X_tensor[start:start + batch_size].to(device)
            h = self.fact(self._apply_calibrators(xb))
            for blk in self.blocks:
                h, _ = blk(h, return_details=True,
                           explain_disable_norm=explain_disable_norm,
                           explain_exclude_proj=explain_exclude_proj)
            H_list.append(h.detach().cpu())
        return torch.cat(H_list, dim=0)
nous/prototypes.py ADDED
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import Dict
6
+
7
class ScaledPrototypeLayer(nn.Module):
    """
    Classification head built on learnable prototypes.

    Inputs and prototypes are L2-normalised along dim 1, pairwise distances are
    turned into activations ``exp(-softplus(temperature) * distance)``, and the
    activations are mixed into class logits by a learnable
    ``[num_prototypes, num_classes]`` matrix.
    """

    def __init__(self, input_dim: int, num_prototypes: int, num_classes: int) -> None:
        super().__init__()
        self.num_prototypes = num_prototypes
        self.num_classes = num_classes
        # Raw temperature; softplus is applied wherever it is used.
        self.temperature = nn.Parameter(torch.tensor(1.0))
        initial_prototypes = torch.randn(num_prototypes, input_dim) * 0.1
        self.prototypes = nn.Parameter(initial_prototypes)
        initial_class_weights = torch.randn(num_prototypes, num_classes)
        self.prototype_class = nn.Parameter(initial_class_weights)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return class logits for the representations ``h``."""
        tau = F.softplus(self.temperature)
        unit_h = F.normalize(h, p=2, dim=1)
        unit_protos = F.normalize(self.prototypes, p=2, dim=1)
        similarity = torch.exp(-tau * torch.cdist(unit_h, unit_protos))
        return similarity @ self.prototype_class

    @torch.no_grad()
    def get_params(self) -> Dict[str, torch.Tensor]:
        """Return detached prototypes (raw and normalised), class weights (raw and softmaxed) and temperature."""
        raw_protos = self.prototypes.detach()
        raw_weights = self.prototype_class.detach()
        return {
            "prototypes": raw_protos,
            "prototypes_norm": F.normalize(raw_protos, p=2, dim=1),
            "class_weights": raw_weights,
            "class_probs": F.softmax(raw_weights, dim=1),
            "temperature": F.softplus(self.temperature).detach(),
        }

    @torch.no_grad()
    def compute_dist_act(self, h: torch.Tensor):
        """Return ``(distances, activations)`` between ``h`` and the prototypes, without gradients."""
        unit_h = F.normalize(h, p=2, dim=1)
        unit_protos = F.normalize(self.prototypes, p=2, dim=1)
        dist = torch.cdist(unit_h, unit_protos)
        tau = F.softplus(self.temperature)
        return dist, torch.exp(-tau * dist)
nous/rules/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .fixed import FixedPairRuleLayer
2
+ from .softmax import SoftmaxRuleLayer
3
+ from .sparse import SparseRuleLayer
4
+ from .blocks import SimpleNousBlock
5
+
6
+ __all__ = [
7
+ "FixedPairRuleLayer",
8
+ "SoftmaxRuleLayer",
9
+ "SparseRuleLayer",
10
+ "SimpleNousBlock",
11
+ ]
nous/rules/blocks.py ADDED
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+ import torch
3
+ import torch.nn as nn
4
+ from .fixed import FixedPairRuleLayer
5
+
6
class SimpleNousBlock(nn.Module):
    """
    Residual block over FixedPairRuleLayer with honest interventions and explain flags.

    Options:
    - drop_rule_idx: zero a single rule pre-gating
    - restrict_mask: 0/1 mask over rules (frozen active set)
    - prune_below: threshold prune post-gating
    - explain_disable_norm / explain_exclude_proj for clean contributions

    NOTE(review): when ``in_dim == n_rules`` the projection is ``nn.Identity`` and
    the residual contribution is all zeros here, whereas ``SoftmaxRuleLayer`` /
    ``SparseRuleLayer`` add the input facts in that case — confirm this is intended.
    """

    def __init__(self, in_dim: int, n_rules: int) -> None:
        super().__init__()
        self.rule = FixedPairRuleLayer(in_dim, n_rules)
        self.proj = nn.Identity() if in_dim == n_rules else nn.Linear(in_dim, n_rules, bias=False)
        self.norm = nn.LayerNorm(n_rules)

    def forward(
        self,
        x: torch.Tensor,
        return_details: bool = False,
        drop_rule_idx: int | None = None,
        restrict_mask: torch.Tensor | None = None,
        prune_below: float | None = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        """
        Apply the rule layer, the optional interventions, the residual projection and LayerNorm.

        Args:
            x: block input, passed to the rule layer and to the projection.
            return_details: also return a dict of detached intermediate tensors.
            drop_rule_idx: rule column to zero out.
            restrict_mask: multiplicative mask over rules.
            prune_below: rules whose absolute activation is below this value are zeroed.
            explain_disable_norm: skip the LayerNorm.
            explain_exclude_proj: leave the projection out of the sum.

        Returns:
            ``output`` or ``(output, details)``. ``details["gate_mask"]`` marks the rules
            that survived restrict/drop/prune, matching what the softmax and sparse
            layers report (it used to be all ones regardless of the interventions).
        """
        rule_activations = self.rule(x)
        # Every fixed rule is selected by default; interventions clear entries below.
        gate_mask = torch.ones_like(rule_activations, dtype=torch.float32)

        if restrict_mask is not None:
            rule_activations = rule_activations * restrict_mask
            gate_mask = gate_mask * restrict_mask.to(gate_mask.dtype)
        if drop_rule_idx is not None:
            rule_activations[:, drop_rule_idx] = 0.0
            gate_mask[:, drop_rule_idx] = 0.0

        gated_activations = rule_activations
        if prune_below is not None:
            keep = (gated_activations.abs() >= prune_below).float()
            gated_activations = gated_activations * keep
            gate_mask = gate_mask * keep

        proj_contrib = torch.zeros_like(gated_activations)
        if not isinstance(self.proj, nn.Identity):
            proj_contrib = self.proj(x)

        if explain_exclude_proj:
            # The projection is multiplied by zero rather than omitted, as in the original code.
            pre_sum = (0.0 * proj_contrib) + gated_activations
        else:
            pre_sum = proj_contrib + gated_activations
        output = pre_sum if explain_disable_norm else self.norm(pre_sum)

        if not return_details:
            return output

        n_rules = gated_activations.shape[1]
        details = {
            "pre_rule_activations": rule_activations.detach(),
            "gated_activations": gated_activations.detach(),
            "gate_mask": gate_mask.detach(),
            "aggregator_weights": None,
            "selected_indices": torch.arange(n_rules, device=x.device).unsqueeze(0).repeat(x.size(0), 1),
            "facts_used": self.rule.idx.detach(),
            "pre_norm_sum": pre_sum.detach(),
            "proj_contrib": proj_contrib.detach()
        }
        return output, details
nous/rules/fixed.py ADDED
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class FixedPairRuleLayer(nn.Module):
    """
    AND-style rules over fixed, randomly drawn pairs of facts.

    Each rule multiplies two facts and scales the product by a learnable strength
    ``sigmoid(weight)``. The second index of a pair is the first one shifted by a
    non-zero offset modulo ``input_dim``, so both indices of a rule differ.
    """

    def __init__(self, input_dim: int, num_rules: int) -> None:
        super().__init__()
        first = torch.randint(0, input_dim, (num_rules,))
        shift = torch.randint(1, input_dim, (num_rules,))
        second = (first + shift) % input_dim
        # Stored as a buffer: the pairs are not trained.
        self.register_buffer('idx', torch.stack((first, second), dim=1))
        self.weight = nn.Parameter(torch.ones(num_rules))

    def forward(self, facts: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(weight) * facts[:, i] * facts[:, j]`` for every rule pair ``(i, j)``."""
        left = facts[:, self.idx[:, 0]]
        right = facts[:, self.idx[:, 1]]
        strength = torch.sigmoid(self.weight)
        return strength * (left * right)

    @torch.no_grad()
    def get_rules(self):
        """Return ``(pair_indices, rule_strengths)`` as NumPy arrays."""
        pairs = self.idx.cpu().numpy()
        strengths = torch.sigmoid(self.weight).cpu().numpy()
        return pairs, strengths
nous/rules/softmax.py ADDED
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class SoftmaxRuleLayer(nn.Module):
    """
    Learnable fact connections via softmax gating with per-rule top-k and aggregator mixing.
    Aggregators: AND, OR, k-of-n. Honest interventions supported.

    NOTE(review): the fact-selection mask is a hard top-k mask built with ``scatter_``
    from indices, so in this code no gradient reaches ``fact_logits`` through it —
    confirm whether that is intended.
    """

    def __init__(self, input_dim: int, num_rules: int, top_k_facts: int = 2, top_k_rules: int = 8) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.num_rules = num_rules
        self.top_k_facts = top_k_facts
        self.top_k_rules = top_k_rules

        self.fact_logits = nn.Parameter(torch.randn(num_rules, input_dim) * 0.01)
        self.num_aggregators = 3
        self.aggregator_logits = nn.Parameter(torch.zeros(num_rules, self.num_aggregators))
        self.rule_strength_raw = nn.Parameter(torch.zeros(num_rules))
        self.proj = nn.Linear(input_dim, num_rules, bias=False) if input_dim != num_rules else nn.Identity()
        self.norm = nn.LayerNorm(num_rules)

    def forward(
        self,
        facts: torch.Tensor,
        return_details: bool = False,
        drop_rule_idx: int | None = None,
        restrict_mask: torch.Tensor | None = None,
        prune_below: float | None = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        """
        Compute rule activations, keep the top-k rules per sample and add the residual.

        Args:
            facts: fact activations; broadcast below as ``[batch, 1, facts]``.
            return_details: also return a dict of detached intermediate tensors.
            drop_rule_idx: rule that may not be selected and is zeroed in the gate.
            restrict_mask: per-rule 0/1 mask of rules that may be selected. Float, integer
                and bool masks are accepted; the mask is cast to the activation dtype.
            prune_below: gated rules with absolute activation below this value are zeroed.
            explain_disable_norm: skip the LayerNorm.
            explain_exclude_proj: leave the residual/projection out of the sum.

        Returns:
            ``output`` or ``(output, details)``.
        """
        # Hard selection of the top-k facts of every rule -> [R, F] 0/1 mask.
        fact_logits_soft = F.softmax(self.fact_logits, dim=1)
        _, topk_fact_idx = torch.topk(fact_logits_soft, k=min(self.top_k_facts, self.input_dim), dim=1)
        mask = torch.zeros_like(fact_logits_soft)
        mask.scatter_(1, topk_fact_idx, 1.0)

        facts_expanded = facts.unsqueeze(1)
        mask_expanded = mask.unsqueeze(0)
        selected_facts = facts_expanded * mask_expanded

        # Unselected facts are replaced by the neutral element 1 inside both products.
        and_agg = torch.prod(selected_facts + (1 - mask_expanded), dim=2)
        or_agg = 1 - torch.prod((1 - selected_facts) * mask_expanded + (1 - mask_expanded), dim=2)
        k_of_n_agg = torch.sum(selected_facts, dim=2) / (mask_expanded.sum(dim=2) + 1e-8)

        agg_weights = F.softmax(self.aggregator_logits, dim=1)
        aggregators = torch.stack([and_agg, or_agg, k_of_n_agg], dim=2)
        mixed_agg = torch.sum(aggregators * agg_weights.unsqueeze(0), dim=2)

        rule_strength = torch.sigmoid(self.rule_strength_raw)
        rule_activations = mixed_agg * rule_strength.unsqueeze(0)

        restrict = None
        if restrict_mask is not None:
            # Cast once so the arithmetic below also works for bool/integer masks.
            restrict = restrict_mask.to(device=rule_activations.device, dtype=rule_activations.dtype)

        # Push forbidden rules far below every real activation before the top-k.
        pre_for_topk = rule_activations.clone()
        if restrict is not None:
            pre_for_topk = pre_for_topk + (restrict - 1) * 1e9
        if drop_rule_idx is not None:
            pre_for_topk[:, drop_rule_idx] = -1e9

        k = min(self.top_k_rules, self.num_rules)
        _, topk_rule_idx = torch.topk(pre_for_topk, k=k, dim=1)
        gate_mask = torch.zeros_like(rule_activations)
        gate_mask.scatter_(1, topk_rule_idx, 1.0)

        # Top-k may still pick forbidden rules when fewer than k are allowed; clear them.
        if restrict is not None:
            gate_mask = gate_mask * restrict.unsqueeze(0)
        if drop_rule_idx is not None:
            gate_mask[:, drop_rule_idx] = 0.0

        gated_activations = rule_activations * gate_mask
        if prune_below is not None:
            keep = (gated_activations.abs() >= prune_below).float()
            gated_activations = gated_activations * keep
            gate_mask = gate_mask * keep

        proj_contrib = self.proj(facts) if not isinstance(self.proj, nn.Identity) else facts
        pre_sum = gated_activations if explain_exclude_proj else (proj_contrib + gated_activations)
        output = pre_sum if explain_disable_norm else self.norm(pre_sum)

        if not return_details:
            return output

        details = {
            "pre_rule_activations": rule_activations.detach(),
            "gated_activations": gated_activations.detach(),
            "gate_mask": gate_mask.detach(),
            "aggregator_weights": agg_weights.detach(),
            "selected_indices": topk_rule_idx.detach(),
            "facts_used": topk_fact_idx.detach(),
            "pre_norm_sum": pre_sum.detach(),
            "proj_contrib": proj_contrib.detach()
        }
        return output, details
nous/rules/sparse.py ADDED
@@ -0,0 +1,142 @@
1
+ from __future__ import annotations
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class HardConcretePerConnection(nn.Module):
    """
    One Hard-Concrete style gate per connection.

    In sampling mode a uniform draw (clamped away from 0 and 1) is pushed through a
    scaled logistic, stretched to ``stretch_limits`` and squashed into [0, 1]; the
    exponent ``log_sigma`` is clamped to [-5, 5]. Otherwise the gate is the hard
    threshold ``sigmoid(beta) > 0.5``.
    """

    def __init__(self, shape, init_mean=0.0, init_std=1.0, temperature=0.1, stretch_limits=(-0.1, 1.1)):
        super().__init__()
        self.shape = shape
        # beta starts at init_mean and log_sigma at init_std, element-wise.
        self.beta = nn.Parameter(torch.ones(shape) * init_mean)
        self.log_sigma = nn.Parameter(torch.ones(shape) * init_std)
        self.temperature = temperature
        self.l, self.r = stretch_limits

    def forward(self, training: bool | None = None) -> torch.Tensor:
        """Return the gate mask; ``training=None`` falls back to the module's own mode."""
        sample = self.training if training is None else training
        if not sample:
            return (torch.sigmoid(self.beta) > 0.5).float()

        noise = torch.rand(self.shape, device=self.beta.device).clamp(1e-6, 1 - 1e-6)
        scale = torch.exp(torch.clamp(self.log_sigma, -5.0, 5.0))
        logit = (torch.log(noise) - torch.log(1 - noise) + self.beta) / (scale + 1e-8)
        stretched = torch.sigmoid(logit) * (self.r - self.l) + self.l
        lower = torch.sigmoid((stretched - self.l) / self.temperature)
        upper = torch.sigmoid((stretched - self.r) / self.temperature)
        return (lower - upper).clamp(0.0, 1.0)

    def get_proba(self) -> torch.Tensor:
        """Return the per-connection probability used for the L0 penalty (``beta`` clamped to [-10, 10])."""
        where = self.beta.device
        low = torch.tensor(self.l, device=where)
        high = torch.tensor(self.r, device=where)
        shift = self.temperature * torch.log(-low / high + 1e-8)
        return torch.sigmoid(torch.clamp(self.beta, -10.0, 10.0) - shift)

    def l0_penalty(self) -> torch.Tensor:
        """Return the sum of ``get_proba()`` over all connections."""
        return self.get_proba().sum()
41
+
42
+
43
class SparseRuleLayer(nn.Module):
    """
    Hard-Concrete sparse connections + aggregator mixing + honest interventions.
    Aggregators: AND, OR, k-of-n, NOT.
    """

    def __init__(self, input_dim: int, num_rules: int, top_k_facts: int = 2, top_k_rules: int = 8, l0_lambda: float = 1e-3, hc_temperature: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_rules = num_rules
        self.top_k_facts = top_k_facts
        self.top_k_rules = top_k_rules
        self.l0_lambda = l0_lambda

        self.hard_concrete = HardConcretePerConnection(
            shape=(num_rules, input_dim),
            init_mean=0.0,
            init_std=1.0,
            temperature=hc_temperature,
            stretch_limits=(-0.1, 1.1)
        )

        self.num_aggregators = 4
        self.aggregator_logits = nn.Parameter(torch.zeros(num_rules, self.num_aggregators))
        self.rule_strength_raw = nn.Parameter(torch.zeros(num_rules))
        self.proj = nn.Linear(input_dim, num_rules, bias=False) if input_dim != num_rules else nn.Identity()
        self.norm = nn.LayerNorm(num_rules)

    def forward(
        self,
        facts: torch.Tensor,
        return_details: bool = False,
        drop_rule_idx: int | None = None,
        restrict_mask: torch.Tensor | None = None,
        prune_below: float | None = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        """
        Compute rule activations from gated connections, keep the top-k rules and add the residual.

        Args:
            facts: fact activations; broadcast below as ``[batch, 1, facts]``.
            return_details: also return a dict of detached intermediate tensors.
            drop_rule_idx: rule that may not be selected and is zeroed in the gate.
            restrict_mask: per-rule 0/1 mask of rules that may be selected. Float, integer
                and bool masks are accepted; the mask is cast to the activation dtype.
            prune_below: gated rules with absolute activation below this value are zeroed.
            explain_disable_norm: skip the LayerNorm.
            explain_exclude_proj: leave the residual/projection out of the sum.

        Returns:
            ``output`` or ``(output, details)``. ``details["facts_used"]`` holds, per rule,
            the indices of the ``top_k_facts`` largest connection-gate values.
        """
        mask = self.hard_concrete(training=self.training)  # [R, F]
        facts_expanded = facts.unsqueeze(1)
        mask_expanded = mask.unsqueeze(0)
        selected_facts = facts_expanded * mask_expanded

        # (1 - mask) fills unselected connections with the neutral element of the products.
        and_agg = torch.prod(selected_facts + (1 - mask_expanded), dim=2)
        or_agg = 1 - torch.prod((1 - selected_facts) * mask_expanded + (1 - mask_expanded), dim=2)
        selected_count = torch.sum(mask_expanded, dim=2) + 1e-8
        k_of_n_agg = torch.sum(selected_facts, dim=2) / selected_count
        not_agg = 1 - k_of_n_agg

        agg_weights = F.softmax(self.aggregator_logits, dim=1)  # [R, 4]
        aggregators = torch.stack([and_agg, or_agg, k_of_n_agg, not_agg], dim=2)
        mixed_agg = torch.sum(aggregators * agg_weights.unsqueeze(0), dim=2)  # [B, R]

        rule_strength = torch.sigmoid(self.rule_strength_raw)  # [R]
        rule_activations = mixed_agg * rule_strength.unsqueeze(0)  # [B, R]

        restrict = None
        if restrict_mask is not None:
            # Cast once so the arithmetic below also works for bool/integer masks.
            restrict = restrict_mask.to(device=rule_activations.device, dtype=rule_activations.dtype)

        # Push forbidden rules far below every real activation before the top-k.
        pre_for_topk = rule_activations.clone()
        if restrict is not None:
            pre_for_topk = pre_for_topk + (restrict - 1) * 1e9
        if drop_rule_idx is not None:
            pre_for_topk[:, drop_rule_idx] = -1e9

        k = min(self.top_k_rules, self.num_rules)
        _, topk_rule_idx = torch.topk(pre_for_topk, k=k, dim=1)
        gate_mask = torch.zeros_like(rule_activations)
        gate_mask.scatter_(1, topk_rule_idx, 1.0)

        # Top-k may still pick forbidden rules when fewer than k are allowed; clear them.
        if restrict is not None:
            gate_mask = gate_mask * restrict.unsqueeze(0)
        if drop_rule_idx is not None:
            gate_mask[:, drop_rule_idx] = 0.0

        gated_activations = rule_activations * gate_mask
        if prune_below is not None:
            keep = (gated_activations.abs() >= prune_below).float()
            gated_activations = gated_activations * keep
            gate_mask = gate_mask * keep

        proj_contrib = self.proj(facts) if not isinstance(self.proj, nn.Identity) else facts
        pre_sum = gated_activations if explain_exclude_proj else (proj_contrib + gated_activations)
        output = pre_sum if explain_disable_norm else self.norm(pre_sum)

        if not return_details:
            return output

        with torch.no_grad():
            _, topk_fact_idx = torch.topk(mask, k=min(self.top_k_facts, self.input_dim), dim=1)
        details = {
            "pre_rule_activations": rule_activations.detach(),
            "gated_activations": gated_activations.detach(),
            "gate_mask": gate_mask.detach(),
            "aggregator_weights": agg_weights.detach(),
            "selected_indices": topk_rule_idx.detach(),
            "facts_used": topk_fact_idx.detach(),
            "pre_norm_sum": pre_sum.detach(),
            "proj_contrib": proj_contrib.detach()
        }
        return output, details

    def compute_l0_loss(self) -> torch.Tensor:
        """Return ``l0_lambda * l0_penalty`` divided by the number of connections (at least 1)."""
        num_conn = float(self.num_rules * self.input_dim)
        return (self.l0_lambda * self.hard_concrete.l0_penalty()) / max(1.0, num_conn)
@@ -0,0 +1,5 @@
1
+ from .train import train_model
2
+ from .evaluation import evaluate_classification, evaluate_regression
3
+ from .schedulers import make_sparse_regression_hook
4
+
5
+ __all__ = ["train_model", "evaluate_classification", "evaluate_regression", "make_sparse_regression_hook"]
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ from sklearn.metrics import (
7
+ roc_auc_score, accuracy_score, mean_absolute_error,
8
+ mean_squared_error, r2_score
9
+ )
10
+
11
def evaluate_classification(model, test_loader: DataLoader, device, class_names=None):
    """
    Evaluate classification metrics on a dataloader: accuracy and one-vs-rest AUC.

    Args:
        model: callable returning class logits for a batch (softmax is applied here).
        test_loader: yields ``(X_batch, y_batch)`` pairs.
        device: device the inputs are moved to before calling the model.
        class_names: accepted for API compatibility; not used.

    Returns:
        ``(accuracy, auc, all_probs, all_labels)`` where ``all_probs`` is an
        ``[N, C]`` array and ``all_labels`` a list of labels. ``auc`` falls back to
        0.5 when scikit-learn cannot compute it (e.g. a single class in the labels).
    """
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch = X_batch.to(device)
            probs = F.softmax(model(X_batch), dim=1).cpu().numpy()
            all_probs.extend(probs)
            # Labels normally arrive on the CPU; .cpu() is a no-op then and keeps
            # .numpy() from failing if the loader yields device tensors.
            all_labels.extend(y_batch.cpu().numpy())
    all_probs = np.array(all_probs)
    all_preds = np.argmax(all_probs, axis=1)
    acc = accuracy_score(all_labels, all_preds)

    # With binary targets roc_auc_score expects 1-D scores of the class with the
    # greater label; passing the [N, 2] matrix raised ValueError, which was swallowed
    # below and reported as AUC 0.5 for every two-class model.
    if all_probs.ndim == 2 and all_probs.shape[1] == 2:
        scores = all_probs[:, 1]
        auc_kwargs = {}
    else:
        scores = all_probs
        auc_kwargs = {'multi_class': 'ovr'}
    try:
        auc = roc_auc_score(all_labels, scores, **auc_kwargs)
    except ValueError:
        auc = 0.5
    return acc, auc, all_probs, all_labels
31
+
32
def evaluate_regression(model, test_loader: DataLoader, device, y_scaler=None):
    """
    Compute RMSE, MAE and R^2 of ``model`` over ``test_loader``.

    Predictions and targets are flattened batch by batch. When ``y_scaler`` is
    given, both are passed through ``y_scaler.inverse_transform`` (as column
    vectors) before the metrics are computed.

    Returns:
        ``(rmse, mae, r2, preds, labels)``; the arrays are the inverse-transformed
        ones if a scaler was supplied, otherwise the raw ones.
    """
    model.eval()
    collected_preds, collected_targets = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            batch_out = model(X_batch.to(device))
            collected_preds.extend(batch_out.cpu().numpy().ravel())
            collected_targets.extend(y_batch.numpy().ravel())

    preds = np.array(collected_preds)
    labels = np.array(collected_targets)

    if y_scaler is None:
        preds_u, labels_u = preds, labels
    else:
        preds_u = y_scaler.inverse_transform(preds.reshape(-1,1)).ravel()
        labels_u = y_scaler.inverse_transform(labels.reshape(-1,1)).ravel()

    mse = mean_squared_error(labels_u, preds_u)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(labels_u, preds_u)
    r2 = r2_score(labels_u, preds_u)
    return rmse, mae, r2, preds_u, labels_u