dreadnode 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,183 @@
1
+ import importlib.util
2
+
3
+ if importlib.util.find_spec("transformers") is None:
4
+ raise ModuleNotFoundError("Please install the `transformers` package to use this integration")
5
+
6
+ import typing as t
7
+
8
+ from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
9
+ from transformers.training_args import TrainingArguments
10
+
11
+ import dreadnode as dn
12
+
13
+ if t.TYPE_CHECKING:
14
+ from dreadnode.tracing.span import RunSpan, Span
15
+
16
+ # ruff: noqa: ARG002
17
+
18
+
19
+ def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]:
20
+ cleaned: dict[str, t.Any] = {}
21
+ for key, val in data.items():
22
+ _key = key.replace("eval_", "eval/").replace("test_", "test/").replace("train_", "train/")
23
+ cleaned[_key] = val
24
+ return cleaned
25
+
26
+
27
class DreadnodeCallback(TrainerCallback):
    """
    An implementation of the `TrainerCallback` interface for Dreadnode.

    This callback is used to log metrics and parameters to Dreadnode during training inside
    the `transformers` library or derivations (`trl`, etc.).

    Lifecycle:
        - `on_train_begin` opens a Dreadnode run (on world-process zero only) and
          logs the training arguments plus any model / PEFT configuration as params.
        - `on_epoch_begin` / `on_epoch_end` wrap each epoch in a task span.
        - `on_step_begin` / `on_step_end` wrap each step in a span, opened while the
          current epoch span is active.
        - `on_log` forwards numeric log values as metrics at `state.global_step`.
        - `on_train_end` closes every open span and the run, and resets the callback
          so the same instance opens a fresh run on a later `train()` call
          (e.g. the next hyperparameter-search trial).

    Args:
        project: Project passed to `dn.run`.
        run_name: Explicit run name; falls back to `args.run_name`, then
            `state.trial_name`.
        tags: Tags passed to `dn.run`.
    """

    def __init__(
        self,
        project: str | None = None,
        run_name: str | None = None,
        tags: list[str] | None = None,
    ):
        self.project = project
        self.run_name = run_name
        self.tags = tags or []

        # True once `_setup` has run for the current training session.
        self._initialized = False
        # Currently entered run / spans; `None` when not open.
        self._run: RunSpan | None = None
        self._epoch_span: Span | None = None
        self._step_span: Span | None = None

    def _close_step_span(self) -> None:
        """Exit the currently open step span, if any."""
        if self._step_span is not None:
            self._step_span.__exit__(None, None, None)
            self._step_span = None

    def _close_epoch_span(self) -> None:
        """Exit the currently open epoch span, closing any step span first."""
        # Step spans are opened after (inside) the epoch span, so they must be
        # exited first to keep the span nesting well-formed.
        self._close_step_span()
        if self._epoch_span is not None:
            self._epoch_span.__exit__(None, None, None)
            self._epoch_span = None

    def _shutdown(self) -> None:
        """Close all open spans and the run, and allow `_setup` to run again."""
        self._close_epoch_span()

        if self._run is not None:
            self._run.__exit__(None, None, None)
            self._run = None

        # Without this reset, a reused callback (second `train()` call, later
        # hyperparameter-search trials) would skip `_setup` and log nothing.
        self._initialized = False

    def _setup(self, args: TrainingArguments, state: TrainerState, model: t.Any) -> None:
        """
        Open the Dreadnode run and log training/model parameters.

        Runs at most once per training session. On processes other than
        world-process zero it only marks the callback as initialized; `_run`
        stays `None`, so every other hook is a no-op there.
        """
        if self._initialized:
            return

        self._initialized = True

        if not state.is_world_process_zero:
            return

        combined_dict = dict(args.to_sanitized_dict())

        if hasattr(model, "config") and model.config is not None:
            # Configs may already be plain dicts or objects exposing `to_dict()`.
            model_config = (
                model.config if isinstance(model.config, dict) else model.config.to_dict()
            )
            for key, value in model_config.items():
                combined_dict[f"model/{key}"] = value
        if hasattr(model, "peft_config") and model.peft_config is not None:
            # NOTE(review): values are forwarded as-is; confirm `dn.log_params`
            # accepts whatever objects `peft_config` holds.
            for key, value in model.peft_config.items():
                combined_dict[f"peft/{key}"] = value

        run_name = self.run_name or args.run_name or state.trial_name

        self._run = dn.run(
            name=run_name,
            project=self.project,
            tags=self.tags,
        )
        self._run.__enter__()

        dn.log_params(**combined_dict)
        dn.push_update()

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: t.Any | None = None,
        **kwargs: t.Any,
    ) -> None:
        """Open the run (once per training session)."""
        if not self._initialized:
            self._setup(args, state, model)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Close all spans and the run."""
        self._shutdown()

    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Log the epoch counter and open a task span for this epoch."""
        if self._run is None or state.epoch is None:
            return

        # Defensive: never leak a span if the previous epoch was not closed.
        self._close_epoch_span()

        dn.log_metric("epoch", state.epoch)

        self._epoch_span = dn.task_span(f"Epoch {state.epoch}")
        self._epoch_span.__enter__()

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Close the current epoch span (and any lingering step span)."""
        self._close_epoch_span()

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Log the global step and open a span for this step."""
        if self._run is None:
            return

        # Defensive: never leak a span if the previous step was not closed.
        self._close_step_span()

        dn.log_metric("step", state.global_step)

        self._step_span = dn.span(f"Step {state.global_step}")
        self._step_span.__enter__()

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Close the current step span."""
        self._close_step_span()

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict[str, t.Any] | None = None,
        **kwargs: t.Any,
    ) -> None:
        """Forward numeric entries of `logs` as metrics at the current global step."""
        if self._run is None or logs is None:
            return

        for key, value in _clean_keys(logs).items():
            # Non-numeric log entries are skipped.
            if isinstance(value, float | int):
                dn.log_metric(key, value, step=state.global_step)

        dn.push_update()