erictransformer 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import json
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import asdict, dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
|
+
|
|
10
|
+
from erictransformer.args.eric_args import EricTrainArgs
|
|
11
|
+
from erictransformer.exceptions import EricIOError
|
|
12
|
+
from erictransformer.eric_tracker.save_plot import (
|
|
13
|
+
save_lr_plot,
|
|
14
|
+
save_metric_plots,
|
|
15
|
+
save_train_eval_loss_plot,
|
|
16
|
+
)
|
|
17
|
+
from erictransformer.utils import make_dir
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class TrackerState:
    """Mutable snapshot of training progress owned by ``EricTracker``.

    Round-trips through plain dicts (``to_dict`` / ``from_dict``) so it can be
    persisted as JSON in checkpoints and restored on resume.
    """

    current_step: int = 0
    last_eval_step: int = 0
    original_eval_loss: float = 0.0
    eval_loss: float = 0.0
    best_eval_loss: Optional[float] = None
    eval_loss_improvement: float = 0.0
    train_loss: float = 0.0
    last_checkpoint_step: Optional[int] = None
    epoch: int = 1
    lr: float = 0
    metrics: Optional[dict] = None
    num_tokens: int = 0

    def to_dict(self):
        """Return a JSON-serializable plain-dict copy of this state."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """Rebuild a ``TrackerState`` from a mapping produced by ``to_dict``."""
        return cls(**dict(data))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class TrackerPaths:
    """Filesystem locations created and used by ``EricTracker``."""

    # JSONL log; one tracker-state record appended per logged step.
    log_path: str
    # Directory holding the best (lowest eval loss) model.
    best_model_path: str
    # JSON dump of the EricTrainArgs for this run (inside the checkpoint dir).
    train_args_path: str
    # Directory holding the most recent checkpoint.
    checkpoint_path: str
    # Tracker state saved alongside the checkpoint (used for resume).
    checkpoint_state_path: str
    # Tracker state saved alongside the best model.
    best_model_state_path: str
    # LR-scheduler state file — presumably torch-serialized; written elsewhere.
    lr_sched_path: str
    # Optimizer state file — presumably torch-serialized; written elsewhere.
    optim_path: str
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class EricTracker:
    """Tracks training progress: state, JSONL logs, checkpoints, plots, tqdm bar.

    Owns a ``TrackerState`` that is updated every step, periodically appended
    to a JSONL log, written out for checkpoint resume, and rendered into
    loss / learning-rate / metric plots under ``out_dir``.
    """

    def __init__(
        self, args: EricTrainArgs, train_steps: int, out_dir: str, tracker_state
    ):
        """Set up the output directory layout and (optionally) resume state.

        Args:
            args: training arguments; supplies log/eval/checkpoint intervals.
            train_steps: total number of training steps for this run.
            out_dir: root directory for logs, checkpoints and plots.
            tracker_state: dict from a previous run's checkpoint to resume
                from, or ``None`` to start fresh.

        Raises:
            EricIOError: if directories/files cannot be created or the train
                args cannot be written.
        """
        self.train_args = args
        self.train_steps = train_steps

        self.out_dir = out_dir

        self.tracker_paths = self.__init_dirs(out_dir)
        self.log_steps, self.eval_steps, self.checkpoint_steps = (
            self.__get_log_eval_checkpoint_steps(self.train_args)
        )
        self._write_train_args()

        self.tracker_state_history = []

        if tracker_state is not None:
            self.state = TrackerState(**tracker_state)
            # Resume one step past the checkpointed step so it isn't redone.
            self.state.current_step += 1

        else:
            self.state = TrackerState()

        # NOTE(review): these histories restart empty on resume, so plots
        # produced after a resume only cover the new run's logged steps —
        # confirm this is intended.
        self.epoch_steps = []
        self.eval_steps_history = []
        self.pbar = tqdm(
            total=train_steps, initial=self.state.current_step, desc="Model Training"
        )

    def __init_dirs(self, out_dir):
        """Create the output tree and touch every file the tracker writes.

        Returns:
            TrackerPaths bundling every path used by the tracker.

        Raises:
            EricIOError: if any directory or file cannot be created.
        """
        try:
            best_model_path = os.path.join(out_dir, "best_model")
            make_dir(best_model_path)

            best_model_state_path = os.path.join(
                best_model_path, "eric_tracker_state.json"
            )
            # Touch files ("a" then close) so they exist before the first write.
            open(best_model_state_path, "a").close()

            log_path = os.path.join(out_dir, "logs.jsonl")
            open(log_path, "a").close()

            checkpoint_path = os.path.join(out_dir, "checkpoint")
            make_dir(checkpoint_path)

            train_args_path = os.path.join(out_dir, "checkpoint/train_args.json")
            open(train_args_path, "a").close()

            checkpoint_state_path = os.path.join(
                checkpoint_path, "eric_tracker_state.json"
            )
            open(checkpoint_state_path, "a").close()

            lr_sched_path = os.path.join(checkpoint_path, "lr_sched.pt")
            open(lr_sched_path, "a").close()

            optim_path = os.path.join(checkpoint_path, "optim.pt")
            open(optim_path, "a").close()
        except Exception as e:
            raise EricIOError(
                f"Error initializing tracker directories or files: {str(e)}"
            )

        return TrackerPaths(
            best_model_state_path=best_model_state_path,
            best_model_path=best_model_path,
            log_path=log_path,
            train_args_path=train_args_path,
            checkpoint_path=checkpoint_path,
            checkpoint_state_path=checkpoint_state_path,
            lr_sched_path=lr_sched_path,
            optim_path=optim_path,
        )

    def __get_log_eval_checkpoint_steps(self, args: EricTrainArgs):
        """Resolve (log, eval, checkpoint) step intervals from the train args.

        Values below 1 resolve to 0 (interval disabled); ``log_steps`` falls
        back to 1 so logging is always enabled.
        """

        def resolve_steps(value):
            if value >= 1:
                return math.ceil(value)
            return 0

        return (
            resolve_steps(args.log_steps) or 1,
            resolve_steps(args.eval_steps),
            resolve_steps(args.checkpoint_steps),
        )

    def _write_train_args(self):
        """Dump the training args as JSON into the checkpoint directory.

        Raises:
            EricIOError: if the file cannot be written.
        """
        try:
            with open(self.tracker_paths.train_args_path, "w") as f:
                json.dump(asdict(self.train_args), f, indent=2)
        except Exception as e:
            raise EricIOError(
                f"Error writing train args to {self.tracker_paths.train_args_path}: {e}"
            )

    def time_to_eval(self):
        """True when evaluation is due: at step 0, every ``eval_steps`` steps,
        or on the final step — and only if eval is enabled (interval > 0)."""
        return bool(
            self.eval_steps
            and (
                self.state.current_step % self.eval_steps == 0
                or self.state.current_step == 0
                or self.state.current_step + 1 == self.train_steps
            )
        )

    def time_to_checkpoint(self):
        """True when a checkpoint is due: every ``checkpoint_steps`` steps or
        on the final step; never when checkpointing is disabled."""
        if not self.checkpoint_steps:
            return False

        return (
            self.state.current_step % self.checkpoint_steps == 0
            or self.state.current_step + 1 == self.train_steps
        )

    def time_to_log(self):
        """True when a log record is due: at step 0, every ``log_steps``
        steps, or on the final step."""
        return (
            self.state.current_step % self.log_steps == 0
            or self.state.current_step == 0
            or self.state.current_step + 1 == self.train_steps
        )

    def set_train_loss(self, loss: float):
        """Record the most recent training loss on the state."""
        self.state.train_loss = loss

    def set_eval_loss(self, eval_loss):
        """Record an eval loss; return True if it is a new best.

        The very first eval (step 0) is also stored as ``original_eval_loss``
        so improvement can be reported relative to the starting point.
        """
        is_best_model = False

        if self.state.current_step == 0:
            self.state.original_eval_loss = eval_loss

        self.state.eval_loss = eval_loss

        if self.state.best_eval_loss is None or eval_loss < self.state.best_eval_loss:
            self.state.best_eval_loss = eval_loss
            is_best_model = True
        self.state.eval_loss_improvement = self.state.original_eval_loss - eval_loss
        self.state.last_eval_step = self.state.current_step

        self.eval_steps_history.append(self.state.current_step)

        return is_best_model

    def set_metrics(self, metrics: dict):
        """Attach the latest eval metrics (name -> value) to the state."""
        self.state.metrics = metrics

    def mark_epoch(self):
        """Record the step at which an epoch ended and bump the epoch count."""
        self.epoch_steps.append(self.state.current_step)
        self.state.epoch += 1

    def step(self, loss: float, lr: float, num_tokens: int):
        """Per-step bookkeeping: update state, maybe checkpoint/log/plot.

        Args:
            loss: training loss for this step — assumes it has ``.item()``
                (e.g. a torch scalar tensor); TODO confirm.
            lr: current learning rate.
            num_tokens: tokens processed this step (accumulated on the state).

        Raises:
            EricIOError: if the checkpoint state cannot be written.
        """
        self.state.lr = lr
        self.state.num_tokens += num_tokens

        if self.time_to_checkpoint():
            try:
                with open(self.tracker_paths.checkpoint_state_path, "w") as f:
                    # Record where this checkpoint was taken before dumping.
                    self.state.last_checkpoint_step = self.state.current_step
                    json.dump(self.state.to_dict(), f, indent=2)
            except Exception as e:
                raise EricIOError(
                    f"Error writing checkpoint to {self.tracker_paths.checkpoint_state_path}: {e}"
                )

        if self.time_to_log() or self.time_to_checkpoint():
            self.set_train_loss(loss.item())
            # Deep-copy: the live state keeps mutating after this snapshot.
            self.tracker_state_history.append(copy.deepcopy(self.state))
            self._save_state(self.tracker_paths.log_path)
            save_train_eval_loss_plot(
                self.tracker_state_history, self.eval_steps_history, self.out_dir
            )
            save_lr_plot(self.tracker_state_history, self.out_dir)
            if self.eval_steps_history:  # if eval has happened
                save_metric_plots(self.tracker_state_history, self.out_dir)
            self.pbar.set_postfix(
                last_eval_step=self.state.last_eval_step,
                eval_loss_improvement=f"{self.state.eval_loss_improvement:.4f}",
                eval_loss=self.state.eval_loss,
                loss=f"{loss:.4f}",
                num_tokens=self.state.num_tokens,
            )

        self.state.current_step += 1

    def optim_step(self):
        """Advance the progress bar by one optimizer step."""
        self.pbar.update(1)
        self.pbar.refresh()  # force update

    def close(self):
        """Close the progress bar."""
        self.pbar.close()

    def _save_state(self, save_path):
        """Append the current state as one JSON line to ``save_path``.

        Raises:
            EricIOError: if the record cannot be written.
        """
        try:
            with open(save_path, "a") as f:
                to_save = self.state.to_dict()
                del to_save[
                    "original_eval_loss"
                ]  # deleted this since it's static. Maybe we should include it anyways
                json.dump(to_save, f, separators=(",", ":"), allow_nan=True)
                f.write("\n")
        except Exception as e:
            raise EricIOError(f"Error writing state to {save_path} {e}")
|
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Dict, Sequence, Tuple
|
|
4
|
+
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
import matplotlib.patheffects as pe
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
from matplotlib.ticker import MaxNLocator
|
|
10
|
+
|
|
11
|
+
from erictransformer.exceptions.eric_exceptions import EricPlotError
|
|
12
|
+
|
|
13
|
+
ERIC_RED = "#d62828"
|
|
14
|
+
ERIC_BLUE = "#1d84e2"
|
|
15
|
+
MID_GREY = "#4c4c4c"
|
|
16
|
+
BG_LIGHT = "#fafafa"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class StyleConfig:
    """Styling options shared by all tracker plots.

    ``apply()`` pushes the colour/background settings into matplotlib's
    global rcParams; the remaining attributes are read directly by the
    plotting functions.
    """

    dpi: int = 150
    fig_size: Tuple[float, float] = (7.5, 4.5)

    colour_red: str = ERIC_RED
    colour_blue: str = ERIC_BLUE
    colour_light_grey: str = BG_LIGHT
    colour_grey: str = MID_GREY

    line_width: float = 2.5
    fill_alpha: float = 0.07
    line_shadow_alpha: float = 0.25
    line_shadow_offset: Tuple[float, float] = (-0.5, -0.5)

    marker_shape: str = "o"
    marker_size: int = 3
    marker_facecolor: str = "white"
    marker_edgecolor: str = "white"
    marker_edgewidth: float = 1.0
    marker_zorder: int = 1

    extrema_size: int = 120
    extrema_up_marker: str = "^"
    extrema_down_marker: str = "v"
    extrema_edgewidth: float = 1.5

    summary_xy: Tuple[float, float] = (0.50, 0.98)
    summary_fontsize: int = 8
    summary_box_kwargs: Dict[str, object] = field(
        default_factory=lambda: dict(
            boxstyle="round,pad=0.3",
            facecolor=BG_LIGHT,
            edgecolor=MID_GREY,
            alpha=0.85,
        )
    )

    grid_alpha: float = 0.15
    # Extra rcParams merged on top of (and overriding) the derived base.
    rc_extra: Dict[str, object] = field(default_factory=dict)

    def rc_params(self) -> Dict[str, object]:
        """Return the matplotlib rcParams derived from this config."""
        base = {
            "figure.dpi": self.dpi,
            "figure.facecolor": self.colour_light_grey,
            "axes.facecolor": self.colour_light_grey,
            "axes.edgecolor": "none",
            "grid.alpha": self.grid_alpha,
            "axes.labelcolor": self.colour_grey,
            "text.color": self.colour_grey,
            "xtick.color": self.colour_grey,
            "ytick.color": self.colour_grey,
        }
        base.update(self.rc_extra)
        return base

    def apply(self) -> None:
        """Install ``rc_params()`` into matplotlib's global rcParams."""
        mpl.rcParams.update(self.rc_params())

    def shadow(self) -> Sequence[pe.AbstractPathEffect]:
        """Path effects giving plotted lines a soft drop shadow.

        Bug fix: the offset tuple used to be passed as ``rho`` — in
        matplotlib's SimpleLineShadow that is a scalar shade factor which is
        ignored whenever ``shadow_color`` is set (it defaults to 'k') — so
        ``line_shadow_offset`` never took effect. It now goes to ``offset``.
        """
        return [
            pe.SimpleLineShadow(
                offset=self.line_shadow_offset, alpha=self.line_shadow_alpha
            ),
            pe.Normal(),
        ]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def save_train_eval_loss_plot(
    tracker_state_history,
    eval_steps_history,
    out_dir,
    style: StyleConfig = StyleConfig(),
):
    """Plot train (and, if available, eval) loss vs. step to ``loss_curve.png``.

    Args:
        tracker_state_history: sequence of TrackerState snapshots.
        eval_steps_history: steps at which an evaluation was recorded; eval
            loss is only plotted at these steps.
        out_dir: directory the PNG is written into (created if missing).
        style: plot styling options.

    Raises:
        EricPlotError: if anything during plotting or saving fails.
    """
    try:
        # Nothing logged yet — nothing to plot.
        if not tracker_state_history:
            return

        style.apply()

        # history → arrays
        steps = np.array([s.current_step for s in tracker_state_history])
        train_loss = np.array([s.train_loss for s in tracker_state_history])
        eval_loss = np.array([s.eval_loss for s in tracker_state_history])

        # Keep eval values only at steps where an eval actually ran; other
        # snapshots just carry the stale previous eval loss.
        eval_mask = np.isin(steps, np.array(eval_steps_history))
        eval_steps_plot = steps[eval_mask]
        eval_loss_plot = eval_loss[eval_mask]
        has_eval = eval_steps_plot.size > 0

        fig, ax = plt.subplots(figsize=style.fig_size)

        # train curve (red) with a faint area fill underneath
        ax.fill_between(steps, train_loss, alpha=style.fill_alpha)
        ax.plot(
            steps,
            train_loss,
            label="train",
            linewidth=style.line_width,
            marker=style.marker_shape,
            markersize=style.marker_size,
            markerfacecolor=style.marker_facecolor,
            markeredgecolor=style.marker_edgecolor,
            markeredgewidth=style.marker_edgewidth,
            path_effects=style.shadow(),
            color=style.colour_red,
        )

        # eval curve
        if has_eval:
            ax.fill_between(
                eval_steps_plot,
                eval_loss_plot,
                alpha=style.fill_alpha,
                color=style.colour_blue,
            )
            ax.plot(
                eval_steps_plot,
                eval_loss_plot,
                label="eval",
                linewidth=style.line_width,
                marker=style.marker_shape,
                markersize=style.marker_size,
                markerfacecolor=style.marker_facecolor,
                markeredgecolor=style.marker_edgecolor,
                markeredgewidth=style.marker_edgewidth,
                path_effects=style.shadow(),
                color=style.colour_blue,
            )

        # extrema markers
        def mark(idx, xs, ys, marker, color):
            # Highlight a single (step, value) point with a large marker.
            ax.scatter(
                xs[idx],
                ys[idx],
                s=style.extrema_size,
                marker=marker,
                zorder=style.marker_zorder,
                color=color,
                edgecolor=style.marker_edgecolor,
                linewidth=style.extrema_edgewidth,
            )

        mark(
            np.argmin(train_loss),
            steps,
            train_loss,
            style.extrema_down_marker,
            style.colour_red,
        )
        mark(
            np.argmax(train_loss),
            steps,
            train_loss,
            style.extrema_up_marker,
            style.colour_red,
        )
        if has_eval:
            mark(
                np.argmin(eval_loss_plot),
                eval_steps_plot,
                eval_loss_plot,
                style.extrema_down_marker,
                style.colour_blue,
            )
            mark(
                np.argmax(eval_loss_plot),
                eval_steps_plot,
                eval_loss_plot,
                style.extrema_up_marker,
                style.colour_blue,
            )

        # summary box
        summary = (
            f"max train: {train_loss.max():.3f}\nmin train: {train_loss.min():.3f}\n"
        )
        if has_eval:
            summary += (
                f"max eval : {eval_loss_plot.max():.3f}\n"
                f"min eval : {eval_loss_plot.min():.3f}\n"
            )

        ax.text(
            *style.summary_xy,
            summary,
            transform=ax.transAxes,
            va="top",
            ha="left",
            fontsize=style.summary_fontsize,
            color=style.colour_grey,
            bbox=style.summary_box_kwargs,
        )

        ax.set_xlabel("Step", labelpad=6)
        ax.set_ylabel("Loss", labelpad=6)
        ax.grid(True)
        ax.legend(frameon=False, loc="upper right")
        ax.set_xlim(left=steps.min())  # start at step 0
        ax.xaxis.set_major_locator(
            MaxNLocator(integer=True)
        )  # show only whole-number ticks
        fig.tight_layout()

        # write file
        plot_path = Path(out_dir) / "loss_curve.png"
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_path, dpi=style.dpi)
        plt.close(fig)
    except Exception as e:
        raise EricPlotError(f"Error saving train/eval loss plot: {e}")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def save_lr_plot(
    tracker_state_history,
    out_dir,
    style: StyleConfig = StyleConfig(),
):
    """Plot learning rate vs. step and save it as ``lr_curve.png``.

    Args:
        tracker_state_history: sequence of TrackerState snapshots.
        out_dir: directory the PNG is written into (created if missing).
        style: plot styling options.

    Raises:
        EricPlotError: if anything during plotting or saving fails.
    """
    try:
        # Nothing logged yet — nothing to plot.
        if not tracker_state_history:
            return

        style.apply()

        # history → arrays
        steps = np.array([s.current_step for s in tracker_state_history])
        lr = np.array([s.lr for s in tracker_state_history])

        fig, ax = plt.subplots(figsize=style.fig_size)

        # lr curve (red) with a faint area fill underneath
        ax.fill_between(steps, lr, alpha=style.fill_alpha)
        ax.plot(
            steps,
            lr,
            label="learning‑rate",
            linewidth=style.line_width,
            marker=style.marker_shape,
            markersize=style.marker_size,
            markerfacecolor=style.marker_facecolor,
            markeredgecolor=style.marker_edgecolor,
            markeredgewidth=style.marker_edgewidth,
            path_effects=style.shadow(),
            color=style.colour_red,
        )

        # extrema markers
        def mark(idx, xs, ys, marker, color):
            # Highlight a single (step, value) point with a large marker.
            ax.scatter(
                xs[idx],
                ys[idx],
                s=style.extrema_size,
                marker=marker,
                zorder=style.marker_zorder,
                color=color,
                edgecolor=style.marker_edgecolor,
                linewidth=style.extrema_edgewidth,
            )

        mark(np.argmin(lr), steps, lr, style.extrema_down_marker, style.colour_red)
        mark(np.argmax(lr), steps, lr, style.extrema_up_marker, style.colour_red)

        # summary box
        summary = f"max lr : {lr.max():.6f}\nmin lr : {lr.min():.6f}\n"

        ax.text(
            *style.summary_xy,
            summary,
            transform=ax.transAxes,
            va="top",
            ha="left",
            fontsize=style.summary_fontsize,
            color=style.colour_grey,
            bbox=style.summary_box_kwargs,
        )

        ax.set_xlabel("Step", labelpad=6)
        ax.set_ylabel("Learning Rate", labelpad=6)
        ax.grid(True)
        ax.legend(frameon=False, loc="upper right")
        ax.set_xlim(left=steps.min())  # start at step 0
        ax.xaxis.set_major_locator(
            MaxNLocator(integer=True)
        )  # show only whole-number ticks

        fig.tight_layout()

        # write file
        plot_path = Path(out_dir) / "lr_curve.png"
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_path, dpi=style.dpi)
        plt.close(fig)

    except Exception as e:
        raise EricPlotError(f"Error saving lr plot: {e}")
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def save_metric_plots(
    tracker_state_history,
    out_dir,
    style: StyleConfig = StyleConfig(),
):
    """Render one ``<metric_name>.png`` per tracked metric under ``out_dir``.

    Each plot shows the metric value against training step with min/max
    extrema markers and a small summary box.

    Args:
        tracker_state_history: sequence of TrackerState snapshots.
        out_dir: directory the PNGs are written into (created if missing).
        style: plot styling options.

    Raises:
        EricPlotError: if gathering data or writing any plot fails.
    """
    try:
        # Nothing logged yet — nothing to plot.
        if not tracker_state_history:
            return

        style.apply()

        # Gather (step, value) pairs per metric name. Snapshots captured
        # before any metrics were set have ``metrics=None`` (the TrackerState
        # default); skip them instead of crashing on ``None.items()``.
        all_metrics = {}
        for state in tracker_state_history:
            if not state.metrics:
                continue
            for metric_name, value in state.metrics.items():
                all_metrics.setdefault(metric_name, []).append(
                    (state.current_step, value)
                )  # Track step + value

        # plot each metric separately
        for metric_name, data in all_metrics.items():
            steps, values = zip(*data)

            # history → arrays
            steps = np.array(steps)
            values = np.array(values)

            # plotting
            fig, ax = plt.subplots(figsize=style.fig_size)

            ax.fill_between(steps, values, alpha=style.fill_alpha)
            ax.plot(
                steps,
                values,
                label=metric_name,
                linewidth=style.line_width,
                marker=style.marker_shape,
                markersize=style.marker_size,
                markerfacecolor=style.marker_facecolor,
                markeredgecolor=style.marker_edgecolor,
                markeredgewidth=style.marker_edgewidth,
                path_effects=style.shadow(),
                color=style.colour_red,
            )

            # extrema markers
            def mark(idx, xs, ys, marker, color):
                # Highlight a single (step, value) point with a large marker.
                ax.scatter(
                    xs[idx],
                    ys[idx],
                    s=style.extrema_size,
                    marker=marker,
                    zorder=style.marker_zorder,
                    color=color,
                    edgecolor=style.marker_edgecolor,
                    linewidth=style.extrema_edgewidth,
                )

            mark(
                np.argmin(values),
                steps,
                values,
                style.extrema_down_marker,
                style.colour_red,
            )
            mark(
                np.argmax(values),
                steps,
                values,
                style.extrema_up_marker,
                style.colour_red,
            )

            # summary box
            summary = (
                f"max {metric_name} : {values.max():.6f}\n"
                f"min {metric_name} : {values.min():.6f}\n"
            )

            ax.text(
                *style.summary_xy,
                summary,
                transform=ax.transAxes,
                va="top",
                ha="left",
                fontsize=style.summary_fontsize,
                color=style.colour_grey,
                bbox=style.summary_box_kwargs,
            )

            ax.set_xlabel("Step", labelpad=6)
            ax.set_ylabel(f"{metric_name}", labelpad=6)
            ax.grid(True)
            ax.legend(frameon=False, loc="upper right")
            ax.set_xlim(left=steps.min())  # start at step 0
            ax.xaxis.set_major_locator(
                MaxNLocator(integer=True)
            )  # show only whole-number ticks

            fig.tight_layout()

            # write file
            plot_path = Path(out_dir) / f"{metric_name}.png"
            plot_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(plot_path, dpi=style.dpi)
            plt.close(fig)

    except Exception as e:
        # Message fixed: previously said "lr plot" (copy-paste from save_lr_plot).
        raise EricPlotError(f"Error saving metric plots: {e}")
|