scratchkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. mlscratch/__init__.py +56 -0
  2. mlscratch/__main__.py +118 -0
  3. mlscratch/bayesian/__init__.py +53 -0
  4. mlscratch/bayesian/bayesian_linear_regression.py +171 -0
  5. mlscratch/bayesian/bayesian_network.py +248 -0
  6. mlscratch/bayesian/bayesian_nn.py +315 -0
  7. mlscratch/bayesian/gaussian_process.py +207 -0
  8. mlscratch/bayesian/hmm.py +277 -0
  9. mlscratch/bayesian/init.py +52 -0
  10. mlscratch/bayesian/kalman_filter.py +182 -0
  11. mlscratch/bayesian/naive_bayes.py +209 -0
  12. mlscratch/metrics/__init__.py +59 -0
  13. mlscratch/metrics/classification.py +365 -0
  14. mlscratch/metrics/regression.py +79 -0
  15. mlscratch/neural/__init__.py +121 -0
  16. mlscratch/neural/attention.py +420 -0
  17. mlscratch/neural/autoencoder.py +543 -0
  18. mlscratch/neural/boltzmann.py +231 -0
  19. mlscratch/neural/cnn.py +593 -0
  20. mlscratch/neural/cvnn.py +322 -0
  21. mlscratch/neural/gan.py +364 -0
  22. mlscratch/neural/hopfield.py +193 -0
  23. mlscratch/neural/perceptron.py +398 -0
  24. mlscratch/neural/rbf_network.py +230 -0
  25. mlscratch/neural/recurrent.py +569 -0
  26. mlscratch/preprocessing/__init__.py +38 -0
  27. mlscratch/preprocessing/encoders.py +140 -0
  28. mlscratch/preprocessing/model_selection.py +119 -0
  29. mlscratch/preprocessing/polynomial.py +105 -0
  30. mlscratch/preprocessing/scalers.py +220 -0
  31. mlscratch/py.typed +0 -0
  32. mlscratch/reinforcement/__init__.py +59 -0
  33. mlscratch/reinforcement/ddpg.py +363 -0
  34. mlscratch/reinforcement/dqn.py +319 -0
  35. mlscratch/reinforcement/ppo.py +452 -0
  36. mlscratch/reinforcement/q_learning.py +352 -0
  37. mlscratch/reinforcement/sac.py +382 -0
  38. mlscratch/reinforcement/utils.py +594 -0
  39. mlscratch/supervised/__init__.py +76 -0
  40. mlscratch/supervised/_validation.py +50 -0
  41. mlscratch/supervised/adaboost.py +255 -0
  42. mlscratch/supervised/decision_tree.py +495 -0
  43. mlscratch/supervised/gradient_boosting.py +354 -0
  44. mlscratch/supervised/knn.py +234 -0
  45. mlscratch/supervised/lasso_regression.py +125 -0
  46. mlscratch/supervised/linear_models.py +459 -0
  47. mlscratch/supervised/linear_regression.py +197 -0
  48. mlscratch/supervised/logistic_regression.py +119 -0
  49. mlscratch/supervised/naive_bayes.py +113 -0
  50. mlscratch/supervised/random_forest.py +321 -0
  51. mlscratch/supervised/ridge_regression.py +93 -0
  52. mlscratch/supervised/svm.py +356 -0
  53. mlscratch/unsupervised/__init__.py +39 -0
  54. mlscratch/unsupervised/apriori.py +178 -0
  55. mlscratch/unsupervised/dbscan.py +141 -0
  56. mlscratch/unsupervised/gmm.py +204 -0
  57. mlscratch/unsupervised/hierarchical_clustering.py +137 -0
  58. mlscratch/unsupervised/ica.py +167 -0
  59. mlscratch/unsupervised/kmeans.py +135 -0
  60. mlscratch/unsupervised/kmedoids.py +133 -0
  61. mlscratch/unsupervised/pca.py +103 -0
  62. mlscratch/unsupervised/tsne.py +200 -0
  63. scratchkit-0.2.0.dist-info/METADATA +241 -0
  64. scratchkit-0.2.0.dist-info/RECORD +68 -0
  65. scratchkit-0.2.0.dist-info/WHEEL +5 -0
  66. scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
  67. scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
  68. scratchkit-0.2.0.dist-info/top_level.txt +1 -0
mlscratch/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ mlscratch
3
+ =========
4
+ Pure-NumPy from-scratch implementations of ML / AI / RL / Bayesian algorithms.
5
+ No PyTorch. No TensorFlow. No scikit-learn. Just numpy and the maths.
6
+
7
+ Sub-packages
8
+ ------------
9
+ mlscratch.supervised Supervised learning algorithms
10
+ mlscratch.unsupervised Unsupervised learning algorithms
11
+ mlscratch.bayesian Bayesian methods
12
+ mlscratch.reinforcement Reinforcement learning algorithms
13
+ mlscratch.neural Neural network architectures
14
+ mlscratch.metrics Classification & regression evaluation metrics
15
+ mlscratch.preprocessing Scalers, encoders, polynomial features, train_test_split
16
+
17
+ Quick-start
18
+ -----------
19
+ >>> from mlscratch.unsupervised import KMeans
20
+ >>> from mlscratch.supervised import LinearRegression, RandomForestClassifier
21
+ >>> from mlscratch.bayesian import GaussianNB
22
+ >>> from mlscratch.reinforcement import QLearning
23
+ >>> from mlscratch.neural import MultiLayerPerceptron
24
+ >>> from mlscratch.metrics import accuracy_score
25
+ >>> from mlscratch.preprocessing import StandardScaler, train_test_split
26
+
27
+ Install
28
+ -------
29
+ pip install scratchkit # core (numpy only); import name is "mlscratch"
30
+ pip install "scratchkit[dev]" # + pytest, ruff, black, mypy
31
+ pip install "scratchkit[docs]" # + mkdocs
32
+ pip install "scratchkit[all]" # everything
33
+
34
+ Links
35
+ -----
36
+ GitHub : https://github.com/Mattral/ML-AI-Algorithms-from-scratch
37
+ Issues : https://github.com/Mattral/ML-AI-Algorithms-from-scratch/issues
38
+ Changelog : https://github.com/Mattral/ML-AI-Algorithms-from-scratch/blob/main/CHANGELOG.md
39
+ """
40
+
41
+ from importlib.metadata import PackageNotFoundError, version
42
+
43
# Names re-exported by ``from mlscratch import *``.
__all__ = ["__version__", "__author__", "__license__"]

__author__ = "Mattral"
__license__ = "Apache-2.0"

# Version metadata is registered under the distribution name "scratchkit",
# not the import name "mlscratch" (that PyPI name belongs to an unrelated
# project), so the lookup must use the distribution name -- just as
# ``import bs4`` reads the metadata of "beautifulsoup4".
try:
    __version__: str = version("scratchkit")
except PackageNotFoundError:
    # No installed distribution (e.g. imported straight from a source
    # checkout): fall back to a placeholder development version.
    __version__ = "0.0.0+dev"
mlscratch/__main__.py ADDED
@@ -0,0 +1,118 @@
1
+ """
2
+ mlscratch CLI
3
+ =============
4
+ Usage
5
+ -----
6
+ python -m mlscratch # same as 'info'
7
+ python -m mlscratch version # print version
8
+ python -m mlscratch info # version + sub-package summary
9
+ python -m mlscratch list # list all available algorithm classes
10
+ python -m mlscratch list supervised
11
+ python -m mlscratch list unsupervised
12
+ python -m mlscratch list bayesian
13
+ python -m mlscratch list reinforcement
14
+ python -m mlscratch list neural
15
+ python -m mlscratch list metrics
16
+ python -m mlscratch list preprocessing
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import sys
23
+
24
+
25
+ # ── helpers ───────────────────────────────────────────────────────────────────
26
+
27
def _print_version() -> None:
    """Print ``mlscratch <version>`` on a single line of stdout."""
    # Imported lazily so that merely loading the CLI module stays cheap.
    import mlscratch

    banner = f"mlscratch {mlscratch.__version__}"
    print(banner)
30
+
31
+
32
def _print_info() -> None:
    """
    Print an environment summary followed by one line per sub-package.

    The header shows the mlscratch, numpy and Python versions.  Each
    sub-package line reports the length of that module's ``__all__`` (0 when
    the attribute is absent), or "not yet installed" when importing the
    sub-package raises ImportError.
    """
    import importlib

    import mlscratch
    import numpy as np

    header = (
        f"\nmlscratch {mlscratch.__version__}",
        f" numpy : {np.__version__}",
        f" python : {sys.version.split()[0]}",
        "",
    )
    for line in header:
        print(line)

    subpackages = (
        "supervised",
        "unsupervised",
        "bayesian",
        "reinforcement",
        "neural",
        "metrics",
        "preprocessing",
    )
    for name in subpackages:
        try:
            mod = importlib.import_module(f"mlscratch.{name}")
            n = len(getattr(mod, "__all__", []))
        except ImportError:
            print(f" {name:<18} not yet installed")
        else:
            print(f" {name:<18} {n} public symbol(s)")

    print()
61
+
62
+
63
def _list_algorithms(subpackage: str | None = None) -> None:
    """
    Print the public symbols (``__all__``) of one or all sub-packages.

    Parameters
    ----------
    subpackage : str or None
        Name of a single sub-package to list.  When None, every known
        sub-package is listed.

    Sections are emitted in alphabetical order of sub-package name, and the
    symbols inside a section are sorted as well.  A sub-package whose import
    raises ImportError gets a one-line "not available" notice instead.
    """
    import importlib

    if subpackage is None:
        targets = {
            "supervised",
            "unsupervised",
            "bayesian",
            "reinforcement",
            "neural",
            "metrics",
            "preprocessing",
        }
    else:
        targets = {subpackage}

    for name in sorted(targets):
        try:
            mod = importlib.import_module(f"mlscratch.{name}")
            symbols = getattr(mod, "__all__", [])
        except ImportError:
            print(f"\n[{name}] — not available (sub-package not installed)")
        else:
            print(f"\n[{name}]")
            for s in sorted(symbols):
                print(f" {s}")
82
+
83
+
84
+ # ── entry point ───────────────────────────────────────────────────────────────
85
+
86
def main(argv: list[str] | None = None) -> int:
    """
    Command-line entry point for ``python -m mlscratch``.

    Parameters
    ----------
    argv : list of str or None
        Argument vector handed to argparse; None makes argparse read
        ``sys.argv[1:]``.

    Returns
    -------
    int
        Always 0.  Invalid arguments never reach the return statement:
        argparse reports them and raises SystemExit itself.
    """
    parser = argparse.ArgumentParser(
        prog="mlscratch",
        description="mlscratch — ML algorithms from scratch",
    )
    commands = parser.add_subparsers(dest="command")
    commands.add_parser("version", help="Print version and exit")
    commands.add_parser("info", help="Print version, numpy, and sub-package summary")
    listing = commands.add_parser("list", help="List available algorithms")
    listing.add_argument(
        "subpackage",
        nargs="?",
        choices=["supervised", "unsupervised", "bayesian", "reinforcement", "neural", "metrics", "preprocessing"],
        default=None,
        help="Restrict listing to one sub-package",
    )

    args = parser.parse_args(argv)
    command = args.command

    if command == "version":
        _print_version()
    elif command == "list":
        _list_algorithms(args.subpackage)
    else:
        # argparse only lets through None (no sub-command) or "info" here;
        # both mean "show the summary".
        _print_info()

    return 0
115
+
116
+
117
+ if __name__ == "__main__":
118
+ sys.exit(main())
@@ -0,0 +1,53 @@
1
+ """
2
+ mlscratch.bayesian
3
+ ==================
4
+ From-scratch implementations of Bayesian learning algorithms.
5
+ Drop these files alongside existing code in src/mlscratch/bayesian/.
6
+
7
+ Algorithms
8
+ ----------
9
+ GaussianNB – Gaussian Naive Bayes
10
+ MultinomialNB – Multinomial Naive Bayes
11
+ BernoulliNB – Bernoulli Naive Bayes
12
+ BayesianLinearRegression – Conjugate Gaussian prior regression
13
+ GaussianProcessRegressor – GP Regression (RBF, Matern52, Linear, Periodic)
14
+ RBFKernel – RBF / Squared-Exponential kernel
15
+ Matern52Kernel – Matern 5/2 kernel
16
+ LinearKernel – Linear kernel
17
+ PeriodicKernel – Periodic kernel
18
+ HiddenMarkovModel – Discrete HMM (forward-backward, Viterbi, Baum-Welch)
19
+ BayesianNeuralNetwork – BNN via mean-field variational inference
20
+ BayesianNetwork – Discrete DAG (variable elimination, sampling)
21
+ KalmanFilter – Linear Kalman Filter + RTS Smoother
22
+ """
23
+
24
+ from .naive_bayes import GaussianNB, MultinomialNB, BernoulliNB # noqa: F401
25
+ from .bayesian_linear_regression import BayesianLinearRegression # noqa: F401
26
+ from .gaussian_process import ( # noqa: F401
27
+ GaussianProcessRegressor,
28
+ RBFKernel,
29
+ Matern52Kernel,
30
+ LinearKernel,
31
+ PeriodicKernel,
32
+ )
33
+ from .hmm import HiddenMarkovModel # noqa: F401
34
+ from .bayesian_nn import BayesianNeuralNetwork, BayesianLayer # noqa: F401
35
+ from .bayesian_network import BayesianNetwork # noqa: F401
36
+ from .kalman_filter import KalmanFilter # noqa: F401
37
+
38
+ __all__ = [
39
+ "GaussianNB",
40
+ "MultinomialNB",
41
+ "BernoulliNB",
42
+ "BayesianLinearRegression",
43
+ "GaussianProcessRegressor",
44
+ "RBFKernel",
45
+ "Matern52Kernel",
46
+ "LinearKernel",
47
+ "PeriodicKernel",
48
+ "HiddenMarkovModel",
49
+ "BayesianNeuralNetwork",
50
+ "BayesianLayer",
51
+ "BayesianNetwork",
52
+ "KalmanFilter",
53
+ ]
@@ -0,0 +1,171 @@
1
+ """
2
+ Bayesian Linear Regression
3
+ ===========================
4
+ Treats the model weights as a distribution, not a point estimate.
5
+
6
+ Model
7
+ -----
8
+ y = X w + ε, ε ~ N(0, β⁻¹)
9
+ w ~ N(0, α⁻¹ I)
10
+
11
+ With a Gaussian prior over weights and Gaussian noise, the posterior over
12
+ weights is also Gaussian (conjugate prior):
13
+
14
+ p(w | X, y) = N(w | m_N, S_N)
15
+
16
+ S_N = (α I + β X^T X)^{-1}
17
+ m_N = β S_N X^T y
18
+
19
+ Predictions are also Gaussian:
20
+
21
+ p(y* | x*, X, y) = N(y* | m_N^T x*, σ_N²(x*))
22
+ σ_N²(x*) = β⁻¹ + x*^T S_N x*
23
+
24
+ Parameters α (weight precision) and β (noise precision) can be fixed or
25
+ estimated via type-II maximum likelihood (evidence approximation).
26
+
27
+ Reference: Bishop, PRML, Chapter 3.
28
+ Only numpy is used.
29
+ """
30
+
31
+ import numpy as np
32
+
33
+
34
class BayesianLinearRegression:
    """
    Bayesian linear regression with a conjugate Gaussian prior.

    The weights get a zero-mean isotropic Gaussian prior with precision
    ``alpha`` and the targets Gaussian noise with precision ``beta``, so the
    weight posterior N(m_N, S_N) is available in closed form.

    Parameters
    ----------
    alpha : float
        Prior precision over weights (1 / prior variance).
    beta : float
        Noise precision (1 / noise variance).
    fit_intercept : bool
        If True, prepend a column of ones to X.
    optimize_hyperparams : bool
        If True, ``fit`` re-estimates alpha and beta by evidence
        maximisation, starting from the values given here, and stores the
        result back on ``alpha`` / ``beta``.  If False, both stay fixed.
    max_iter : int
        Maximum iterations for hyperparameter optimisation.
    tol : float
        Absolute change in both alpha and beta below which the
        hyperparameter iteration stops.

    Attributes
    ----------
    m_N_ : ndarray or None
        Posterior mean of the weights (bias first when fit_intercept=True);
        None until ``fit`` has been called.
    S_N_ : ndarray or None
        Posterior covariance of the weights; None until ``fit`` is called.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        fit_intercept: bool = True,
        optimize_hyperparams: bool = False,
        max_iter: int = 300,
        tol: float = 1e-5,
    ):
        self.alpha = alpha
        self.beta = beta
        self.fit_intercept = fit_intercept
        self.optimize_hyperparams = optimize_hyperparams
        self.max_iter = max_iter
        self.tol = tol

        self.m_N_ = None  # posterior mean  (n_features,)
        self.S_N_ = None  # posterior cov   (n_features, n_features)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _add_bias(self, X: np.ndarray) -> np.ndarray:
        """Return X as a 2-D float design matrix, with a leading ones column
        when ``fit_intercept`` is True.

        A 1-D input is treated as n_samples observations of one feature.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        if self.fit_intercept:
            return np.column_stack([np.ones(len(X)), X])
        return X

    def _compute_posterior(
        self, X: np.ndarray, y: np.ndarray, alpha=None, beta=None
    ) -> tuple:
        """
        Return (m_N, S_N) for the design matrix X and targets y.

        ``alpha`` / ``beta`` default to the instance attributes; they can be
        passed explicitly so the hyperparameter loop can evaluate the
        posterior at its current estimates.
        """
        alpha = self.alpha if alpha is None else alpha
        beta = self.beta if beta is None else beta

        n_features = X.shape[1]
        S_N_inv = alpha * np.eye(n_features) + beta * (X.T @ X)
        S_N = np.linalg.inv(S_N_inv)
        m_N = beta * (S_N @ (X.T @ y))
        return m_N, S_N

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(self, X: np.ndarray, y: np.ndarray) -> "BayesianLinearRegression":
        """
        Compute posterior distribution over weights.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
        y : ndarray of shape (n_samples,)

        Returns
        -------
        self
        """
        X_ = self._add_bias(X)
        y = np.asarray(y, dtype=float)
        n_samples = X_.shape[0]

        if self.optimize_hyperparams:
            # Evidence approximation (Bishop PRML §3.5.2)
            alpha, beta = float(self.alpha), float(self.beta)

            # Eigenvalues of β XᵀX are β times those of XᵀX, so decompose
            # once.  Clip the tiny negatives round-off can produce for a
            # positive semi-definite matrix.
            gram_eigvals = np.clip(np.linalg.eigvalsh(X_.T @ X_), 0.0, None)

            # Floor for the denominators below: an all-zero posterior mean
            # or a perfect fit would otherwise divide by zero.
            floor = 1e-12

            for _ in range(self.max_iter):
                # The posterior must be evaluated at the *current* estimates,
                # otherwise the re-estimation never moves the mean.
                m_N, _cov = self._compute_posterior(X_, y, alpha, beta)

                eigvals = beta * gram_eigvals
                gamma = np.sum(eigvals / (alpha + eigvals))
                residuals = y - X_ @ m_N

                alpha_new = gamma / max(float(m_N @ m_N), floor)
                beta_new = (n_samples - gamma) / max(
                    float(residuals @ residuals), floor
                )

                alpha_new = max(float(alpha_new), 1e-10)
                beta_new = max(float(beta_new), 1e-10)

                converged = (
                    abs(alpha_new - alpha) < self.tol
                    and abs(beta_new - beta) < self.tol
                )
                alpha, beta = alpha_new, beta_new
                if converged:
                    break

            self.alpha, self.beta = alpha, beta

        self.m_N_, self.S_N_ = self._compute_posterior(X_, y)
        return self

    def predict(self, X: np.ndarray, return_std: bool = False):
        """
        Predictive mean (and optionally std) for new inputs.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
        return_std : bool
            If True, also return the predictive standard deviation.

        Returns
        -------
        y_mean : ndarray of shape (n_samples,)
        y_std  : ndarray of shape (n_samples,)  [only if return_std=True]
        """
        X_ = self._add_bias(X)
        y_mean = X_ @ self.m_N_

        if not return_std:
            return y_mean

        # Predictive variance: β⁻¹ + xᵀ S_N x, evaluated row by row.
        var = (1.0 / self.beta) + np.einsum("ij,jk,ik->i", X_, self.S_N_, X_)
        # Guard the square root against round-off driving var below zero.
        return y_mean, np.sqrt(np.maximum(var, 0.0))

    @property
    def coef_(self) -> np.ndarray:
        """Posterior mean weights (excluding bias if fit_intercept=True)."""
        if self.fit_intercept:
            return self.m_N_[1:]
        return self.m_N_

    @property
    def intercept_(self) -> float:
        """Posterior mean of the bias term; 0.0 when fit_intercept=False."""
        if self.fit_intercept:
            return float(self.m_N_[0])
        return 0.0
@@ -0,0 +1,248 @@
1
+ """
2
+ Bayesian Network
3
+ =================
4
+ A Directed Acyclic Graph (DAG) where each node represents a discrete random
5
+ variable and each node stores a Conditional Probability Table (CPT).
6
+
7
+ Supports:
8
+ - Manual CPT specification
9
+ - Exact inference via Variable Elimination
10
+ - Ancestral sampling
11
+
12
+ Notation
13
+ --------
14
+ Each variable is a string name. Observations are dicts {name: value}.
15
+ CPTs are given as numpy arrays indexed in the order (*parents, var).
16
+
17
+ Only numpy and Python stdlib are used.
18
+ """
19
+
20
+ import numpy as np
21
+ from itertools import product
22
+
23
+
24
class BayesianNetwork:
    """
    Discrete Bayesian Network.

    Variables are identified by string names and take integer values
    0 … domain_size-1.  Each variable's CPT is an array whose leading axes
    follow the variable's parents (in the order they were declared) and
    whose last axis is the variable itself.

    Usage
    -----
    >>> bn = BayesianNetwork()
    >>> bn.add_variable('Rain',     2)
    >>> bn.add_variable('Sprinkler',2, parents=['Rain'])
    >>> bn.add_variable('Wet',      2, parents=['Rain','Sprinkler'])
    >>> bn.set_cpt('Rain',      np.array([0.8, 0.2]))
    >>> bn.set_cpt('Sprinkler', np.array([[0.6,0.4],[0.99,0.01]]))
    >>> bn.set_cpt('Wet',       np.array([[[0.99,0.01],[0.1,0.9]],
    ...                                   [[0.1,0.9],[0.01,0.99]]]))
    >>> bn.query('Wet', evidence={'Rain':1})
    """

    def __init__(self):
        self._parents: dict[str, list] = {}     # name → parent names
        self._cpt: dict[str, np.ndarray] = {}   # name → CPT array
        self._domain: dict[str, int] = {}       # name → domain size
        self._order: list[str] = []             # insertion order, no duplicates

    # ------------------------------------------------------------------
    # Graph construction
    # ------------------------------------------------------------------

    def add_variable(
        self,
        name: str,
        domain_size: int,
        parents: list | None = None,
    ) -> None:
        """
        Register a variable.

        Calling this again with an existing name replaces that variable's
        domain size and parents without registering it a second time.

        Parameters
        ----------
        name : str
        domain_size : int
            Number of possible values (0 … domain_size-1).
        parents : list of str or None
            Names of parent variables.
        """
        # Only record the name once: a duplicate entry in _order would make
        # query() multiply the variable's CPT in twice.
        if name not in self._domain:
            self._order.append(name)
        self._domain[name] = domain_size
        # Copy so later mutation of the caller's list cannot alter the graph.
        self._parents[name] = list(parents) if parents else []

    def set_cpt(self, name: str, cpt: np.ndarray) -> None:
        """
        Set the CPT for a variable.

        The array shape must be:
          (domain_size,)                      for root nodes (no parents)
          (*parent_domain_sizes, domain_size) for nodes with parents
        """
        self._cpt[name] = np.array(cpt, dtype=float)

    # ------------------------------------------------------------------
    # Topological order (Kahn's algorithm)
    # ------------------------------------------------------------------

    def _topological_sort(self) -> list:
        """
        Return the variable names with every parent before its children.

        Raises
        ------
        ValueError
            If some variable can never be scheduled, i.e. the graph has a
            cycle or a parent name that was never registered.
        """
        from collections import deque

        children = {n: [] for n in self._order}
        in_degree = {}
        for n in self._order:
            in_degree[n] = len(self._parents[n])
            for p in self._parents[n]:
                if p in children:
                    children[p].append(n)

        queue = deque(n for n in self._order if in_degree[n] == 0)
        result = []
        while queue:
            node = queue.popleft()
            result.append(node)
            for child in children[node]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)

        if len(result) != len(self._order):
            stuck = [n for n in self._order if n not in set(result)]
            raise ValueError(
                "Cannot order variables (cycle or unknown parent): "
                + ", ".join(map(str, stuck))
            )
        return result

    # ------------------------------------------------------------------
    # Ancestral sampling
    # ------------------------------------------------------------------

    def sample(self, n_samples: int = 1, random_state=None) -> list:
        """
        Generate samples by ancestral sampling.

        Parameters
        ----------
        n_samples : int
            Number of joint assignments to draw.
        random_state : seed accepted by ``np.random.default_rng``, or None

        Returns
        -------
        samples : list of dicts {variable_name: value}
            For ``n_samples == 1`` the single dict is returned directly
            rather than wrapped in a list; for ``n_samples <= 0`` the
            result is an empty list.
        """
        rng = np.random.default_rng(random_state)
        order = self._topological_sort()
        results = []

        for _ in range(n_samples):
            assignment = {}
            for var in order:
                parents = self._parents[var]
                cpt = self._cpt[var]
                # Parents were sampled earlier in `order`, so their values
                # select the conditional distribution of `var`.
                if parents:
                    probs = cpt[tuple(assignment[p] for p in parents)]
                else:
                    probs = cpt
                assignment[var] = int(rng.choice(len(probs), p=probs))
            results.append(assignment)

        if n_samples == 1:
            return results[0]
        return results

    # ------------------------------------------------------------------
    # Variable Elimination (exact inference)
    # ------------------------------------------------------------------

    def query(
        self,
        query_var: str,
        evidence: dict | None = None,
    ) -> np.ndarray:
        """
        Compute P(query_var | evidence) via variable elimination.

        Parameters
        ----------
        query_var : str
        evidence  : dict {var_name: observed_value} or None

        Returns
        -------
        proba : ndarray of shape (domain_size_of_query_var,)
            Normalised to sum to 1 unless the unnormalised total is 0.  If
            ``query_var`` is itself observed, all mass is placed on the
            observed value.
        """
        evidence = evidence or {}

        if query_var in evidence:
            point_mass = np.zeros(self._domain[query_var])
            point_mass[evidence[query_var]] = 1.0
            return point_mass

        # Build initial factors from CPTs.  A factor is (scope, array) with
        # one array axis per scope variable, in scope order.
        factors: list[tuple[tuple, np.ndarray]] = []
        for var in self._order:
            scope = tuple(self._parents[var]) + (var,)
            # Fix every observed axis in ONE indexing step; indexing axis by
            # axis would shift the positions of the axes still to be fixed.
            index = tuple(
                evidence[v] if v in evidence else slice(None) for v in scope
            )
            reduced_scope = tuple(v for v in scope if v not in evidence)
            factors.append((reduced_scope, self._cpt[var][index]))

        # Eliminate every variable that is neither queried nor observed.
        for var in self._order:
            if var == query_var or var in evidence:
                continue

            relevant = [(s, f) for s, f in factors if var in s]
            if not relevant:
                continue
            remaining = [(s, f) for s, f in factors if var not in s]

            product_scope, product_factor = self._factor_product(relevant)
            summed = np.sum(product_factor, axis=product_scope.index(var))
            new_scope = tuple(s for s in product_scope if s != var)

            remaining.append((new_scope, summed))
            factors = remaining

        if not factors:
            return np.ones(self._domain[query_var]) / self._domain[query_var]

        final_scope, final_factor = self._factor_product(factors)

        # Sum out anything that is still present besides query_var.
        extra_axes = tuple(
            i for i, v in enumerate(final_scope) if v != query_var
        )
        if extra_axes:
            final_factor = np.sum(final_factor, axis=extra_axes)

        result = np.asarray(final_factor, dtype=float).ravel()
        total = result.sum()
        return result / total if total > 0 else result

    def _factor_product(
        self, factors: list[tuple[tuple, np.ndarray]]
    ) -> tuple[tuple, np.ndarray]:
        """
        Multiply a list of (scope, array) factors together.

        Returns
        -------
        (union_scope, product) : the union of all scopes, in order of first
            appearance, and an array with one axis per variable of that
            union.  An empty input yields ``((), array(1.0))``.
        """
        if not factors:
            return ((), np.array(1.0))

        # Union scope, keeping first-appearance order.
        union_scope: list = []
        for scope, _ in factors:
            for v in scope:
                if v not in union_scope:
                    union_scope.append(v)
        position = {v: i for i, v in enumerate(union_scope)}

        shape = tuple(self._domain[v] for v in union_scope)
        result = np.ones(shape)

        for scope, factor in factors:
            source_shape = tuple(self._domain[v] for v in scope)
            values = np.asarray(factor, dtype=float).reshape(source_shape)

            # Target axis of each factor axis inside the union.
            axes = [position[v] for v in scope]
            if len(axes) > 1:
                # Put the factor's axes into union order first: a bare
                # reshape would keep the memory order and pair values with
                # the wrong variable assignments.
                order = sorted(range(len(axes)), key=axes.__getitem__)
                values = np.transpose(values, order)

            # Size-1 axes for variables the factor does not mention, so the
            # multiplication broadcasts over them.
            broadcast_shape = [1] * len(union_scope)
            for ax in axes:
                broadcast_shape[ax] = shape[ax]
            result = result * values.reshape(broadcast_shape)

        return tuple(union_scope), result