zeroth_learn-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- zeroth/__init__.py +5 -0
- zeroth/experiment.py +134 -0
- zeroth/losses.py +121 -0
- zeroth/model.py +144 -0
- zeroth/plot_losses.py +190 -0
- zeroth_learn-0.1.0.dist-info/METADATA +7 -0
- zeroth_learn-0.1.0.dist-info/RECORD +9 -0
- zeroth_learn-0.1.0.dist-info/WHEEL +5 -0
- zeroth_learn-0.1.0.dist-info/top_level.txt +1 -0
zeroth/__init__.py
ADDED
zeroth/experiment.py
ADDED
@@ -0,0 +1,134 @@
+from zeroth.utils.dataclasses_utils import config_serializer, generate_param_map, get_name, set_value_by_path
+from zeroth.model import ModelConfig, Model
+from zeroth.plot_losses import plot_losses
+from zeroth.data import Data
+
+from dataclasses import dataclass, replace
+from typing import Callable
+
+import pandas as pd
+import itertools
+import json
+import os
+
+
+
+
+@dataclass(frozen=True)
+class VariationConfig:
+    param: str
+    values: list
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    name: str
+    title: str
+    base_model: ModelConfig
+    variations: list[VariationConfig]
+    create_data: Callable
+    plot_dimension: int
+
+    def instantiate(self):
+        return Experiment(self)
+
+
+class Experiment:
+    """Manages the full lifecycle of a deep learning experiment.
+
+    It handles data loading, model instantiation, training loops, and results visualization.
+
+    Attributes:
+        name (str): Name of the experiment.
+        title (str): Title of the graphs.
+        models (list[Model]): List of models to train/compare.
+        data (Data): The dataset wrapper.
+    """
+
+    def __init__(self, config: ExperimentConfig):
+        self.config = config
+        self.name: str = config.name
+        self.title: str = config.title
+        self.base_model_config: ModelConfig = config.base_model
+        self.models: list[Model] = generate_models(config.base_model, config.variations)
+        self.data: Data = config.create_data()
+        self.plot_dimension: int = config.plot_dimension
+
+        self.save_dir = os.path.join("results", self.name)
+
+    def launch(self, do_train, do_test, nb_print_train, do_plot_train, do_save):
+        """
+        Executes the experiment pipeline.
+
+        Args:
+            do_train (bool): Whether to run the training loop.
+            do_test (bool): Whether to run evaluation on the test set.
+            nb_print_train (int): Number of logs to print during training.
+            do_plot_train (bool): If True, plots loss curves after training.
+            do_save (bool): If True, saves the plots and dataframes.
+        """
+        print(f"### Launching Experiment : {self.name} ###")
+        if do_train:
+            self.train(nb_print=nb_print_train, do_plot=do_plot_train, do_save=do_save)
+        if do_test:
+            self.test()
+        if do_save:
+            self.save_df()
+
+    def train(self, nb_print: int, do_plot: bool, do_save: bool):
+        for model in self.models:
+            model.train(self.data, nb_print)
+
+        if do_plot:
+            plot_path = None
+            if do_save:
+                os.makedirs(self.save_dir, exist_ok=True)
+                plot_path = os.path.join(self.save_dir, "training_losses.png")
+
+            plot_losses(dimension=self.plot_dimension, models=self.models, title=self.title, save_path=plot_path)
+
+    def test(self):
+        for model in self.models:
+            model.test(self.data)
+
+    def save_df(self):
+        """
+        Saves each model's hyperparameters and test metrics, plus the experiment config.
+        """
+        os.makedirs(self.save_dir, exist_ok=True)
+        print(f" Saving results to: {self.save_dir}")
+
+        data = [model.id | {"test_loss": model.test_loss, "test_accuracy": model.test_accuracy}
+                for model in self.models]
+
+        df = pd.DataFrame(data)
+        df.to_csv(os.path.join(self.save_dir, "models_accuracy.csv"), index_label="iteration")
+
+        config_path = os.path.join(self.save_dir, "config.json")
+        with open(config_path, "w") as f:
+            json.dump(self.config, f, default=config_serializer, indent=4)
+
+
+def generate_models(base_model: ModelConfig, variations: list[VariationConfig]) -> list[Model]:
+
+    models = []
+
+    names = [v.param for v in variations]
+    values = [v.values for v in variations]
+
+    # maps each varied param name to its path inside the nested config
+    PARAM_MAP = generate_param_map(base_model)
+
+    for combination in itertools.product(*values):
+        id_ = {}
+        for key, val in zip(names, combination):
+            id_[key] = get_name(val)
+        current_model = base_model
+        for param_key, value in zip(names, combination):
+            path = PARAM_MAP[param_key]
+            current_model = set_value_by_path(current_model, path, value)
+
+        current_model = replace(current_model, id=id_)
+        models.append(current_model.instantiate())
+
+    return models
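For orientation, here is a hedged usage sketch (not part of the package) of how VariationConfig, ExperimentConfig, and generate_models compose. The objects nn_config, opt_config, make_data, and accuracy are hypothetical and not defined in this diff, and a learning_rate field is assumed to exist somewhere in the nested config so that generate_param_map can resolve it:

# Hypothetical usage sketch; nn_config, opt_config, make_data, and accuracy
# are assumed to be defined elsewhere -- their definitions are not in this diff.
from zeroth.experiment import ExperimentConfig, VariationConfig
from zeroth.model import FirstOrderModelConfig
from zeroth.losses import MSE

base = FirstOrderModelConfig(
    name="baseline",
    id={},                            # overwritten per variation by generate_models
    loss=MSE(),
    metric=accuracy,                  # hypothetical (Y_pred, Y_true) -> float
    batch_size=64,
    nb_epochs=5,
    neural_network_config=nn_config,  # assumed NeuralNetworkConfig instance
    optimizer_config=opt_config,      # assumed FirstOrderOptimizerConfig instance
)

config = ExperimentConfig(
    name="lr_sweep",
    title="Learning-rate sweep",
    base_model=base,
    variations=[VariationConfig(param="learning_rate", values=[1e-3, 1e-2])],
    create_data=make_data,            # assumed callable returning a Data instance
    plot_dimension=1,                 # one varied parameter -> one row of subplots
)

config.instantiate().launch(do_train=True, do_test=True, nb_print_train=5,
                            do_plot_train=True, do_save=True)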
zeroth/losses.py
ADDED
@@ -0,0 +1,121 @@
+from abc import ABC, abstractmethod
+import numpy as np
+
+
+class Loss(ABC):
+    @staticmethod
+    @abstractmethod
+    def compute_loss(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
+        """
+        :param Y_pred: shape (out, batch)
+        :param Y_true: shape (out, batch)
+        :return: average loss over the batch (float)
+        """
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def compute_batch_losses(Y_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
+        """
+        :param Y_pred: shape (out, batch)
+        :param Y_true: shape (out, batch)
+        :return: per-sample losses, shape (batch,)
+        """
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def compute_perturbed_losses(pY_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
+        """
+        :param pY_pred: (T, out, batch)
+        :param Y_true: (out, batch)
+        :return: perturbed losses, shape (T, batch)
+        """
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def compute_gradient_wrt_preactivation(last_layer, Y_pred: np.ndarray, Y_true: np.ndarray) -> np.ndarray:
+        """
+        :param last_layer: last layer of the network
+        :param Y_pred: shape (out, batch)
+        :param Y_true: shape (out, batch)
+        :return: gradient w.r.t. the pre-activation, shape (out, batch)
+        """
+        pass
+
+    @abstractmethod
+    def compute_losses_for_zeroth_order(self, pY_pred: np.ndarray, Y_true: np.ndarray) -> tuple[float, np.ndarray]:
+        pass
+
+
+    @abstractmethod
+    def compute_losses_for_first_order(self, last_layer, Y_pred: np.ndarray, Y_true: np.ndarray):
+        pass
+
+
+class MSE(Loss):
+    name = "MSE"
+
+    @staticmethod
+    def compute_loss(Y_pred, Y_true) -> float:
+        return np.mean((Y_pred - Y_true) ** 2, axis=(0, 1))
+
+    @staticmethod
+    def compute_batch_losses(Y_pred, Y_true):
+        return np.mean((Y_pred - Y_true) ** 2, axis=0)
+
+    @staticmethod
+    def compute_perturbed_losses(pY_pred, Y_true):
+        return np.mean((pY_pred - Y_true) ** 2, axis=1)
+
+    @staticmethod
+    def compute_gradient_wrt_activation(Y_pred, Y_true):
+        return 2 * np.mean(Y_pred - Y_true, axis=0)
+
+    @staticmethod
+    def compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true):
+        dL_dA = 2 * (Y_pred - Y_true) / Y_true.size
+        dL_dZ = dL_dA * last_layer.df(last_layer.Z)
+        return dL_dZ
+
+    def compute_losses_for_zeroth_order(self, pY_pred, Y_true):
+        return (self.compute_loss(pY_pred[0], Y_true),
+                self.compute_perturbed_losses(pY_pred, Y_true))
+
+    def compute_losses_for_first_order(self, last_layer, Y_pred, Y_true):
+        return self.compute_loss(Y_pred, Y_true), self.compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true)
+
+
+class CrossEntropy(Loss):
+    name = "CrossEntropy"
+
+    @staticmethod
+    def compute_loss(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
+        idx = np.arange(Y_pred.shape[1])
+        return - np.mean(np.log(1e-8 + Y_pred[Y_true, idx]))
+
+    @staticmethod
+    def compute_batch_losses(Y_pred: np.ndarray, Y_true: np.ndarray):
+        idx = np.arange(Y_pred.shape[1])
+        return - np.log(1e-8 + Y_pred[Y_true, idx])
+
+    @staticmethod
+    def compute_perturbed_losses(pY_pred: np.ndarray, Y_true: np.ndarray):
+        idx = np.arange(pY_pred.shape[2])
+        return - np.mean(np.log(1e-8 + pY_pred[:, Y_true, idx]), axis=1)
+
+    @staticmethod
+    def compute_gradient_wrt_preactivation(last_layer, Y_pred: np.ndarray, Y_true: np.ndarray):
+        dZ = Y_pred.copy()
+        batch_size = Y_pred.shape[1]
+        dZ[Y_true[0], np.arange(batch_size)] -= 1.0
+        return dZ
+
+    def compute_losses_for_zeroth_order(self, pY_pred: np.ndarray, Y_true: np.ndarray):
+        avg_loss = self.compute_loss(pY_pred[0], Y_true)
+        p_loss = self.compute_perturbed_losses(pY_pred, Y_true)
+        return avg_loss, p_loss
+
+    def compute_losses_for_first_order(self, last_layer, Y_pred: np.ndarray, Y_true: np.ndarray) -> tuple[float, np.ndarray]:
+        return self.compute_loss(Y_pred, Y_true), self.compute_gradient_wrt_preactivation(last_layer, Y_pred, Y_true)
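The (out, batch) layout above is easy to get backwards, so here is a small illustrative shape check of the conventions (values chosen arbitrarily); it exercises only the static methods defined in this file:

import numpy as np
from zeroth.losses import MSE, CrossEntropy

# Regression: rows are output dimensions, columns are samples.
Y_true = np.array([[1.0, 2.0],
                   [3.0, 4.0]])                  # (out=2, batch=2)
Y_pred = Y_true + 0.5
print(MSE.compute_loss(Y_pred, Y_true))          # 0.25 (mean over all entries)
print(MSE.compute_batch_losses(Y_pred, Y_true))  # [0.25 0.25], one value per sample

# Classification: rows are class probabilities, Y_true holds integer labels.
probs = np.array([[0.9, 0.2],
                  [0.1, 0.8]])                   # (classes=2, batch=2)
labels = np.array([[0, 1]])                      # (1, batch), as CrossEntropy indexes it
print(CrossEntropy.compute_loss(probs, labels))  # ~0.164 = -(log 0.9 + log 0.8) / 2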
zeroth/model.py
ADDED
@@ -0,0 +1,144 @@
+from dataclasses import dataclass
+from matplotlib.axes import Axes
+from typing import Callable
+from abc import ABC
+
+import pandas as pd
+import numpy as np
+
+from zeroth.zeroth_order import ZerothOrderNeuralNetwork, ZerothOrderOptimizerConfig, GradientEstimatorConfig, \
+    GradientEstimator
+from zeroth.first_order import FirstOrderNeuralNetwork, FirstOrderOptimizerConfig
+from zeroth.abstract import BlackBox, NeuralNetworkConfig, Optimizer
+from zeroth.losses import Loss
+from zeroth.data import Data
+
+
+@dataclass(frozen=True)
+class ModelConfig:
+    """
+    name (str): Name of the model (used for display and saving).
+    id (dict): Varied hyperparameter names mapped to their values.
+    loss (Loss): The loss object.
+    metric (Callable): Function (Y_pred, Y_true) -> score (e.g., accuracy).
+    batch_size (int): Number of samples per gradient update.
+    nb_epochs (int): Number of passes through the entire dataset.
+    """
+    name: str
+    id: dict
+    loss: Loss
+    metric: Callable
+    batch_size: int
+    nb_epochs: int
+
+    def instantiate(self):
+        raise NotImplementedError
+
+
+@dataclass(frozen=True)
+class FirstOrderModelConfig(ModelConfig):
+    neural_network_config: NeuralNetworkConfig
+    optimizer_config: FirstOrderOptimizerConfig
+
+    def instantiate(self):
+        return FirstOrderModel(self)
+
+
+@dataclass(frozen=True)
+class ZerothOrderModelConfig(ModelConfig):
+    neural_network_config: NeuralNetworkConfig
+    optimizer_config: ZerothOrderOptimizerConfig
+    gradient_estimator_config: GradientEstimatorConfig
+
+    def instantiate(self):
+        return ZerothOrderModel(self)
+
+
+class Model(ABC):
+    """
+    Base class orchestrating the training and testing loop.
+
+    It factors out the training logic that is shared
+    regardless of the underlying engine (backpropagation or zeroth-order).
+    """
+
+    def __init__(self, config: ModelConfig):
+
+        self.name: str = config.name
+        self.id: dict = config.id
+        self.loss: Loss = config.loss
+        self.metric: Callable = config.metric
+        self.batch_size: int = config.batch_size
+        self.nb_epochs: int = config.nb_epochs
+        self.neural_network: BlackBox | None = None
+        self.optimizer: Optimizer | None = None
+
+        self.train_loss: np.ndarray = np.array([])
+        self.test_loss: float | None = None
+        self.test_accuracy: float | None = None
+
+    def train(self, data: Data, nb_print: int = 0):
+        """Runs the training loop over the dataset.
+
+        Args:
+            data (Data): The dataset object containing train/test sets.
+            nb_print (int): Number of progress updates to print per epoch.
+
+        The loss recorded at each training step is stored in `self.train_loss`
+        for later plotting.
+        """
+
+        nb_batches = data.nb_data // self.batch_size
+
+        self.train_loss = np.zeros(self.nb_epochs * nb_batches, dtype=np.float64)
+
+        print_indexes = np.linspace(0, nb_batches - 1, nb_print).astype(int)
+        print(f" Training {self.id} Model")
+        for epoch_idx in range(self.nb_epochs):
+            print(f" epoch n°{epoch_idx + 1} out of {self.nb_epochs}")
+            data.prepare_data(self.batch_size)
+            for batch_idx in range(nb_batches):
+                X_train, Y_train = data.X_train[batch_idx], data.Y_train[batch_idx]
+                avg_loss = self.optimizer.do_descent(self.neural_network, self.loss, X_train, Y_train)
+                self.train_loss[epoch_idx * nb_batches + batch_idx] = avg_loss
+
+                if batch_idx in print_indexes:
+                    print(f" batch n°{batch_idx + 1} out of {nb_batches}, "
+                          f"loss : {np.round(self.train_loss[epoch_idx * nb_batches + batch_idx], 3)}")
+        self.test(data)
+
+    def plot_loss(self, ax: Axes, label: str, smooth_span: int = 50):
+        ax.plot(self.train_loss, alpha=0.25, linewidth=1.0)
+        smooth = self.smooth_curve(self.train_loss, smooth_span)
+        ax.plot(smooth, label=label, linewidth=2.5)
+
+    @staticmethod
+    def smooth_curve(loss: np.ndarray, smooth_span: int) -> np.ndarray:
+        # EWM average computed in log-space, so the smoothed curve behaves well on a log-scaled loss axis
+        return np.exp(pd.Series(np.log(loss)).ewm(span=smooth_span, adjust=True).mean())
+
+    def test(self, data):
+        X_test, Y_true = data.X_test, data.Y_test  # (in, batch), (out, batch)
+        Y_pred = self.neural_network.forward(X_test)  # (out, batch)
+
+        self.test_accuracy = self.metric(Y_pred, Y_true)
+        self.test_loss = self.loss.compute_loss(Y_pred, Y_true)
+
+        print(f" {self.id} accuracy : {self.test_accuracy}, loss : {self.test_loss}")
+
+
+class FirstOrderModel(Model):
+    def __init__(self, config: FirstOrderModelConfig):
+        super().__init__(config)
+
+        self.neural_network = FirstOrderNeuralNetwork(config.neural_network_config)
+        self.optimizer = config.optimizer_config.instantiate()
+
+
+class ZerothOrderModel(Model):
+    def __init__(self, config: ZerothOrderModelConfig):
+        super().__init__(config)
+
+        self.neural_network: ZerothOrderNeuralNetwork = ZerothOrderNeuralNetwork(config.neural_network_config)
+        nb_params = self.neural_network.params.nb_params
+        self.gradient_estimator: GradientEstimator = config.gradient_estimator_config.instantiate(nb_params)
+        self.optimizer: Optimizer = config.optimizer_config.instantiate(self.gradient_estimator)
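The metric above is only specified as a Callable. As a concrete illustration (a hypothetical helper, not part of the package), an accuracy metric matching the (out, batch) forward-pass convention and the (1, batch) integer-label convention used by CrossEntropy could look like this:

import numpy as np

def accuracy(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
    """Fraction of columns whose arg-max class matches the integer label.

    Assumes Y_pred is (classes, batch) and Y_true is (1, batch), matching
    the conventions used by CrossEntropy in zeroth/losses.py.
    """
    return float(np.mean(np.argmax(Y_pred, axis=0) == Y_true[0]))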
zeroth/plot_losses.py
ADDED
@@ -0,0 +1,190 @@
+from zeroth.model import Model
+
+from matplotlib.axes import Axes
+from cycler import cycler
+
+import matplotlib.pyplot as plt
+
+def set_style():
+    plt.rcParams.update({
+        "font.family": "serif",
+        "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
+        "font.size": 10,
+        "axes.labelsize": 10,
+        "axes.titlesize": 11,
+        "axes.linewidth": 0.8,
+        "axes.spines.top": False,
+        "axes.spines.right": False,
+        "xtick.labelsize": 9,
+        "ytick.labelsize": 9,
+        "xtick.direction": "out",
+        "ytick.direction": "out",
+        "xtick.major.size": 3,
+        "ytick.major.size": 3,
+        "lines.linewidth": 2.2,
+        "legend.fontsize": 9,
+        "legend.frameon": False,
+        "grid.color": "0.85",
+        "grid.linewidth": 0.6,
+        "grid.linestyle": "-",
+        "savefig.dpi": 300,
+        "savefig.bbox": "tight",
+    })
+    plt.rcParams["axes.prop_cycle"] = cycler(color=[
+        "#4477AA", "#EE6677", "#228833",
+        "#CCBB44", "#66CCEE", "#AA3377"
+    ])
+
+
+def format_ax(ax: Axes):
+    ax.set_axisbelow(True)
+    ax.set_yscale('log')
+    ax.grid(True, which="major", axis="y")
+    ax.grid(False, axis="x")
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+
+
+def plot_0d(models: list[Model], title: str, smooth_span: int = 50):
+    """
+    Plots a single graph overlaying all models on one set of axes.
+    """
+    fig, ax = plt.subplots(figsize=(5.5, 3.5))
+
+    for model in models:
+        others = [f"{k}={v}" for k, v in model.id.items()]
+        label = ", ".join(others)
+        model.plot_loss(ax, label, smooth_span)
+
+    format_ax(ax)
+
+    plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.2)
+    fig.suptitle(title, fontweight='bold', fontsize=12)
+    ax.set_xlabel("Training steps")
+    ax.set_ylabel("Training loss")
+
+    handles, labels = ax.get_legend_handles_labels()
+
+    if handles:
+        fig.legend(handles, labels, loc='lower center', ncol=len(handles),
+                   bbox_to_anchor=(0.5, 0), frameon=False, fontsize=9)
+
+
+def plot_1d(models: list[Model], title: str, key: str, smooth_span: int = 50):
+    """
+    Plots a row of subplots, varying one hyperparameter (key) across columns.
+    """
+
+    cols = list(dict.fromkeys([m.id[key] for m in models]))
+    n_cols = len(cols)
+    # squeeze=False keeps `axs` a 2-D array even when there is a single column
+    fig, axs = plt.subplots(1, n_cols, figsize=(4.5 * n_cols, 3.5), sharey=True, squeeze=False)
+
+    for i, val in enumerate(cols):
+        ax = axs[0, i]
+        cell_models = [m for m in models if m.id[key] == val]
+
+        for model in cell_models:
+            others = [f"{k}={v}" for k, v in model.id.items()
+                      if k != key]
+            label = ", ".join(others)
+            model.plot_loss(ax, label, smooth_span)
+
+        format_ax(ax)
+
+        ax.set_title(f"{key} = {val}")
+
+
+    fig.text(0.5, 0.1, "Training steps", ha='center', fontsize=10)
+    fig.text(0.01, 0.5, "Training loss", va='center', rotation='vertical', fontsize=10)
+
+    plt.subplots_adjust(left=0.05, right=0.96, top=0.85, bottom=0.2, wspace=0.10, hspace=0.18)
+    fig.suptitle(title, fontweight='bold', fontsize=12)
+
+    handles, labels = axs[0, 0].get_legend_handles_labels()
+
+    if handles:
+        fig.legend(handles, labels, loc='lower center', ncol=len(handles),
+                   bbox_to_anchor=(0.5, 0), frameon=False, fontsize=9)
+
+
+def plot_2d_grid(models: list[Model], title: str, row_key: str, col_key: str, smooth_span: int = 50):
+    """
+    Plots a grid of subplots varying two hyperparameters: one across rows, one across columns.
+
+    Args:
+        models (list): List of model objects.
+        title (str): The title of the plot.
+        row_key (str): The hyperparameter key changing across rows.
+        col_key (str): The hyperparameter key changing across columns.
+        smooth_span (int): The span for the EWM average.
+    """
+    rows = list(dict.fromkeys([m.id[row_key] for m in models]))
+    cols = list(dict.fromkeys([m.id[col_key] for m in models]))
+
+    fig, axs = plt.subplots(len(rows), len(cols),
+                            figsize=(4.5 * len(cols), 3.5 * len(rows)),
+                            sharex=True, sharey=True, squeeze=False)
+
+    for i, r_val in enumerate(rows):
+        for j, c_val in enumerate(cols):
+            ax = axs[i, j]
+            cell_models = [m for m in models if m.id[row_key] == r_val and m.id[col_key] == c_val]
+
+            for model in cell_models:
+                others = [f"{k}={v}" for k, v in model.id.items() if k not in [row_key, col_key]]
+                label = ", ".join(others)
+                model.plot_loss(ax, label, smooth_span)
+
+            format_ax(ax)
+
+
+            if i == 0:
+                ax.set_title(f"{col_key} = {c_val}")
+
+            if j == len(cols) - 1:
+                ax.text(1.02, 0.5, f"{row_key} = {r_val}",
+                        transform=ax.transAxes, rotation=-90,
+                        va="center", ha="left")
+
+    plt.subplots_adjust(left=0.06, right=0.96, top=0.90, bottom=0.12, wspace=0.10, hspace=0.18)
+    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)
+
+    handles, labels = axs[0, 0].get_legend_handles_labels()
+    if handles:
+        fig.legend(handles, labels, loc='lower center', ncol=len(handles),
+                   bbox_to_anchor=(0.5, 0.02), frameon=False, fontsize=9)
+
+    fig.text(0.5, 0.07, "Training steps", ha='center', fontsize=10)
+    fig.text(0.02, 0.5, "Training loss", va='center', rotation='vertical', fontsize=10)
+
+
+
+def plot_losses(dimension: int, models: list, title: str, save_path: str | None = None, smooth_span: int = 100):
+    """
+    Main entry point for plotting. Dispatches to a 0D, 1D, or 2D layout according
+    to `dimension`; the varied hyperparameters (the keys of each model's id)
+    supply the subplot axes.
+
+    Args:
+        dimension (int): Dimension of the plot grid (0, 1, or 2).
+        models (list): List of model objects.
+        title (str): The title of the plot.
+        save_path (str, optional): File path to save the figure (e.g., 'plot.png').
+        smooth_span (int): EWM span for smoothing. Defaults to 100.
+    """
+    set_style()
+
+    keys = list(models[0].id.keys())
+
+    if dimension == 0:
+        plot_0d(models, title, smooth_span)
+    elif dimension == 1:
+        plot_1d(models, title, keys[0], smooth_span)
+    else:
+        plot_2d_grid(models, title, keys[0], keys[1], smooth_span)
+
+    if save_path:
+        plt.savefig(save_path, dpi=300, bbox_inches="tight")
+        print(f"Plot saved to {save_path}")
+
+    plt.show()
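To tie the dispatch together: `dimension` selects the layout and the keys of each model's id dict supply the grid axes. A hedged call sketch (assuming `models` have already been trained via Model.train and their ids carry two varied parameters; the title and path are illustrative):

from zeroth.plot_losses import plot_losses

# dimension == 0 -> plot_0d      : one axes, every model overlaid
# dimension == 1 -> plot_1d      : one column per value of the first id key
# dimension >= 2 -> plot_2d_grid : rows from the first id key, columns from the second
plot_losses(dimension=2, models=models, title="LR x batch-size sweep",
            save_path="results/lr_sweep/training_losses.png")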
zeroth_learn-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+zeroth/__init__.py,sha256=Zi4bPQQUNwI9sXGkaQFEZA84w0a4je3vdJJEIAsH3kU,260
+zeroth/experiment.py,sha256=Ca60CyNjarDWTCaamvW_2WINAO6SZE4HQsPNIsM_HWA,4320
+zeroth/losses.py,sha256=V9w_o9asXIWtEx_V7dWeurlv-jXGYjsjIc2zT1AQxK0,4071
+zeroth/model.py,sha256=PufxZTHkDb1xe0ExjVy-MAhkR0LYm-ft-o0xlQSMPDc,5355
+zeroth/plot_losses.py,sha256=fT3ipCWtgeBSmhALIn6IkHKcyKRH2ErCrSiX7FhI-nA,6503
+zeroth_learn-0.1.0.dist-info/METADATA,sha256=DsSHGh2eGfTnA0UPg7aY7ir131FqGI2g58Zq0jQmcag,147
+zeroth_learn-0.1.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
+zeroth_learn-0.1.0.dist-info/top_level.txt,sha256=ZNQ6QAwK_P_ev_5dpk-HUbpu49Lu25HJ6kbv6f0mOyo,7
+zeroth_learn-0.1.0.dist-info/RECORD,,
zeroth_learn-0.1.0.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+zeroth