arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +89 -5
- arbor/client/api.py +1 -2
- arbor/server/api/models/schemas.py +209 -5
- arbor/server/api/routes/files.py +39 -10
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +110 -7
- arbor/server/core/config.py +44 -7
- arbor/server/main.py +6 -5
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -16
- arbor/server/services/file_manager.py +270 -109
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +74 -69
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +337 -40
- arbor_ai-0.1.6.dist-info/METADATA +78 -0
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +2 -1
- arbor_ai-0.1.6.dist-info/entry_points.txt +2 -0
- arbor_ai-0.1.6.dist-info/top_level.txt +1 -0
- arbor/server/api/routes/training.py +0 -16
- arbor_ai-0.1.4.dist-info/METADATA +0 -97
- arbor_ai-0.1.4.dist-info/RECORD +0 -27
- arbor_ai-0.1.4.dist-info/entry_points.txt +0 -3
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,576 @@
|
|
1
|
+
###############################################################################
|
2
|
+
# Initial Versions of this File Borrowed from Will Brown's Verifiers Library #
|
3
|
+
# https://github.com/willccbb/verifiers #
|
4
|
+
###############################################################################
|
5
|
+
|
6
|
+
import argparse
|
7
|
+
import json
|
8
|
+
import random
|
9
|
+
import threading
|
10
|
+
import time
|
11
|
+
from functools import lru_cache
|
12
|
+
from typing import Any, List, Optional, Union
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import zmq
|
16
|
+
from accelerate import Accelerator
|
17
|
+
from accelerate.utils import gather
|
18
|
+
from datasets import Dataset, IterableDataset, load_dataset
|
19
|
+
from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig # type: ignore
|
20
|
+
from torch.utils.data import Dataset
|
21
|
+
from transformers import (
|
22
|
+
PreTrainedModel,
|
23
|
+
PreTrainedTokenizerBase,
|
24
|
+
Trainer,
|
25
|
+
TrainerCallback,
|
26
|
+
)
|
27
|
+
from trl import GRPOConfig, GRPOTrainer
|
28
|
+
from trl.data_utils import maybe_apply_chat_template
|
29
|
+
|
30
|
+
from arbor.server.services.comms.comms import (
|
31
|
+
ArborScriptCommsHandler,
|
32
|
+
ArborServerCommsHandler,
|
33
|
+
)
|
34
|
+
|
35
|
+
# Module-level wall-clock timestamps (values from time.time()), shared between
# the training loop, the dataset and the CommandMonitor thread. None means the
# corresponding event (a finished training step / a pull from the data queue)
# has not happened yet.
last_step_time = last_queue_pop_time = None
|
37
|
+
|
38
|
+
|
39
|
+
def time_since_last_step():
    """Return the number of seconds since the last recorded training step.

    Reads the module-level ``last_step_time`` timestamp. While no step has
    been recorded (the timestamp is ``None``) the result is ``float("inf")``.
    """
    stamp = last_step_time
    return float("inf") if stamp is None else time.time() - stamp
|
44
|
+
|
45
|
+
|
46
|
+
def get_time_since_last_queue_pop():
    """Return the number of seconds since data was last pulled from the queue.

    Reads the module-level ``last_queue_pop_time`` timestamp and yields
    ``float("inf")`` while it is still ``None``.
    """
    popped_at = last_queue_pop_time
    if popped_at is not None:
        return time.time() - popped_at
    return float("inf")
|
51
|
+
|
52
|
+
|
53
|
+
class ArborGRPOTrainer(GRPOTrainer):
    """GRPO trainer that trains on externally generated, pre-scored completions.

    The parent ``GRPOTrainer`` is constructed with ``reward_funcs=[]``; instead
    of generating and scoring completions itself, this subclass expects every
    training example to already carry ``"messages"`` (the prompt as chat
    messages), ``"completion"`` (a single chat message) and ``"reward"``.
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        scale_rewards: bool = True,
        args: Optional[GRPOConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        comms_handler: Optional[ArborScriptCommsHandler] = None,
        update_interval: Optional[int] = 5,
        lora: Optional[bool] = False,
        **kwargs,
    ):
        """Create the trainer.

        Args:
            model: Model name/path or an already-instantiated model.
            scale_rewards: If True, group-relative advantages are divided by
                the per-group reward standard deviation.
            args, train_dataset, eval_dataset, processing_class, callbacks,
            optimizers, peft_config: Forwarded to ``GRPOTrainer``.
            comms_handler: Handler used to talk to the Arbor server; stored
                on the instance (may be assigned later).
            update_interval: Stored on the instance; not read by this class.
            lora: Accepted so that a ``"lora"`` entry in the Arbor training
                kwargs is not forwarded to ``GRPOTrainer``; LoRA itself is
                enabled through ``peft_config``.
            **kwargs: Forwarded to ``GRPOTrainer``.
        """
        super().__init__(
            model=model,
            reward_funcs=[],
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
            peft_config=peft_config,
            **kwargs,
        )
        self.peft_config = peft_config
        self.scale_rewards = scale_rewards
        self.comms_handler = comms_handler
        self.update_interval = update_interval

    def _generate_and_score_completions(
        self, batch: List[dict[str, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """Tokenize pre-generated completions and compute group advantages.

        Args:
            batch: This process's examples; each is a dict with
                ``"messages"``, ``"completion"`` and ``"reward"`` keys.

        Returns:
            A dict with ``prompt_ids``/``prompt_mask`` (left padded),
            ``completion_ids``/``completion_mask`` (right padded),
            ``old_per_token_logps`` and ``ref_per_token_logps`` (each may be
            None), and ``advantages`` for this process's slice of the batch.
        """
        device = self.accelerator.device

        # Render every prompt/completion pair through the chat template.
        prompt_completion_texts = [
            maybe_apply_chat_template(
                {
                    "prompt": example["messages"],
                    "completion": [example["completion"]],
                },
                self.processing_class,
            )
            for example in batch
        ]

        # Tokenize prompts. Left padding keeps each prompt flush against the
        # start of its completion once the two are concatenated below.
        prompt_texts = [texts["prompt"] for texts in prompt_completion_texts]
        prompt_inputs = self.processing_class(
            prompt_texts,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        ).to(device)
        prompt_ids, prompt_mask = (
            prompt_inputs["input_ids"],
            prompt_inputs["attention_mask"],
        )

        # Tokenize completions with explicit RIGHT padding. Relying on the
        # tokenizer default is unsafe (TRL may create it with
        # padding_side="left"): pads would then sit between prompt and
        # completion, and the head-truncation below would keep padding while
        # cutting off real completion tokens.
        completion_texts = [texts["completion"] for texts in prompt_completion_texts]
        completion_inputs = self.processing_class(
            completion_texts,
            return_tensors="pt",
            padding=True,
            padding_side="right",
            add_special_tokens=False,
        ).to(device)
        completion_ids, completion_mask = (
            completion_inputs["input_ids"],
            completion_inputs["attention_mask"],
        )

        if self.max_prompt_length is not None:
            if prompt_ids.shape[1] > self.max_prompt_length:
                print(f"Truncating prompt to {self.max_prompt_length} tokens")
                # Keep the tail of the prompt (the part nearest the completion).
                prompt_ids = prompt_ids[:, -self.max_prompt_length :]
                prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        if self.max_completion_length is not None:
            if completion_ids.shape[1] > self.max_completion_length:
                print(f"Truncating completion to {self.max_completion_length} tokens")
                # Keep the head of the completion.
                completion_ids = completion_ids[:, : self.max_completion_length]
                completion_mask = completion_mask[:, : self.max_completion_length]

        # Keeping this for when we switch to vllm
        # if self.state.global_step != self._last_loaded_step:
        #     self._move_model_to_vllm()
        #     self._last_loaded_step = self.state.global_step

        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        print(
            f"prompt_completion_ids.shape (after truncation, if enabled): {prompt_completion_ids.shape}, prompt_ids.shape: {prompt_ids.shape}, completion_ids.shape: {completion_ids.shape}"
        )

        # Only completion positions need log-probs.
        logits_to_keep = completion_ids.size(1)

        with torch.no_grad():
            # When using num_iterations == 1, old_per_token_logps == per_token_logps,
            # so its computation is skipped here and per_token_logps.detach()
            # is used instead.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                # No KL term, so no reference log-probs are needed.
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                )
            else:
                # No separate reference model: use the policy with its
                # adapter disabled as the reference.
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                    )

        rewards = torch.tensor(
            [example["reward"] for example in batch], dtype=torch.float32
        ).to(device)
        # Rewards from all processes are needed to form complete groups.
        rewards = gather(rewards)
        # Assumes the gathered rewards are laid out as contiguous groups of
        # num_generations completions per prompt — confirm with the sampler.
        grouped_rewards = rewards.view(-1, self.num_generations)
        mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(
            self.num_generations, dim=0
        )
        std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(
            self.num_generations, dim=0
        )
        advantages = rewards - mean_grouped_rewards

        if self.scale_rewards:
            # Normalize by the per-group std; the epsilon guards groups whose
            # rewards are all identical (std == 0).
            advantages = advantages / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(batch),
            (self.accelerator.process_index + 1) * len(batch),
        )
        advantages = advantages[process_slice]

        ## Logged Metrics Removed Here

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }
|
233
|
+
|
234
|
+
|
235
|
+
class LastStepTimeCallback(TrainerCallback):
    """Record when each training step finishes.

    At the end of every step this prints the seconds elapsed since the
    previous step and stores the current ``time.time()`` in the module-level
    ``last_step_time``. That timestamp is what ``time_since_last_step()``
    reports, and ``CommandMonitor`` uses it to wait for training to go quiet
    before saving.
    """

    def on_step_end(self, args, state, control, **kwargs):
        """Print the time since the previous step, then stamp this one."""
        global last_step_time
        print(f"Time since last step: {time_since_last_step()}")
        last_step_time = time.time()
|
242
|
+
|
243
|
+
|
244
|
+
class BlockingQueueDataset(Dataset):
    """Map-style dataset whose items are pulled from the Arbor server.

    Indices are synthetic: ``__len__`` reports a fixed ``size``, and each index
    not currently cached triggers one ``comms_handler.receive_data()`` call.
    The received batch (a list of examples) is cached per index in an LRU
    cache, and repeated lookups of the same index walk through that batch's
    examples round-robin.
    """

    def __init__(
        self,
        accelerator: Accelerator,
        comms_handler: ArborScriptCommsHandler,
        size=10_000,  # Arbitrary nominal length; indices are not stored rows.
        maxsize=100,
    ):
        """
        Args:
            accelerator: Provides the process rank and whether this is the
                main process.
            comms_handler: Source of batches via ``receive_data()``.
            size: Value reported by ``__len__``.
            maxsize: Maximum number of indices whose batches stay cached.
        """
        self.size = size
        self.accelerator = accelerator
        self.comms_handler = comms_handler
        # Each index hits the server at most once while it remains cached.
        self.get_cached_data = lru_cache(maxsize=maxsize)(self._get_data)
        # idx -> position of the next example to hand out from that batch.
        self.completion_counters = {}

    def __len__(self):
        return self.size

    def _get_data(self, idx):
        """Receive a fresh batch for ``idx`` (called only on a cache miss).

        Returns:
            The received batch, or None if receiving raised. Because this is
            wrapped by ``lru_cache``, a None result is cached for ``idx`` too.
        """
        global last_queue_pop_time

        rank = self.accelerator.process_index

        if self.accelerator.is_main_process:
            # Read by CommandMonitor to decide when training has gone idle.
            last_queue_pop_time = time.time()

        # A new batch always starts at its first example. Resetting (rather
        # than only initializing) matters when idx is refetched after LRU
        # eviction: a stale position could exceed the new batch's length.
        self.completion_counters[idx] = 0

        try:
            new_data = self.comms_handler.receive_data()
        except Exception as e:
            print(f"[rank {rank}] Error receiving data: {e}")
            new_data = None

        return new_data

    def __getitem__(self, idx):
        """Return the next example from the batch cached for ``idx``.

        Returns None when no batch could be received for ``idx``.
        """
        data = self.get_cached_data(idx)
        if data is None:
            return None

        counter = self.completion_counters.get(idx, 0)
        item = data[counter]
        # Advance round-robin so repeated requests for idx walk the batch.
        self.completion_counters[idx] = (counter + 1) % len(data)
        return item
|
293
|
+
|
294
|
+
|
295
|
+
class CommandMonitor:
    """Background listeners for server commands and broadcasts.

    Constructing an instance immediately starts two daemon threads:
    one consumes ``comms_handler.receive_command()`` (acting on
    ``save_model``), the other consumes ``comms_handler.receive_broadcast()``
    (acting on a ``terminate`` message).
    """

    def __init__(
        self,
        comms_handler: ArborScriptCommsHandler,
        trainer: ArborGRPOTrainer,
        base_model_name: str,
    ):
        self.comms_handler = comms_handler
        self.trainer = trainer
        self.base_model_name = base_model_name
        self.command_thread = threading.Thread(
            target=self._monitor_commands, daemon=True
        )
        self.command_thread.start()

        self.broadcast_thread = threading.Thread(
            target=self._monitor_broadcasts, daemon=True
        )
        self.broadcast_thread.start()

    def _monitor_commands(self):
        """Background thread that monitors for commands from the server.

        Each command is handled under its own try/except. A failure (for
        example while saving) is reported via ``send_status`` and the thread
        keeps listening instead of exiting.
        """
        if not self.comms_handler:
            return
        try:
            for command in self.comms_handler.receive_command():
                print(f"Main process received command: {command}")
                try:
                    self._handle_command(command)
                except Exception as e:
                    print(e)
                    self.comms_handler.send_status({"status": "error", "error": str(e)})
        except Exception as e:
            # receive_command() itself failed; the listener cannot continue.
            print(e)
            self.comms_handler.send_status({"status": "error", "error": str(e)})

    def _handle_command(self, command):
        """Dispatch one command; only ``save_model`` on the main process acts."""
        if (
            command.get("command") == "save_model"
            and self.trainer.accelerator.is_main_process
        ):
            self._save_model()

    def _wait_for_training_idle(self):
        """Block until no step and no queue pop has happened for over 10 s."""
        while (
            time_since_last_step() <= 10
            or get_time_since_last_queue_pop() <= 10
        ):
            print("Waiting for steps to finish")
            print(
                f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
            )
            print(
                f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
            )
            time.sleep(5)  # Poll interval; avoids busy waiting.

    def _save_model(self):
        """Save the model to ``output_dir`` once training is idle, then report it.

        With a PEFT config the adapter is saved to ``output_dir/adapter/``,
        reloaded via ``AutoPeftModelForCausalLM``, merged with
        ``merge_and_unload()`` and written (with the tokenizer) to
        ``output_dir``; otherwise ``trainer.save_model()`` is used.
        """
        output_dir = self.trainer.args.output_dir
        print(f"[Training Script] Instructed to save model at {output_dir}")
        # Avoid saving mid-step: wait until neither steps nor data pulls
        # have happened recently.
        self._wait_for_training_idle()
        print("[Training Script] Saving model...")

        if self.trainer.peft_config:
            adapter_dir = output_dir + "/adapter/"
            self.trainer.save_model(output_dir=adapter_dir)
            model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
                adapter_dir,
                config=self.trainer.peft_config,
            )
            merged_model = model_to_merge.merge_and_unload()
            merged_model.save_pretrained(
                output_dir,
                safe_serialization=True,
            )
            self.trainer.processing_class.save_pretrained(output_dir)
        else:
            self.trainer.save_model()

        print("[Training Script] Model saved")
        self.comms_handler.send_status(
            {
                "status": "model_saved",
                "output_dir": output_dir,
            }
        )

    def _monitor_broadcasts(self):
        """Background thread that monitors for broadcasts from the server.

        On a ``terminate`` message, sets ``trainer.control.should_training_stop``
        and acknowledges via ``send_status``.
        """
        if not self.comms_handler:
            return
        try:
            for broadcast in self.comms_handler.receive_broadcast():
                print(f"!!!Received broadcast: {broadcast}")
                if broadcast.get("message") == "terminate":
                    self.trainer.control.should_training_stop = True
                    self.comms_handler.send_status(
                        {
                            "status": "Received termination command",
                            "process_id": self.trainer.accelerator.process_index,
                        }
                    )
        except Exception as e:
            print(e)
            self.comms_handler.send_status({"status": "error", "error": str(e)})
|
401
|
+
|
402
|
+
|
403
|
+
def _parse_args():
    """Define the training script's command-line interface and parse argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")

    pipe_args = parser.add_argument_group("Comms arguments")
    pipe_args.add_argument("--host", default="localhost")
    pipe_args.add_argument("--command_port", type=int, required=True)
    pipe_args.add_argument("--status_port", type=int, required=True)
    pipe_args.add_argument("--data_port", type=int, required=True)
    pipe_args.add_argument("--broadcast_port", type=int, required=True)
    pipe_args.add_argument("--handshake_port", type=int, required=True)

    training_args = parser.add_argument_group("Training arguments")
    training_args.add_argument(
        "--model",
        type=str,
        help="Model to use for training",
    )
    training_args.add_argument(
        "--trl_train_kwargs",
        type=json.loads,
        help="GRPOConfig keyword arguments as a JSON string",
    )
    training_args.add_argument(
        "--arbor_train_kwargs",
        type=json.loads,
        help="ArborGRPOTrainer keyword arguments as a JSON string",
    )

    return parser.parse_args()


def _start_debug_server(args):
    """Run an in-process server comms handler that feeds synthetic data.

    Overwrites the port fields of ``args`` with the debug handler's ports and
    starts two daemon threads: one streams trl-lib/tldr prompts with fake
    completions and length-based rewards (sending ``save_model`` once, after
    25 batches), the other prints every status message received.
    """
    server_comms_handler = ArborServerCommsHandler(
        host=args.host,
    )

    args.command_port = server_comms_handler.command_port
    args.status_port = server_comms_handler.status_port
    args.data_port = server_comms_handler.data_port
    args.broadcast_port = server_comms_handler.broadcast_port
    args.handshake_port = server_comms_handler.handshake_port

    def debug_data_generator():
        tldr_dataset = load_dataset("trl-lib/tldr", split="train")
        # Count sent batches so the save command below can actually fire
        # (previously the counter was never incremented).
        for idx, item in enumerate(tldr_dataset, start=1):
            input_messages = [{"role": "user", "content": item["prompt"]}]
            completions = [
                {
                    "role": "assistant",
                    "content": "This is a test completion"
                    + hex(random.randint(0, 0xFFFFFF))[2:],
                }
                for _ in range(8)
            ]

            # Reward completions whose length is close to 20 characters.
            rewards = [-abs(20 - len(c["content"])) for c in completions]
            batch = [
                {
                    "messages": input_messages,
                    "completion": completion,
                    "reward": reward,
                }
                for completion, reward in zip(completions, rewards)
            ]
            server_comms_handler.send_data(batch)
            time.sleep(1)

            if idx == 25:
                server_comms_handler.send_command({"command": "save_model"})

    def status_listener():
        # Need to set subscription for PUB/SUB pattern
        server_comms_handler.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        for status in server_comms_handler.receive_status():
            print(f"Status: {status}")

    threading.Thread(target=debug_data_generator, daemon=True).start()
    threading.Thread(target=status_listener, daemon=True).start()


def _build_lora_config():
    """Return the fixed LoRA configuration used when ``lora`` is enabled."""
    return LoraConfig(
        r=16,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "down_proj",
            "gate_proj",
        ],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
        inference_mode=False,
    )


def main():
    """Entry point: parse CLI args, build the trainer and comms, then train.

    Any exception other than KeyboardInterrupt is reported to the server
    (when the comms handler exists) and re-raised; the comms handler is
    always closed on exit.
    """
    args = _parse_args()

    if args.debug:
        _start_debug_server(args)

    # Bound before the try so that except/finally never hit an
    # UnboundLocalError when setup fails before the handler is created.
    comms_handler = None
    try:
        trl_train_args = {**(args.trl_train_kwargs or {})}
        arbor_train_args = {**(args.arbor_train_kwargs or {})}
        use_lora = arbor_train_args.get("lora", False)

        # TODO: Move argument validation somewhere more structured.
        # (Raised explicitly rather than asserted so it survives `python -O`.)
        if "output_dir" not in trl_train_args:
            raise ValueError("output_dir is required")
        if "gradient_checkpointing_kwargs" in trl_train_args and use_lora:
            print(
                "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
            )
            trl_train_args["gradient_checkpointing_kwargs"] = {
                **(trl_train_args.get("gradient_checkpointing_kwargs") or {}),
                "use_reentrant": False,
            }

        lora_config = None
        if use_lora:
            print("Using LORA for PEFT")
            lora_config = _build_lora_config()

        training_args = GRPOConfig(
            dataloader_num_workers=0,
            shuffle_dataset=False,
            **trl_train_args,
        )

        trainer = ArborGRPOTrainer(
            model=args.model,
            args=training_args,
            # Placeholder; replaced below once accelerator and comms exist.
            train_dataset=BlockingQueueDataset(None, None),
            callbacks=[LastStepTimeCallback()],
            peft_config=lora_config,
            **arbor_train_args,
        )
        # Create client handler
        comms_handler = ArborScriptCommsHandler(
            host=args.host,
            command_port=args.command_port,
            status_port=args.status_port,
            data_port=args.data_port,
            broadcast_port=args.broadcast_port,
            handshake_port=args.handshake_port,
            is_main_process=trainer.accelerator.is_main_process,
        )
        trainer.comms_handler = comms_handler

        # Initialize the dataset with the actual accelerator
        trainer.train_dataset = BlockingQueueDataset(
            accelerator=trainer.accelerator,
            comms_handler=trainer.comms_handler,
        )

        # The monitor starts its background threads in its constructor.
        CommandMonitor(
            comms_handler=comms_handler,
            trainer=trainer,
            base_model_name=args.model,
        )

        print("Training...")
        trainer.train()

    except KeyboardInterrupt:
        print("\nReceived interrupt, shutting down...")
    except Exception as e:
        print(f"Error: {e}")
        if comms_handler is not None:
            comms_handler.send_status({"status": "error", "error": str(e)})
        raise
    finally:
        if comms_handler is not None:
            comms_handler.close()


if __name__ == "__main__":
    main()
|