scratchkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. mlscratch/__init__.py +56 -0
  2. mlscratch/__main__.py +118 -0
  3. mlscratch/bayesian/__init__.py +53 -0
  4. mlscratch/bayesian/bayesian_linear_regression.py +171 -0
  5. mlscratch/bayesian/bayesian_network.py +248 -0
  6. mlscratch/bayesian/bayesian_nn.py +315 -0
  7. mlscratch/bayesian/gaussian_process.py +207 -0
  8. mlscratch/bayesian/hmm.py +277 -0
  9. mlscratch/bayesian/init.py +52 -0
  10. mlscratch/bayesian/kalman_filter.py +182 -0
  11. mlscratch/bayesian/naive_bayes.py +209 -0
  12. mlscratch/metrics/__init__.py +59 -0
  13. mlscratch/metrics/classification.py +365 -0
  14. mlscratch/metrics/regression.py +79 -0
  15. mlscratch/neural/__init__.py +121 -0
  16. mlscratch/neural/attention.py +420 -0
  17. mlscratch/neural/autoencoder.py +543 -0
  18. mlscratch/neural/boltzmann.py +231 -0
  19. mlscratch/neural/cnn.py +593 -0
  20. mlscratch/neural/cvnn.py +322 -0
  21. mlscratch/neural/gan.py +364 -0
  22. mlscratch/neural/hopfield.py +193 -0
  23. mlscratch/neural/perceptron.py +398 -0
  24. mlscratch/neural/rbf_network.py +230 -0
  25. mlscratch/neural/recurrent.py +569 -0
  26. mlscratch/preprocessing/__init__.py +38 -0
  27. mlscratch/preprocessing/encoders.py +140 -0
  28. mlscratch/preprocessing/model_selection.py +119 -0
  29. mlscratch/preprocessing/polynomial.py +105 -0
  30. mlscratch/preprocessing/scalers.py +220 -0
  31. mlscratch/py.typed +0 -0
  32. mlscratch/reinforcement/__init__.py +59 -0
  33. mlscratch/reinforcement/ddpg.py +363 -0
  34. mlscratch/reinforcement/dqn.py +319 -0
  35. mlscratch/reinforcement/ppo.py +452 -0
  36. mlscratch/reinforcement/q_learning.py +352 -0
  37. mlscratch/reinforcement/sac.py +382 -0
  38. mlscratch/reinforcement/utils.py +594 -0
  39. mlscratch/supervised/__init__.py +76 -0
  40. mlscratch/supervised/_validation.py +50 -0
  41. mlscratch/supervised/adaboost.py +255 -0
  42. mlscratch/supervised/decision_tree.py +495 -0
  43. mlscratch/supervised/gradient_boosting.py +354 -0
  44. mlscratch/supervised/knn.py +234 -0
  45. mlscratch/supervised/lasso_regression.py +125 -0
  46. mlscratch/supervised/linear_models.py +459 -0
  47. mlscratch/supervised/linear_regression.py +197 -0
  48. mlscratch/supervised/logistic_regression.py +119 -0
  49. mlscratch/supervised/naive_bayes.py +113 -0
  50. mlscratch/supervised/random_forest.py +321 -0
  51. mlscratch/supervised/ridge_regression.py +93 -0
  52. mlscratch/supervised/svm.py +356 -0
  53. mlscratch/unsupervised/__init__.py +39 -0
  54. mlscratch/unsupervised/apriori.py +178 -0
  55. mlscratch/unsupervised/dbscan.py +141 -0
  56. mlscratch/unsupervised/gmm.py +204 -0
  57. mlscratch/unsupervised/hierarchical_clustering.py +137 -0
  58. mlscratch/unsupervised/ica.py +167 -0
  59. mlscratch/unsupervised/kmeans.py +135 -0
  60. mlscratch/unsupervised/kmedoids.py +133 -0
  61. mlscratch/unsupervised/pca.py +103 -0
  62. mlscratch/unsupervised/tsne.py +200 -0
  63. scratchkit-0.2.0.dist-info/METADATA +241 -0
  64. scratchkit-0.2.0.dist-info/RECORD +68 -0
  65. scratchkit-0.2.0.dist-info/WHEEL +5 -0
  66. scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
  67. scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
  68. scratchkit-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,322 @@
1
+ """
2
+ Complex-Valued Neural Network (CVNN)
3
+ ======================================
4
+ A feedforward network whose weights, biases, and activations are complex
5
+ numbers (z = a + bi), useful for signal-processing tasks (audio spectra,
6
+ radar, MRI, wireless communications) where phase information is meaningful
7
+ and would be lost by treating real/imaginary parts as separate real channels.
8
+
9
+ Complex linear layer
10
+ ----------------------
11
+ z_out = W z_in + b, W, b, z ∈ ℂ
12
+
13
+ Complex activation functions
14
+ -------------------------------
15
+ ``modReLU`` (Arjovsky et al., 2016):
16
+ modReLU(z) = ReLU(|z| + b) · (z / |z|) if |z| + b > 0, else 0
17
+
18
+ ``complex tanh`` (split, applied to real and imaginary parts independently):
19
+ ctanh(z) = tanh(Re(z)) + i·tanh(Im(z))
20
+
21
+ ``zReLU``:
22
+ zReLU(z) = z if Re(z) > 0 and Im(z) > 0, else 0
23
+
24
+ Backpropagation
25
+ -----------------
26
+ Implemented via Wirtinger calculus / the CR-calculus convention: for a
27
+ real-valued loss L(z, z̄), gradients are computed with respect to the
28
+ conjugate ∂L/∂z̄, and parameter updates use:
29
+
30
+ W ← W - η · ∂L/∂W̄
31
+
32
+ For the layers and activations implemented here, this reduces to applying
33
+ the same real-valued backprop formulas independently to the real and
34
+ imaginary parts (the standard split-complex backprop technique, treating
35
+ ℂ as ℝ²). The activations above are not holomorphic, and the modReLU input gradient is approximated by an on/off mask.
36
+
37
+ References
38
+ ----------
39
+ Arjovsky, M., Shah, A., & Bengio, Y. (2016). Unitary evolution recurrent
40
+ neural networks. ICML.
41
+
42
+ Trabelsi, C. et al. (2018). Deep Complex Networks. ICLR.
43
+
44
+ Only numpy is used.
45
+ """
46
+
47
+ from __future__ import annotations
48
+
49
+ import numpy as np
50
+
51
+
52
+ # ============================================================
53
+ # Complex activations
54
+ # ============================================================
55
+
56
def _complex_tanh(z: np.ndarray) -> np.ndarray:
    """Split complex tanh: squash the real and imaginary parts separately.

    ctanh(z) = tanh(Re z) + i tanh(Im z)
    """
    real_out = np.tanh(z.real)
    imag_out = np.tanh(z.imag)
    return real_out + 1j * imag_out
59
+
60
+
61
def _complex_tanh_grad(z: np.ndarray) -> np.ndarray:
    """Elementwise derivative of the split complex tanh.

    The real part holds d tanh/dx at Re z and the imaginary part holds
    d tanh/dy at Im z, i.e. (1 - tanh(Re z)^2) + i (1 - tanh(Im z)^2).
    """
    d_re, d_im = (1 - np.tanh(part) ** 2 for part in (z.real, z.imag))
    return d_re + 1j * d_im
64
+
65
+
66
def _mod_relu(z: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """modReLU: shift the modulus by ``bias``, clip at zero, keep the phase.

    modReLU(z) = ReLU(|z| + b) * (z / |z|); a 1e-8 term guards |z| = 0.
    """
    modulus = np.abs(z)
    shifted = np.maximum(modulus + bias, 0.0)
    return z * (shifted / (modulus + 1e-8))
71
+
72
+
73
def _mod_relu_grad(z: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """
    Approximate split-real/imag gradient of modReLU.

    Returns a float mask used as a multiplicative factor on incoming
    gradients: 1.0 where |z| + bias > 0 (unit active), 0.0 otherwise.
    """
    return np.where(np.abs(z) + bias > 0, 1.0, 0.0)
82
+
83
+
84
def _z_relu(z: np.ndarray) -> np.ndarray:
    """zReLU: pass z through only when it lies in the open first quadrant."""
    in_first_quadrant = np.logical_and(z.real > 0, z.imag > 0)
    return z * in_first_quadrant
88
+
89
+
90
def _z_relu_grad(z: np.ndarray) -> np.ndarray:
    """Gradient mask of zReLU: 1.0 where Re z > 0 and Im z > 0, else 0.0."""
    in_first_quadrant = np.logical_and(z.real > 0, z.imag > 0)
    return in_first_quadrant.astype(float)
93
+
94
+
95
+ # ============================================================
96
+ # Complex Dense Layer
97
+ # ============================================================
98
+
99
class ComplexDense:
    """
    Complex-valued fully-connected layer.

    :meth:`forward` computes ``a = act(z_in @ W + b)``. :meth:`backward`
    propagates split-complex gradients (``dL/dRe + 1j * dL/dIm``) and applies
    one plain SGD step to the layer's own parameters.

    Parameters
    ----------
    in_features : int
        Number of input features.
    out_features : int
        Number of output units.
    activation : str
        ``'modrelu'``, ``'ctanh'``, ``'zrelu'``, or ``'linear'``.
    learning_rate : float
        SGD step size used by :meth:`backward`.
    random_state : int or None
        Seed for the weight initialisation.

    Raises
    ------
    ValueError
        If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation: str = "ctanh",
        learning_rate: float = 1e-3,
        random_state: int | None = None,
    ) -> None:
        if activation not in {"modrelu", "ctanh", "zrelu", "linear"}:
            raise ValueError("activation must be 'modrelu', 'ctanh', 'zrelu', or 'linear'.")
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        self.learning_rate = learning_rate

        rng = np.random.default_rng(random_state)
        scale = np.sqrt(1.0 / in_features)

        # Complex weights: independent real and imaginary Gaussian parts,
        # each with standard deviation 1/sqrt(in_features).
        self.W = (rng.normal(0, scale, (in_features, out_features))
                  + 1j * rng.normal(0, scale, (in_features, out_features)))
        self.b = np.zeros(out_features, dtype=complex)

        if activation == "modrelu":
            # Real-valued per-unit bias added to |z| inside modReLU.
            self.mod_bias = np.zeros(out_features)

        # Input and pre-activation of the last forward pass, read by backward.
        self._cache: dict = {}

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(self, z_in: np.ndarray) -> np.ndarray:
        """
        z_in : (B, in_features) complex array
        Returns (B, out_features) complex array
        """
        z = z_in @ self.W + self.b

        if self.activation == "ctanh":
            a = _complex_tanh(z)
        elif self.activation == "modrelu":
            a = _mod_relu(z, self.mod_bias)
        elif self.activation == "zrelu":
            a = _z_relu(z)
        else:
            a = z

        self._cache = {"z_in": z_in, "z": z}
        return a

    # ------------------------------------------------------------------
    # Backward (split real/imaginary backprop)
    # ------------------------------------------------------------------

    def backward(self, d_a: np.ndarray) -> np.ndarray:
        """
        Backpropagate through the layer and apply one SGD step.

        Parameters
        ----------
        d_a : (B, out_features) complex array
            Per-sample split-complex gradient of the loss w.r.t. this layer's
            output (``dL/dRe(a) + 1j * dL/dIm(a)``). Parameter gradients are
            averaged over the batch here.

        Returns
        -------
        d_z_in : (B, in_features) complex array
            Per-sample gradient w.r.t. the layer input. It is computed with
            the weights used in the forward pass, i.e. before this update.
        """
        z_in, z = self._cache["z_in"], self._cache["z"]
        n = len(z_in)
        d_bias = None

        if self.activation == "ctanh":
            grad = _complex_tanh_grad(z)
            d_z = d_a.real * grad.real + 1j * (d_a.imag * grad.imag)
        elif self.activation == "modrelu":
            active = _mod_relu_grad(z, self.mod_bias)
            # Approximate input gradient: pass through where the unit is active.
            d_z = d_a * active
            # For active units a = (|z| + b) * z/|z|, so da/db is the unit
            # phase z/|z|; dL/db = Re(conj(d_a) * phase), which uses both the
            # real and imaginary components of the incoming gradient.
            phase = z / (np.abs(z) + 1e-8)
            d_bias = (np.real(np.conj(d_a) * phase) * active).mean(axis=0)
        elif self.activation == "zrelu":
            active = _z_relu_grad(z)
            d_z = d_a * active
        else:
            d_z = d_a

        # Gradients w.r.t. W and b (split real/imag, standard complex backprop)
        d_W = np.conj(z_in).T @ d_z / n
        d_b = d_z.mean(axis=0)

        # Propagate to the input BEFORE the in-place update below, so the
        # upstream gradient matches the weights used in the forward pass.
        d_z_in = d_z @ np.conj(self.W).T

        self.W -= self.learning_rate * d_W
        self.b -= self.learning_rate * d_b
        if d_bias is not None:
            self.mod_bias -= self.learning_rate * d_bias

        return d_z_in
199
+
200
+
201
+ # ============================================================
202
+ # Complex-Valued Neural Network (multi-layer)
203
+ # ============================================================
204
+
205
class ComplexValuedNN:
    """
    Multi-layer complex-valued feedforward network for regression on
    complex-valued data (e.g. signal reconstruction).

    The final layer is linear (no activation) so outputs can take any
    complex value. The training loss is the mean squared modulus error
    ``mean(|y_hat - y|**2)``, minimised with mini-batch SGD. Each
    :class:`ComplexDense` layer updates its own parameters during backprop.

    Parameters
    ----------
    layer_sizes : list[int]
        e.g. ``[n_in, hidden1, hidden2, n_out]``.
    hidden_activation : str
        ``'ctanh'``, ``'modrelu'``, or ``'zrelu'`` for hidden layers.
    learning_rate : float
        SGD step size passed to every layer.
    epochs : int
        Number of passes over the training data.
    batch_size : int or None
        Mini-batch size; ``None`` (or 0) uses the whole dataset per step.
    random_state : int or None
        Seed for weight initialisation and mini-batch shuffling
        (``None`` behaves like seed 0).

    Attributes
    ----------
    losses_ : list[float]
        Mean training loss per epoch, filled by :meth:`fit`.
    """

    def __init__(
        self,
        layer_sizes: list[int],
        hidden_activation: str = "ctanh",
        learning_rate: float = 1e-3,
        epochs: int = 100,
        batch_size: int | None = 32,
        random_state: int | None = None,
    ) -> None:
        self.layer_sizes = layer_sizes
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.random_state = random_state

        self.layers: list[ComplexDense] = []
        for i in range(len(layer_sizes) - 1):
            is_last = (i == len(layer_sizes) - 2)
            act = "linear" if is_last else hidden_activation
            # Distinct deterministic seed per layer.
            seed = (random_state or 0) + i
            self.layers.append(
                ComplexDense(layer_sizes[i], layer_sizes[i + 1], act,
                             learning_rate, seed)
            )

        self.losses_: list[float] = []

    # ------------------------------------------------------------------
    # Forward / backward
    # ------------------------------------------------------------------

    def forward(self, z: np.ndarray) -> np.ndarray:
        """z : (B, n_in) complex → (B, n_out) complex"""
        a = z
        for layer in self.layers:
            a = layer.forward(a)
        return a

    def _backward(self, d_out: np.ndarray) -> None:
        """Propagate ``d_out`` from the last layer to the first, updating each."""
        d = d_out
        for layer in reversed(self.layers):
            d = layer.backward(d)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(self, X: np.ndarray, y: np.ndarray) -> "ComplexValuedNN":
        """
        Train on complex data.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_in), complex dtype
        y : ndarray of shape (n_samples, n_out), complex dtype
            A 1-D ``y`` is treated as a single output column.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If X is empty, X and y have different lengths, or the batch
            size is negative.
        """
        X = np.asarray(X, dtype=complex)
        y = np.asarray(y, dtype=complex)
        if y.ndim == 1:
            # Avoid silent (B, 1) - (B,) broadcasting to a (B, B) error matrix.
            y = y.reshape(-1, 1)

        n = len(X)
        if n == 0:
            raise ValueError("X must contain at least one sample.")
        if len(y) != n:
            raise ValueError("X and y must have the same number of samples.")
        bs = self.batch_size or n
        if bs < 1:
            raise ValueError("batch_size must be a positive integer or None.")

        rng = np.random.default_rng(0 if self.random_state is None else self.random_state)
        self.losses_ = []

        for _ in range(self.epochs):
            idx = rng.permutation(n)
            epoch_loss = 0.0
            n_batches = 0

            for start in range(0, n, bs):
                mb = idx[start:start + bs]
                Xb, yb = X[mb], y[mb]

                y_hat = self.forward(Xb)
                diff = y_hat - yb
                loss = float(np.mean(np.abs(diff) ** 2))
                epoch_loss += loss
                n_batches += 1

                # Split-complex gradient of |y_hat - y|^2 w.r.t. y_hat is
                # 2 * (y_hat - y); the constant 2 is absorbed into the learning
                # rate. Each layer already averages parameter gradients over
                # the batch, so pass the per-sample gradient (dividing by the
                # batch size here as well would scale every step by 1/B^2).
                self._backward(diff)

            self.losses_.append(epoch_loss / n_batches)

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return complex-valued predictions."""
        return self.forward(X)

    def predict_magnitude(self, X: np.ndarray) -> np.ndarray:
        """Return |ŷ| (magnitude of predictions) — often the quantity of interest."""
        return np.abs(self.forward(X))

    def predict_phase(self, X: np.ndarray) -> np.ndarray:
        """Return arg(ŷ) (phase of predictions) in radians."""
        return np.angle(self.forward(X))
@@ -0,0 +1,364 @@
1
+ """
2
+ Generative Adversarial Network (GAN)
3
+ ======================================
4
+ Two networks trained adversarially (Goodfellow et al., 2014):
5
+
6
+ Generator G: z → x̂ maps noise to fake samples
7
+ Discriminator D: x → [0,1] estimates P(x is real)
8
+
9
+ Minimax objective
10
+ ------------------
11
+ min_G max_D E_x[log D(x)] + E_z[log(1 - D(G(z)))]
12
+
13
+ In practice the generator is trained with the non-saturating loss:
14
+ max_G E_z[log D(G(z))]
15
+
16
+ Training loop (per batch)
17
+ ---------------------------
18
+ 1. Sample real batch x ~ data, noise z ~ N(0,I)
19
+ 2. Generate fakes: x̂ = G(z)
20
+ 3. Update D to maximise log D(x) + log(1 - D(x̂))
21
+ 4. Update G to maximise log D(x̂) (non-saturating)
22
+
23
+ Both G and D are simple MLPs (Linear → ReLU/Tanh/Sigmoid),
24
+ implemented with manual forward/backward passes.
25
+
26
+ Reference
27
+ ----------
28
+ Goodfellow et al. (2014). Generative Adversarial Networks. NeurIPS.
29
+
30
+ Only numpy is used.
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ import numpy as np
36
+
37
+
38
+ # ============================================================
39
+ # Activations
40
+ # ============================================================
41
+
42
def _sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic sigmoid; the input is clipped to [-500, 500] to avoid exp overflow."""
    bounded = np.clip(x, -500, 500)
    return 1.0 / (1.0 + np.exp(-bounded))
44
+
45
+
46
def _relu(x: np.ndarray) -> np.ndarray:
    """Rectified linear unit: elementwise max(0, x)."""
    floor = 0.0
    return np.maximum(floor, x)
48
+
49
+
50
def _relu_grad(x: np.ndarray) -> np.ndarray:
    """Derivative of ReLU as a float mask: 1.0 where x > 0, else 0.0."""
    positive = np.greater(x, 0)
    return positive.astype(float)
52
+
53
+
54
def _tanh(x: np.ndarray) -> np.ndarray:
    """Elementwise hyperbolic tangent (output in [-1, 1])."""
    squashed = np.tanh(x)
    return squashed
56
+
57
+
58
def _tanh_grad(x: np.ndarray) -> np.ndarray:
    """Derivative of tanh evaluated at the pre-activation x: 1 - tanh(x)^2."""
    t = np.tanh(x)
    return 1.0 - t ** 2
60
+
61
+
62
+ # ============================================================
63
+ # Simple MLP block (used for both G and D)
64
+ # ============================================================
65
+
66
class _MLPBlock:
    """
    A minimal MLP with manual forward/backward for GAN sub-networks.

    Hidden layers use ReLU. The last layer uses ``output_activation``
    (``'sigmoid'``, ``'tanh'``, or any other value for the identity).
    Weights are drawn from ``N(0, 2 / fan_in)``; biases start at zero.

    Parameters
    ----------
    layer_sizes : list[int]
        Widths ``[n_in, hidden..., n_out]``.
    output_activation : str
        Activation applied to the final layer.
    rng : numpy.random.Generator
        Source of randomness for weight initialisation.
    """

    def __init__(
        self,
        layer_sizes: list[int],
        output_activation: str,
        rng: np.random.Generator,
    ) -> None:
        self.layer_sizes = layer_sizes
        self.output_activation = output_activation

        self.W: list[np.ndarray] = []
        self.b: list[np.ndarray] = []
        for i in range(len(layer_sizes) - 1):
            scale = np.sqrt(2.0 / layer_sizes[i])
            self.W.append(rng.normal(0, scale, (layer_sizes[i], layer_sizes[i + 1])))
            self.b.append(np.zeros(layer_sizes[i + 1]))

        # Pre-activations ("zs") and activations ("acts", input first) of the
        # last forward pass, read by backward.
        self._cache: dict = {}

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Run ``x`` (one sample per row) through the network and cache intermediates."""
        a = x
        zs, acts = [], [x]
        for i, (W, b) in enumerate(zip(self.W, self.b)):
            z = a @ W + b
            zs.append(z)
            if i < len(self.W) - 1:
                a = _relu(z)
            else:
                if self.output_activation == "sigmoid":
                    a = _sigmoid(z)
                elif self.output_activation == "tanh":
                    a = _tanh(z)
                else:
                    a = z
            acts.append(a)
        self._cache = {"zs": zs, "acts": acts}
        return a

    def backward(self, d_out: np.ndarray, learning_rate: float) -> np.ndarray:
        """
        Backprop ``d_out`` through the network and apply one SGD step.

        Parameters
        ----------
        d_out : ndarray of shape (B, n_out)
            Per-sample gradient of the loss w.r.t. the network output from the
            most recent :meth:`forward`. Parameter gradients are averaged over
            the batch here.
        learning_rate : float
            SGD step size; 0.0 leaves the weights unchanged.

        Returns
        -------
        ndarray of shape (B, n_in)
            Per-sample gradient w.r.t. the network input, computed with the
            weights used in the forward pass (i.e. before this update).
        """
        zs, acts = self._cache["zs"], self._cache["acts"]
        n = len(d_out)

        # Output layer activation gradient
        if self.output_activation == "sigmoid":
            delta = d_out * acts[-1] * (1 - acts[-1])
        elif self.output_activation == "tanh":
            delta = d_out * _tanh_grad(zs[-1])
        else:
            delta = d_out

        for i in reversed(range(len(self.W))):
            dW = acts[i].T @ delta / n
            db = delta.mean(axis=0)

            # Propagate through the forward-pass weights BEFORE the in-place
            # update; otherwise upstream layers see gradients of the new weights.
            upstream = delta @ self.W[i].T

            self.W[i] -= learning_rate * dW
            self.b[i] -= learning_rate * db

            if i > 0:
                delta = upstream * _relu_grad(zs[i - 1])
            else:
                delta = upstream

        # After the loop, delta is the gradient w.r.t. the network input.
        return delta
132
+
133
+
134
+ # ============================================================
135
+ # Generator / Discriminator
136
+ # ============================================================
137
+
138
class Generator:
    """
    Generator network mapping latent noise to synthetic samples.

    Layer widths are ``latent_dim -> hidden_sizes... -> output_dim`` with ReLU
    hidden layers and a tanh output, so real data is expected in [-1, 1].

    Parameters
    ----------
    latent_dim : int
        Dimension of the noise vectors fed to the network.
    output_dim : int
        Dimension of each generated sample.
    hidden_sizes : list[int]
        Hidden layer widths; a falsy value (None or empty) means ``[64, 64]``.
    random_state : int or None
        Seed for weight initialisation.
    """

    def __init__(
        self,
        latent_dim: int,
        output_dim: int,
        hidden_sizes: list[int] | None = None,
        random_state: int | None = None,
    ) -> None:
        self.latent_dim = latent_dim
        widths = hidden_sizes if hidden_sizes else [64, 64]
        init_rng = np.random.default_rng(random_state)
        self._net = _MLPBlock(
            [latent_dim] + widths + [output_dim],
            output_activation="tanh",
            rng=init_rng,
        )

    def forward(self, z: np.ndarray) -> np.ndarray:
        """Map noise ``z`` of shape (B, latent_dim) to samples of shape (B, output_dim)."""
        generated = self._net.forward(z)
        return generated

    def backward(self, d_out: np.ndarray, learning_rate: float) -> np.ndarray:
        """Delegate backprop (and the SGD update) to the underlying MLP."""
        grad_in = self._net.backward(d_out, learning_rate)
        return grad_in

    def sample_noise(self, n: int, rng: np.random.Generator) -> np.ndarray:
        """Draw ``n`` latent vectors from a standard normal, shape (n, latent_dim)."""
        shape = (n, self.latent_dim)
        return rng.standard_normal(shape)
179
+
180
+
181
class Discriminator:
    """
    Discriminator network mapping samples to the probability they are real.

    Layer widths are ``input_dim -> hidden_sizes... -> 1`` with ReLU hidden
    layers and a sigmoid output.

    Parameters
    ----------
    input_dim : int
        Dimension of each input sample.
    hidden_sizes : list[int]
        Hidden layer widths; a falsy value (None or empty) means ``[64, 64]``.
    random_state : int or None
        Seed for weight initialisation.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_sizes: list[int] | None = None,
        random_state: int | None = None,
    ) -> None:
        widths = hidden_sizes if hidden_sizes else [64, 64]
        init_rng = np.random.default_rng(random_state)
        self._net = _MLPBlock(
            [input_dim] + widths + [1],
            output_activation="sigmoid",
            rng=init_rng,
        )

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Score ``x`` of shape (B, input_dim); returns (B, 1) probabilities of being real."""
        scores = self._net.forward(x)
        return scores

    def backward(self, d_out: np.ndarray, learning_rate: float) -> np.ndarray:
        """Delegate backprop (and the SGD update) to the underlying MLP."""
        grad_in = self._net.backward(d_out, learning_rate)
        return grad_in
214
+
215
+
216
+ # ============================================================
217
+ # GAN — training orchestration
218
+ # ============================================================
219
+
220
class GAN:
    """
    Generative Adversarial Network — orchestrates G and D training.

    Parameters
    ----------
    latent_dim : int
        Dimension of the generator's noise input.
    data_dim : int
        Dimension of each data sample.
    hidden_sizes : list[int]
        Hidden layer widths shared by G and D (None means their default).
    learning_rate : float
        SGD step size for both networks.
    random_state : int or None
        Seed for noise sampling, shuffling and G's initialisation
        (D is seeded with ``(random_state or 0) + 1``).

    Attributes
    ----------
    d_losses_, g_losses_ : list[float]
        Discriminator / generator loss recorded at every :meth:`train_step`.
    """

    def __init__(
        self,
        latent_dim: int,
        data_dim: int,
        hidden_sizes: list[int] | None = None,
        learning_rate: float = 1e-3,
        random_state: int | None = None,
    ) -> None:
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.learning_rate = learning_rate
        self._rng = np.random.default_rng(random_state)

        self.generator = Generator(latent_dim, data_dim, hidden_sizes, random_state)
        self.discriminator = Discriminator(data_dim, hidden_sizes,
                                           (random_state or 0) + 1)

        self.d_losses_: list[float] = []
        self.g_losses_: list[float] = []

    # ------------------------------------------------------------------
    # Single training step
    # ------------------------------------------------------------------

    def train_step(self, real_batch: np.ndarray) -> tuple[float, float]:
        """
        Run one D-update followed by one G-update.

        Parameters
        ----------
        real_batch : ndarray of shape (batch_size, data_dim)
            Real data scaled to [-1, 1] (to match the generator's tanh output).

        Returns
        -------
        (d_loss, g_loss) : tuple[float, float]
        """
        eps = 1e-8
        B = len(real_batch)

        # ── 1. Discriminator update ─────────────────────────────────
        z = self.generator.sample_noise(B, self._rng)
        fake = self.generator.forward(z)

        # Score real and fake samples in ONE forward pass so D takes a single
        # simultaneous gradient step, and the gradients below match the
        # activations cached for backward (no stale outputs after a
        # half-applied update).
        d_all = self.discriminator.forward(np.vstack([real_batch, fake]))
        d_real, d_fake = d_all[:B], d_all[B:]

        d_loss = float(-np.mean(np.log(d_real + eps) + np.log(1 - d_fake + eps)))

        # Per-sample gradient of -log(d_real) - log(1 - d_fake) w.r.t. D's output.
        # The MLP averages over all 2B rows, so scale by 2 to recover the
        # gradient of d_loss (a mean over B real/fake pairs). No extra /B here:
        # the MLP already averages.
        grad_d = 2.0 * np.vstack([-1.0 / (d_real + eps), 1.0 / (1 - d_fake + eps)])
        self.discriminator.backward(grad_d, self.learning_rate)

        # ── 2. Generator update (non-saturating loss) ────────────────
        z2 = self.generator.sample_noise(B, self._rng)
        fake2 = self.generator.forward(z2)
        d_fake2 = self.discriminator.forward(fake2)

        g_loss = float(-np.mean(np.log(d_fake2 + eps)))

        # Per-sample d(-log d_fake2)/d(d_fake2); G's MLP averages over the batch.
        grad_g_out = -1.0 / (d_fake2 + eps)
        # Backprop through D (no weight update) to get gradient w.r.t. fake2
        d_input_to_D = self.discriminator.backward(grad_g_out, learning_rate=0.0)
        # Backprop that gradient through G (updates G's weights)
        self.generator.backward(d_input_to_D, self.learning_rate)

        self.d_losses_.append(d_loss)
        self.g_losses_.append(g_loss)
        return d_loss, g_loss

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------

    def fit(
        self,
        X: np.ndarray,
        epochs: int = 100,
        batch_size: int = 32,
    ) -> "GAN":
        """
        Train the GAN on dataset X.

        Parameters
        ----------
        X : ndarray of shape (n_samples, data_dim)
            Should be scaled to [-1, 1].
        epochs : int
            Number of shuffled passes over X.
        batch_size : int
            Rows per :meth:`train_step`; the last batch may be smaller.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        n = len(X)

        for _ in range(epochs):
            idx = self._rng.permutation(n)
            for start in range(0, n, batch_size):
                mb = idx[start:start + batch_size]
                self.train_step(X[mb])

        return self

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def generate(self, n_samples: int) -> np.ndarray:
        """
        Generate n_samples fake samples.

        Returns
        -------
        ndarray of shape (n_samples, data_dim), values in [-1, 1]
        """
        z = self.generator.sample_noise(n_samples, self._rng)
        return self.generator.forward(z)

    def discriminate(self, X: np.ndarray) -> np.ndarray:
        """Return D(X) — probability each sample is real."""
        return self.discriminator.forward(X).ravel()