erictransformer 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,256 @@
1
+ import copy
2
+ import json
3
+ import math
4
+ import os
5
+ from dataclasses import asdict, dataclass
6
+ from typing import Optional
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ from erictransformer.args.eric_args import EricTrainArgs
11
+ from erictransformer.exceptions import EricIOError
12
+ from erictransformer.eric_tracker.save_plot import (
13
+ save_lr_plot,
14
+ save_metric_plots,
15
+ save_train_eval_loss_plot,
16
+ )
17
+ from erictransformer.utils import make_dir
18
+
19
+
20
@dataclass
class TrackerState:
    """Serializable snapshot of training progress.

    One instance is mutated throughout training and periodically deep-copied
    into the tracker history / dumped to JSON, so every field must stay
    JSON-friendly.
    """

    # Progress counters.
    current_step: int = 0
    last_eval_step: int = 0

    # Loss bookkeeping (original_eval_loss is the step-0 baseline).
    original_eval_loss: float = 0.0
    eval_loss: float = 0.0
    best_eval_loss: Optional[float] = None
    eval_loss_improvement: float = 0.0
    train_loss: float = 0.0

    # Checkpoint / schedule state.
    last_checkpoint_step: Optional[int] = None
    epoch: int = 1
    lr: float = 0
    metrics: Optional[dict] = None
    num_tokens: int = 0

    def to_dict(self):
        """Return a plain-dict copy suitable for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """Rebuild a state object from a mapping produced by ``to_dict``."""
        return cls(**data)
41
+
42
+
43
@dataclass
class TrackerPaths:
    """Filesystem locations managed by ``EricTracker``.

    All paths are pre-created (touched) by ``EricTracker.__init_dirs`` so
    later writes never fail on a missing file or directory.
    """

    log_path: str               # append-only JSONL step log
    best_model_path: str        # directory holding the best model snapshot
    train_args_path: str        # JSON dump of the training arguments
    checkpoint_path: str        # directory holding the rolling checkpoint
    checkpoint_state_path: str  # tracker-state JSON inside the checkpoint
    best_model_state_path: str  # tracker-state JSON inside best_model
    lr_sched_path: str          # serialized LR-scheduler state
    optim_path: str             # serialized optimizer state
53
+
54
+
55
class EricTracker:
    """Tracks training progress: per-step state, JSONL logs, plots,
    checkpoint/eval/log scheduling, and a tqdm progress bar.

    Side effects on construction: creates ``best_model/`` and
    ``checkpoint/`` under ``out_dir``, touches all tracked files, and
    writes the training args JSON.
    """

    def __init__(
        self, args: EricTrainArgs, train_steps: int, out_dir: str, tracker_state
    ):
        """Set up directories, cadence, and (optionally resumed) state.

        Args:
            args: training arguments (provides log/eval/checkpoint steps).
            train_steps: total number of training steps (tqdm total).
            out_dir: root output directory for logs, plots, checkpoints.
            tracker_state: dict from a saved TrackerState to resume from,
                or None to start fresh.
        """
        self.train_args = args
        self.train_steps = train_steps

        self.out_dir = out_dir

        self.tracker_paths = self.__init_dirs(out_dir)
        self.log_steps, self.eval_steps, self.checkpoint_steps = (
            self.__get_log_eval_checkpoint_steps(self.train_args)
        )
        self._write_train_args()

        # Deep-copied snapshots of self.state, appended on every log step.
        self.tracker_state_history = []

        if tracker_state is not None:
            # Resume: the saved step was already completed, move past it.
            self.state = TrackerState(**tracker_state)
            self.state.current_step += 1

        else:
            self.state = TrackerState()

        self.epoch_steps = []
        self.eval_steps_history = []
        self.pbar = tqdm(
            total=train_steps, initial=self.state.current_step, desc="Model Training"
        )

    def __init_dirs(self, out_dir):
        """Create the output tree and touch every tracked file.

        Files are opened in append mode and immediately closed ("touch")
        so they exist without clobbering content on resume.

        Raises:
            EricIOError: wrapping any OS-level failure.
        """
        try:
            best_model_path = os.path.join(out_dir, "best_model")
            make_dir(best_model_path)

            best_model_state_path = os.path.join(
                best_model_path, "eric_tracker_state.json"
            )
            open(best_model_state_path, "a").close()

            log_path = os.path.join(out_dir, "logs.jsonl")
            open(log_path, "a").close()

            checkpoint_path = os.path.join(out_dir, "checkpoint")
            make_dir(checkpoint_path)

            # NOTE(review): hard-coded forward slash in the joined segment;
            # works on POSIX and Windows via os.path, but is inconsistent
            # with the os.path.join style used elsewhere.
            train_args_path = os.path.join(out_dir, "checkpoint/train_args.json")
            open(train_args_path, "a").close()

            checkpoint_state_path = os.path.join(
                checkpoint_path, "eric_tracker_state.json"
            )
            open(checkpoint_state_path, "a").close()

            lr_sched_path = os.path.join(checkpoint_path, "lr_sched.pt")
            open(lr_sched_path, "a").close()

            optim_path = os.path.join(checkpoint_path, "optim.pt")
            open(optim_path, "a").close()
        except Exception as e:
            raise EricIOError(
                f"Error initializing tracker directories or files: {str(e)}"
            )

        return TrackerPaths(
            best_model_state_path=best_model_state_path,
            best_model_path=best_model_path,
            log_path=log_path,
            train_args_path=train_args_path,
            checkpoint_path=checkpoint_path,
            checkpoint_state_path=checkpoint_state_path,
            lr_sched_path=lr_sched_path,
            optim_path=optim_path,
        )

    def __get_log_eval_checkpoint_steps(self, args: EricTrainArgs):
        """Resolve (log_steps, eval_steps, checkpoint_steps) cadences.

        Values >= 1 are ceiled to an int; values < 1 resolve to 0, which
        disables eval/checkpoint entirely (log falls back to every step
        via ``or 1``).
        NOTE(review): fractional values in (0, 1) are silently dropped —
        presumably intended to mean "fraction of train_steps"; confirm.
        """
        def resolve_steps(value):
            if value >= 1:
                return math.ceil(value)
            return 0

        return (
            resolve_steps(args.log_steps) or 1,
            resolve_steps(args.eval_steps),
            resolve_steps(args.checkpoint_steps),
        )

    def _write_train_args(self):
        """Persist the training args as pretty-printed JSON.

        Raises:
            EricIOError: wrapping any write failure.
        """
        try:
            with open(self.tracker_paths.train_args_path, "w") as f:
                json.dump(asdict(self.train_args), f, indent=2)
        except Exception as e:
            raise EricIOError(
                f"Error writing train args to {self.tracker_paths.train_args_path}: {e}"
            )

    def time_to_eval(self):
        """True on eval-cadence steps, the first step, and the last step.

        Always False when eval is disabled (eval_steps == 0).
        """
        return bool(
            self.eval_steps
            and (
                self.state.current_step % self.eval_steps == 0
                or self.state.current_step == 0
                or self.state.current_step + 1 == self.train_steps
            )
        )

    def time_to_checkpoint(self):
        """True on checkpoint-cadence steps and the last step.

        Unlike time_to_eval/time_to_log, step 0 only qualifies via the
        modulo test (0 % n == 0), not a separate clause.
        """
        if not self.checkpoint_steps:
            return False

        return (
            self.state.current_step % self.checkpoint_steps == 0
            or self.state.current_step + 1 == self.train_steps
        )

    def time_to_log(self):
        """True on log-cadence steps, the first step, and the last step."""
        return (
            self.state.current_step % self.log_steps == 0
            or self.state.current_step == 0
            or self.state.current_step + 1 == self.train_steps
        )

    def set_train_loss(self, loss: float):
        """Record the latest training loss on the mutable state."""
        self.state.train_loss = loss

    def set_eval_loss(self, eval_loss):
        """Record an eval loss; return True when it is a new best.

        Also captures the step-0 loss as the improvement baseline, updates
        the improvement delta, and remembers this step in
        ``eval_steps_history`` (used to mask eval points on the plots).
        """
        is_best_model = False

        if self.state.current_step == 0:
            self.state.original_eval_loss = eval_loss

        self.state.eval_loss = eval_loss

        if self.state.best_eval_loss is None or eval_loss < self.state.best_eval_loss:
            self.state.best_eval_loss = eval_loss
            is_best_model = True
        self.state.eval_loss_improvement = self.state.original_eval_loss - eval_loss
        self.state.last_eval_step = self.state.current_step

        self.eval_steps_history.append(self.state.current_step)

        return is_best_model

    def set_metrics(self, metrics: dict):
        """Attach the latest eval metrics dict to the state."""
        self.state.metrics = metrics

    def mark_epoch(self):
        """Record the step at which an epoch ended and bump the counter."""
        self.epoch_steps.append(self.state.current_step)
        self.state.epoch += 1

    def step(self, loss: float, lr: float, num_tokens: int):
        """Advance one training step: checkpoint, log, plot, update pbar.

        Args:
            loss: this step's training loss.
                NOTE(review): annotated float, but ``.item()`` is called
                below — callers apparently pass a 0-d tensor; confirm.
            lr: current learning rate.
            num_tokens: tokens processed this step (accumulated).

        Raises:
            EricIOError: if the checkpoint state cannot be written.
        """
        self.state.lr = lr
        self.state.num_tokens += num_tokens

        if self.time_to_checkpoint():
            try:
                with open(self.tracker_paths.checkpoint_state_path, "w") as f:
                    # Mark the checkpoint step before dumping so it is
                    # included in the written state.
                    self.state.last_checkpoint_step = self.state.current_step
                    json.dump(self.state.to_dict(), f, indent=2)
            except Exception as e:
                raise EricIOError(
                    f"Error writing checkpoint to {self.tracker_paths.checkpoint_state_path}: {e}"
                )

        if self.time_to_log() or self.time_to_checkpoint():
            self.set_train_loss(loss.item())
            # Deep copy: self.state keeps mutating in place.
            self.tracker_state_history.append(copy.deepcopy(self.state))
            self._save_state(self.tracker_paths.log_path)
            save_train_eval_loss_plot(
                self.tracker_state_history, self.eval_steps_history, self.out_dir
            )
            save_lr_plot(self.tracker_state_history, self.out_dir)
            if self.eval_steps_history:  # if eval has happened
                save_metric_plots(self.tracker_state_history, self.out_dir)
            self.pbar.set_postfix(
                last_eval_step=self.state.last_eval_step,
                eval_loss_improvement=f"{self.state.eval_loss_improvement:.4f}",
                eval_loss=self.state.eval_loss,
                loss=f"{loss:.4f}",
                num_tokens=self.state.num_tokens,
            )

        self.state.current_step += 1

    def optim_step(self):
        """Advance the progress bar by one optimizer step."""
        self.pbar.update(1)
        self.pbar.refresh()  # force update

    def close(self):
        """Release the tqdm progress bar."""
        self.pbar.close()

    def _save_state(self, save_path):
        """Append the current state as one compact JSON line.

        Raises:
            EricIOError: wrapping any write failure.
        """
        try:
            with open(save_path, "a") as f:
                to_save = self.state.to_dict()
                del to_save[
                    "original_eval_loss"
                ]  # deleted this since it's static. Maybe we should include it anyways
                json.dump(to_save, f, separators=(",", ":"), allow_nan=True)
                f.write("\n")
        except Exception as e:
            raise EricIOError(f"Error writing state to {save_path} {e}")
@@ -0,0 +1,422 @@
1
+ from dataclasses import dataclass, field
2
+ from pathlib import Path
3
+ from typing import Dict, Sequence, Tuple
4
+
5
+ import matplotlib as mpl
6
+ import matplotlib.patheffects as pe
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from matplotlib.ticker import MaxNLocator
10
+
11
+ from erictransformer.exceptions.eric_exceptions import EricPlotError
12
+
13
# Shared colour palette for all tracker plots.
ERIC_RED = "#d62828"  # primary series colour (train loss, lr, metric curves)
ERIC_BLUE = "#1d84e2"  # secondary series colour (eval loss curve)
MID_GREY = "#4c4c4c"  # text, tick labels, and summary-box edge
BG_LIGHT = "#fafafa"  # figure/axes background and summary-box fill
17
+
18
+
19
@dataclass
class StyleConfig:
    """Shared look-and-feel for the tracker plots (loss, lr, metrics).

    Passed to the ``save_*_plot`` helpers; ``apply()`` pushes the config
    into matplotlib's global rcParams.
    """

    dpi: int = 150
    fig_size: Tuple[float, float] = (7.5, 4.5)

    # Eric brand palette (module-level constants).
    colour_red: str = ERIC_RED
    colour_blue: str = ERIC_BLUE
    colour_light_grey: str = BG_LIGHT
    colour_grey: str = MID_GREY

    # Curve line + drop-shadow styling.
    line_width: float = 2.5
    fill_alpha: float = 0.07
    line_shadow_alpha: float = 0.25
    line_shadow_offset: Tuple[float, float] = (-0.5, -0.5)

    # Per-point markers drawn along the curves.
    marker_shape: str = "o"
    marker_size: int = 3
    marker_facecolor: str = "white"
    marker_edgecolor: str = "white"
    marker_edgewidth: float = 1.0
    marker_zorder: int = 1

    # Min/max highlight markers.
    extrema_size: int = 120
    extrema_up_marker: str = "^"
    extrema_down_marker: str = "v"
    extrema_edgewidth: float = 1.5

    # Summary text box (axes-fraction position).
    summary_xy: Tuple[float, float] = (0.50, 0.98)
    summary_fontsize: int = 8
    summary_box_kwargs: Dict[str, object] = field(
        default_factory=lambda: dict(
            boxstyle="round,pad=0.3",
            facecolor=BG_LIGHT,
            edgecolor=MID_GREY,
            alpha=0.85,
        )
    )

    grid_alpha: float = 0.15
    # Extra rcParams merged on top of the defaults built by rc_params().
    rc_extra: Dict[str, object] = field(default_factory=dict)

    def rc_params(self) -> Dict[str, object]:
        """Build the matplotlib rcParams dict for this style.

        ``rc_extra`` entries override the computed defaults.
        """
        base = {
            "figure.dpi": self.dpi,
            "figure.facecolor": self.colour_light_grey,
            "axes.facecolor": self.colour_light_grey,
            "axes.edgecolor": "none",
            "grid.alpha": self.grid_alpha,
            "axes.labelcolor": self.colour_grey,
            "text.color": self.colour_grey,
            "xtick.color": self.colour_grey,
            "ytick.color": self.colour_grey,
        }
        base.update(self.rc_extra)
        return base

    def apply(self) -> None:
        """Install this style into matplotlib's global rcParams."""
        mpl.rcParams.update(self.rc_params())

    def shadow(self) -> Sequence[pe.AbstractPathEffect]:
        """Path effects giving plotted lines a soft drop shadow.

        Bug fix: ``line_shadow_offset`` was previously passed as ``rho``
        (a colour-scaling float that SimpleLineShadow only uses when
        ``shadow_color`` is None), so the configured offset was silently
        ignored and the library default offset was used instead. It is
        now passed as ``offset``, which is what the field name intends.
        """
        return [
            pe.SimpleLineShadow(
                offset=self.line_shadow_offset, alpha=self.line_shadow_alpha
            ),
            pe.Normal(),
        ]
85
+
86
+
87
def save_train_eval_loss_plot(
    tracker_state_history,
    eval_steps_history,
    out_dir,
    style: StyleConfig = StyleConfig(),
):
    """Plot train loss (and eval loss, if any) vs. step to ``loss_curve.png``.

    Args:
        tracker_state_history: sequence of TrackerState snapshots; their
            ``current_step`` / ``train_loss`` / ``eval_loss`` fields drive
            the curves.
        eval_steps_history: step numbers at which an eval actually ran;
            used to mask which history points contribute to the eval curve.
        out_dir: directory the PNG is written into (created if missing).
        style: plot styling.
            NOTE(review): default is a single shared StyleConfig instance
            (mutable default argument) — harmless unless rc_extra is
            mutated; confirm intent.

    Raises:
        EricPlotError: wrapping any failure while building/saving the plot.
    """
    try:
        # Nothing logged yet -> nothing to draw.
        if not tracker_state_history:
            return

        style.apply()

        # History -> arrays.
        steps = np.array([s.current_step for s in tracker_state_history])
        train_loss = np.array([s.train_loss for s in tracker_state_history])
        eval_loss = np.array([s.eval_loss for s in tracker_state_history])

        # Only plot eval values at steps where an eval actually ran; the
        # state carries the *last* eval loss at every logged step.
        eval_mask = np.isin(steps, np.array(eval_steps_history))
        eval_steps_plot = steps[eval_mask]
        eval_loss_plot = eval_loss[eval_mask]
        has_eval = eval_steps_plot.size > 0

        fig, ax = plt.subplots(figsize=style.fig_size)

        # Train curve: filled area under it, then the line on top.
        ax.fill_between(steps, train_loss, alpha=style.fill_alpha)
        ax.plot(
            steps,
            train_loss,
            label="train",
            linewidth=style.line_width,
            marker=style.marker_shape,
            markersize=style.marker_size,
            markerfacecolor=style.marker_facecolor,
            markeredgecolor=style.marker_edgecolor,
            markeredgewidth=style.marker_edgewidth,
            path_effects=style.shadow(),
            color=style.colour_red,
        )

        # eval curve
        if has_eval:
            ax.fill_between(
                eval_steps_plot,
                eval_loss_plot,
                alpha=style.fill_alpha,
                color=style.colour_blue,
            )
            ax.plot(
                eval_steps_plot,
                eval_loss_plot,
                label="eval",
                linewidth=style.line_width,
                marker=style.marker_shape,
                markersize=style.marker_size,
                markerfacecolor=style.marker_facecolor,
                markeredgecolor=style.marker_edgecolor,
                markeredgewidth=style.marker_edgewidth,
                path_effects=style.shadow(),
                color=style.colour_blue,
            )

        # extrema markers
        def mark(idx, xs, ys, marker, color):
            # Highlight a single (step, value) point with a large marker.
            ax.scatter(
                xs[idx],
                ys[idx],
                s=style.extrema_size,
                marker=marker,
                zorder=style.marker_zorder,
                color=color,
                edgecolor=style.marker_edgecolor,
                linewidth=style.extrema_edgewidth,
            )

        mark(
            np.argmin(train_loss),
            steps,
            train_loss,
            style.extrema_down_marker,
            style.colour_red,
        )
        mark(
            np.argmax(train_loss),
            steps,
            train_loss,
            style.extrema_up_marker,
            style.colour_red,
        )
        if has_eval:
            mark(
                np.argmin(eval_loss_plot),
                eval_steps_plot,
                eval_loss_plot,
                style.extrema_down_marker,
                style.colour_blue,
            )
            mark(
                np.argmax(eval_loss_plot),
                eval_steps_plot,
                eval_loss_plot,
                style.extrema_up_marker,
                style.colour_blue,
            )

        # summary box
        summary = (
            f"max train: {train_loss.max():.3f}\nmin train: {train_loss.min():.3f}\n"
        )
        if has_eval:
            summary += (
                f"max eval : {eval_loss_plot.max():.3f}\n"
                f"min eval : {eval_loss_plot.min():.3f}\n"
            )

        ax.text(
            *style.summary_xy,
            summary,
            transform=ax.transAxes,
            va="top",
            ha="left",
            fontsize=style.summary_fontsize,
            color=style.colour_grey,
            bbox=style.summary_box_kwargs,
        )

        ax.set_xlabel("Step", labelpad=6)
        ax.set_ylabel("Loss", labelpad=6)
        ax.grid(True)
        ax.legend(frameon=False, loc="upper right")
        ax.set_xlim(left=steps.min())  # start at step 0
        ax.xaxis.set_major_locator(
            MaxNLocator(integer=True)
        )  # show only whole‑number ticks
        fig.tight_layout()

        # write file
        plot_path = Path(out_dir) / "loss_curve.png"
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_path, dpi=style.dpi)
        plt.close(fig)
    except Exception as e:
        raise EricPlotError(f"Error saving train/eval loss plot: {e}")
228
+
229
+
230
def save_lr_plot(
    tracker_state_history,
    out_dir,
    style: StyleConfig = StyleConfig(),
):
    """Render the learning-rate schedule to ``<out_dir>/lr_curve.png``.

    Silently returns when the history is empty; any matplotlib or I/O
    failure is wrapped in an ``EricPlotError``.
    """
    try:
        # Nothing logged yet -> nothing to draw.
        if not tracker_state_history:
            return

        style.apply()

        # History -> arrays.
        step_arr = np.array([snap.current_step for snap in tracker_state_history])
        lr_arr = np.array([snap.lr for snap in tracker_state_history])

        fig, ax = plt.subplots(figsize=style.fig_size)

        # Filled area under the schedule, then the curve itself on top.
        ax.fill_between(step_arr, lr_arr, alpha=style.fill_alpha)
        ax.plot(
            step_arr,
            lr_arr,
            label="learning‑rate",
            linewidth=style.line_width,
            marker=style.marker_shape,
            markersize=style.marker_size,
            markerfacecolor=style.marker_facecolor,
            markeredgecolor=style.marker_edgecolor,
            markeredgewidth=style.marker_edgewidth,
            path_effects=style.shadow(),
            color=style.colour_red,
        )

        # Highlight the global minimum (v) and maximum (^) of the schedule.
        for pick, glyph in (
            (np.argmin, style.extrema_down_marker),
            (np.argmax, style.extrema_up_marker),
        ):
            point = pick(lr_arr)
            ax.scatter(
                step_arr[point],
                lr_arr[point],
                s=style.extrema_size,
                marker=glyph,
                zorder=style.marker_zorder,
                color=style.colour_red,
                edgecolor=style.marker_edgecolor,
                linewidth=style.extrema_edgewidth,
            )

        # Min/max summary box.
        summary = f"max lr : {lr_arr.max():.6f}\nmin lr : {lr_arr.min():.6f}\n"
        ax.text(
            *style.summary_xy,
            summary,
            transform=ax.transAxes,
            va="top",
            ha="left",
            fontsize=style.summary_fontsize,
            color=style.colour_grey,
            bbox=style.summary_box_kwargs,
        )

        ax.set_xlabel("Step", labelpad=6)
        ax.set_ylabel("Learning Rate", labelpad=6)
        ax.grid(True)
        ax.legend(frameon=False, loc="upper right")
        ax.set_xlim(left=step_arr.min())  # start at step 0
        ax.xaxis.set_major_locator(
            MaxNLocator(integer=True)
        )  # show only whole‑number ticks

        fig.tight_layout()

        # Persist the figure and release it.
        plot_path = Path(out_dir) / "lr_curve.png"
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_path, dpi=style.dpi)
        plt.close(fig)

    except Exception as e:
        raise EricPlotError(f"Error saving lr plot: {e}")
310
+
311
+
312
def save_metric_plots(
    tracker_state_history,
    out_dir,
    style: "StyleConfig | None" = None,
):
    """Render one ``<out_dir>/<metric>.png`` curve per logged eval metric.

    Args:
        tracker_state_history: sequence of TrackerState snapshots; each
            snapshot's ``metrics`` dict (if set) contributes one point per
            metric at that snapshot's ``current_step``.
        out_dir: directory the PNGs are written into (created if missing).
        style: optional plot styling. A fresh ``StyleConfig`` is created
            per call when omitted — the old ``style=StyleConfig()`` default
            was a single shared mutable instance across all calls.

    Raises:
        EricPlotError: wrapping any failure while building/saving a plot.
    """
    try:
        if not tracker_state_history:
            return

        if style is None:
            style = StyleConfig()

        style.apply()

        # Group (step, value) pairs by metric name. Snapshots taken before
        # set_metrics() was ever called carry metrics=None; skip them
        # instead of crashing on `.items()` (the old code assumed metrics
        # was always a dict).
        all_metrics = {}
        for state in tracker_state_history:
            if not state.metrics:
                continue
            for metric_name, value in state.metrics.items():
                all_metrics.setdefault(metric_name, []).append(
                    (state.current_step, value)
                )

        # One figure per metric.
        for metric_name, data in all_metrics.items():
            steps, values = zip(*data)

            # history -> arrays
            steps = np.array(steps)
            values = np.array(values)

            fig, ax = plt.subplots(figsize=style.fig_size)

            # Filled area under the curve, then the line on top.
            ax.fill_between(steps, values, alpha=style.fill_alpha)
            ax.plot(
                steps,
                values,
                label=metric_name,
                linewidth=style.line_width,
                marker=style.marker_shape,
                markersize=style.marker_size,
                markerfacecolor=style.marker_facecolor,
                markeredgecolor=style.marker_edgecolor,
                markeredgewidth=style.marker_edgewidth,
                path_effects=style.shadow(),
                color=style.colour_red,
            )

            # Highlight the global minimum (v) and maximum (^).
            def mark(idx, xs, ys, marker, color):
                ax.scatter(
                    xs[idx],
                    ys[idx],
                    s=style.extrema_size,
                    marker=marker,
                    zorder=style.marker_zorder,
                    color=color,
                    edgecolor=style.marker_edgecolor,
                    linewidth=style.extrema_edgewidth,
                )

            mark(
                np.argmin(values),
                steps,
                values,
                style.extrema_down_marker,
                style.colour_red,
            )
            mark(
                np.argmax(values),
                steps,
                values,
                style.extrema_up_marker,
                style.colour_red,
            )

            # Min/max summary box.
            summary = (
                f"max {metric_name} : {values.max():.6f}\n"
                f"min {metric_name} : {values.min():.6f}\n"
            )

            ax.text(
                *style.summary_xy,
                summary,
                transform=ax.transAxes,
                va="top",
                ha="left",
                fontsize=style.summary_fontsize,
                color=style.colour_grey,
                bbox=style.summary_box_kwargs,
            )

            ax.set_xlabel("Step", labelpad=6)
            ax.set_ylabel(f"{metric_name}", labelpad=6)
            ax.grid(True)
            ax.legend(frameon=False, loc="upper right")
            ax.set_xlim(left=steps.min())  # start at step 0
            ax.xaxis.set_major_locator(
                MaxNLocator(integer=True)
            )  # show only whole‑number ticks

            fig.tight_layout()

            # Persist the figure and release it.
            plot_path = Path(out_dir) / f"{metric_name}.png"
            plot_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(plot_path, dpi=style.dpi)
            plt.close(fig)

    except Exception as e:
        # Was "Error saving lr plot" — a copy-paste from save_lr_plot.
        raise EricPlotError(f"Error saving metric plots: {e}")