arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +98 -62
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +6 -4
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
- arbor_ai-0.2.2.dist-info/RECORD +51 -0
- arbor_ai-0.2.1.dist-info/RECORD +0 -42
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/server/services/scripts/mmgrpo_training.py (new file, 510 lines):

```python
import argparse
import json
import signal
import sys
import time
from typing import Any, Optional, Union

import torch
import trl.extras.vllm_client
from datasets import Dataset, IterableDataset, load_dataset
from peft import LoraConfig, PeftConfig
from torch.utils.data import Dataset
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from trl.data_utils import maybe_apply_chat_template
from trl.trainer.grpo_trainer import GRPOConfig, GRPOTrainer, nanmax, nanmin

from arbor.server.services.comms.comms import ArborScriptCommsHandler
from arbor.server.services.inference.vllm_client import VLLMClient
from arbor.server.services.scripts.utils.callbacks import WeightUpdateCallback
from arbor.server.services.scripts.utils.comms_monitors import CommandMonitor
from arbor.server.services.scripts.utils.dataset import BlockingQueueDataset
from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor

trl.extras.vllm_client.VLLMClient = VLLMClient


class MMGRPOTrainer(GRPOTrainer):
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        args: Optional[GRPOConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        lora: Optional[bool] = False,
        vllm_group_port: Optional[int] = None,
        max_context_length: Optional[int] = None,
        grpo_flavor: Optional[str] = "mmgrpo",
        **kwargs,
    ):
        super().__init__(
            model=model,
            reward_funcs=[],
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
            peft_config=peft_config,
            **kwargs,
        )
        self.peft_config = peft_config
        self.loss_type = "mmgrpo"

        self.vllm_client = None
        args.use_vllm = True
        self.use_vllm = True
        if self.accelerator.is_main_process:
            print(
                f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
            )
            self.vllm_client = VLLMClient(
                args.vllm_server_host,
                args.vllm_server_port,
                group_port=vllm_group_port,
                connection_timeout=args.vllm_server_timeout,
            )
            self.vllm_client.init_communicator()

        # vLLM specific sampling arguments
        self.guided_decoding_regex = args.vllm_guided_decoding_regex

        self._last_loaded_step = (
            -1
        )  # tag to avoid useless loading during grad accumulation

        # When using vLLM, the main process is responsible for loading the model weights. This can cause process
        # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
        # synchronize all processes after vLLM has been fully initialized.
        self.accelerator.wait_for_everyone()

    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
        self.comms_handler = comms_handler

    def get_train_dataloader(self):
        return Trainer.get_train_dataloader(self)

    def _get_train_sampler(self, dataset: Optional[Dataset] = None):
        return Trainer._get_train_sampler(self, dataset)

    def _tensorize_prompts_completions(
        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        prompt_completion_texts = []
        for example in generation_batch:
            prompt_completion_texts.append(
                maybe_apply_chat_template(
                    {
                        "prompt": example["messages"],
                        "completion": (
                            example["completion"]
                            if isinstance(example["completion"], list)
                            else [example["completion"]]
                        ),
                    },
                    self.processing_class,
                )
            )
        prompts_text = [
            prompt_completion_text["prompt"]
            for prompt_completion_text in prompt_completion_texts
        ]
        prompt_inputs = self.processing_class(
            prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        prompt_ids, prompt_mask = (
            prompt_inputs["input_ids"],
            prompt_inputs["attention_mask"],
        )

        completion_text = [
            prompt_completion_text["completion"]
            for prompt_completion_text in prompt_completion_texts
        ]
        completion_inputs = self.processing_class(
            completion_text,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
        )
        completion_ids, completion_mask = (
            completion_inputs["input_ids"],
            completion_inputs["attention_mask"],
        )

        if self.max_prompt_length is not None:
            if prompt_ids.shape[1] > self.max_prompt_length:
                print(f"Truncating prompt to {self.max_prompt_length} tokens")
                prompt_ids = prompt_ids[:, -self.max_prompt_length :]
                prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        if self.max_completion_length is not None:
            if completion_ids.shape[1] > self.max_completion_length:
                print(f"Truncating completion to {self.max_completion_length} tokens")
                completion_ids = completion_ids[:, : self.max_completion_length]
                completion_mask = completion_mask[:, : self.max_completion_length]

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
        }

    def _get_trajectory_lengths(
        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        trajectory_lengths = []
        for example in generation_batch:
            full_trajectory = example["trajectory"]
            prompt_completion_tensors = self._tensorize_prompts_completions(
                full_trajectory
            )
            completion_mask = prompt_completion_tensors["completion_mask"]
            trajectory_lengths.append(completion_mask.sum())
        return torch.tensor(trajectory_lengths)

    def _prepare_inputs(
        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"
        prompt_completion_tensors = self._tensorize_prompts_completions(
            generation_batch
        )
        prompt_ids, prompt_mask = prompt_completion_tensors["prompt_ids"].to(
            device
        ), prompt_completion_tensors["prompt_mask"].to(device)
        completion_ids, completion_mask = prompt_completion_tensors[
            "completion_ids"
        ].to(device), prompt_completion_tensors["completion_mask"].to(device)

        advantages = torch.tensor(
            [example["advantage"] for example in generation_batch]
        ).to(device)
        trajectory_lengths = self._get_trajectory_lengths(generation_batch).to(device)

        out = {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "advantages": advantages,
            "old_per_token_logps": None,
            "trajectory_lengths": trajectory_lengths,
        }
        return out

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = (
            inputs["completion_ids"],
            inputs["completion_mask"],
        )
        inputs_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)

        per_token_logps = self._get_per_token_logps(
            model, inputs_ids, attention_mask, logits_to_keep
        )

        if self.beta != 0.0:
            with torch.no_grad():
                if self.ref_model is not None:
                    ref_per_token_logps = self._get_per_token_logps(
                        self.ref_model, inputs_ids, attention_mask, logits_to_keep
                    )
                else:
                    with self.accelerator.unwrap_model(self.model).disable_adapter():
                        ref_per_token_logps = self._get_per_token_logps(
                            self.model, inputs_ids, attention_mask, logits_to_keep
                        )
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps)
                - (ref_per_token_logps - per_token_logps)
                - 1
            )

        advantages = inputs["advantages"]
        trajectory_lengths = inputs["trajectory_lengths"]
        # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps
        # old_per_token_logps == per_token_logps, so we can skip its computation
        # (see _generate_and_score_completions) and use per_token_logps.detach() instead.
        old_per_token_logps = (
            per_token_logps.detach()
            if inputs["old_per_token_logps"] is None
            else inputs["old_per_token_logps"]
        )
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

        if self.args.delta is not None:
            coef_1 = torch.clamp(coef_1, max=self.args.delta)

        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        if self.loss_type == "grpo":
            loss = (
                (per_token_loss * completion_mask).sum(-1)
                / completion_mask.sum(-1).clamp(min=1.0)
            ).mean()
        elif self.loss_type == "bnpo":
            loss = (
                per_token_loss * completion_mask
            ).sum() / completion_mask.sum().clamp(min=1.0)
        elif self.loss_type == "dr_grpo":
            loss = (per_token_loss * completion_mask).sum() / (
                per_token_loss.size(0) * self.max_completion_length
            )
        elif self.loss_type == "mmgrpo":
            # Sum the loss over tokens for each trajectory
            trajectory_losses = (per_token_loss * completion_mask).sum(dim=-1)
            # Normalize by the actual trajectory lengths
            normalized_losses = trajectory_losses / trajectory_lengths.clamp(min=1.0)
            # Take the mean over the batch
            loss = normalized_losses.mean()
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        # Log the metrics
        mode = "train" if self.model.training else "eval"

        if self.beta != 0.0:
            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
            self._metrics[mode]["kl"].append(
                self.accelerator.gather(mean_kl).nanmean().item()
            )

        # Compute the clipped probability ratios
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (
            advantages.unsqueeze(1) > 0
        )
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
        high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
        clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()

        gathered_low_clip = self.accelerator.gather(low_clip)
        self._metrics[mode]["clip_ratio/low_mean"].append(
            gathered_low_clip.nanmean().item()
        )
        self._metrics[mode]["clip_ratio/low_min"].append(
            nanmin(gathered_low_clip).item()
        )
        gathered_high_clip = self.accelerator.gather(high_clip)
        self._metrics[mode]["clip_ratio/high_mean"].append(
            gathered_high_clip.nanmean().item()
        )
        self._metrics[mode]["clip_ratio/high_max"].append(
            nanmax(gathered_high_clip).item()
        )
        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
        self._metrics[mode]["clip_ratio/region_mean"].append(
            gathered_clip_ratio.nanmean().item()
        )
        print(f"Loss: {loss.item()}")

        return loss


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")

    pipe_args = parser.add_argument_group("Comms arguments")
    pipe_args.add_argument("--host", default="localhost")
    pipe_args.add_argument("--command_port", type=int, required=True)
    pipe_args.add_argument("--status_port", type=int, required=True)
    pipe_args.add_argument("--data_port", type=int, required=True)
    pipe_args.add_argument("--broadcast_port", type=int, required=True)
    pipe_args.add_argument("--handshake_port", type=int, required=True)
    pipe_args.add_argument("--vllm_group_port", type=int, required=True)
    pipe_args.add_argument("--vllm_port", type=int, required=True)

    training_args = parser.add_argument_group("Training arguments")
    training_args.add_argument(
        "--model",
        type=str,
        help="Model to use for training",
    )
    training_args.add_argument(
        "--trl_train_kwargs",
        type=json.loads,
        help="Training arguments as a JSON string",
    )
    training_args.add_argument(
        "--arbor_train_kwargs",
        type=json.loads,
        help="Training arguments as a JSON string",
    )

    args = parser.parse_args()

    if args.debug:
        pass

    try:
        trl_train_args = {**(args.trl_train_kwargs or {})}
        arbor_train_args = {**(args.arbor_train_kwargs or {})}

        # TODO: These assertions should be done in some better way
        assert "output_dir" in trl_train_args, "output_dir is required"
        if "gradient_checkpointing_kwargs" in trl_train_args and arbor_train_args.get(
            "lora", False
        ):
            print(
                "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
            )
            trl_train_args["gradient_checkpointing_kwargs"] = {
                **(trl_train_args.get("gradient_checkpointing_kwargs") or {}),
                "use_reentrant": False,
            }

        lora_config = None
        if arbor_train_args.get("lora", False):
            print("Using LORA for PEFT")
            lora_config = LoraConfig(
                r=16,
                lora_alpha=64,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "up_proj",
                    "down_proj",
                    "gate_proj",
                ],
                task_type="CAUSAL_LM",
                lora_dropout=0.05,
                inference_mode=False,
            )

        training_args = GRPOConfig(
            dataloader_num_workers=0,
            shuffle_dataset=False,
            vllm_server_port=args.vllm_port,
            **trl_train_args,
        )

        # Create ingestion monitor
        ingestion_monitor = IngestionMonitor()

        train_dataset = BlockingQueueDataset(
            ingestion_monitor=ingestion_monitor,
        )
        weight_update_callback = WeightUpdateCallback(
            ingestion_monitor=ingestion_monitor,
        )
        trainer = MMGRPOTrainer(
            model=args.model,
            args=training_args,
            train_dataset=train_dataset,
            callbacks=[weight_update_callback],
            peft_config=lora_config,
            vllm_group_port=args.vllm_group_port,
            **arbor_train_args,
        )

        comms_handler = ArborScriptCommsHandler(
            host=args.host,
            command_port=args.command_port,
            status_port=args.status_port,
            data_port=args.data_port,
            broadcast_port=args.broadcast_port,
            handshake_port=args.handshake_port,
            is_main_process=trainer.accelerator.is_main_process,
        )

        train_dataset.set_comms_handler(comms_handler)
        train_dataset.set_accelerator(trainer.accelerator)

        weight_update_callback.set_comms_handler(comms_handler)
        weight_update_callback.set_trainer(trainer)

        trainer.set_comms_handler(comms_handler)

        command_monitor = CommandMonitor(
            comms_handler=comms_handler,
            trainer=trainer,
            base_model_name=args.model,
            ingestion_monitor=ingestion_monitor,
        )
        command_monitor.start()

        print("command monitor started")

        # Add signal handlers for graceful shutdown
        def signal_handler(signum, frame):
            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
            print("Ending training...")
            trainer.accelerator.end_training()
            print("Closing communications...")
            comms_handler.close()
            sys.exit(0)

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        print("Signal handlers added")

        print("Starting training...")
        try:
            trainer.train()
        except Exception as e:
            print(f"Error during training: {e}")
            print(f"Error type: {type(e).__name__}")
            if "unhashable" in str(e):
                print("DEBUGGING: Unhashable type error during training")
                print(
                    "This could be in data loading, model forward pass, or metrics collection"
                )
            raise

    except Exception as e:
        import traceback

        print(f"Error: {e}")
        print("Stack trace:")
        traceback.print_exc()
        comms_handler.send_status({"status": "error", "error": str(e)})
        raise e
    finally:
        print("Cleaning up resources...")
        trainer.accelerator.end_training()
        comms_handler.close()
        print("Cleanup complete")


# Example usage:
if __name__ == "__main__":
    main()
```
arbor/server/services/scripts/sft_training.py (the removed lines were rendered without content by the diff viewer; they are reconstructed here as the `print` calls that the paired `logger` lines replace):

```diff
@@ -11,6 +11,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from trl import SFTConfig, SFTTrainer, setup_chat_format
 
 from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 def main():
@@ -25,7 +28,7 @@ def main():
         # TODO: These assertions should be done in some better way
         assert "output_dir" in trl_train_kwargs, "output_dir is required"
         if "gradient_checkpointing_kwargs" in trl_train_kwargs and args.lora:
-            print(
+            logger.info(
                 "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
             )
             trl_train_kwargs["gradient_checkpointing_kwargs"] = {
@@ -35,7 +38,7 @@ def main():
 
         lora_config = None
         if args.lora:
-            print("Using LORA for PEFT")
+            logger.info("Using LORA for PEFT")
             lora_config = LoraConfig(
                 r=16,
                 lora_alpha=64,
@@ -91,13 +94,13 @@ def main():
             base_model_name=args.model,
         )
 
-        print("Starting training...")
+        logger.info("Starting training...")
         trainer.train()
 
     except KeyboardInterrupt:
-        print("Received interrupt, shutting down...")
+        logger.info("Received interrupt, shutting down...")
     except Exception as e:
-        print(f"Training error: {e}")
+        logger.error(f"Training error: {e}")
         comms_handler.send_status({"status": "error", "error": str(e)})
         raise e
     finally:
```
arbor/server/services/scripts/utils/callbacks.py (new file, 33 lines):

```python
import logging

from transformers import TrainerCallback

from arbor.server.services.comms.comms import ArborScriptCommsHandler
from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor

logger = logging.getLogger(__name__)


class WeightUpdateCallback(TrainerCallback):
    """A callback that sends weight update completion status after each step"""

    def __init__(self, ingestion_monitor: IngestionMonitor):
        self.comms_handler = None
        self.trainer = None
        self.ingestion_monitor = ingestion_monitor

    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
        self.comms_handler = comms_handler

    def set_trainer(self, trainer):
        self.trainer = trainer

    def on_step_end(self, args, state, control, **kwargs):
        self.ingestion_monitor.set_last_step_time()
        if self.comms_handler and self.comms_handler.is_main_process and self.trainer:
            if state.global_step != self.trainer._last_loaded_step:
                logger.info("Updating inference model...")
                self.comms_handler.send_status({"status": "weight_update_start"})
                self.trainer._move_model_to_vllm()
                self.trainer._last_loaded_step = state.global_step
                self.comms_handler.send_status({"status": "weight_update_complete"})
```