nous-0.1.0-py3-none-any.whl → nous-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (47)
  1. nous/__init__.py +96 -19
  2. nous/data/__init__.py +4 -0
  3. nous/data/california.py +32 -0
  4. nous/data/wine.py +29 -0
  5. nous/explain/__init__.py +26 -0
  6. nous/explain/aggregator.py +34 -0
  7. nous/explain/cf.py +137 -0
  8. nous/explain/facts_desc.py +23 -0
  9. nous/explain/fidelity.py +56 -0
  10. nous/explain/generate.py +86 -0
  11. nous/explain/global_book.py +52 -0
  12. nous/explain/loo.py +130 -0
  13. nous/explain/mse.py +93 -0
  14. nous/explain/pruning.py +117 -0
  15. nous/explain/stability.py +42 -0
  16. nous/explain/traces.py +285 -0
  17. nous/explain/utils.py +15 -0
  18. nous/export/__init__.py +13 -0
  19. nous/export/numpy_infer.py +412 -0
  20. nous/facts.py +112 -0
  21. nous/model.py +226 -0
  22. nous/prototypes.py +43 -0
  23. nous/rules/__init__.py +11 -0
  24. nous/rules/blocks.py +63 -0
  25. nous/rules/fixed.py +26 -0
  26. nous/rules/softmax.py +93 -0
  27. nous/rules/sparse.py +142 -0
  28. nous/training/__init__.py +5 -0
  29. nous/training/evaluation.py +57 -0
  30. nous/training/schedulers.py +34 -0
  31. nous/training/train.py +177 -0
  32. nous/types.py +4 -0
  33. nous/utils/__init__.py +3 -0
  34. nous/utils/metrics.py +2 -0
  35. nous/utils/seed.py +13 -0
  36. nous/version.py +1 -0
  37. nous-0.2.0.dist-info/METADATA +150 -0
  38. nous-0.2.0.dist-info/RECORD +41 -0
  39. nous/causal.py +0 -63
  40. nous/interpret.py +0 -111
  41. nous/layers.py +0 -117
  42. nous/models.py +0 -65
  43. nous-0.1.0.dist-info/METADATA +0 -138
  44. nous-0.1.0.dist-info/RECORD +0 -10
  45. {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/WHEEL +0 -0
  46. {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/top_level.txt +0 -0
nous/training/schedulers.py ADDED
@@ -0,0 +1,34 @@
+from __future__ import annotations
+from ..rules.sparse import SparseRuleLayer
+
+def make_sparse_regression_hook(
+    base_lambda: float = 3e-4, warmup: int = 40, ramp: int = 80,
+    temp_start: float = 0.60, temp_end: float = 0.25, temp_epochs: int = 200,
+    disable_topk: bool = True
+):
+    """
+    Return an after-epoch scheduler hook for sparse regression:
+      - L0 penalty ramp,
+      - Hard-Concrete temperature schedule,
+      - disable top-k gating to preserve gradients.
+    """
+    def hook(model, epoch: int):
+        # 1) L0 schedule
+        l0_factor = 0.0 if epoch < warmup else min(1.0, (epoch - warmup) / max(1, ramp))
+        for blk in model.blocks:
+            if isinstance(blk, SparseRuleLayer):
+                blk.l0_lambda = base_lambda * l0_factor
+
+        # 2) Hard-Concrete temperature schedule
+        alpha = min(1.0, epoch / max(1, temp_epochs))
+        t = temp_start * (1 - alpha) + temp_end * alpha
+        for blk in model.blocks:
+            if isinstance(blk, SparseRuleLayer):
+                blk.hard_concrete.temperature = t
+
+        # 3) Disable top-k (for regression)
+        if disable_topk:
+            for blk in model.blocks:
+                if hasattr(blk, "top_k_rules"):
+                    blk.top_k_rules = blk.num_rules
+    return hook
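For orientation (not part of the diff): the hook returned above uses the legacy `(model, epoch)` signature accepted by `train_model` in `nous/training/train.py` below, so it can be passed directly as `after_epoch_hook`. A minimal sketch, assuming a NousNet-style model whose `blocks` contain `SparseRuleLayer` instances:

```python
# Illustrative sketch, not from the package source.
from nous.training.schedulers import make_sparse_regression_hook

hook = make_sparse_regression_hook(base_lambda=3e-4, warmup=40, ramp=80)

# The training loop calls hook(model, epoch) after every epoch. With these
# defaults: epochs 0-39 keep l0_lambda at 0, epochs 40-119 ramp it linearly
# up to base_lambda, and the Hard-Concrete temperature anneals from 0.60
# down to 0.25 over the first 200 epochs.
```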
nous/training/train.py ADDED
@@ -0,0 +1,177 @@
+from __future__ import annotations
+import inspect
+from typing import Callable, Optional, Dict
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+def train_model(
+    model: nn.Module,
+    train_loader: DataLoader,
+    val_loader: DataLoader,
+    criterion: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    epochs: int,
+    patience: int,
+    device,
+    after_epoch_hook: Optional[Callable[..., None]] = None,
+    # Progress controls
+    verbose: int = 1,
+    log_every: int = 10,
+    use_tqdm: bool = False,
+    print_l0: bool = True,
+) -> float:
+    """
+    Train with early stopping. Adds L0 loss (if model exposes compute_total_l0_loss) and gradient clipping.
+
+    Progress
+      - verbose >= 1 prints epoch-level logs every `log_every` epochs, on improvement, and on the first/last epoch.
+      - use_tqdm shows a progress bar over epochs with train/val (and L0) in the postfix.
+      - after_epoch_hook can be:
+            (model, epoch)                -- legacy signature
+            (model, epoch, metrics_dict)  -- extended signature
+        where metrics_dict contains:
+            {epoch, train_loss, val_loss, l0_loss, improved}.
+    """
+    model.to(device)
+    best_val_loss = float('inf')
+    epochs_no_improve = 0
+    best_model_state = None
+
+    # Epoch iterator (optionally tqdm)
+    if use_tqdm:
+        try:
+            from tqdm.auto import tqdm as _tqdm
+            epoch_iter = _tqdm(range(epochs), leave=False, desc="Training")
+        except Exception:
+            epoch_iter = range(epochs)
+    else:
+        epoch_iter = range(epochs)
+
+    for epoch in epoch_iter:
+        # -------------------------
+        # Train
+        # -------------------------
+        model.train()
+        total_train_loss = 0.0
+        total_l0_loss = 0.0
+        steps = 0
+
+        for X_batch, y_batch in train_loader:
+            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
+            optimizer.zero_grad()
+
+            outputs = model(X_batch)
+            if isinstance(criterion, nn.CrossEntropyLoss) or (outputs.ndim == 2 and outputs.size(-1) > 1):
+                target = y_batch.long()   # classification
+            else:
+                target = y_batch.float()  # regression
+            loss = criterion(outputs, target)
+
+            l0_loss = getattr(model, "compute_total_l0_loss", lambda: torch.tensor(0.0, device=device))()
+            total_loss = loss + l0_loss
+
+            total_loss.backward()
+            nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
+            optimizer.step()
+
+            total_train_loss += float(loss.item())
+            total_l0_loss += float(l0_loss.item()) if isinstance(l0_loss, torch.Tensor) else float(l0_loss)
+            steps += 1
+
+        avg_train_loss = total_train_loss / max(1, steps)
+        avg_l0_loss = total_l0_loss / max(1, steps)
+
+        # -------------------------
+        # Validate
+        # -------------------------
+        model.eval()
+        total_val_loss = 0.0
+        with torch.no_grad():
+            for X_batch, y_batch in val_loader:
+                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
+                outputs = model(X_batch)
+                if isinstance(criterion, nn.CrossEntropyLoss) or (outputs.ndim == 2 and outputs.size(-1) > 1):
+                    target = y_batch.long()
+                else:
+                    target = y_batch.float()
+                vloss = criterion(outputs, target)
+                total_val_loss += float(vloss.item())
+        avg_val_loss = total_val_loss / max(1, len(val_loader))
+
+        improved = avg_val_loss < (best_val_loss - 1e-6)
+
+        # -------------------------
+        # Progress reporting
+        # -------------------------
+        if use_tqdm and hasattr(epoch_iter, "set_postfix"):
+            postfix: Dict[str, str] = {"train": f"{avg_train_loss:.4f}", "val": f"{avg_val_loss:.4f}"}
+            if print_l0:
+                postfix["l0"] = f"{avg_l0_loss:.4f}"
+            try:
+                epoch_iter.set_postfix(postfix)  # type: ignore[attr-defined]
+            except Exception:
+                pass
+
+        if verbose >= 1:
+            should_log = (
+                epoch == 0
+                or (epoch + 1 == epochs)
+                or improved
+                or ((epoch + 1) % max(1, log_every) == 0)
+            )
+            if should_log:
+                msg = f"Epoch [{epoch+1}/{epochs}] train={avg_train_loss:.4f} val={avg_val_loss:.4f}"
+                if print_l0:
+                    msg += f" l0={avg_l0_loss:.4f}"
+                if improved:
+                    msg += " (*)"
+                print(msg)
+
+        # -------------------------
+        # Hook (backward compatible)
+        # -------------------------
+        if after_epoch_hook is not None:
+            metrics: Dict[str, float | int | bool] = dict(
+                epoch=epoch,
+                train_loss=avg_train_loss,
+                val_loss=avg_val_loss,
+                l0_loss=avg_l0_loss,
+                improved=improved,
+            )
+            try:
+                sig = inspect.signature(after_epoch_hook)
+                if len(sig.parameters) >= 3:
+                    after_epoch_hook(model, epoch, metrics)  # extended
+                else:
+                    after_epoch_hook(model, epoch)  # legacy
+            except Exception:
+                # Fall back to the legacy signature on any inspection/runtime error
+                try:
+                    after_epoch_hook(model, epoch)
+                except Exception:
+                    pass
+
+        # -------------------------
+        # Early stopping
+        # -------------------------
+        if improved:
+            best_val_loss = avg_val_loss
+            epochs_no_improve = 0
+            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
+        else:
+            epochs_no_improve += 1
+            if epochs_no_improve >= patience:
+                if verbose >= 1:
+                    print(f"Early stopping at epoch {epoch+1} (best val={best_val_loss:.4f})")
+                break
+
+    # Restore best
+    if best_model_state:
+        model.load_state_dict(best_model_state)
+        model.to(device)
+        if verbose >= 1:
+            print(f"Restored best model (val={best_val_loss:.4f})")
+
+    return best_val_loss
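To make the hook contract concrete, here is a self-contained usage sketch (illustrative only: the plain `nn.Sequential` model stands in for a real NousNet, and `log_hook` is a hypothetical hook using the extended three-argument signature):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from nous.training.train import train_model

# Toy regression data; any nn.Module works, since the L0 term is optional.
X = torch.randn(512, 8)
y = X[:, :1] * 2.0 + 0.1 * torch.randn(512, 1)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

def log_hook(model, epoch, metrics):
    # Extended signature: receives {epoch, train_loss, val_loss, l0_loss, improved}.
    if metrics["improved"]:
        print(f"epoch {epoch}: new best val={metrics['val_loss']:.4f}")

best_val = train_model(
    model,
    DataLoader(TensorDataset(X[:400], y[:400]), batch_size=64, shuffle=True),
    DataLoader(TensorDataset(X[400:], y[400:]), batch_size=64),
    criterion=nn.MSELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    epochs=50,
    patience=10,
    device="cpu",
    after_epoch_hook=log_hook,
)
```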
nous/types.py ADDED
@@ -0,0 +1,4 @@
+from typing import Any
+
+TensorLike = Any
+NDArrayLike = Any
nous/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .seed import set_global_seed
+
+__all__ = ["set_global_seed"]
nous/utils/metrics.py ADDED
@@ -0,0 +1,2 @@
+from __future__ import annotations
+# Placeholder for future metric utilities
nous/utils/seed.py ADDED
@@ -0,0 +1,13 @@
+from __future__ import annotations
+import random
+import numpy as np
+import torch
+
+def set_global_seed(seed: int = 42) -> None:
+    """
+    Set seeds for Python, NumPy, and PyTorch (CPU/CUDA).
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
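Usage note (illustrative): call this once at program start, before building models or DataLoaders, so weight initialization and shuffling are reproducible:

```python
from nous.utils import set_global_seed  # re-exported by nous/utils/__init__.py

set_global_seed(42)  # seeds Python's random, NumPy, and torch (incl. all CUDA devices)
```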
nous/version.py ADDED
@@ -0,0 +1 @@
+__version__ = "0.2.0"
nous-0.2.0.dist-info/METADATA ADDED
@@ -0,0 +1,150 @@
+Metadata-Version: 2.4
+Name: nous
+Version: 0.2.0
+Summary: Nous: A Neuro-Symbolic Library for Interpretable AI
+Author-email: Islam Tlupov <tlupovislam@gmail.com>
+License: MIT License
+
+Copyright (c) 2025 Islam Tlupov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Project-URL: Repository, https://github.com/EmotionEngineer/nous
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Classifier: Typing :: Typed
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=2.1
+Requires-Dist: numpy>=1.22
+Requires-Dist: pandas>=1.5
+Requires-Dist: scikit-learn>=1.2
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0; extra == "dev"
+Requires-Dist: pytest-cov>=4.0; extra == "dev"
+Requires-Dist: mypy>=1.5; extra == "dev"
+Requires-Dist: ruff>=0.5; extra == "dev"
+Requires-Dist: black>=23.0; extra == "dev"
+Requires-Dist: matplotlib>=3.6; extra == "dev"
+Requires-Dist: seaborn>=0.12; extra == "dev"
+Requires-Dist: tqdm>=4.65; extra == "dev"
+Requires-Dist: ucimlrepo>=0.0.5; extra == "dev"
+Provides-Extra: examples
+Requires-Dist: matplotlib>=3.6; extra == "examples"
+Requires-Dist: seaborn>=0.12; extra == "examples"
+Requires-Dist: tqdm>=4.65; extra == "examples"
+Requires-Dist: ucimlrepo>=0.0.5; extra == "examples"
+Dynamic: license-file
+
+# Nous — A Neuro-Symbolic Library for Interpretable AI
+
+[![PyPI](https://img.shields.io/pypi/v/nous.svg)](https://pypi.org/project/nous/)
+[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
+
+Make tabular models you can read.
+Nous learns compact logical rules and optional case‑based prototypes inside one differentiable model — so prediction and explanation come from the same place.
+
+- 🧩 One white‑box → two styles: rules and/or prototypes
+- 🔀 Learned AND / OR / k‑of‑n mixtures capture interactions without bloat
+- ✂️ Minimal, faithful stories: pruning + sufficiency/comprehensiveness checks
+- 🚀 Practical: competitive accuracy, NumPy export, unit‑tested toolkit
+
+## Key Features
+
+- Intrinsic interpretability (not post‑hoc): explanations are part of the forward pass
+- Switchable style: enable/disable prototypes; choose rule selection (fixed / softmax / sparse); add calibrators
+- Fidelity diagnostics: pruned‑forward inference, minimal‑sufficient explanations, stability tools
+- Ready to ship: pure‑NumPy export for inference without PyTorch
+
+## Installation
+
+```bash
+# Stable from PyPI
+pip install nous
+
+# With example extras (plots, progress, UCI fetchers)
+pip install "nous[examples]"
+
+# Dev setup (tests, linters, type checks)
+pip install "nous[dev]"
+```
+
+Requirements (core):
+- Python 3.9+
+- torch>=2.1
+- numpy>=1.22
+- pandas>=1.5
+- scikit-learn>=1.2
+
+Extras:
+- examples: matplotlib>=3.6, seaborn>=0.12, tqdm>=4.65, ucimlrepo>=0.0.5
+- dev: pytest>=7.0, pytest-cov>=4.0, mypy>=1.5, ruff>=0.5, black>=23.0, matplotlib>=3.6, seaborn>=0.12, tqdm>=4.65, ucimlrepo>=0.0.5
+
+## Recommended Configurations
+
+| Profile | Rule selection | Calibrators | Prototypes | Use when | Speed |
+|---------|----------------|-------------|------------|----------|-------|
+| Fast baseline | fixed | off | off | quick sweeps, ablations | ⚡⚡⚡ |
+| Default rules | softmax | on | off | general use, strong accuracy | ⚡⚡ |
+| Explain‑everything | softmax | on | on | rich case‑based narratives | ⚡ |
+
+Tips:
+- Train with prototypes off for speed; enable them only on the final model if you need case‑based stories.
+- 300 epochs with patience ≈ 50 works well on common tabular datasets.
+
+## Bench Snapshot (5‑fold CV, typical)
+
+| Dataset | Metric | Nous (rules) | Nous (+proto) | EBM | XGBoost |
+|---------|--------|--------------|---------------|-----|---------|
+| HELOC (cls) | AUC | ~0.791 | ~0.792 | ~0.799 | ~0.796 |
+| Adult (cls) | AUC | ~0.913 | ~0.914 | ~0.926 | ~0.929 |
+| Breast Cancer (cls) | Acc | ~0.975 | ~0.983 | ~0.970 | ~0.965 |
+| California (reg) | RMSE | ~0.514 | ~0.505 | ~0.562 | ~0.439 |
+
+Numbers vary with seed/HPO. See examples/benchmark.ipynb for reproducible runs.
+
+## What makes Nous different?
+
+- The explanation is the model: rules and prototypes live in the forward pass
+- Interactions without clutter: AND/OR/k‑of‑n mixtures keep explanations short
+- Verified stories: minimal‑sufficient explanations + pruned‑forward confidence checks
+- Lightweight deployment: NumPy export (no torch at inference)
+
+## Repository Layout
+
+- examples/
+  - benchmark.ipynb — end‑to‑end comparison on classic tabular data
+  - wine_classification.py, california_regression.py — minimal scripts
+  - export_numpy_demo.py — deploy without torch
+- nous/
+  - model.py (NousNet), facts.py (calibrated L−R facts)
+  - rules/* (fixed/softmax/sparse), explain/* (pruning, fidelity, traces, prototypes)
+  - training/* (loop, schedulers), export/* (NumPy), utils/*
+- tests/ — unit tests for forward, rules, facts, prototypes, explanations, export
+
+## License
+
+MIT — see LICENSE.
nous-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,41 @@
+nous/__init__.py,sha256=l2QKn4xsdUxcOmbUuiTS79dSqGGa9O5c-lkpKkCw_WI,2724
+nous/facts.py,sha256=GrCinW97CCcBWQ-VfajE44vlKOCR3gijC-l36d4A59w,4843
+nous/model.py,sha256=DjT0PnOuBzJQ2ueOfeQENYLg_SVIG5AqOGzYWlk4z3Y,8910
+nous/prototypes.py,sha256=FWF3VGSAYsXh-3qNwRT_3nZdIk8A5JYK5gKBlccrxA0,1873
+nous/types.py,sha256=RclD7jN8lTT_AqLtTZsoPjsd9HysVZLBz_OaAUzM4n8,61
+nous/version.py,sha256=1KhrBItVjTCR-Sumh0o09b_aKrjTTcJrpTBh5GBw6Lk,21
+nous/data/__init__.py,sha256=q2AgrdHg1zcE5dyaHw13sLFJI2DNpEBaTMHGByNYDFM,143
+nous/data/california.py,sha256=jauBm7Hwivohp79cK9lbPGtjwvnrdrUDZoHlc591Nvc,1297
+nous/data/wine.py,sha256=WvsXsfX_AnukpkboeMye3b2xceBy6Di528Y5HP-AmuY,1252
+nous/explain/__init__.py,sha256=NWpPGMgX4DtGEf_oj6Gbbjh9_k2XD1N7zuojy3SwPbk,957
+nous/explain/aggregator.py,sha256=VM3U8zvpwCN_fyN5_geglL6rCshGFblkpiJnXvBuce4,1297
+nous/explain/cf.py,sha256=iidWOC92ZSNwcoGDCinx85ZmOJntLjp_bBNi7PsrTbw,5878
+nous/explain/facts_desc.py,sha256=BAPkLSq0uLyTaSZFDXmTTS1onB_ju_f-R_I5N3O7Npg,1178
+nous/explain/fidelity.py,sha256=VNUUJqa0mfeWCozc_l-L82DcySwk4Hl171_hnUkpjbk,2401
+nous/explain/generate.py,sha256=HMnhwfaFYKdZNcXYKindnGpIquU0IM79OWM7tbF5yzA,3683
+nous/explain/global_book.py,sha256=i2SeJ4cB5uIZsSLBZgW2Tkv60Tr8M_aAuakeA1TrElU,2185
+nous/explain/loo.py,sha256=fw6aXhMj9jkFyWviiGulU7Ymt6NFAcpYFl5JwHDbXBE,5768
+nous/explain/mse.py,sha256=VfABHZqrVEPOpTxKtfJ7E8Cw5U3VzADC_RIpKD1UMPM,3769
+nous/explain/pruning.py,sha256=GHFIAjaAtCmCsl7hsrP7ALBY5lZ_Rml107bjRCFiSQs,4899
+nous/explain/stability.py,sha256=CHW7TTvySPfoCz_VGN578zKVi6ChvCNP9iajDQQppNk,1844
+nous/explain/traces.py,sha256=SO_PwkMBIJLWNqmK8WRCD8rBJ5s58QO1hUviEYZd3Ps,11569
+nous/explain/utils.py,sha256=6u_dlUeRA9p4IA6yKjAkuUfE2GFMOH9ZfFnqRiVmeA0,564
+nous/export/__init__.py,sha256=lKmcutGQ9xt3th61KEZyttBwnYBktxYfPJPRXO8olys,257
+nous/export/numpy_infer.py,sha256=Oh2CtNqMlxLWhxJOGfp2CXeFqT94eLMi8OfTBa77oyk,15573
+nous/rules/__init__.py,sha256=6PkHHNLuXs9bJIKSWV6mo6I76HBOib7Vu_uwBxBX0A0,259
+nous/rules/blocks.py,sha256=MsU7t93kY_HVRAIODLhQOFdtIoyL9aw1Pe-yuV2wPrw,2569
+nous/rules/fixed.py,sha256=0ABVPmhb5b6CX4Iy2-HJt3W3U9-y5zLRzcsJY40Xpoc,985
+nous/rules/softmax.py,sha256=K8_VJHmRcCTcyitouL_eagWNk_k8AzsuxdYuHDCt5yA,4142
+nous/rules/sparse.py,sha256=jc2GdHT7uQ0CqXOdMw56_EmtDQ_KQiB62SHDRFlvgj4,6249
+nous/training/__init__.py,sha256=CrFOoYJhOGVk6LO3wVEx9TmsxWkwXw3I5Evcs_XZ5t8,259
+nous/training/evaluation.py,sha256=3RiQW74WIO3KYXvspB8WNj2cWbSivHp-9EYvy2xcIc4,2019
+nous/training/schedulers.py,sha256=hZgvxQ0ht5ks76dNRYbBsE1t_20gqWM7LhYSfmuyiRo,1276
+nous/training/train.py,sha256=znaKYO21vkexaYxxBQY8g0ZvRx8MUac95I2AbY-xwKQ,6331
+nous/utils/__init__.py,sha256=OEkCOgzF-mJ-jf_jeWebPhSMa0BLdIalFPCXTltQe-c,64
+nous/utils/metrics.py,sha256=IF9uwjnniFP7r2G3GK9TiX7_CycTh4Jvb1I_-bPgrXg,76
+nous/utils/seed.py,sha256=h0w9y5ExThWDmZjHI79AsQkJrFbFzPfco1OBIs2nNGI,311
+nous-0.2.0.dist-info/licenses/LICENSE,sha256=07nO-ZFpy_s_msfks8VsONyV2cBBggqsEQD2h5sdVRo,1069
+nous-0.2.0.dist-info/METADATA,sha256=u03tWgz3y7ClPSFRzurCAMMlUhPytglX6SEOTjl8w3k,6612
+nous-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+nous-0.2.0.dist-info/top_level.txt,sha256=yUcst4OAspsyKhX0y5ENzFkJKzR_gislA5MykV1pVbk,5
+nous-0.2.0.dist-info/RECORD,,
nous/causal.py DELETED
@@ -1,63 +0,0 @@
-# nous/causal.py
-import torch
-import torch.nn as nn
-from typing import Literal
-from .models import NousNet
-
-def find_counterfactual(
-    model: NousNet,
-    x_sample: torch.Tensor,
-    target_output: float,
-    task: Literal['regression', 'classification'] = 'classification',
-    lr: float = 0.01,
-    steps: int = 200,
-    l1_lambda: float = 0.5
-) -> dict:
-    """
-    Finds a minimal change to the input to achieve a target output.
-
-    Args:
-        model (NousNet): The trained model.
-        x_sample (torch.Tensor): The original input tensor.
-        target_output (float): For classification, the desired probability (e.g., 0.8).
-            For regression, the desired absolute value (e.g., 150.0).
-        task (Literal['regression', 'classification']): The type of task.
-        lr, steps, l1_lambda: Optimization parameters.
-
-    Returns:
-        dict: A dictionary containing the counterfactual sample and a list of changes.
-    """
-    if task == 'classification':
-        if not (0 < target_output < 1):
-            raise ValueError("Target for classification must be a probability between 0 and 1.")
-        if model.output_head.out_features > 1:
-            raise NotImplementedError("Counterfactual analysis for multi-class models is not yet supported.")
-        # Calculate the target in logit space for numerical stability
-        target = torch.log(torch.tensor(target_output) / (1 - torch.tensor(target_output)))
-    else:  # regression
-        target = torch.tensor(target_output, dtype=torch.float32)
-
-    x_sample = x_sample.clone().detach()
-    delta = torch.zeros_like(x_sample, requires_grad=True)
-    optimizer = torch.optim.Adam([delta], lr=lr)
-
-    for _ in range(steps):
-        optimizer.zero_grad()
-        x_perturbed = x_sample + delta
-        prediction = model(x_perturbed.unsqueeze(0)).squeeze()
-
-        target_loss = (prediction - target)**2
-        l1_loss = torch.norm(delta, p=1)
-        total_loss = target_loss + l1_lambda * l1_loss
-
-        # Only ever called for scalar losses, as multi-class is caught above
-        total_loss.backward()
-        optimizer.step()
-
-    final_x = x_sample + delta.detach()
-    changes = []
-    for i, name in enumerate(model.feature_names):
-        if not torch.isclose(x_sample[i], final_x[i], atol=1e-3):
-            changes.append((name, x_sample[i].item(), final_x[i].item()))
-
-    return {"counterfactual_x": final_x, "changes": changes}
nous/interpret.py DELETED
@@ -1,111 +0,0 @@
-# nous/interpret.py
-import torch
-import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-import networkx as nx
-from .models import NousNet
-from .layers import LearnedAtomicFactLayer, BetaFactLayer
-
-def trace_decision_graph(model: NousNet, x_sample: torch.Tensor) -> dict:
-    """Traces the full reasoning path for a single sample."""
-    if x_sample.dim() == 1: x_sample = x_sample.unsqueeze(0)
-    model.eval()
-    graph_data = {"trace": {}}
-    with torch.no_grad():
-        facts = model.atomic_fact_layer(x_sample).squeeze(0)
-        graph_data['trace']['Atomic Facts'] = {name: {"value": facts[i].item()} for i, name in enumerate(model.atomic_fact_layer.fact_names)}
-        h = facts.unsqueeze(0)
-        for i, block in enumerate(model.nous_blocks):
-            h, concepts, rule_activations = block(h)
-            concepts, rule_activations = concepts.squeeze(0), rule_activations.squeeze(0)
-            graph_data['trace'][f'Rules L{i}'] = {name: {"value": rule_activations[j].item()} for j, name in enumerate(block.rule_layer.rule_names)}
-            graph_data['trace'][f'Concepts L{i}'] = {name: {"value": concepts[j].item()} for j, name in enumerate(block.concept_names)}
-    return graph_data
-
-def explain_fact(model: NousNet, fact_name: str) -> pd.DataFrame:
-    """Provides a detailed breakdown of a single learned fact, showing feature weights."""
-    fact_layer = model.atomic_fact_layer
-    if not isinstance(fact_layer, LearnedAtomicFactLayer):
-        raise TypeError("explain_fact is only applicable to 'beta' or 'sigmoid' fact layers.")
-    try:
-        fact_index = fact_layer.fact_names.index(fact_name)
-    except ValueError:
-        raise ValueError(f"Fact '{fact_name}' not found.")
-
-    with torch.no_grad():
-        w_left, w_right = fact_layer.projection_left.weight[fact_index], fact_layer.projection_right.weight[fact_index]
-        threshold = fact_layer.thresholds[fact_index]
-    df = pd.DataFrame({
-        "feature": model.feature_names,
-        "left_weight": w_left.cpu().detach().numpy(),
-        "right_weight": w_right.cpu().detach().numpy(),
-    })
-    df['net_effect'] = df['left_weight'] - df['right_weight']
-    print(f"Explanation for fact '{fact_name}':")
-    print(f"Fact is TRUE when: (Sum(left_weight * feat) - Sum(right_weight * feat)) > {threshold.item():.3f}")
-    return df.sort_values(by='net_effect', key=abs, ascending=False)
-
-def plot_fact_activation_function(model: NousNet, fact_name: str, x_range=(-3, 3), n_points=200):
-    """Visualizes the learned activation function for a single learned fact."""
-    fact_layer = model.atomic_fact_layer
-    if not isinstance(fact_layer, LearnedAtomicFactLayer):
-        raise TypeError("This function is only for 'beta' or 'sigmoid' fact layers.")
-    try:
-        fact_index = fact_layer.fact_names.index(fact_name)
-    except ValueError:
-        raise ValueError(f"Fact '{fact_name}' not found.")
-
-    diff_range = torch.linspace(x_range[0], x_range[1], n_points)
-
-    with torch.no_grad():
-        if isinstance(fact_layer, BetaFactLayer):
-            k = torch.nn.functional.softplus(fact_layer.k_raw[fact_index]) + 1e-4
-            nu = torch.nn.functional.softplus(fact_layer.nu_raw[fact_index]) + 1e-4
-            activations = (1 + torch.exp(-k * diff_range))**(-nu)
-            label = f'Learned Beta-like (k={k:.2f}, ν={nu:.2f})'
-        else:  # SigmoidFactLayer
-            steepness = torch.nn.functional.softplus(fact_layer.steepness[fact_index]) + 1e-4
-            activations = torch.sigmoid(steepness * diff_range)
-            label = f'Learned Sigmoid (steepness={steepness:.2f})'
-
-    plt.figure(figsize=(8, 5))
-    plt.plot(diff_range.numpy(), activations.numpy(), label=label, linewidth=2.5)
-    plt.plot(diff_range.numpy(), torch.sigmoid(diff_range).numpy(), label='Standard Sigmoid', linestyle='--', color='gray')
-    plt.title(f"Activation Function for Fact:\n'{fact_name}'")
-    plt.xlabel("Difference Value (Left Projection - Right Projection - Threshold)")
-    plt.ylabel("Fact Activation (Truth Value)")
-    plt.legend(); plt.grid(True, linestyle=':'); plt.ylim(-0.05, 1.05); plt.show()
-
-def plot_final_layer_contributions(model: NousNet, x_sample: torch.Tensor):
-    """Calculates and plots which high-level concepts most influenced the final prediction."""
-    if x_sample.dim() == 1: x_sample = x_sample.unsqueeze(0)
-    model.eval()
-    with torch.no_grad():
-        h = model.atomic_fact_layer(x_sample)
-        for block in model.nous_blocks: h, _, _ = block(h)
-        final_activations = h.squeeze(0)
-
-    output_dim = model.output_head.out_features
-    weights = model.output_head.weight.squeeze(0)
-    title = "Top Final Layer Concept Contributions"
-
-    if output_dim > 1:
-        predicted_class = model.output_head(h).argmax().item()
-        weights = model.output_head.weight[predicted_class]
-        title += f" for Predicted Class {predicted_class}"
-
-    contributions = final_activations * weights
-    concept_names = model.nous_blocks[-1].concept_names
-
-    df = pd.DataFrame({'concept': concept_names, 'contribution': contributions.cpu().detach().numpy()})
-    df = df.sort_values('contribution', key=abs, ascending=False).head(15)
-
-    plt.figure(figsize=(10, 6))
-    colors = ['#5fba7d' if c > 0 else '#d65f5f' for c in df['contribution']]
-    sns.barplot(x='contribution', y='concept', data=df, palette=colors, dodge=False)
-    plt.title(title); plt.xlabel("Contribution (Activation * Weight)"); plt.ylabel("Final Layer Concept"); plt.show()
-
-def plot_logic_graph(*args, **kwargs):
-    print("Graph visualization is planned for a future release.")
-    pass