dreadnode 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreadnode/__init__.py +51 -0
- dreadnode/api/__init__.py +0 -0
- dreadnode/api/client.py +249 -0
- dreadnode/api/models.py +210 -0
- dreadnode/artifact/__init__.py +0 -0
- dreadnode/artifact/merger.py +599 -0
- dreadnode/artifact/storage.py +126 -0
- dreadnode/artifact/tree_builder.py +455 -0
- dreadnode/constants.py +16 -0
- dreadnode/integrations/__init__.py +0 -0
- dreadnode/integrations/transformers.py +183 -0
- dreadnode/main.py +1042 -0
- dreadnode/metric.py +225 -0
- dreadnode/object.py +29 -0
- dreadnode/py.typed +0 -0
- dreadnode/serialization.py +731 -0
- dreadnode/task.py +447 -0
- dreadnode/tracing/__init__.py +0 -0
- dreadnode/tracing/constants.py +35 -0
- dreadnode/tracing/exporters.py +157 -0
- dreadnode/tracing/span.py +811 -0
- dreadnode/types.py +25 -0
- dreadnode/util.py +150 -0
- dreadnode/version.py +3 -0
- dreadnode-1.0.0rc0.dist-info/METADATA +122 -0
- dreadnode-1.0.0rc0.dist-info/RECORD +27 -0
- dreadnode-1.0.0rc0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
|
|
3
|
+
if importlib.util.find_spec("transformers") is None:
|
|
4
|
+
raise ModuleNotFoundError("Please install the `transformers` package to use this integration")
|
|
5
|
+
|
|
6
|
+
import typing as t
|
|
7
|
+
|
|
8
|
+
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
|
9
|
+
from transformers.training_args import TrainingArguments
|
|
10
|
+
|
|
11
|
+
import dreadnode as dn
|
|
12
|
+
|
|
13
|
+
if t.TYPE_CHECKING:
|
|
14
|
+
from dreadnode.tracing.span import RunSpan, Span
|
|
15
|
+
|
|
16
|
+
# ruff: noqa: ARG002
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]:
|
|
20
|
+
cleaned: dict[str, t.Any] = {}
|
|
21
|
+
for key, val in data.items():
|
|
22
|
+
_key = key.replace("eval_", "eval/").replace("test_", "test/").replace("train_", "train/")
|
|
23
|
+
cleaned[_key] = val
|
|
24
|
+
return cleaned
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DreadnodeCallback(TrainerCallback):
    """
    An implementation of the `TrainerCallback` interface for Dreadnode.

    This callback is used to log metrics and parameters to Dreadnode during training inside
    the `transformers` library or derivations (`trl`, etc.).

    A run is opened on `on_train_begin` (only when `state.is_world_process_zero` is true) and
    closed on `on_train_end`. While the run is open, every epoch is wrapped in a task span,
    every step in a span, and numeric values passed to `on_log` are logged as metrics.
    """

    def __init__(
        self,
        project: str | None = None,
        run_name: str | None = None,
        tags: list[str] | None = None,
    ):
        """
        Args:
            project: Forwarded to `dn.run` as `project`.
            run_name: Name for the run. When falsy, `args.run_name` and then
                `state.trial_name` are used instead.
            tags: Forwarded to `dn.run` as `tags`; defaults to an empty list.
        """
        self.project = project
        self.run_name = run_name
        self.tags = tags or []

        self._initialized = False
        self._run: RunSpan | None = None
        self._epoch_span: Span | None = None
        self._step_span: Span | None = None

    def _close_step_span(self) -> None:
        """Exit and forget the current step span, if one is open."""
        if self._step_span is not None:
            self._step_span.__exit__(None, None, None)
            self._step_span = None

    def _close_epoch_span(self) -> None:
        """Exit and forget the current epoch span, if one is open."""
        if self._epoch_span is not None:
            self._epoch_span.__exit__(None, None, None)
            self._epoch_span = None

    def _shutdown(self) -> None:
        """
        Close any open step span, epoch span and run (innermost first) and
        re-arm the callback so a later `on_train_begin` performs setup again.
        """
        self._close_step_span()
        self._close_epoch_span()

        if self._run is not None:
            self._run.__exit__(None, None, None)
            self._run = None

        # Without this reset, reusing the callback for another training session
        # would skip `_setup`, leave `_run` as None and silently log nothing.
        self._initialized = False

    def _setup(self, args: TrainingArguments, state: TrainerState, model: t.Any) -> None:
        """
        Open the run and log the training arguments and model configuration as params.

        Does nothing if setup already happened for the current session. The run is
        only opened when `state.is_world_process_zero` is true; otherwise the callback
        is marked initialized and every later hook becomes a no-op.

        Args:
            args: Training arguments; `to_sanitized_dict()` supplies the base params.
            state: Trainer state, read for `is_world_process_zero` and `trial_name`.
            model: The model being trained, or None. Its `config` entries are logged
                under `model/<key>` and its `peft_config` entries under `peft/<key>`.
        """
        if self._initialized:
            return

        self._initialized = True

        if not state.is_world_process_zero:
            return

        combined_dict = {**args.to_sanitized_dict()}

        config = getattr(model, "config", None)
        if config is not None:
            model_config = config if isinstance(config, dict) else config.to_dict()
            for key, value in model_config.items():
                combined_dict[f"model/{key}"] = value

        peft_config = getattr(model, "peft_config", None)
        if peft_config is not None:
            for key, value in peft_config.items():
                combined_dict[f"peft/{key}"] = value

        run_name = self.run_name or args.run_name or state.trial_name

        self._run = dn.run(
            name=run_name,
            project=self.project,
            tags=self.tags,
        )
        # The run is entered manually because it must stay open across callback
        # hooks; `_shutdown` performs the matching `__exit__`.
        self._run.__enter__()

        dn.log_params(**combined_dict)
        dn.push_update()

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: t.Any | None = None,
        **kwargs: t.Any,
    ) -> None:
        """Run `_setup` if it has not already run for this training session."""
        if not self._initialized:
            self._setup(args, state, model)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Close all open spans and the run."""
        self._shutdown()

    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Log the `epoch` metric and open a task span for the epoch."""
        if self._run is None or state.epoch is None:
            return

        # If a previous epoch never reached `on_epoch_end`, exit its spans
        # (innermost first) instead of overwriting the handles and leaking them.
        self._close_step_span()
        self._close_epoch_span()

        dn.log_metric("epoch", state.epoch)

        self._epoch_span = dn.task_span(f"Epoch {state.epoch}")
        self._epoch_span.__enter__()

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Close the span opened by `on_epoch_begin`, if any."""
        self._close_epoch_span()

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Log the `step` metric and open a span for the step."""
        if self._run is None:
            return

        # Exit a step span left open by a step that never reached `on_step_end`.
        self._close_step_span()

        dn.log_metric("step", state.global_step)

        self._step_span = dn.span(f"Step {state.global_step}")
        self._step_span.__enter__()

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: t.Any,
    ) -> None:
        """Close the span opened by `on_step_begin`, if any."""
        self._close_step_span()

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict[str, t.Any] | None = None,
        **kwargs: t.Any,
    ) -> None:
        """
        Log every `int`/`float` entry of `logs` as a metric at `state.global_step`.

        Keys are passed through `_clean_keys` first; non-numeric values are skipped.
        Nothing happens when no run is open or `logs` is None.
        """
        if self._run is None or logs is None:
            return

        for key, value in _clean_keys(logs).items():
            if isinstance(value, float | int):
                dn.log_metric(key, value, step=state.global_step)

        dn.push_update()
|