arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/server/services/scripts/mmgrpo_training.py (new file)
@@ -0,0 +1,510 @@
+import argparse
+import json
+import signal
+import sys
+import time
+from typing import Any, Optional, Union
+
+import torch
+import trl.extras.vllm_client
+from datasets import Dataset, IterableDataset, load_dataset
+from peft import LoraConfig, PeftConfig
+from transformers import (
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    Trainer,
+    TrainerCallback,
+    is_wandb_available,
+)
+from trl.data_utils import maybe_apply_chat_template
+from trl.trainer.grpo_trainer import GRPOConfig, GRPOTrainer, nanmax, nanmin
+
+from arbor.server.services.comms.comms import ArborScriptCommsHandler
+from arbor.server.services.inference.vllm_client import VLLMClient
+from arbor.server.services.scripts.utils.callbacks import WeightUpdateCallback
+from arbor.server.services.scripts.utils.comms_monitors import CommandMonitor
+from arbor.server.services.scripts.utils.dataset import BlockingQueueDataset
+from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor
+
+# Patch TRL to use Arbor's extended vLLM client.
+trl.extras.vllm_client.VLLMClient = VLLMClient
+
+
+class MMGRPOTrainer(GRPOTrainer):
+    def __init__(
+        self,
+        model: Union[str, PreTrainedModel],
+        args: Optional[GRPOConfig] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        eval_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[
+            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+        ] = (None, None),
+        peft_config: Optional["PeftConfig"] = None,
+        lora: Optional[bool] = False,
+        vllm_group_port: Optional[int] = None,
+        max_context_length: Optional[int] = None,
+        grpo_flavor: Optional[str] = "mmgrpo",
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            reward_funcs=[],
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            peft_config=peft_config,
+            **kwargs,
+        )
+        self.peft_config = peft_config
+        self.loss_type = "mmgrpo"
+
+        self.vllm_client = None
+        args.use_vllm = True
+        self.use_vllm = True
+        if self.accelerator.is_main_process:
+            print(
+                f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
+            )
+            self.vllm_client = VLLMClient(
+                args.vllm_server_host,
+                args.vllm_server_port,
+                group_port=vllm_group_port,
+                connection_timeout=args.vllm_server_timeout,
+            )
+            self.vllm_client.init_communicator()
+
+        # vLLM specific sampling arguments
+        self.guided_decoding_regex = args.vllm_guided_decoding_regex
+
+        # Tag to avoid useless weight loading during gradient accumulation.
+        self._last_loaded_step = -1
+
+        # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+        # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+        # synchronize all processes after vLLM has been fully initialized.
+        self.accelerator.wait_for_everyone()
+
+    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
+        self.comms_handler = comms_handler
+
+    def get_train_dataloader(self):
+        return Trainer.get_train_dataloader(self)
+
+    def _get_train_sampler(self, dataset: Optional[Dataset] = None):
+        return Trainer._get_train_sampler(self, dataset)
+
+    def _tensorize_prompts_completions(
+        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        prompt_completion_texts = []
+        for example in generation_batch:
+            prompt_completion_texts.append(
+                maybe_apply_chat_template(
+                    {
+                        "prompt": example["messages"],
+                        "completion": (
+                            example["completion"]
+                            if isinstance(example["completion"], list)
+                            else [example["completion"]]
+                        ),
+                    },
+                    self.processing_class,
+                )
+            )
+        prompts_text = [
+            prompt_completion_text["prompt"]
+            for prompt_completion_text in prompt_completion_texts
+        ]
+        prompt_inputs = self.processing_class(
+            prompts_text,
+            return_tensors="pt",
+            padding=True,
+            padding_side="left",
+            add_special_tokens=False,
+        )
+        prompt_ids, prompt_mask = (
+            prompt_inputs["input_ids"],
+            prompt_inputs["attention_mask"],
+        )
+
+        completion_text = [
+            prompt_completion_text["completion"]
+            for prompt_completion_text in prompt_completion_texts
+        ]
+        completion_inputs = self.processing_class(
+            completion_text,
+            return_tensors="pt",
+            padding=True,
+            add_special_tokens=False,
+        )
+        completion_ids, completion_mask = (
+            completion_inputs["input_ids"],
+            completion_inputs["attention_mask"],
+        )
+
+        if self.max_prompt_length is not None:
+            if prompt_ids.shape[1] > self.max_prompt_length:
+                print(f"Truncating prompt to {self.max_prompt_length} tokens")
+                prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+                prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+        if self.max_completion_length is not None:
+            if completion_ids.shape[1] > self.max_completion_length:
+                print(f"Truncating completion to {self.max_completion_length} tokens")
+                completion_ids = completion_ids[:, : self.max_completion_length]
+                completion_mask = completion_mask[:, : self.max_completion_length]
+
+        return {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+        }
+
+    def _get_trajectory_lengths(
+        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
+    ) -> torch.Tensor:
+        trajectory_lengths = []
+        for example in generation_batch:
+            full_trajectory = example["trajectory"]
+            prompt_completion_tensors = self._tensorize_prompts_completions(
+                full_trajectory
+            )
+            completion_mask = prompt_completion_tensors["completion_mask"]
+            trajectory_lengths.append(completion_mask.sum())
+        return torch.tensor(trajectory_lengths)
+
+    def _prepare_inputs(
+        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        device = self.accelerator.device
+        mode = "train" if self.model.training else "eval"
+        prompt_completion_tensors = self._tensorize_prompts_completions(
+            generation_batch
+        )
+        prompt_ids, prompt_mask = prompt_completion_tensors["prompt_ids"].to(
+            device
+        ), prompt_completion_tensors["prompt_mask"].to(device)
+        completion_ids, completion_mask = prompt_completion_tensors[
+            "completion_ids"
+        ].to(device), prompt_completion_tensors["completion_mask"].to(device)
+
+        advantages = torch.tensor(
+            [example["advantage"] for example in generation_batch]
+        ).to(device)
+        trajectory_lengths = self._get_trajectory_lengths(generation_batch).to(device)
+
+        out = {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+            "advantages": advantages,
+            "old_per_token_logps": None,
+            "trajectory_lengths": trajectory_lengths,
+        }
+        return out
+
+    def compute_loss(
+        self, model, inputs, return_outputs=False, num_items_in_batch=None
+    ):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
+
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = (
+            inputs["completion_ids"],
+            inputs["completion_mask"],
+        )
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        logits_to_keep = completion_ids.size(1)
+
+        per_token_logps = self._get_per_token_logps(
+            model, input_ids, attention_mask, logits_to_keep
+        )
+
+        if self.beta != 0.0:
+            with torch.no_grad():
+                if self.ref_model is not None:
+                    ref_per_token_logps = self._get_per_token_logps(
+                        self.ref_model, input_ids, attention_mask, logits_to_keep
+                    )
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        ref_per_token_logps = self._get_per_token_logps(
+                            self.model, input_ids, attention_mask, logits_to_keep
+                        )
+            per_token_kl = (
+                torch.exp(ref_per_token_logps - per_token_logps)
+                - (ref_per_token_logps - per_token_logps)
+                - 1
+            )
+
+        advantages = inputs["advantages"]
+        trajectory_lengths = inputs["trajectory_lengths"]
+        # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
+        # old_per_token_logps == per_token_logps, so we can skip its computation
+        # (see _generate_and_score_completions) and use per_token_logps.detach() instead.
+        old_per_token_logps = (
+            per_token_logps.detach()
+            if inputs["old_per_token_logps"] is None
+            else inputs["old_per_token_logps"]
+        )
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+
+        if self.args.delta is not None:
+            coef_1 = torch.clamp(coef_1, max=self.args.delta)
+
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl
+
+        if self.loss_type == "grpo":
+            loss = (
+                (per_token_loss * completion_mask).sum(-1)
+                / completion_mask.sum(-1).clamp(min=1.0)
+            ).mean()
+        elif self.loss_type == "bnpo":
+            loss = (
+                per_token_loss * completion_mask
+            ).sum() / completion_mask.sum().clamp(min=1.0)
+        elif self.loss_type == "dr_grpo":
+            loss = (per_token_loss * completion_mask).sum() / (
+                per_token_loss.size(0) * self.max_completion_length
+            )
+        elif self.loss_type == "mmgrpo":
+            # Sum the loss over tokens for each trajectory
+            trajectory_losses = (per_token_loss * completion_mask).sum(dim=-1)
+            # Normalize by the actual trajectory lengths
+            normalized_losses = trajectory_losses / trajectory_lengths.clamp(min=1.0)
+            # Take the mean over the batch
+            loss = normalized_losses.mean()
+        else:
+            raise ValueError(f"Unknown loss type: {self.loss_type}")
+
+        # Log the metrics
+        mode = "train" if self.model.training else "eval"
+
+        if self.beta != 0.0:
+            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+            self._metrics[mode]["kl"].append(
+                self.accelerator.gather(mean_kl).nanmean().item()
+            )
+
+        # Compute the clipped probability ratios
+        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
+        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (
+            advantages.unsqueeze(1) > 0
+        )
+        is_region_clipped = is_low_clipped | is_high_clipped
+
+        low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
+        high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
+        clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()
+
+        gathered_low_clip = self.accelerator.gather(low_clip)
+        self._metrics[mode]["clip_ratio/low_mean"].append(
+            gathered_low_clip.nanmean().item()
+        )
+        self._metrics[mode]["clip_ratio/low_min"].append(
+            nanmin(gathered_low_clip).item()
+        )
+        gathered_high_clip = self.accelerator.gather(high_clip)
+        self._metrics[mode]["clip_ratio/high_mean"].append(
+            gathered_high_clip.nanmean().item()
+        )
+        self._metrics[mode]["clip_ratio/high_max"].append(
+            nanmax(gathered_high_clip).item()
+        )
+        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
+        self._metrics[mode]["clip_ratio/region_mean"].append(
+            gathered_clip_ratio.nanmean().item()
+        )
+        print(f"Loss: {loss.item()}")
+
+        return loss
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--debug", action="store_true")
+
+    pipe_args = parser.add_argument_group("Comms arguments")
+    pipe_args.add_argument("--host", default="localhost")
+    pipe_args.add_argument("--command_port", type=int, required=True)
+    pipe_args.add_argument("--status_port", type=int, required=True)
+    pipe_args.add_argument("--data_port", type=int, required=True)
+    pipe_args.add_argument("--broadcast_port", type=int, required=True)
+    pipe_args.add_argument("--handshake_port", type=int, required=True)
+    pipe_args.add_argument("--vllm_group_port", type=int, required=True)
+    pipe_args.add_argument("--vllm_port", type=int, required=True)
+
+    training_args = parser.add_argument_group("Training arguments")
+    training_args.add_argument(
+        "--model",
+        type=str,
+        help="Model to use for training",
+    )
+    training_args.add_argument(
+        "--trl_train_kwargs",
+        type=json.loads,
+        help="TRL training arguments as a JSON string",
+    )
+    training_args.add_argument(
+        "--arbor_train_kwargs",
+        type=json.loads,
+        help="Arbor-specific training arguments as a JSON string",
+    )
+
+    args = parser.parse_args()
+
+    if args.debug:
+        pass
+
+    # Initialized up front so the except/finally blocks below can safely
+    # check them even if setup fails before they are created.
+    trainer = None
+    comms_handler = None
+
+    try:
+        trl_train_args = {**(args.trl_train_kwargs or {})}
+        arbor_train_args = {**(args.arbor_train_kwargs or {})}
+
+        # TODO: These assertions should be done in some better way
+        assert "output_dir" in trl_train_args, "output_dir is required"
+        if "gradient_checkpointing_kwargs" in trl_train_args and arbor_train_args.get(
+            "lora", False
+        ):
+            print(
+                "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
+            )
+            trl_train_args["gradient_checkpointing_kwargs"] = {
+                **(trl_train_args.get("gradient_checkpointing_kwargs") or {}),
+                "use_reentrant": False,
+            }
+
+        lora_config = None
+        if arbor_train_args.get("lora", False):
+            print("Using LORA for PEFT")
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=64,
+                target_modules=[
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "o_proj",
+                    "up_proj",
+                    "down_proj",
+                    "gate_proj",
+                ],
+                task_type="CAUSAL_LM",
+                lora_dropout=0.05,
+                inference_mode=False,
+            )
+
+        training_args = GRPOConfig(
+            dataloader_num_workers=0,
+            shuffle_dataset=False,
+            vllm_server_port=args.vllm_port,
+            **trl_train_args,
+        )
+
+        # Create ingestion monitor
+        ingestion_monitor = IngestionMonitor()
+
+        train_dataset = BlockingQueueDataset(
+            ingestion_monitor=ingestion_monitor,
+        )
+        weight_update_callback = WeightUpdateCallback(
+            ingestion_monitor=ingestion_monitor,
+        )
+        trainer = MMGRPOTrainer(
+            model=args.model,
+            args=training_args,
+            train_dataset=train_dataset,
+            callbacks=[weight_update_callback],
+            peft_config=lora_config,
+            vllm_group_port=args.vllm_group_port,
+            **arbor_train_args,
+        )
+
+        comms_handler = ArborScriptCommsHandler(
+            host=args.host,
+            command_port=args.command_port,
+            status_port=args.status_port,
+            data_port=args.data_port,
+            broadcast_port=args.broadcast_port,
+            handshake_port=args.handshake_port,
+            is_main_process=trainer.accelerator.is_main_process,
+        )
+
+        train_dataset.set_comms_handler(comms_handler)
+        train_dataset.set_accelerator(trainer.accelerator)
+
+        weight_update_callback.set_comms_handler(comms_handler)
+        weight_update_callback.set_trainer(trainer)
+
+        trainer.set_comms_handler(comms_handler)
+
+        command_monitor = CommandMonitor(
+            comms_handler=comms_handler,
+            trainer=trainer,
+            base_model_name=args.model,
+            ingestion_monitor=ingestion_monitor,
+        )
+        command_monitor.start()
+
+        print("Command monitor started")
+
+        # Add signal handlers for graceful shutdown
+        def signal_handler(signum, frame):
+            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
+            print("Ending training...")
+            trainer.accelerator.end_training()
+            print("Closing communications...")
+            comms_handler.close()
+            sys.exit(0)
+
+        signal.signal(signal.SIGINT, signal_handler)
+        signal.signal(signal.SIGTERM, signal_handler)
+
+        print("Signal handlers added")
+
+        print("Starting training...")
+        try:
+            trainer.train()
+        except Exception as e:
+            print(f"Error during training: {e}")
+            print(f"Error type: {type(e).__name__}")
+            if "unhashable" in str(e):
+                print("DEBUGGING: Unhashable type error during training")
+                print(
+                    "This could be in data loading, model forward pass, or metrics collection"
+                )
+            raise
+
+    except Exception as e:
+        import traceback
+
+        print(f"Error: {e}")
+        print("Stack trace:")
+        traceback.print_exc()
+        if comms_handler is not None:
+            comms_handler.send_status({"status": "error", "error": str(e)})
+        raise e
+    finally:
+        print("Cleaning up resources...")
+        if trainer is not None:
+            trainer.accelerator.end_training()
+        if comms_handler is not None:
+            comms_handler.close()
+        print("Cleanup complete")
+
+
+if __name__ == "__main__":
+    main()
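
The `mmgrpo` loss branch above differs from standard `grpo` only in its denominator: each segment's summed per-token loss is divided by the token count of that segment's full multi-module trajectory (as computed by `_get_trajectory_lengths`), not by the segment's own completion length. A minimal standalone sketch of the two normalizations, using hypothetical toy tensors rather than Arbor's real batches:

    import torch

    # Toy batch: two completion segments, padded to 5 tokens each.
    per_token_loss = torch.tensor([[0.2, 0.4, 0.1, 0.0, 0.0],
                                   [0.3, 0.3, 0.2, 0.1, 0.0]])
    completion_mask = torch.tensor([[1., 1., 1., 0., 0.],
                                    [1., 1., 1., 1., 0.]])
    # Assumed precomputed: total token count across each segment's whole
    # trajectory (what _get_trajectory_lengths returns above).
    trajectory_lengths = torch.tensor([10.0, 4.0])

    # Standard GRPO: normalize each segment by its own token count.
    grpo = ((per_token_loss * completion_mask).sum(-1)
            / completion_mask.sum(-1).clamp(min=1.0)).mean()

    # MMGRPO: normalize by the whole trajectory's length, so a segment
    # contributes in proportion to its share of the trajectory.
    seg_sums = (per_token_loss * completion_mask).sum(dim=-1)
    mmgrpo = (seg_sums / trajectory_lengths.clamp(min=1.0)).mean()

    print(round(grpo.item(), 4), round(mmgrpo.item(), 4))  # 0.2292 0.1475

The practical effect is that many short segments drawn from one long trajectory cannot dominate the batch mean the way they would under per-segment normalization.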
arbor/server/services/scripts/sft_training.py
@@ -11,6 +11,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from trl import SFTConfig, SFTTrainer, setup_chat_format
 
 from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 def main():
@@ -25,7 +28,7 @@ def main():
         # TODO: These assertions should be done in some better way
         assert "output_dir" in trl_train_kwargs, "output_dir is required"
         if "gradient_checkpointing_kwargs" in trl_train_kwargs and args.lora:
-            print(
+            logger.info(
                 "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
             )
             trl_train_kwargs["gradient_checkpointing_kwargs"] = {
@@ -35,7 +38,7 @@ def main():
 
         lora_config = None
         if args.lora:
-            print("Using LORA for PEFT")
+            logger.info("Using LORA for PEFT")
             lora_config = LoraConfig(
                 r=16,
                 lora_alpha=64,
@@ -91,13 +94,13 @@ def main():
             base_model_name=args.model,
         )
 
-        print("Training...")
+        logger.info("Starting training...")
         trainer.train()
 
     except KeyboardInterrupt:
-        print("\nReceived interrupt, shutting down...")
+        logger.info("Received interrupt, shutting down...")
     except Exception as e:
-        print(f"Error: {e}")
+        logger.error(f"Training error: {e}")
         comms_handler.send_status({"status": "error", "error": str(e)})
         raise e
     finally:
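
These sft_training.py hunks are part of a wider print-to-logger migration built on the new arbor/server/utils/logging.py module (+298 lines in this release). Only `get_logger(__name__)` is visible in the diff; assuming it is a thin, preconfigured wrapper over the standard library, the migration pattern looks roughly like the sketch below (the formatter details are illustrative, not Arbor's actual configuration):

    import logging

    def get_logger(name: str) -> logging.Logger:
        # Hypothetical stand-in for arbor.server.utils.logging.get_logger;
        # the real module presumably adds its own handlers and formatting.
        logger = logging.getLogger(name)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
            )
            logger.addHandler(handler)
            logger.setLevel(logging.INFO)
        return logger

    logger = get_logger(__name__)
    logger.info("Starting training...")  # was: print("Training...")
    logger.error("Training error: ...")  # was: print(f"Error: {e}")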
arbor/server/services/scripts/utils/callbacks.py (new file)
@@ -0,0 +1,33 @@
+import logging
+
+from transformers import TrainerCallback
+
+from arbor.server.services.comms.comms import ArborScriptCommsHandler
+from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor
+
+logger = logging.getLogger(__name__)
+
+
+class WeightUpdateCallback(TrainerCallback):
+    """A callback that sends weight update completion status after each step."""
+
+    def __init__(self, ingestion_monitor: IngestionMonitor):
+        self.comms_handler = None
+        self.trainer = None
+        self.ingestion_monitor = ingestion_monitor
+
+    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
+        self.comms_handler = comms_handler
+
+    def set_trainer(self, trainer):
+        self.trainer = trainer
+
+    def on_step_end(self, args, state, control, **kwargs):
+        self.ingestion_monitor.set_last_step_time()
+        if self.comms_handler and self.comms_handler.is_main_process and self.trainer:
+            if state.global_step != self.trainer._last_loaded_step:
+                logger.info("Updating inference model...")
+                self.comms_handler.send_status({"status": "weight_update_start"})
+                self.trainer._move_model_to_vllm()
+                self.trainer._last_loaded_step = state.global_step
+                self.comms_handler.send_status({"status": "weight_update_complete"})
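
The main() in mmgrpo_training.py above shows how this callback is wired: it is constructed before the trainer (so it can be passed via `callbacks=[...]`), then handed the comms handler and trainer references once both exist. Condensed from that file:

    weight_update_callback = WeightUpdateCallback(ingestion_monitor=ingestion_monitor)
    trainer = MMGRPOTrainer(
        model=args.model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[weight_update_callback],
        peft_config=lora_config,
        vllm_group_port=args.vllm_group_port,
    )
    # Late binding: neither object exists when the callback is constructed.
    weight_update_callback.set_comms_handler(comms_handler)
    weight_update_callback.set_trainer(trainer)

At the end of each optimizer step the callback brackets `_move_model_to_vllm()` with `weight_update_start`/`weight_update_complete` status messages, presumably so the serving side can coordinate inference around the weight transfer.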