scratchkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. mlscratch/__init__.py +56 -0
  2. mlscratch/__main__.py +118 -0
  3. mlscratch/bayesian/__init__.py +53 -0
  4. mlscratch/bayesian/bayesian_linear_regression.py +171 -0
  5. mlscratch/bayesian/bayesian_network.py +248 -0
  6. mlscratch/bayesian/bayesian_nn.py +315 -0
  7. mlscratch/bayesian/gaussian_process.py +207 -0
  8. mlscratch/bayesian/hmm.py +277 -0
  9. mlscratch/bayesian/init.py +52 -0
  10. mlscratch/bayesian/kalman_filter.py +182 -0
  11. mlscratch/bayesian/naive_bayes.py +209 -0
  12. mlscratch/metrics/__init__.py +59 -0
  13. mlscratch/metrics/classification.py +365 -0
  14. mlscratch/metrics/regression.py +79 -0
  15. mlscratch/neural/__init__.py +121 -0
  16. mlscratch/neural/attention.py +420 -0
  17. mlscratch/neural/autoencoder.py +543 -0
  18. mlscratch/neural/boltzmann.py +231 -0
  19. mlscratch/neural/cnn.py +593 -0
  20. mlscratch/neural/cvnn.py +322 -0
  21. mlscratch/neural/gan.py +364 -0
  22. mlscratch/neural/hopfield.py +193 -0
  23. mlscratch/neural/perceptron.py +398 -0
  24. mlscratch/neural/rbf_network.py +230 -0
  25. mlscratch/neural/recurrent.py +569 -0
  26. mlscratch/preprocessing/__init__.py +38 -0
  27. mlscratch/preprocessing/encoders.py +140 -0
  28. mlscratch/preprocessing/model_selection.py +119 -0
  29. mlscratch/preprocessing/polynomial.py +105 -0
  30. mlscratch/preprocessing/scalers.py +220 -0
  31. mlscratch/py.typed +0 -0
  32. mlscratch/reinforcement/__init__.py +59 -0
  33. mlscratch/reinforcement/ddpg.py +363 -0
  34. mlscratch/reinforcement/dqn.py +319 -0
  35. mlscratch/reinforcement/ppo.py +452 -0
  36. mlscratch/reinforcement/q_learning.py +352 -0
  37. mlscratch/reinforcement/sac.py +382 -0
  38. mlscratch/reinforcement/utils.py +594 -0
  39. mlscratch/supervised/__init__.py +76 -0
  40. mlscratch/supervised/_validation.py +50 -0
  41. mlscratch/supervised/adaboost.py +255 -0
  42. mlscratch/supervised/decision_tree.py +495 -0
  43. mlscratch/supervised/gradient_boosting.py +354 -0
  44. mlscratch/supervised/knn.py +234 -0
  45. mlscratch/supervised/lasso_regression.py +125 -0
  46. mlscratch/supervised/linear_models.py +459 -0
  47. mlscratch/supervised/linear_regression.py +197 -0
  48. mlscratch/supervised/logistic_regression.py +119 -0
  49. mlscratch/supervised/naive_bayes.py +113 -0
  50. mlscratch/supervised/random_forest.py +321 -0
  51. mlscratch/supervised/ridge_regression.py +93 -0
  52. mlscratch/supervised/svm.py +356 -0
  53. mlscratch/unsupervised/__init__.py +39 -0
  54. mlscratch/unsupervised/apriori.py +178 -0
  55. mlscratch/unsupervised/dbscan.py +141 -0
  56. mlscratch/unsupervised/gmm.py +204 -0
  57. mlscratch/unsupervised/hierarchical_clustering.py +137 -0
  58. mlscratch/unsupervised/ica.py +167 -0
  59. mlscratch/unsupervised/kmeans.py +135 -0
  60. mlscratch/unsupervised/kmedoids.py +133 -0
  61. mlscratch/unsupervised/pca.py +103 -0
  62. mlscratch/unsupervised/tsne.py +200 -0
  63. scratchkit-0.2.0.dist-info/METADATA +241 -0
  64. scratchkit-0.2.0.dist-info/RECORD +68 -0
  65. scratchkit-0.2.0.dist-info/WHEEL +5 -0
  66. scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
  67. scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
  68. scratchkit-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,356 @@
1
+ r"""
2
+ Kernel Support Vector Classifier
3
+ =================================
4
+ A soft-margin SVM solved in its dual form by Platt's Sequential Minimal
5
+ Optimization (SMO) algorithm — pure numpy, no external QP solver.
6
+
7
+ Primal problem (for reference; never solved directly):
8
+
9
+ .. math::
10
+ \min_{w,b,\xi} \ \tfrac12\|w\|^2 + C\sum_i \xi_i
11
+ \quad \text{s.t.} \quad y_i(w^\top \phi(x_i) + b) \ge 1-\xi_i,\ \xi_i \ge 0
12
+
13
+ Dual problem (what SMO actually maximises):
14
+
15
+ .. math::
16
+ \max_{\alpha} \ \sum_i \alpha_i - \tfrac12\sum_{i,j}\alpha_i\alpha_j y_i y_j K(x_i,x_j)
17
+ \quad \text{s.t.} \quad 0 \le \alpha_i \le C,\ \ \sum_i \alpha_i y_i = 0
18
+
19
+ SMO repeatedly picks a pair :math:`(\alpha_i,\alpha_j)` and solves the
20
+ resulting 1-D constrained quadratic exactly in closed form — this is
21
+ the simplified heuristic from Platt's original paper / the CS229 SMO
22
+ notes, with random second-variable selection rather than the full
23
+ working-set heuristic.
24
+
25
+ Kernels
26
+ -------
27
+ - ``'linear'`` : :math:`K(x,y) = x^\top y`
28
+ - ``'poly'`` : :math:`K(x,y) = (\gamma x^\top y + c_0)^d`
29
+ - ``'rbf'`` : :math:`K(x,y) = \exp(-\gamma\|x-y\|^2)`
30
+ - ``'sigmoid'`` : :math:`K(x,y) = \tanh(\gamma x^\top y + c_0)`
31
+ - a user-supplied callable ``K(X, Y) -> (n_X, n_Y)`` Gram matrix
32
+
33
+ Multiclass
34
+ ----------
35
+ ``SVC`` is natively binary; if more than two classes are present at
36
+ ``fit`` time it transparently trains one binary SVM per class
37
+ (one-vs-rest) and predicts via arg-max of the binary decision
38
+ functions.
39
+
40
+ Complexity
41
+ ----------
42
+ - Training : O(n^2 d) to build the Gram matrix, then O(n) work per SMO
43
+ step for a heuristically-bounded number of sweeps.
44
+ - Inference: O(n_SV * d) per sample.
45
+ """
46
+
47
+ from __future__ import annotations
48
+
49
+ from collections.abc import Callable
50
+
51
+ import numpy as np
52
+ from numpy.typing import ArrayLike, NDArray
53
+
54
+ from ._validation import validate_x, validate_xy
55
+
56
+ FloatArray = NDArray[np.float64]
57
+ IntArray = NDArray[np.int64]
58
+
59
+ _EPS = 1e-12
60
+ _MAX_SWEEPS = 2000
61
+
62
+
63
+ # ──────────────────────────────────────────────────────────────────────────
64
+ # Kernels
65
+ # ──────────────────────────────────────────────────────────────────────────
66
+
67
+
68
+ def _linear_kernel(X: FloatArray, Y: FloatArray) -> FloatArray:
69
+ return X @ Y.T
70
+
71
+
72
+ def _poly_kernel(
73
+ X: FloatArray, Y: FloatArray, degree: int, gamma: float, coef0: float
74
+ ) -> FloatArray:
75
+ return (gamma * (X @ Y.T) + coef0) ** degree
76
+
77
+
78
+ def _rbf_kernel(X: FloatArray, Y: FloatArray, gamma: float) -> FloatArray:
79
+ X_sq = np.sum(X**2, axis=1)[:, None]
80
+ Y_sq = np.sum(Y**2, axis=1)[None, :]
81
+ sq_dists = np.maximum(X_sq + Y_sq - 2.0 * (X @ Y.T), 0.0)
82
+ return np.exp(-gamma * sq_dists)
83
+
84
+
85
+ def _sigmoid_kernel(X: FloatArray, Y: FloatArray, gamma: float, coef0: float) -> FloatArray:
86
+ return np.tanh(gamma * (X @ Y.T) + coef0)
87
+
88
+
89
+ _BUILTIN_KERNELS = {"linear", "poly", "rbf", "sigmoid"}
90
+
91
+
92
+ # ──────────────────────────────────────────────────────────────────────────
93
+ # SVC
94
+ # ──────────────────────────────────────────────────────────────────────────
95
+
96
+
97
class SVC:
    """Kernel Support Vector Classifier trained via Sequential Minimal Optimization.

    Parameters
    ----------
    C : float, default=1.0
        Inverse regularisation strength (penalty on margin violations).
        Must be strictly positive.
    kernel : str | Callable, default='rbf'
        ``'linear'``, ``'poly'``, ``'rbf'``, ``'sigmoid'``, or a callable
        ``K(X, Y) -> ndarray`` returning a Gram matrix of shape
        ``(len(X), len(Y))``.
    degree : int, default=3
        Degree for the ``'poly'`` kernel.
    gamma : float | str, default='scale'
        Kernel coefficient for ``'rbf'``/``'poly'``/``'sigmoid'``.
        ``'scale'`` uses ``1 / (n_features * X.var())``, ``'auto'`` uses
        ``1 / n_features``.
    coef0 : float, default=0.0
        Independent term for ``'poly'``/``'sigmoid'`` kernels.
    tol : float, default=1e-3
        KKT violation tolerance used to select active variables.
    max_iter : int, default=10
        Number of consecutive full sweeps over the data with *no*
        alpha updates before SMO is declared converged (Platt's
        ``max_passes``). A hard cap of ``_MAX_SWEEPS`` total sweeps
        applies regardless, to guarantee termination.
    random_state : int | None, default=None
        Seed for the random selection of the second SMO variable.

    Attributes
    ----------
    support_ : indices of the support vectors within the training data
    support_vectors_ : the support vectors themselves
    dual_coef_ : ``alpha_i * y_i`` for each support vector
    intercept_ : the bias term ``b``
    n_support_ : number of support vectors
    n_iter_ : number of full sweeps SMO performed
    classes_ : sorted unique labels seen during fit
    multiclass_ : whether one-vs-rest decomposition was used

    The support-vector attributes are populated by a binary fit only; after
    a multiclass fit they are ``None`` on this object and live on the
    per-class estimators instead.
    """

    def __init__(
        self,
        C: float = 1.0,
        kernel: str | Callable[[FloatArray, FloatArray], FloatArray] = "rbf",
        degree: int = 3,
        gamma: float | str = "scale",
        coef0: float = 0.0,
        tol: float = 1e-3,
        max_iter: int = 10,
        random_state: int | None = None,
    ) -> None:
        if not callable(kernel) and kernel not in _BUILTIN_KERNELS:
            raise ValueError(f"kernel must be one of {_BUILTIN_KERNELS} or a callable.")
        if C <= 0:
            raise ValueError("C must be positive.")
        self.C = float(C)
        self.kernel = kernel
        self.degree = int(degree)
        self.gamma = gamma
        self.coef0 = float(coef0)
        self.tol = float(tol)
        self.max_iter = int(max_iter)
        self.random_state = random_state

        self.classes_: NDArray | None = None
        self.n_features_in_: int | None = None
        self.multiclass_: bool = False

        # binary-fit attributes
        self.support_: IntArray | None = None
        self.support_vectors_: FloatArray | None = None
        self.dual_coef_: FloatArray | None = None
        self.intercept_: float | None = None
        self.n_support_: int | None = None
        self.n_iter_: int | None = None

        # multiclass (one-vs-rest) attributes
        self._ovr_estimators_: list[SVC] | None = None

        self._gamma_value: float | None = None
        self._fitted = False

    # -- kernel plumbing -----------------------------------------------------

    def _resolve_gamma(self, X: FloatArray) -> float:
        """Turn the ``gamma`` hyper-parameter into a concrete float for ``X``.

        Raises
        ------
        ValueError
            If ``gamma`` is an unknown string or a non-positive number.
        """
        if isinstance(self.gamma, str):
            if self.gamma == "scale":
                var = float(X.var())
                # Constant data has (near-)zero variance; fall back to 1.0
                # rather than dividing by ~0.
                return 1.0 / (X.shape[1] * var) if var > _EPS else 1.0
            if self.gamma == "auto":
                return 1.0 / X.shape[1]
            raise ValueError("gamma must be 'scale', 'auto', or a positive float.")
        if self.gamma <= 0:
            raise ValueError("gamma must be positive.")
        return float(self.gamma)

    def _kernel_fn(self, X: FloatArray, Y: FloatArray) -> FloatArray:
        """Evaluate the configured kernel between the rows of ``X`` and ``Y``."""
        if callable(self.kernel):
            return np.asarray(self.kernel(X, Y), dtype=np.float64)
        if self.kernel == "linear":
            return _linear_kernel(X, Y)
        if self.kernel == "poly":
            return _poly_kernel(X, Y, self.degree, self._gamma_value, self.coef0)
        if self.kernel == "rbf":
            return _rbf_kernel(X, Y, self._gamma_value)
        return _sigmoid_kernel(X, Y, self._gamma_value, self.coef0)

    # -- public API -----------------------------------------------------------

    def fit(self, X: ArrayLike, y: ArrayLike) -> SVC:
        """Fit the SVM. Dispatches to one-vs-rest if >2 classes are present.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``gamma`` is invalid or fewer than two classes are present.
        """
        X_arr, y_arr = validate_xy(X, y)
        classes = np.unique(y_arr)
        gamma_value = self._resolve_gamma(X_arr)
        if classes.size < 2:
            raise ValueError("SVC requires at least 2 classes.")

        # Drop everything learned by a previous fit so that a binary refit
        # does not keep old one-vs-rest estimators (and vice versa).
        self._fitted = False
        self._ovr_estimators_ = None
        self.support_ = None
        self.support_vectors_ = None
        self.dual_coef_ = None
        self.intercept_ = None
        self.n_support_ = None
        self.n_iter_ = None

        self.classes_ = classes
        self.n_features_in_ = X_arr.shape[1]
        self._gamma_value = gamma_value

        if classes.size == 2:
            self.multiclass_ = False
            self._fit_binary(X_arr, y_arr)
        else:
            self.multiclass_ = True
            # Every one-vs-rest sub-problem uses the same inputs and kernel
            # settings, so the Gram matrix is computed once and shared
            # (the solver only reads it).
            gram = self._kernel_fn(X_arr, X_arr)
            estimators: list[SVC] = []
            for cls in classes:
                binary_y = (y_arr == cls).astype(np.int64)  # {0, 1}, 1 = "is this class"
                sub = SVC(
                    C=self.C,
                    kernel=self.kernel,
                    degree=self.degree,
                    gamma=self.gamma,
                    coef0=self.coef0,
                    tol=self.tol,
                    max_iter=self.max_iter,
                    random_state=self.random_state,
                )
                sub._gamma_value = self._gamma_value
                sub.n_features_in_ = self.n_features_in_
                sub._fit_binary(X_arr, binary_y, K=gram)
                estimators.append(sub)
            self._ovr_estimators_ = estimators
        self._fitted = True
        return self

    def decision_function(self, X: ArrayLike) -> FloatArray:
        """Signed distance to the separating hyperplane (margin).

        For multiclass, returns shape ``(n_samples, n_classes)`` — one
        one-vs-rest margin per class.

        Raises
        ------
        RuntimeError
            If the estimator has not been fitted.
        ValueError
            If ``X`` has a different number of features than the training data.
        """
        if not self._fitted:
            raise RuntimeError("Call fit() before decision_function().")
        X_arr = validate_x(X)
        if (
            self.n_features_in_ is not None
            and np.ndim(X_arr) == 2
            and np.shape(X_arr)[1] != self.n_features_in_
        ):
            raise ValueError(
                f"X has {np.shape(X_arr)[1]} features, but SVC was fitted with "
                f"{self.n_features_in_} features."
            )
        if self.multiclass_:
            return np.column_stack([est.decision_function(X_arr) for est in self._ovr_estimators_])
        K = self._kernel_fn(self.support_vectors_, X_arr)
        return self.dual_coef_ @ K + self.intercept_

    def predict(self, X: ArrayLike) -> NDArray:
        """Predict class labels."""
        if not self._fitted:
            raise RuntimeError("Call fit() before predict().")
        scores = self.decision_function(X)
        if self.multiclass_:
            return self.classes_[np.argmax(scores, axis=1)]
        # Binary case: non-negative margin maps to the larger label.
        return np.where(scores >= 0.0, self.classes_[1], self.classes_[0])

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return classification accuracy."""
        X_arr, y_arr = validate_xy(X, y)
        return float(np.mean(self.predict(X_arr) == y_arr))

    # -- core binary SMO solver -------------------------------------------------

    def _fit_binary(
        self, X_arr: FloatArray, y_arr: NDArray, K: FloatArray | None = None
    ) -> None:
        """Solve the two-class dual problem with simplified SMO.

        Parameters
        ----------
        X_arr : training inputs, one sample per row.
        y_arr : labels containing exactly two distinct values; the smaller
            one is mapped to -1 and the larger to +1.
        K : optional precomputed Gram matrix of ``X_arr`` with itself. When
            omitted it is computed with the configured kernel. It is only
            read, never modified.

        Raises
        ------
        ValueError
            If ``y_arr`` does not contain exactly two classes.
        """
        classes = np.unique(y_arr)
        if classes.size != 2:
            raise ValueError("_fit_binary requires exactly 2 classes.")
        self.classes_ = classes
        y_signed = np.where(y_arr == classes[0], -1.0, 1.0)
        n = X_arr.shape[0]
        if K is None:
            K = self._kernel_fn(X_arr, X_arr)

        alpha = np.zeros(n, dtype=np.float64)
        b = 0.0
        rng = np.random.default_rng(self.random_state)
        ay = alpha * y_signed  # kept in sync after every update

        def f(i: int) -> float:
            # Current decision value for training sample i.
            return float(np.dot(ay, K[:, i]) + b)

        passes = 0
        sweeps = 0
        while passes < self.max_iter and sweeps < _MAX_SWEEPS:
            num_changed = 0
            for i in range(n):
                Ei = f(i) - y_signed[i]
                violates = (y_signed[i] * Ei < -self.tol and alpha[i] < self.C) or (
                    y_signed[i] * Ei > self.tol and alpha[i] > 0.0
                )
                if not violates:
                    continue

                # Second variable: uniformly random index different from i.
                j = i
                while j == i:
                    j = int(rng.integers(0, n))
                Ej = f(j) - y_signed[j]

                ai_old, aj_old = alpha[i], alpha[j]
                # Box bounds for alpha_j that keep both multipliers in [0, C]
                # while preserving sum(alpha * y).
                if y_signed[i] != y_signed[j]:
                    lo = max(0.0, aj_old - ai_old)
                    hi = min(self.C, self.C + aj_old - ai_old)
                else:
                    lo = max(0.0, ai_old + aj_old - self.C)
                    hi = min(self.C, ai_old + aj_old)
                if lo >= hi:
                    continue

                # eta is the second derivative of the objective along the
                # constraint line; it must be negative for a proper maximum.
                eta = 2.0 * K[i, j] - K[i, i] - K[j, j]
                if eta >= 0:
                    continue

                # Work on local copies and commit only if the step is
                # meaningful. Writing the clipped value into alpha[j] before
                # this check would leave alpha[j] changed without the
                # matching alpha[i] / ay update, breaking sum(alpha * y) = 0.
                aj_new = aj_old - y_signed[j] * (Ei - Ej) / eta
                aj_new = min(hi, max(lo, aj_new))
                if abs(aj_new - aj_old) < 1e-7:
                    continue
                ai_new = ai_old + y_signed[i] * y_signed[j] * (aj_old - aj_new)

                delta_i = y_signed[i] * (ai_new - ai_old)
                delta_j = y_signed[j] * (aj_new - aj_old)
                b1 = b - Ei - delta_i * K[i, i] - delta_j * K[i, j]
                b2 = b - Ej - delta_i * K[i, j] - delta_j * K[j, j]
                # Prefer the threshold implied by a multiplier strictly
                # inside the box; otherwise average the two candidates.
                if 0.0 < ai_new < self.C:
                    b = b1
                elif 0.0 < aj_new < self.C:
                    b = b2
                else:
                    b = (b1 + b2) / 2.0

                alpha[i] = ai_new
                alpha[j] = aj_new
                ay[i] = ai_new * y_signed[i]
                ay[j] = aj_new * y_signed[j]
                num_changed += 1

            sweeps += 1
            passes = passes + 1 if num_changed == 0 else 0

        # Samples with a (numerically) non-zero multiplier are support vectors.
        sv_mask = alpha > 1e-8
        self.support_ = np.flatnonzero(sv_mask).astype(np.int64)
        self.support_vectors_ = X_arr[sv_mask]
        self.dual_coef_ = ay[sv_mask]
        self.intercept_ = float(b)
        self.n_support_ = int(sv_mask.sum())
        self.n_iter_ = sweeps
        self._fitted = True
@@ -0,0 +1,39 @@
1
+ """
2
+ mlscratch.unsupervised
3
+ ======================
4
+ From-scratch implementations of unsupervised learning algorithms.
5
+ Drop these files alongside the existing kmeans.py in src/mlscratch/unsupervised/.
6
+
7
+ New algorithms added
8
+ --------------------
9
+ DBSCAN – Density-based spatial clustering
10
+ PCA – Principal Component Analysis
11
+ GaussianMixtureModel – GMM via Expectation-Maximization
12
+ AgglomerativeClustering – Hierarchical agglomerative clustering
13
+ KMedoids – K-Medoids (PAM) clustering
14
+ Apriori – Association rule mining
15
+ FastICA – Independent Component Analysis (FastICA)
16
+ TSNE – t-SNE dimensionality reduction
17
+ """
18
+
19
+ from .kmeans import KMeans # noqa: F401
20
+ from .dbscan import DBSCAN # noqa: F401
21
+ from .pca import PCA # noqa: F401
22
+ from .gmm import GaussianMixtureModel # noqa: F401
23
+ from .hierarchical_clustering import AgglomerativeClustering # noqa: F401
24
+ from .kmedoids import KMedoids # noqa: F401
25
+ from .apriori import Apriori # noqa: F401
26
+ from .ica import FastICA # noqa: F401
27
+ from .tsne import TSNE # noqa: F401
28
+
29
+ __all__ = [
30
+ "KMeans",
31
+ "DBSCAN",
32
+ "PCA",
33
+ "GaussianMixtureModel",
34
+ "AgglomerativeClustering",
35
+ "KMedoids",
36
+ "Apriori",
37
+ "FastICA",
38
+ "TSNE",
39
+ ]
@@ -0,0 +1,178 @@
1
+ """
2
+ Apriori Algorithm for Association Rule Mining
3
+ ==============================================
4
+ Discovers frequent itemsets in a transaction database and generates
5
+ association rules with user-specified support and confidence thresholds.
6
+
7
+ Core concepts
8
+ -------------
9
+ - Support(A) = (# transactions containing A) / (# transactions total)
10
+ - Confidence(A→B) = Support(A ∪ B) / Support(A)
11
+ - Lift(A→B) = Confidence(A→B) / Support(B)
12
+
13
+ The Apriori property: every subset of a frequent itemset is frequent.
14
+ This prunes the candidate search space dramatically.
15
+
16
+ Only the Python standard library is used.
17
+ """
18
+
19
+ from itertools import combinations
20
+ from collections import defaultdict
21
+
22
+
23
class Apriori:
    """
    Apriori frequent-itemset mining and association-rule generation.

    Parameters
    ----------
    min_support : float
        Minimum support threshold in [0, 1].
    min_confidence : float
        Minimum confidence threshold for rule generation, in [0, 1].
    min_lift : float
        Minimum lift for rule generation. The default of 1.0 drops rules
        whose lift is below 1 (antecedent and consequent negatively
        associated); pass 0.0 to disable the lift filter.

    Attributes
    ----------
    frequent_itemsets_ : dict mapping each frequent ``frozenset`` to its support
    rules_ : list of dicts with keys ``antecedent``, ``consequent``,
        ``support``, ``confidence`` and ``lift``
    """

    def __init__(
        self,
        min_support: float = 0.5,
        min_confidence: float = 0.5,
        min_lift: float = 1.0,
    ):
        self.min_support = min_support
        self.min_confidence = min_confidence
        self.min_lift = min_lift

        self.frequent_itemsets_ = {}  # frozenset → support
        self.rules_ = []  # list of dicts with antecedent/consequent/metrics

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_support(self, itemset: frozenset, transactions: list) -> float:
        """Compute support of an itemset over the transaction list.

        Returns 0.0 for an empty transaction list instead of dividing by zero.
        """
        if not transactions:
            return 0.0
        count = sum(1 for t in transactions if itemset.issubset(t))
        return count / len(transactions)

    def _generate_candidates(self, prev_frequent: list, k: int) -> list:
        """
        Generate candidate k-itemsets by joining (k-1)-itemsets that share
        their first k-2 items (standard Apriori join step).
        """
        prefix_len = k - 2
        prev_list = sorted(sorted(fs) for fs in prev_frequent)
        candidates = set()
        for left, right in combinations(prev_list, 2):
            if left[:prefix_len] != right[:prefix_len]:
                continue
            union = frozenset(left) | frozenset(right)
            if len(union) == k:
                candidates.add(union)
        return list(candidates)

    def _prune_candidates(
        self, candidates: list, prev_frequent_set: set, k: int
    ) -> list:
        """
        Remove candidates that have an infrequent subset (Apriori pruning).
        """
        return [
            candidate
            for candidate in candidates
            if all(
                frozenset(subset) in prev_frequent_set
                for subset in combinations(candidate, k - 1)
            )
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(self, transactions: list) -> "Apriori":
        """
        Mine frequent itemsets from a list of transactions.

        Results of any previous call are discarded, so the fitted state
        always reflects only the transactions passed to this call.

        Parameters
        ----------
        transactions : list of list/set
            Each element is a collection of items in one transaction.
            e.g. [['milk', 'bread'], ['milk', 'eggs', 'butter'], ...]

        Returns
        -------
        self
        """
        # Start from a clean slate: without this, itemsets mined from an
        # earlier dataset would survive and feed into the new rules.
        self.frequent_itemsets_ = {}
        self.rules_ = []

        # Encode transactions as frozensets
        trans = [frozenset(t) for t in transactions]
        all_items = sorted({item for t in trans for item in t})

        # --- Pass 1: find frequent 1-itemsets ---
        frequent_k = {}
        for item in all_items:
            fs = frozenset([item])
            sup = self._get_support(fs, trans)
            if sup >= self.min_support:
                frequent_k[fs] = sup

        self.frequent_itemsets_.update(frequent_k)

        # --- Passes k >= 2: grow itemsets until no level is frequent ---
        k = 2
        while frequent_k:
            prev_frequent_set = set(frequent_k)

            # Generate and prune candidates
            candidates = self._generate_candidates(list(frequent_k), k)
            candidates = self._prune_candidates(candidates, prev_frequent_set, k)

            # Count support
            new_frequent = {}
            for candidate in candidates:
                sup = self._get_support(candidate, trans)
                if sup >= self.min_support:
                    new_frequent[candidate] = sup

            self.frequent_itemsets_.update(new_frequent)
            frequent_k = new_frequent
            k += 1

        # --- Generate association rules ---
        for itemset, itemset_sup in self.frequent_itemsets_.items():
            if len(itemset) < 2:
                continue
            # Try all non-empty proper subsets as antecedents
            for r in range(1, len(itemset)):
                for antecedent in map(frozenset, combinations(itemset, r)):
                    consequent = itemset - antecedent
                    ant_sup = self.frequent_itemsets_.get(antecedent, 0)
                    con_sup = self.frequent_itemsets_.get(consequent, 0)

                    if ant_sup == 0:
                        continue

                    confidence = itemset_sup / ant_sup
                    lift = confidence / con_sup if con_sup > 0 else 0.0

                    if confidence >= self.min_confidence and lift >= self.min_lift:
                        self.rules_.append({
                            "antecedent": antecedent,
                            "consequent": consequent,
                            "support": itemset_sup,
                            "confidence": confidence,
                            "lift": lift,
                        })

        return self

    def get_frequent_itemsets(self) -> list:
        """Return list of (itemset, support) tuples sorted by support desc."""
        return sorted(
            self.frequent_itemsets_.items(), key=lambda x: -x[1]
        )

    def get_rules(self) -> list:
        """Return association rules sorted by confidence descending."""
        return sorted(self.rules_, key=lambda r: -r["confidence"])