zeroth-learn 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
+ Metadata-Version: 2.4
+ Name: zeroth-learn
+ Version: 0.1.0
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: cycler
@@ -0,0 +1,244 @@
+ # Zeroth-Learn
+
+ **A research library for zeroth-order optimization (gradient-free) in machine learning, with applications to quantum computing.**
+
+ *Research project — Nicolas, X24*
+
+ ---
+
+ ## Context & Motivation
+
+ This project originated from a fundamental question in quantum machine learning:
+
+ > **How do you train parameterized quantum circuits when backpropagation is impossible?**
+
+ Quantum circuits are too complex to differentiate analytically, so we treat them as black boxes. This rules out backpropagation, and SPSA is an excellent alternative.
+
+ **SPSA (Simultaneous Perturbation Stochastic Approximation)** solves this by estimating gradients from only O(1) function evaluations per iteration.
+
+ Before deploying on quantum simulators, I built this library to:
+ 1. Understand the theoretical foundations of zeroth-order optimization
+ 2. Validate SPSA stability on classical benchmarks (MNIST)
+ 3. Compare the results against backpropagation baselines
+
+ ---
+
+ ## Technical Implementation
+
+ ### Architecture Decisions
+
+ **Problem**: Standard deep learning frameworks (PyTorch, JAX) are tightly coupled to automatic differentiation. I needed an architecture where gradient computation is a **swappable abstraction**.
+
+ **Solution**: Clean separation of concerns using abstract base classes:
+ ```
+ Model (training loop orchestration)
+ ├── NeuralNetwork (forward pass interface)
+ │   ├── NeuralNetworkBackpropagation (layer-based, stores activations)
+ │   └── NeuralNetworkPerturbation (parameter-vector based, no activation storage)
+ └── Optimizer (gradient computation + update rule)
+     ├── OptimizerBackprop (analytical gradients via chain rule)
+     └── OptimizerPerturbation (estimated gradients via function evaluations)
+ ```
+
+ **Key insight**: By treating gradients as an *estimated quantity* rather than an *exact derivative*, both methods become instances of the same abstraction.
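This abstraction can be sketched as follows. The class names, signatures, and plain gradient-descent update below are illustrative only, not the library's exact API; they show how an estimated gradient and an analytical one plug into the same interface:

```python
from abc import ABC, abstractmethod
import numpy as np

class Optimizer(ABC):
    """Gradient computation + update rule, decoupled from the network."""
    def __init__(self, lr: float = 0.01):
        self.lr = lr

    @abstractmethod
    def estimate_gradient(self, loss_fn, theta: np.ndarray) -> np.ndarray:
        """Return a gradient (exact or estimated) of loss_fn at theta."""

    def step(self, loss_fn, theta: np.ndarray) -> np.ndarray:
        # Same update rule regardless of how the gradient was obtained.
        return theta - self.lr * self.estimate_gradient(loss_fn, theta)

class SPSAOptimizer(Optimizer):
    """Estimated gradients via T perturbed function evaluations."""
    def __init__(self, lr: float = 0.01, delta: float = 1e-2, T: int = 30):
        super().__init__(lr)
        self.delta, self.T = delta, T

    def estimate_gradient(self, loss_fn, theta):
        # Rademacher (+/-1) perturbation directions, one row per evaluation.
        deltas = np.random.choice([-1.0, 1.0], size=(self.T, theta.size))
        base = loss_fn(theta)
        diffs = np.array([loss_fn(theta + self.delta * d) - base for d in deltas])
        # One-sided SPSA estimator averaged over T directions.
        return diffs @ deltas / (self.T * self.delta)
```

A backpropagation optimizer would subclass `Optimizer` the same way, overriding `estimate_gradient` with the chain rule, which is exactly the swappability the tree above describes.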
+
+ ---
+
+ ## Quick Start
+
+ 1. **Clone the repository:**
+    ```bash
+    git clone https://github.com/nicolasmalet/Zeroth-Learn.git
+    cd Zeroth-Learn
+    ```
+
+ 2. **Install dependencies:**
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ 3. **Run a benchmark experiment:**
+    To train a linear MLP on MNIST using SPSA with 50 perturbations:
+    ```bash
+    python -m lab.mnist
+    ```
+
+ ---
+
+ ### SPSA Implementation Details
+
+ The core challenge: **evaluate multiple perturbed models in parallel without Python loops**.
+
+ #### Naive Approach (slow):
+ ```python
+ for i, perturbation in enumerate(perturbations):
+     theta_perturbed = theta + perturbation
+     loss_perturbed[i] = evaluate_model(theta_perturbed)
+ ```
+ **Cost**: O(T) sequential forward passes for T perturbations.
+
+ #### Vectorized Approach (implemented):
+ ```python
+ # Shape: (T, n_params)
+ pThetas = theta[None, :] + perturbations
+
+ # Unpack into per-layer weight and bias tensors, each with leading dimension T
+ Ws, Bs = params.from_pThetas(pThetas)
+
+ # Broadcast input across all T models simultaneously
+ # X: (input_dim, batch) -> (T, input_dim, batch)
+ for W, B, f in zip(Ws, Bs, fs):
+     X = f(W @ X + B)  # Matrix multiplication broadcasts automatically
+ ```
+ **Result**: All T forward passes run through batched NumPy operations, one per layer, with no Python loop over models.
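A self-contained illustration of this broadcasting trick for a two-layer MLP. The dimensions and the hand-built stack of perturbed layers are made up for the example; they stand in for the library's `params.from_pThetas` helper:

```python
import numpy as np

rng = np.random.default_rng(0)
T, in_dim, hid, out_dim, batch = 8, 4, 16, 3, 5

# T perturbed copies of each layer's parameters, stacked on a leading axis.
layers = [
    (rng.normal(size=(T, hid, in_dim)), rng.normal(size=(T, hid, 1)), np.tanh),
    (rng.normal(size=(T, out_dim, hid)), rng.normal(size=(T, out_dim, 1)), lambda z: z),
]

X = rng.normal(size=(in_dim, batch))  # one shared input batch, shape (in, batch)
A = X                                 # broadcasts up to (T, ., batch) below
for W, B, f in layers:
    A = f(W @ A + B)  # stacked matmul: (T, out, in) @ (in, batch) -> (T, out, batch)

# A now holds T complete forward passes, computed without a loop over models.
```

The design choice here is that `np.matmul` treats the leading `T` axis as a stack of matrices, so the shared input is multiplied against all T candidate models at once.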
+
+ ---
+
+ ### Mathematical Rigor
+
+ #### Gradient Estimation
+ The SPSA gradient estimator:
+
+ $$\nabla L(\theta) \approx \frac{1}{T \cdot \delta} \sum_{i=1}^{T} \left( L(\theta + \delta \Delta_i) - L(\theta) \right) \Delta_i$$
+
+ where $\Delta_i \sim \text{Rademacher}(\pm 1)$ are random perturbation directions.
+
+ **Implementation** (using Einstein summation for efficiency):
+ ```python
+ # L_diff: (T, batch_size), Ps: (T, n_params)
+ grad = np.einsum('ij,ik->k', L_diff, self.Ps) / (batch_size * T * delta)
+ ```
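To see that the einsum matches the summation formula term by term, here is a small check against an explicit double loop. The `L_diff` values are random placeholders standing in for per-sample loss differences:

```python
import numpy as np

rng = np.random.default_rng(42)
T, batch_size, n_params, delta = 30, 16, 10, 1e-2

# Per-sample loss differences L(theta + delta*P_i) - L(theta), and Rademacher directions.
L_diff = rng.normal(size=(T, batch_size))
Ps = rng.choice([-1.0, 1.0], size=(T, n_params))

# Einsum form: sums over perturbations i and batch samples j in one call.
grad = np.einsum('ij,ik->k', L_diff, Ps) / (batch_size * T * delta)

# Explicit double loop, matching the estimator formula term by term.
grad_loop = np.zeros(n_params)
for i in range(T):
    for j in range(batch_size):
        grad_loop += L_diff[i, j] * Ps[i]
grad_loop /= batch_size * T * delta
```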
+
+ #### Numerical Stability Considerations
+ - **Softmax**: Shifted by max to prevent overflow: `exp(x - max(x))`
+ - **CrossEntropy**: Added epsilon (1e-8) to prevent log(0)
+ - **Xavier initialization**: Weights sampled from $U(-\sqrt{6/(n_{in}+n_{out})}, +\sqrt{6/(n_{in}+n_{out})})$
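A minimal sketch of the first two tricks. The function names are illustrative; the indexing convention `Y_pred[y_true, idx]` mirrors the column-major `(out, batch)` layout used throughout the library:

```python
import numpy as np

def stable_softmax(x: np.ndarray) -> np.ndarray:
    """Column-wise softmax; shifting by the max keeps exp() from overflowing."""
    z = x - x.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

def cross_entropy(Y_pred: np.ndarray, y_true: np.ndarray, eps: float = 1e-8) -> float:
    """Mean negative log-likelihood of the true classes; eps guards log(0)."""
    idx = np.arange(Y_pred.shape[1])
    return float(-np.mean(np.log(eps + Y_pred[y_true, idx])))

# Extreme logits that would overflow a naive exp().
logits = np.array([[1000.0, -1000.0],
                   [0.0, 0.0]])
P = stable_softmax(logits)
loss = cross_entropy(P, np.array([0, 1]))
```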
+
+ ---
+
+ ## Experimental Validation
+
+ ### Research Question
+ *What are the optimal conditions (architecture depth, learning rate, perturbation count) for SPSA to compete with backpropagation?*
+
+ ### Methodology
+
+ **Phase 1: Hyperparameter Sensitivity Analysis**
+ - Grid search over learning rates × architectures
+ - Identified stability thresholds (divergence boundaries)
+
+ <img alt="Learning Rate Analysis" src="assets/plots/lr_adam.png" height="300">
+
+ **Finding**: Adam requires lr ~ 0.001 for networks with 10K to 100K parameters to avoid gradient explosion in SPSA.
+
+ ---
+
+ **Phase 2: Scalability Limits**
+ - Trained 6 models from 7K to 1.3M parameters (the first three are shown below)
+ - Measured convergence speed vs. parameter count
+
+ <img alt="Architecture Scaling" src="assets/plots/small_sizes.png" height="300"/>
+
+ **Finding**: Models with 100K parameters are sufficient to reach 97% accuracy.
+
+ ---
+
+ **Phase 3: Sample Efficiency**
+ - Varied perturbation count T ∈ {10, 30, 100}
+ - Measured gradient variance vs. computational cost
+
+ <img alt="Perturbation Analysis" src="assets/plots/nb_perturbations.png" height="300"/>
+
+ **Finding**: Since the gradient estimate's standard deviation falls as $\sigma \propto 1/\sqrt{T}$, returns diminish beyond T=30.
+
+ **Practical implication**: For quantum circuits, 30 evaluations/step is feasible on current hardware.
+
+ ---
+
+ ## Software Engineering Practices
+
+ ### Type Safety & Configuration Management
+ - **Frozen dataclasses** for all configs → immutable
+ - **Config serialization** → full experiment reproducibility (saved as JSON)
+
+ ### Modular Design
+ - **Catalog pattern** for hyperparameters (see `config.py`):
+ ```python
+ @dataclass(frozen=True)
+ class OptimizerCatalog:
+     FirstOrderAdam = FirstOrderAdamConfig(lr=0.001, ...)
+     ZerothOrderAdam = ZerothOrderAdamConfig(lr=0.001, ...)
+ ```
+ Enables experiment generation via `itertools.product`.
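A sketch of how such a catalog combines with `itertools.product` to enumerate runs. The config class here is a simplified stand-in for the library's nested configs:

```python
from dataclasses import dataclass, replace
import itertools

# Hypothetical minimal config; the real library nests these inside ModelConfig.
@dataclass(frozen=True)
class OptimizerConfig:
    lr: float = 0.001
    T: int = 30  # number of SPSA perturbations

lrs = [0.0005, 0.001, 0.002]
Ts = [10, 30, 100]

# Every (lr, T) combination becomes its own immutable config variant.
base = OptimizerConfig()
grid = [replace(base, lr=lr, T=T) for lr, T in itertools.product(lrs, Ts)]
```

Because the dataclasses are frozen, `replace` returns fresh copies and the base config can never be mutated mid-experiment, which is what makes saved configs trustworthy for reproducibility.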
+
+ ### Experiment Reproducibility
+ - Automatic result saving (loss curves, results dataframe, hyperparameter logs)
+ - Plot styling configured globally (publication-ready figures)
+
+ ---
+
+ ## Software Design Principles
+
+ - **Separation of Concerns**: Gradient computation (Optimizer) is decoupled from forward pass (NeuralNetwork)
+ - **Config-Driven**: All hyperparameters defined as immutable dataclasses → reproducibility
+ - **Polymorphism**: Models can swap between backprop and SPSA without code changes
+
+ ---
+
+ ## Skills Demonstrated
+
+ - **Deep Learning Fundamentals**: Implemented backprop from scratch (no PyTorch/TensorFlow)
+ - **Numerical Optimization**: SPSA, Adam, gradient estimation theory
+ - **Scientific Computing**: Vectorized NumPy, broadcasting, numerical stability
+ - **Software Architecture**: Abstract base classes, config management
+ - **Research Methodology**: Systematic experimentation, reproducible results
+ - **Mathematical Rigor**: Gradient derivations, loss functions
+
+ ---
+
+ ## Next Steps
+
+ ### Quantum Simulation (In Progress)
+ - Implement `QuantumCircuitSimulator` class using Dynamics
+ - Test SPSA on parameterized quantum circuits
+ - Validate that convergence behavior matches classical benchmarks
+
+ ---
+
+ ## Technical Stack
+
+ - **Language**: Python
+ - **Core Libraries**: NumPy (vectorization), Pandas (results), Matplotlib (visualization)
+ - **Design Patterns**: Strategy (Optimizer), Abstract Factory (Config instantiation), Template Method (Model training loop)
+
+ ---
+
+ ## Project Structure
+ ```
+ zeroth/
+ ├── first-order/              # Analytical gradient methods
+ │   ├── layer.py              # Forward/backward pass logic
+ │   └── optimizers.py         # SGD, Adam implementations
+ ├── zeroth-order/             # Zeroth-order methods
+ │   ├── gradient_estimator.py # Gradient estimation strategies
+ │   ├── parameter_manager.py  # Parameter vector management
+ │   └── optimizers.py         # SPSA + Adam/SGD variants
+ ├── abstract/                 # Shared abstractions
+ │   ├── neural_network.py     # Abstract base class
+ │   └── optimizer.py          # Optimizer interface
+ └── experiment.py             # Experiment orchestration
+
+ lab/
+ ├── experiments.py            # Pre-configured experiments
+ ├── models.py                 # Model definitions
+ └── config.py                 # Hyperparameter catalogs
+ ```
+
+ ---
+
+ ## Contact
+
+ **Nicolas Malet**
+ X24 — École Polytechnique
+ nicolas.malet@polytechnique.edu
+ [GitHub](https://github.com/nicolasmalet) | [LinkedIn](https://www.linkedin.com/in/nicolas-malet-pro)
@@ -0,0 +1,16 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "zeroth-learn"
+ version = "0.1.0"
+ dependencies = [
+     "numpy",
+     "pandas",
+     "matplotlib",
+     "cycler"
+ ]
+
+ [tool.setuptools]
+ packages = ["zeroth"]
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,5 @@
+ from zeroth.model import (Model, ModelConfig, ZerothOrderModel, FirstOrderModel,
+                           ZerothOrderModelConfig, FirstOrderModelConfig)
+ from zeroth.experiment import Experiment, ExperimentConfig, VariationConfig
+
+ from zeroth.data import Data
@@ -0,0 +1,134 @@
+ from zeroth.utils.dataclasses_utils import config_serializer, generate_param_map, get_name, set_value_by_path
+ from zeroth.model import ModelConfig, Model
+ from zeroth.plot_losses import plot_losses
+ from zeroth.data import Data
+
+ from dataclasses import dataclass, replace
+ from typing import Callable
+
+ import pandas as pd
+ import itertools
+ import json
+ import os
+
+
+ @dataclass(frozen=True)
+ class VariationConfig:
+     param: str
+     values: list
+
+
+ @dataclass(frozen=True)
+ class ExperimentConfig:
+     name: str
+     title: str
+     base_model: ModelConfig
+     variations: list[VariationConfig]
+     create_data: Callable
+     plot_dimension: int
+
+     def instantiate(self):
+         return Experiment(self)
+
+
+ class Experiment:
+     """Manages the full lifecycle of a deep learning experiment.
+
+     It handles data loading, model instantiation, training loops, and results visualization.
+
+     Attributes:
+         name (str): Name of the experiment.
+         title (str): Title of the graphs.
+         models (list[Model]): List of models to train/compare.
+         data (Data): The dataset wrapper.
+     """
+
+     def __init__(self, config: ExperimentConfig):
+         self.config = config
+         self.name: str = config.name
+         self.title: str = config.title
+         self.base_model_config: ModelConfig = config.base_model
+         self.models: list[Model] = generate_models(config.base_model, config.variations)
+         self.data: Data = config.create_data()
+         self.plot_dimension: int = config.plot_dimension
+
+         self.save_dir = os.path.join("results", self.name)
+
+     def launch(self, do_train, do_test, nb_print_train, do_plot_train, do_save):
+         """Executes the experiment pipeline.
+
+         Args:
+             do_train (bool): Whether to run the training loop.
+             do_test (bool): Whether to run evaluation on the test set.
+             nb_print_train (int): Number of logs to print during training.
+             do_plot_train (bool): If True, plots loss curves after training.
+             do_save (bool): If True, saves the plots and dataframes.
+         """
+         print(f"### Launching Experiment : {self.name} ###")
+         if do_train:
+             self.train(nb_print=nb_print_train, do_plot=do_plot_train, do_save=do_save)
+         if do_test:
+             self.test()
+         if do_save:
+             self.save_df()
+
+     def train(self, nb_print: int, do_plot: bool, do_save: bool):
+         for model in self.models:
+             model.train(self.data, nb_print)
+
+         if do_plot:
+             plot_path = None
+             if do_save:
+                 os.makedirs(self.save_dir, exist_ok=True)
+                 plot_path = os.path.join(self.save_dir, "training_losses.png")
+
+             plot_losses(dimension=self.plot_dimension, models=self.models, title=self.title, save_path=plot_path)
+
+     def test(self):
+         for model in self.models:
+             model.test(self.data)
+
+     def save_df(self):
+         """Saves the model results and the experiment configuration."""
+         os.makedirs(self.save_dir, exist_ok=True)
+         print(f" Saving results to: {self.save_dir}")
+
+         data = [model.id | {"test_loss": model.test_loss, "test_accuracy": model.test_accuracy}
+                 for model in self.models]
+
+         df = pd.DataFrame(data)
+         df.to_csv(os.path.join(self.save_dir, "models_accuracy.csv"), index_label="iteration")
+
+         config_path = os.path.join(self.save_dir, "config.json")
+         with open(config_path, "w") as f:
+             json.dump(self.config, f, default=config_serializer, indent=4)
+
+
+ def generate_models(base_model: ModelConfig, variations: list[VariationConfig]) -> list[Model]:
+     models = []
+
+     names = [v.param for v in variations]
+     values = [v.values for v in variations]
+
+     # Maps each param name to its attribute path inside the nested config.
+     PARAM_MAP = generate_param_map(base_model)
+
+     for combination in itertools.product(*values):
+         id_ = {}
+         for key, val in zip(names, combination):
+             id_[key] = get_name(val)
+         current_model = base_model
+         for param_key, value in zip(names, combination):
+             path = PARAM_MAP[param_key]
+             current_model = set_value_by_path(current_model, path, value)
+
+         current_model = replace(current_model, id=id_)
+         models.append(current_model.instantiate())
+
+     return models
@@ -0,0 +1,121 @@
+ from abc import ABC, abstractmethod
+ import numpy as np
+
+
+ class Loss(ABC):
+     @staticmethod
+     @abstractmethod
+     def compute_loss(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
+         """
+         :param Y_pred: shape (out, batch)
+         :param Y_true: shape (out, batch)
+         :return: average loss as a float
+         """
+         pass
+
+     @staticmethod
+     @abstractmethod
+     def compute_batch_losses(Y_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
+         """
+         :param Y_pred: shape (out, batch)
+         :param Y_true: shape (out, batch)
+         :return: per-sample losses, shape (batch,)
+         """
+         pass
+
+     @staticmethod
+     @abstractmethod
+     def compute_perturbed_losses(pY_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
+         """
+         :param pY_pred: shape (T, out, batch)
+         :param Y_true: shape (out, batch)
+         :return: perturbed losses, leading dimension T
+         """
+         pass
+
+     @staticmethod
+     @abstractmethod
+     def compute_gradient_wrt_preactivation(last_layer, Y_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
+         """
+         :param last_layer: last layer of the network
+         :param Y_pred: shape (out, batch)
+         :param Y_true: shape (out, batch)
+         :return: gradient w.r.t. the last pre-activation, shape (out, batch)
+         """
+         pass
+
+     @abstractmethod
+     def compute_losses_for_zeroth_order(self, pY_pred: np.ndarray, Y_true: np.ndarray) -> tuple[float, np.ndarray]:
+         pass
+
+     @abstractmethod
+     def compute_losses_for_first_order(self, last_layer, Y_pred: np.ndarray, Y_true: np.ndarray):
+         pass
+
+
+ class MSE(Loss):
+     name = "MSE"
+
+     @staticmethod
+     def compute_loss(Y_pred, Y_true) -> float:
+         return np.mean((Y_pred - Y_true) ** 2, axis=(0, 1))
+
+     @staticmethod
+     def compute_batch_losses(Y_pred, Y_true):
+         return np.mean((Y_pred - Y_true) ** 2, axis=0)
+
+     @staticmethod
+     def compute_perturbed_losses(pY_pred, Y_true):
+         return np.mean((pY_pred - Y_true) ** 2, axis=1)
+
+     @staticmethod
+     def compute_gradient_wrt_activation(Y_pred, Y_true):
+         return 2 * np.mean(Y_pred - Y_true, axis=0)
+
+     @staticmethod
+     def compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true):
+         dL_dA = 2 * (Y_pred - Y_true) / Y_true.size
+         dL_dZ = dL_dA * last_layer.df(last_layer.Z)
+         return dL_dZ
+
+     def compute_losses_for_zeroth_order(self, pY_pred, Y_true):
+         return (self.compute_loss(pY_pred[0], Y_true),
+                 self.compute_perturbed_losses(pY_pred, Y_true))
+
+     def compute_losses_for_first_order(self, last_layer, Y_pred, Y_true):
+         return self.compute_loss(Y_pred, Y_true), self.compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true)
+
+
+ class CrossEntropy(Loss):
+     name = "CrossEntropy"
+
+     @staticmethod
+     def compute_loss(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
+         idx = np.arange(Y_pred.shape[1])
+         return -np.mean(np.log(1e-8 + Y_pred[Y_true, idx]))
+
+     @staticmethod
+     def compute_batch_losses(Y_pred: np.ndarray, Y_true: np.ndarray):
+         idx = np.arange(Y_pred.shape[1])
+         return -np.log(1e-8 + Y_pred[Y_true, idx])
+
+     @staticmethod
+     def compute_perturbed_losses(pY_pred: np.ndarray, Y_true: np.ndarray):
+         idx = np.arange(pY_pred.shape[2])
+         return -np.mean(np.log(1e-8 + pY_pred[:, Y_true, idx]), axis=1)
+
+     @staticmethod
+     def compute_gradient_wrt_preactivation(last_layer, Y_pred: np.ndarray, Y_true: np.ndarray):
+         dZ = Y_pred.copy()
+         batch_size = Y_pred.shape[1]
+         dZ[Y_true[0], np.arange(batch_size)] -= 1.0
+         return dZ
+
+     def compute_losses_for_zeroth_order(self, pY_pred: np.ndarray, Y_true: np.ndarray):
+         avg_loss = self.compute_loss(pY_pred[0], Y_true)
+         p_loss = self.compute_perturbed_losses(pY_pred, Y_true)
+         return avg_loss, p_loss
+
+     def compute_losses_for_first_order(self, last_layer, Y_pred: np.ndarray, Y_true: np.ndarray) -> tuple[float, np.ndarray]:
+         return self.compute_loss(Y_pred, Y_true), self.compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true)
@@ -0,0 +1,144 @@
+ from dataclasses import dataclass
+ from matplotlib.axes import Axes
+ from typing import Callable
+ from abc import ABC
+
+ import pandas as pd
+ import numpy as np
+
+ from zeroth.zeroth_order import ZerothOrderNeuralNetwork, ZerothOrderOptimizerConfig, GradientEstimatorConfig, \
+     GradientEstimator
+ from zeroth.first_order import FirstOrderNeuralNetwork, FirstOrderOptimizerConfig
+ from zeroth.abstract import BlackBox, NeuralNetworkConfig, Optimizer
+ from zeroth.losses import Loss
+ from zeroth.data import Data
+
+
+ @dataclass(frozen=True)
+ class ModelConfig:
+     """
+     name (str): Name of the model (used for display and saving).
+     id (dict): Identifier mapping varied hyperparameter names to their values.
+     loss (Loss): The loss class.
+     metric (Callable): Function (Y_pred, Y_true) -> score (e.g., accuracy).
+     batch_size (int): Number of samples per gradient update.
+     nb_epochs (int): Number of passes through the entire dataset.
+     """
+     name: str
+     id: dict
+     loss: Loss
+     metric: Callable
+     batch_size: int
+     nb_epochs: int
+
+     def instantiate(self):
+         raise NotImplementedError
+
+
+ @dataclass(frozen=True)
+ class FirstOrderModelConfig(ModelConfig):
+     neural_network_config: NeuralNetworkConfig
+     optimizer_config: FirstOrderOptimizerConfig
+
+     def instantiate(self):
+         return FirstOrderModel(self)
+
+
+ @dataclass(frozen=True)
+ class ZerothOrderModelConfig(ModelConfig):
+     neural_network_config: NeuralNetworkConfig
+     optimizer_config: ZerothOrderOptimizerConfig
+     gradient_estimator_config: GradientEstimatorConfig
+
+     def instantiate(self):
+         return ZerothOrderModel(self)
+
+
+ class Model(ABC):
+     """
+     Base class orchestrating the training and testing loop.
+
+     It implements the shared training logic regardless of the underlying
+     engine (backpropagation or zeroth-order).
+     """
+
+     def __init__(self, config: ModelConfig):
+         self.name: str = config.name
+         self.id: dict = config.id
+         self.loss: Loss = config.loss
+         self.metric: Callable = config.metric
+         self.batch_size: int = config.batch_size
+         self.nb_epochs: int = config.nb_epochs
+         self.neural_network: BlackBox | None = None
+         self.optimizer: Optimizer | None = None
+
+         self.train_loss: np.ndarray = np.array([])
+         self.test_loss: float | None = None
+         self.test_accuracy: float | None = None
+
+     def train(self, data: Data, nb_print: int = 0):
+         """Runs the training loop over the dataset.
+
+         Args:
+             data (Data): The dataset object containing train/test sets.
+             nb_print (int): Number of progress updates to print per epoch.
+
+         The loss recorded at each step is stored in `self.train_loss` (for plotting).
+         """
+         nb_batches = data.nb_data // self.batch_size
+
+         self.train_loss = np.zeros(self.nb_epochs * nb_batches, dtype=np.float64)
+
+         print_indexes = np.linspace(0, nb_batches - 1, nb_print).astype(int)
+         print(f" Training {self.id} Model")
+         for epoch_idx in range(self.nb_epochs):
+             print(f" epoch n°{epoch_idx + 1} out of {self.nb_epochs}")
+             data.prepare_data(self.batch_size)
+             for batch_idx in range(nb_batches):
+                 X_train, Y_train = data.X_train[batch_idx], data.Y_train[batch_idx]
+                 avg_loss = self.optimizer.do_descent(self.neural_network, self.loss, X_train, Y_train)
+                 self.train_loss[epoch_idx * nb_batches + batch_idx] = avg_loss
+
+                 if batch_idx in print_indexes:
+                     print(f" batch n°{batch_idx + 1} out of {nb_batches}, "
+                           f"loss : {np.round(self.train_loss[epoch_idx * nb_batches + batch_idx], 3)}")
+         self.test(data)
+
+     def plot_loss(self, ax: Axes, label: str, smooth_span: int = 50):
+         ax.plot(self.train_loss, alpha=0.25, linewidth=1.0)
+         smooth = self.smooth_curve(self.train_loss, smooth_span)
+         ax.plot(smooth, label=label, linewidth=2.5)
+
+     @staticmethod
+     def smooth_curve(loss: np.ndarray, smooth_span: int) -> np.ndarray:
+         return np.exp(pd.Series(np.log(loss)).ewm(span=smooth_span, adjust=True).mean())
+
+     def test(self, data):
+         X_test, Y_true = data.X_test, data.Y_test  # (in, batch), (out, batch)
+         Y_pred = self.neural_network.forward(X_test)  # (out, batch)
+
+         self.test_accuracy = self.metric(Y_pred, Y_true)
+         self.test_loss = self.loss.compute_loss(Y_pred, Y_true)
+
+         print(f" {self.id} accuracy : {self.test_accuracy}, loss : {self.test_loss}")
+
+
+ class FirstOrderModel(Model):
+     def __init__(self, config: FirstOrderModelConfig):
+         super().__init__(config)
+
+         self.neural_network = FirstOrderNeuralNetwork(config.neural_network_config)
+         self.optimizer = config.optimizer_config.instantiate()
+
+
+ class ZerothOrderModel(Model):
+     def __init__(self, config: ZerothOrderModelConfig):
+         super().__init__(config)
+
+         self.neural_network: ZerothOrderNeuralNetwork = ZerothOrderNeuralNetwork(config.neural_network_config)
+         nb_params = self.neural_network.params.nb_params
+         self.gradient_estimator: GradientEstimator = config.gradient_estimator_config.instantiate(nb_params)
+         self.optimizer: Optimizer = config.optimizer_config.instantiate(self.gradient_estimator)
@@ -0,0 +1,190 @@
+ from zeroth.model import Model
+
+ from matplotlib.axes import Axes
+ from cycler import cycler
+
+ import matplotlib.pyplot as plt
+
+
+ def set_style():
+     plt.rcParams.update({
+         "font.family": "serif",
+         "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
+         "font.size": 10,
+         "axes.labelsize": 10,
+         "axes.titlesize": 11,
+         "axes.linewidth": 0.8,
+         "axes.spines.top": False,
+         "axes.spines.right": False,
+         "xtick.labelsize": 9,
+         "ytick.labelsize": 9,
+         "xtick.direction": "out",
+         "ytick.direction": "out",
+         "xtick.major.size": 3,
+         "ytick.major.size": 3,
+         "lines.linewidth": 2.2,
+         "legend.fontsize": 9,
+         "legend.frameon": False,
+         "grid.color": "0.85",
+         "grid.linewidth": 0.6,
+         "grid.linestyle": "-",
+         "savefig.dpi": 300,
+         "savefig.bbox": "tight",
+     })
+     plt.rcParams["axes.prop_cycle"] = cycler(color=[
+         "#4477AA", "#EE6677", "#228833",
+         "#CCBB44", "#66CCEE", "#AA3377"
+     ])
+
+
+ def format_ax(ax: Axes):
+     ax.set_axisbelow(True)
+     ax.set_yscale('log')
+     ax.grid(True, which="major", axis="y")
+     ax.grid(False, axis="x")
+     ax.spines['top'].set_visible(False)
+     ax.spines['right'].set_visible(False)
+
+
+ def plot_0d(models: list[Model], title: str, smooth_span: int = 50):
+     """
+     Plots a single graph overlaying multiple models that share the same hyperparameters.
+     """
+     fig, ax = plt.subplots(figsize=(5.5, 3.5))
+
+     for model in models:
+         others = [f"{k}={v}" for k, v in model.id.items()]
+         label = ", ".join(others)
+         model.plot_loss(ax, label, smooth_span)
+
+     plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.2)
+     fig.suptitle(title, fontweight='bold', fontsize=12)
+     ax.set_xlabel("Training steps")
+     ax.set_ylabel("Training loss")
+     format_ax(ax)
+
+     handles, labels = ax.get_legend_handles_labels()
+
+     if handles:
+         fig.legend(handles, labels, loc='lower center', ncol=len(handles),
+                    bbox_to_anchor=(0.5, 0), frameon=False, fontsize=9)
+
+
+ def plot_1d(models: list[Model], title: str, key: str, smooth_span: int = 50):
+     """
+     Plots a row of subplots, varying one hyperparameter (key) across columns.
+     """
+     cols = list(dict.fromkeys([m.id[key] for m in models]))
+     n_models = len(cols)
+     fig, axs = plt.subplots(1, n_models, figsize=(4.5 * n_models, 3.5), sharey=True)
+
+     for i, val in enumerate(cols):
+         ax = axs[i]
+         cell_models = [m for m in models if m.id[key] == val]
+
+         for model in cell_models:
+             others = [f"{k}={v}" for k, v in model.id.items()
+                       if k != key]
+             label = ", ".join(others)
+             model.plot_loss(ax, label, smooth_span)
+
+         format_ax(ax)
+         ax.set_title(f"{key} = {val}")
+
+     fig.text(0.5, 0.1, "Training steps", ha='center', fontsize=10)
+     fig.text(0.01, 0.5, "Training loss", va='center', rotation='vertical', fontsize=10)
+
+     plt.subplots_adjust(left=0.05, right=0.96, top=0.85, bottom=0.2, wspace=0.10, hspace=0.18)
+     fig.suptitle(title, fontweight='bold', fontsize=12)
+
+     handles, labels = axs[0].get_legend_handles_labels()
+
+     if handles:
+         fig.legend(handles, labels, loc='lower center', ncol=len(handles),
+                    bbox_to_anchor=(0.5, 0), frameon=False, fontsize=9)
+
+
+ def plot_2d_grid(models: list[Model], title: str, row_key: str, col_key: str, smooth_span: int = 50):
+     """
+     Plots a grid of subplots varying two hyperparameters: one across rows, one across columns.
+
+     Args:
+         models (list): List of model objects.
+         title (str): The title of the plot.
+         row_key (str): The hyperparameter key changing across rows.
+         col_key (str): The hyperparameter key changing across columns.
+         smooth_span (int): The span for the EWM average.
+     """
+     rows = list(dict.fromkeys([m.id[row_key] for m in models]))
+     cols = list(dict.fromkeys([m.id[col_key] for m in models]))
+
+     fig, axs = plt.subplots(len(rows), len(cols),
+                             figsize=(4.5 * len(cols), 3.5 * len(rows)),
+                             sharex=True, sharey=True, squeeze=False)
+
+     for i, r_val in enumerate(rows):
+         for j, c_val in enumerate(cols):
+             ax = axs[i, j]
+             cell_models = [m for m in models if m.id[row_key] == r_val and m.id[col_key] == c_val]
+
+             for model in cell_models:
+                 others = [f"{k}={v}" for k, v in model.id.items() if k not in [row_key, col_key]]
+                 label = ", ".join(others)
+                 model.plot_loss(ax, label, smooth_span)
+
+             format_ax(ax)
+
+             if i == 0:
+                 ax.set_title(f"{col_key} = {c_val}")
+
+             if j == len(cols) - 1:
+                 ax.text(1.02, 0.5, f"{row_key} = {r_val}",
+                         transform=ax.transAxes, rotation=-90,
+                         va="center", ha="left")
+
+     plt.subplots_adjust(left=0.06, right=0.96, top=0.90, bottom=0.12, wspace=0.10, hspace=0.18)
+     fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)
+
+     handles, labels = axs[0, 0].get_legend_handles_labels()
+     if handles:
+         fig.legend(handles, labels, loc='lower center', ncol=len(handles),
+                    bbox_to_anchor=(0.5, 0.02), frameon=False, fontsize=9)
+
+     fig.text(0.5, 0.07, "Training steps", ha='center', fontsize=10)
+     fig.text(0.02, 0.5, "Training loss", va='center', rotation='vertical', fontsize=10)
+
+
+ def plot_losses(dimension: int, models: list, title: str, save_path: str = None, smooth_span: int = 100):
+     """
+     Main entry point for plotting. Dispatches to a 0D, 1D, or 2D layout
+     depending on the requested dimension.
+
+     Args:
+         dimension (int): Dimension of the plot (0, 1, or 2).
+         models (list): List of model objects.
+         title (str): The title of the plot.
+         save_path (str, optional): File path to save the figure (e.g., 'plot.png').
+         smooth_span (int): EWM span for smoothing. Defaults to 100.
+     """
+     set_style()
+
+     keys = list(models[0].id.keys())
+
+     if dimension == 0:
+         plot_0d(models, title, smooth_span)
+     elif dimension == 1:
+         plot_1d(models, title, keys[0], smooth_span)
+     else:
+         plot_2d_grid(models, title, keys[0], keys[1], smooth_span)
+
+     if save_path:
+         plt.savefig(save_path, dpi=300, bbox_inches="tight")
+         print(f"Plot saved to {save_path}")
+
+     plt.show()
@@ -0,0 +1,7 @@
+ Metadata-Version: 2.4
+ Name: zeroth-learn
+ Version: 0.1.0
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: cycler
@@ -0,0 +1,12 @@
+ README.md
+ pyproject.toml
+ zeroth/__init__.py
+ zeroth/experiment.py
+ zeroth/losses.py
+ zeroth/model.py
+ zeroth/plot_losses.py
+ zeroth_learn.egg-info/PKG-INFO
+ zeroth_learn.egg-info/SOURCES.txt
+ zeroth_learn.egg-info/dependency_links.txt
+ zeroth_learn.egg-info/requires.txt
+ zeroth_learn.egg-info/top_level.txt
@@ -0,0 +1,4 @@
+ numpy
+ pandas
+ matplotlib
+ cycler