zeroth-learn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zeroth/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from zeroth.model import (Model, ModelConfig, ZerothOrderModel, FirstOrderModel,
2
+ ZerothOrderModelConfig, FirstOrderModelConfig)
3
+ from zeroth.experiment import Experiment, ExperimentConfig, VariationConfig
4
+
5
+ from zeroth.data import Data
zeroth/experiment.py ADDED
@@ -0,0 +1,134 @@
1
+ from zeroth.utils.dataclasses_utils import config_serializer, generate_param_map, get_name, set_value_by_path
2
+ from zeroth.model import ModelConfig, Model
3
+ from zeroth.plot_losses import plot_losses
4
+ from zeroth.data import Data
5
+
6
+ from dataclasses import dataclass, replace
7
+ from typing import Callable
8
+
9
+ import pandas as pd
10
+ import itertools
11
+ import json
12
+ import os
13
+
14
+
15
+
16
+
17
@dataclass(frozen=True)
class VariationConfig:
    """One hyperparameter sweep axis: a parameter name and the values to try.

    Attributes:
        param: Name of the hyperparameter to vary.
        values: Candidate values swept for that hyperparameter.
    """
    param: str
    values: list
21
+
22
+
23
@dataclass(frozen=True)
class ExperimentConfig:
    """Declarative description of an experiment.

    Attributes:
        name: Name of the experiment (also used as the results sub-directory).
        title: Title used for the generated plots.
        base_model: Reference model configuration that the variations modify.
        variations: Hyperparameter axes swept to generate the model grid.
        create_data: Zero-argument factory returning the Data object.
        plot_dimension: Layout dimension (0, 1 or 2) used when plotting losses.
    """
    name: str
    title: str
    base_model: ModelConfig
    variations: list[VariationConfig]
    create_data: Callable
    plot_dimension: int

    def instantiate(self):
        """Builds the runnable Experiment from this config."""
        return Experiment(self)
34
+
35
+
36
class Experiment:
    """Manages the full lifecycle of a deep learning experiment.

    It handles data loading, model instantiation, training loops, and
    results visualization.

    Attributes:
        name (str): Name of the experiment.
        title (str): Title of the graphs.
        models (list[Model]): List of models to train/compare.
        data (Data): The dataset wrapper.
    """

    def __init__(self, config: ExperimentConfig):
        self.config = config
        self.name: str = config.name
        self.title: str = config.title
        self.base_model_config: ModelConfig = config.base_model
        # One model per combination of the variation values.
        self.models: list[Model] = generate_models(config.base_model, config.variations)
        self.data: Data = config.create_data()
        self.plot_dimension: int = config.plot_dimension

        self.save_dir = os.path.join("results", self.name)

    def launch(self, do_train, do_test, nb_print_train, do_plot_train, do_save):
        """Executes the experiment pipeline.

        Args:
            do_train (bool): Whether to run the training loop.
            do_test (bool): Whether to run evaluation on the test set.
            nb_print_train (int): Number of logs to print during training.
            do_plot_train (bool): If True, plots loss curves after training.
            do_save (bool): If True, saves the plots and dataframes.
        """
        print(f"### Launching Experiment : {self.name} ###")
        if do_train:
            self.train(nb_print=nb_print_train, do_plot=do_plot_train, do_save=do_save)
        if do_test:
            self.test()
        if do_save:
            self.save_df()

    def train(self, nb_print: int, do_plot: bool, do_save: bool):
        """Trains every model, optionally plotting (and saving) the loss curves."""
        for model in self.models:
            model.train(self.data, nb_print)

        if not do_plot:
            return

        plot_path = None
        if do_save:
            os.makedirs(self.save_dir, exist_ok=True)
            plot_path = os.path.join(self.save_dir, "training_losses.png")

        plot_losses(dimension=self.plot_dimension, models=self.models, title=self.title, save_path=plot_path)

    def test(self):
        """Evaluates every model on the test set."""
        for model in self.models:
            model.test(self.data)

    def save_df(self):
        """Saves the models' test metrics as CSV and the experiment config as JSON."""
        os.makedirs(self.save_dir, exist_ok=True)
        print(f" Saving results to: {self.save_dir}")

        # One row per model: its id (varied hyperparameters) plus test metrics.
        rows = [model.id | {"test_loss": model.test_loss, "test_accuracy": model.test_accuracy}
                for model in self.models]
        frame = pd.DataFrame(rows)
        frame.to_csv(os.path.join(self.save_dir, "models_accuracy.csv"), index_label="iteration")

        config_path = os.path.join(self.save_dir, "config.json")
        with open(config_path, "w") as f:
            # config_serializer handles the dataclass/callable fields.
            json.dump(self.config, f, default=config_serializer, indent=4)
110
+
111
+
112
def generate_models(base_model: ModelConfig, variations: list[VariationConfig]) -> list[Model]:
    """Builds one Model per combination of variation values (Cartesian product).

    Each generated model gets an `id` dict mapping every varied parameter
    name to a display name for its value.

    Args:
        base_model: Reference configuration modified for each combination.
        variations: Sweep axes (parameter name + values).

    Returns:
        list[Model]: Instantiated models, one per combination.
    """
    names = [v.param for v in variations]
    values = [v.values for v in variations]

    # Maps a parameter name to its path inside the (nested) config.
    param_map = generate_param_map(base_model)

    models = []
    for combination in itertools.product(*values):
        id_ = {name: get_name(val) for name, val in zip(names, combination)}

        cfg = base_model
        for name, val in zip(names, combination):
            cfg = set_value_by_path(cfg, param_map[name], val)

        models.append(replace(cfg, id=id_).instantiate())

    return models
zeroth/losses.py ADDED
@@ -0,0 +1,121 @@
1
+ from abc import ABC, abstractmethod
2
+ import numpy as np
3
+
4
+
5
class Loss(ABC):
    """Interface for loss functions usable by both zeroth- and first-order training."""

    @staticmethod
    @abstractmethod
    def compute_loss(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
        """Average loss over the whole batch.

        :param Y_pred: shape (out, batch)
        :param Y_true: shape (out, batch)
        :return: scalar average loss
        """

    @staticmethod
    @abstractmethod
    def compute_batch_losses(Y_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
        """Per-sample losses.

        :param Y_pred: shape (out, batch)
        :param Y_true: shape (out, batch)
        :return: losses, shape (batch,)
        """

    @staticmethod
    @abstractmethod
    def compute_perturbed_losses(pY_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
        """Losses of T perturbed forward passes.

        :param pY_pred: shape (T, out, batch)
        :param Y_true: shape (out, batch)
        :return: per-perturbation losses with leading axis T
            (exact trailing shape is implementation-specific)
        """

    @staticmethod
    @abstractmethod
    def compute_gradient_wrt_preactivation(last_layer, Y_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
        """Gradient of the loss w.r.t. the last layer's pre-activation.

        :param last_layer: last layer of the network
        :param Y_pred: shape (out, batch)
        :param Y_true: shape (out, batch)
        :return: gradient, same shape as Y_pred
        """

    @abstractmethod
    def compute_losses_for_zeroth_order(self, pY_pred: np.ndarray, Y_true: np.ndarray) -> tuple[float, np.ndarray]:
        """Returns (average loss of the baseline pass, perturbed losses)."""

    @abstractmethod
    def compute_losses_for_first_order(self, last_layer, Y_pred: np.ndarray, Y_true: np.ndarray):
        """Returns (average loss, gradient w.r.t. the last pre-activation)."""


class MSE(Loss):
    """Mean squared error loss."""

    name = "MSE"

    @staticmethod
    def compute_loss(Y_pred, Y_true) -> float:
        errors = Y_pred - Y_true
        return np.mean(errors ** 2, axis=(0, 1))

    @staticmethod
    def compute_batch_losses(Y_pred, Y_true):
        # Mean over the output axis -> one loss per sample, shape (batch,).
        return np.mean((Y_pred - Y_true) ** 2, axis=0)

    @staticmethod
    def compute_perturbed_losses(pY_pred, Y_true):
        # Y_true broadcasts over the perturbation axis; mean over outputs -> (T, batch).
        return np.mean((pY_pred - Y_true) ** 2, axis=1)

    @staticmethod
    def compute_gradient_wrt_activation(Y_pred, Y_true):
        return 2 * np.mean(Y_pred - Y_true, axis=0)

    @staticmethod
    def compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true):
        # d(MSE)/dA, then chain rule through the last layer's activation derivative.
        dL_dA = 2 * (Y_pred - Y_true) / Y_true.size
        return dL_dA * last_layer.df(last_layer.Z)

    def compute_losses_for_zeroth_order(self, pY_pred, Y_true):
        # First entry of pY_pred is treated as the unperturbed prediction.
        avg = self.compute_loss(pY_pred[0], Y_true)
        perturbed = self.compute_perturbed_losses(pY_pred, Y_true)
        return avg, perturbed

    def compute_losses_for_first_order(self, last_layer, Y_pred, Y_true):
        return (self.compute_loss(Y_pred, Y_true),
                self.compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true))


class CrossEntropy(Loss):
    """Cross-entropy loss over predicted class probabilities, indexed by integer labels."""

    name = "CrossEntropy"

    @staticmethod
    def compute_loss(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
        batch_idx = np.arange(Y_pred.shape[1])
        # Probability assigned to the true class; 1e-8 guards against log(0).
        picked = Y_pred[Y_true, batch_idx]
        return - np.mean(np.log(1e-8 + picked))

    @staticmethod
    def compute_batch_losses(Y_pred: np.ndarray, Y_true: np.ndarray):
        batch_idx = np.arange(Y_pred.shape[1])
        return - np.log(1e-8 + Y_pred[Y_true, batch_idx])

    @staticmethod
    def compute_perturbed_losses(pY_pred: np.ndarray, Y_true: np.ndarray):
        batch_idx = np.arange(pY_pred.shape[2])
        return - np.mean(np.log(1e-8 + pY_pred[:, Y_true, batch_idx]), axis=1)

    @staticmethod
    def compute_gradient_wrt_preactivation(last_layer, Y_pred: np.ndarray, Y_true: np.ndarray):
        # Assumes Y_pred are softmax probabilities (dL/dZ = p - one_hot(y)) —
        # TODO confirm against the network's output layer.
        dZ = Y_pred.copy()
        dZ[Y_true[0], np.arange(Y_pred.shape[1])] -= 1.0
        return dZ

    def compute_losses_for_zeroth_order(self, pY_pred: np.ndarray, Y_true: np.ndarray):
        # First entry of pY_pred is treated as the unperturbed prediction.
        avg_loss = self.compute_loss(pY_pred[0], Y_true)
        p_loss = self.compute_perturbed_losses(pY_pred, Y_true)
        return avg_loss, p_loss

    def compute_losses_for_first_order(self, last_layer, Y_pred: np.ndarray, Y_true: np.ndarray) -> tuple[float, np.ndarray]:
        return self.compute_loss(Y_pred, Y_true), self.compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true)
zeroth/model.py ADDED
@@ -0,0 +1,144 @@
1
+ from dataclasses import dataclass
2
+ from matplotlib.axes import Axes
3
+ from typing import Callable
4
+ from abc import ABC
5
+
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ from zeroth.zeroth_order import ZerothOrderNeuralNetwork, ZerothOrderOptimizerConfig, GradientEstimatorConfig, \
10
+ GradientEstimator
11
+ from zeroth.first_order import FirstOrderNeuralNetwork, FirstOrderOptimizerConfig
12
+ from zeroth.abstract import BlackBox, NeuralNetworkConfig, Optimizer
13
+ from zeroth.losses import Loss
14
+ from zeroth.data import Data
15
+
16
+
17
@dataclass(frozen=True)
class ModelConfig:
    """Shared configuration for every model variant.

    Attributes:
        name: Name of the model (used for display and saving).
        id: Mapping of varied hyperparameter names to display values.
        loss: The loss used for training and evaluation.
        metric: Function (Y_pred, Y_true) -> score (e.g., accuracy).
        batch_size: Number of samples per gradient update.
        nb_epochs: Number of passes through the entire dataset.
    """
    # NOTE(review): the original docstring documented a `plot_results` field
    # that does not exist on this dataclass; the stale entry was removed.
    name: str
    id: dict
    loss: Loss
    metric: Callable
    batch_size: int
    nb_epochs: int

    def instantiate(self):
        """Builds the concrete Model; implemented by subclasses."""
        raise NotImplementedError
36
+
37
+
38
@dataclass(frozen=True)
class FirstOrderModelConfig(ModelConfig):
    """ModelConfig plus network and optimizer settings for first-order training."""
    neural_network_config: NeuralNetworkConfig
    optimizer_config: FirstOrderOptimizerConfig

    def instantiate(self):
        """Builds a FirstOrderModel from this config."""
        return FirstOrderModel(self)
45
+
46
+
47
@dataclass(frozen=True)
class ZerothOrderModelConfig(ModelConfig):
    """ModelConfig plus network, optimizer and gradient-estimator settings
    for zeroth-order training."""
    neural_network_config: NeuralNetworkConfig
    optimizer_config: ZerothOrderOptimizerConfig
    gradient_estimator_config: GradientEstimatorConfig

    def instantiate(self):
        """Builds a ZerothOrderModel from this config."""
        return ZerothOrderModel(self)
55
+
56
+
57
class Model(ABC):
    """Base class orchestrating the training and testing loop.

    Abstracts the shared logic for training regardless of the underlying
    engine (backpropagation or zeroth-order estimation). Subclasses are
    responsible for setting `self.neural_network` and `self.optimizer`.
    """

    def __init__(self, config: ModelConfig):
        self.name: str = config.name
        self.id: dict = config.id
        self.loss: Loss = config.loss
        self.metric: Callable = config.metric
        self.batch_size: int = config.batch_size
        self.nb_epochs: int = config.nb_epochs
        # Populated by the concrete subclass constructors.
        self.neural_network: BlackBox | None = None
        self.optimizer: Optimizer | None = None

        self.train_loss: np.ndarray = np.array([])
        self.test_loss: float | None = None
        self.test_accuracy: float | None = None

    def train(self, data: Data, nb_print: int = 0):
        """Runs the training loop over the dataset.

        Per-step losses are stored in `self.train_loss` (for plotting), and a
        final evaluation on the test set is run afterwards. Returns None.
        (The original docstring claimed an np.ndarray return; the method has
        always stored the losses on the instance instead.)

        Args:
            data (Data): The dataset object containing train/test sets.
            nb_print (int): Number of progress updates to print per epoch.
        """
        nb_batches = data.nb_data // self.batch_size

        self.train_loss = np.zeros(self.nb_epochs * nb_batches, dtype=np.float64)

        # Evenly spaced batch indices at which to log; a set gives O(1) membership
        # instead of scanning an array every batch.
        print_indexes = set(np.linspace(0, nb_batches - 1, nb_print).astype(int))
        print(f" Training {self.id} Model")
        for epoch_idx in range(self.nb_epochs):
            print(f" epoch n°{epoch_idx + 1} out of {self.nb_epochs}")
            data.prepare_data(self.batch_size)
            for batch_idx in range(nb_batches):
                X_train, Y_train = data.X_train[batch_idx], data.Y_train[batch_idx]
                avg_loss = self.optimizer.do_descent(self.neural_network, self.loss, X_train, Y_train)
                self.train_loss[epoch_idx * nb_batches + batch_idx] = avg_loss

                if batch_idx in print_indexes:
                    print(f" batch n°{batch_idx + 1} out of {nb_batches}, "
                          f"loss : {np.round(self.train_loss[epoch_idx * nb_batches + batch_idx], 3)}")
        self.test(data)

    def plot_loss(self, ax: Axes, label: str, smooth_span: int = 50):
        """Plots the raw loss curve (faint) and its smoothed version on `ax`."""
        ax.plot(self.train_loss, alpha=0.25, linewidth=1.0)
        smooth = self.smooth_curve(self.train_loss, smooth_span)
        ax.plot(smooth, label=label, linewidth=2.5)

    @staticmethod
    def smooth_curve(loss: np.ndarray, smooth_span: int) -> np.ndarray:
        """Exponentially-weighted moving average computed in log space
        (keeps the smoothed curve positive for log-scale plots)."""
        return np.exp(pd.Series(np.log(loss)).ewm(span=smooth_span, adjust=True).mean())

    def test(self, data):
        """Evaluates metric and loss on the test set; stores and prints the results."""
        X_test, Y_true = data.X_test, data.Y_test  # (in, batch), (out, batch)
        Y_pred = self.neural_network.forward(X_test)  # (out, batch)

        self.test_accuracy = self.metric(Y_pred, Y_true)
        self.test_loss = self.loss.compute_loss(Y_pred, Y_true)

        print(f" {self.id} accuracy : {self.test_accuracy}, loss : {self.test_loss}")
127
+
128
+
129
class FirstOrderModel(Model):
    """Model trained with backpropagation (first-order gradients)."""

    def __init__(self, config: FirstOrderModelConfig):
        super().__init__(config)
        self.neural_network = FirstOrderNeuralNetwork(config.neural_network_config)
        self.optimizer = config.optimizer_config.instantiate()
135
+
136
+
137
class ZerothOrderModel(Model):
    """Model trained from function evaluations only, via a gradient estimator."""

    def __init__(self, config: ZerothOrderModelConfig):
        super().__init__(config)
        self.neural_network: ZerothOrderNeuralNetwork = ZerothOrderNeuralNetwork(config.neural_network_config)
        # The estimator needs the parameter count to size its perturbations.
        nb_params = self.neural_network.params.nb_params
        self.gradient_estimator: GradientEstimator = config.gradient_estimator_config.instantiate(nb_params)
        self.optimizer: Optimizer = config.optimizer_config.instantiate(self.gradient_estimator)
zeroth/plot_losses.py ADDED
@@ -0,0 +1,190 @@
1
+ from zeroth.model import Model
2
+
3
+ from matplotlib.axes import Axes
4
+ from cycler import cycler
5
+
6
+ import matplotlib.pyplot as plt
7
+
8
def set_style():
    """Applies a publication-style matplotlib theme: serif fonts, open spines,
    light horizontal grid, high-dpi tight saving, and a fixed colour palette."""
    plt.rcParams.update({
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
        "font.size": 10,
        "axes.labelsize": 10,
        "axes.titlesize": 11,
        "axes.linewidth": 0.8,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 3,
        "ytick.major.size": 3,
        "lines.linewidth": 2.2,
        "legend.fontsize": 9,
        "legend.frameon": False,
        "grid.color": "0.85",
        "grid.linewidth": 0.6,
        "grid.linestyle": "-",
        "savefig.dpi": 300,
        "savefig.bbox": "tight",
    })
    # Fixed qualitative colour cycle shared by all plots.
    palette = [
        "#4477AA", "#EE6677", "#228833",
        "#CCBB44", "#66CCEE", "#AA3377"
    ]
    plt.rcParams["axes.prop_cycle"] = cycler(color=palette)
37
+
38
+
39
def format_ax(ax: Axes):
    """Formats one Axes: log-scale y, horizontal grid only, no top/right spines."""
    ax.set_axisbelow(True)
    ax.set_yscale('log')
    ax.grid(True, which="major", axis="y")
    ax.grid(False, axis="x")
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
46
+
47
+
48
def plot_0d(models: list[Model], title: str, smooth_span: int = 50):
    """Plots a single graph overlaying multiple models.

    Args:
        models (list): List of model objects to overlay.
        title (str): The figure title.
        smooth_span (int): The span for the EWM average.
    """
    fig, ax = plt.subplots(figsize=(5.5, 3.5))

    for model in models:
        # Label each curve with the model's varied hyperparameters.
        others = [f"{k}={v}" for k, v in model.id.items()]
        label = ", ".join(others)
        model.plot_loss(ax, label, smooth_span)

    plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.2)
    fig.suptitle(title, fontweight='bold', fontsize=12)
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Training loss")
    # The original called format_ax twice; once is sufficient (the styling is idempotent).
    format_ax(ax)

    handles, labels = ax.get_legend_handles_labels()

    if handles:
        fig.legend(handles, labels, loc='lower center', ncol=len(handles),
                   bbox_to_anchor=(0.5, 0), frameon=False, fontsize=9)
72
+
73
+
74
def plot_1d(models: list[Model], title: str, key: str, smooth_span: int = 50):
    """Plots a row of subplots, varying one hyperparameter (key) across columns.

    Args:
        models (list): List of model objects.
        title (str): The figure title.
        key (str): The hyperparameter key changing across columns.
        smooth_span (int): The span for the EWM average.
    """
    # Unique values in first-seen order, one column per value.
    cols = list(dict.fromkeys([m.id[key] for m in models]))
    n_models = len(cols)
    # squeeze=False: with a single column plt.subplots would otherwise return a
    # bare Axes, making axs[i] below raise. Flatten the 1xN array to a row.
    fig, axs = plt.subplots(1, n_models, figsize=(4.5 * n_models, 3.5),
                            sharey=True, squeeze=False)
    axs = axs[0]

    for i, val in enumerate(cols):
        ax = axs[i]
        cell_models = [m for m in models if m.id[key] == val]

        for model in cell_models:
            # Label with the remaining (non-column) hyperparameters.
            others = [f"{k}={v}" for k, v in model.id.items()
                      if k != key]
            label = ", ".join(others)
            model.plot_loss(ax, label, smooth_span)

        format_ax(ax)

        ax.set_title(f"{key} = {val}")

    fig.text(0.5, 0.1, "Training steps", ha='center', fontsize=10)
    fig.text(0.01, 0.5, "Training loss", va='center', rotation='vertical', fontsize=10)

    plt.subplots_adjust(left=0.05, right=0.96, top=0.85, bottom=0.2, wspace=0.10, hspace=0.18)
    fig.suptitle(title, fontweight='bold', fontsize=12)

    handles, labels = axs[0].get_legend_handles_labels()

    if handles:
        fig.legend(handles, labels, loc='lower center', ncol=len(handles),
                   bbox_to_anchor=(0.5, 0), frameon=False, fontsize=9)
109
+
110
+
111
def plot_2d_grid(models: list[Model], title: str, row_key: str, col_key: str, smooth_span: int = 50):
    """Plots a grid of subplots varying two hyperparameters: one across rows,
    one across columns.

    Args:
        models (list): List of model objects.
        title (str): The title of the plot.
        row_key (str): The hyperparameter key changing across rows.
        col_key (str): The hyperparameter key changing across columns.
        smooth_span (int): The span for the EWM average.
    """
    # Unique values in first-seen order define the grid layout.
    row_vals = list(dict.fromkeys([m.id[row_key] for m in models]))
    col_vals = list(dict.fromkeys([m.id[col_key] for m in models]))

    fig, axs = plt.subplots(len(row_vals), len(col_vals),
                            figsize=(4.5 * len(col_vals), 3.5 * len(row_vals)),
                            sharex=True, sharey=True, squeeze=False)

    for i, r_val in enumerate(row_vals):
        for j, c_val in enumerate(col_vals):
            ax = axs[i, j]

            for model in models:
                if model.id[row_key] != r_val or model.id[col_key] != c_val:
                    continue
                # Label with the hyperparameters not already encoded by the grid.
                label = ", ".join(f"{k}={v}" for k, v in model.id.items()
                                  if k not in (row_key, col_key))
                model.plot_loss(ax, label, smooth_span)

            format_ax(ax)

            if i == 0:
                ax.set_title(f"{col_key} = {c_val}")

            if j == len(col_vals) - 1:
                # Row value label along the right edge of the grid.
                ax.text(1.02, 0.5, f"{row_key} = {r_val}",
                        transform=ax.transAxes, rotation=-90,
                        va="center", ha="left")

    plt.subplots_adjust(left=0.06, right=0.96, top=0.90, bottom=0.12, wspace=0.10, hspace=0.18)
    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)

    handles, labels = axs[0, 0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc='lower center', ncol=len(handles),
                   bbox_to_anchor=(0.5, 0.02), frameon=False, fontsize=9)

    fig.text(0.5, 0.07, "Training steps", ha='center', fontsize=10)
    fig.text(0.02, 0.5, "Training loss", va='center', rotation='vertical', fontsize=10)
160
+
161
+
162
+
163
def plot_losses(dimension: int, models: list, title: str, save_path: str = None, smooth_span: int = 100):
    """Main entry point for plotting: dispatches on the requested layout.

    Args:
        dimension (int): Plot layout: 0 = single axes, 1 = one row of subplots
            (first variation key across columns), any other value = 2D grid
            (first two variation keys).
        models (list): List of model objects.
        title (str): The title of the plot.
        save_path (str, optional): File path to save the figure (e.g., 'plot.png').
        smooth_span (int): EWM span for smoothing. Defaults to 100.
    """
    set_style()

    # The varied hyperparameter keys drive which parameter maps to columns/rows.
    keys = list(models[0].id.keys())

    if dimension == 0:
        plot_0d(models, title, smooth_span)
    elif dimension == 1:
        plot_1d(models, title, keys[0], smooth_span)
    else:
        plot_2d_grid(models, title, keys[0], keys[1], smooth_span)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Plot saved to {save_path}")

    plt.show()
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: zeroth-learn
3
+ Version: 0.1.0
4
+ Requires-Dist: numpy
5
+ Requires-Dist: pandas
6
+ Requires-Dist: matplotlib
7
+ Requires-Dist: cycler
@@ -0,0 +1,9 @@
1
+ zeroth/__init__.py,sha256=Zi4bPQQUNwI9sXGkaQFEZA84w0a4je3vdJJEIAsH3kU,260
2
+ zeroth/experiment.py,sha256=Ca60CyNjarDWTCaamvW_2WINAO6SZE4HQsPNIsM_HWA,4320
3
+ zeroth/losses.py,sha256=V9w_o9asXIWtEx_V7dWeurlv-jXGYjsjIc2zT1AQxK0,4071
4
+ zeroth/model.py,sha256=PufxZTHkDb1xe0ExjVy-MAhkR0LYm-ft-o0xlQSMPDc,5355
5
+ zeroth/plot_losses.py,sha256=fT3ipCWtgeBSmhALIn6IkHKcyKRH2ErCrSiX7FhI-nA,6503
6
+ zeroth_learn-0.1.0.dist-info/METADATA,sha256=DsSHGh2eGfTnA0UPg7aY7ir131FqGI2g58Zq0jQmcag,147
7
+ zeroth_learn-0.1.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
8
+ zeroth_learn-0.1.0.dist-info/top_level.txt,sha256=ZNQ6QAwK_P_ev_5dpk-HUbpu49Lu25HJ6kbv6f0mOyo,7
9
+ zeroth_learn-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ zeroth