arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +98 -62
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +6 -4
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
- arbor_ai-0.2.2.dist-info/RECORD +51 -0
- arbor_ai-0.2.1.dist-info/RECORD +0 -42
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
@@ -5,24 +5,19 @@
|
|
5
5
|
|
6
6
|
import argparse
|
7
7
|
import json
|
8
|
-
import os
|
9
8
|
import random
|
10
|
-
import shutil
|
11
9
|
import signal
|
12
10
|
import sys
|
13
11
|
import threading
|
14
12
|
import time
|
15
|
-
from functools import lru_cache
|
16
13
|
from typing import Any, List, Optional, Union
|
17
14
|
|
18
15
|
import torch
|
19
16
|
import trl.extras.vllm_client
|
20
17
|
import zmq
|
21
|
-
from accelerate import Accelerator
|
22
18
|
from accelerate.utils import broadcast_object_list, gather, gather_object
|
23
19
|
from datasets import Dataset, IterableDataset, load_dataset
|
24
|
-
from peft import
|
25
|
-
from torch.utils.data import Dataset
|
20
|
+
from peft import LoraConfig, PeftConfig
|
26
21
|
from transformers import (
|
27
22
|
PreTrainedModel,
|
28
23
|
PreTrainedTokenizerBase,
|
@@ -38,28 +33,19 @@ from arbor.server.services.comms.comms import (
|
|
38
33
|
ArborServerCommsHandler,
|
39
34
|
)
|
40
35
|
from arbor.server.services.inference.vllm_client import VLLMClient
|
36
|
+
from arbor.server.services.scripts.utils.callbacks import WeightUpdateCallback
|
37
|
+
from arbor.server.services.scripts.utils.comms_monitors import CommandMonitor
|
38
|
+
from arbor.server.services.scripts.utils.dataset import BlockingRotatingQueueDataset
|
39
|
+
from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor
|
41
40
|
|
42
41
|
trl.extras.vllm_client.VLLMClient = VLLMClient
|
43
42
|
|
44
|
-
|
45
|
-
import wandb
|
46
|
-
|
47
|
-
last_step_time = None
|
48
|
-
last_queue_pop_time = None
|
49
|
-
|
50
|
-
|
51
|
-
def time_since_last_step():
|
52
|
-
global last_step_time
|
53
|
-
if last_step_time is None:
|
54
|
-
return float("inf")
|
55
|
-
return time.time() - last_step_time
|
43
|
+
from arbor.server.utils.logging import get_logger
|
56
44
|
|
45
|
+
logger = get_logger(__name__)
|
57
46
|
|
58
|
-
def get_time_since_last_queue_pop():
|
59
|
-
global last_queue_pop_time
|
60
|
-
if last_queue_pop_time is None:
|
61
|
-
return float("inf")
|
62
|
-
return time.time() - last_queue_pop_time
|
47
|
+
if is_wandb_available():
|
48
|
+
import wandb
|
63
49
|
|
64
50
|
|
65
51
|
class ArborGRPOTrainer(GRPOTrainer):
|
@@ -77,10 +63,7 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
77
63
|
] = (None, None),
|
78
64
|
peft_config: Optional["PeftConfig"] = None,
|
79
65
|
comms_handler: Optional[ArborScriptCommsHandler] = None,
|
80
|
-
lora: Optional[bool] = False,
|
81
|
-
# We do nothing with max_context_length right now
|
82
66
|
vllm_group_port: Optional[int] = None,
|
83
|
-
max_context_length: Optional[int] = None,
|
84
67
|
**kwargs,
|
85
68
|
):
|
86
69
|
super().__init__(
|
@@ -103,7 +86,7 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
103
86
|
args.use_vllm = True
|
104
87
|
self.use_vllm = True
|
105
88
|
if self.accelerator.is_main_process:
|
106
|
-
|
89
|
+
logger.info(
|
107
90
|
f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
|
108
91
|
)
|
109
92
|
self.vllm_client = VLLMClient(
|
@@ -185,13 +168,15 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
185
168
|
|
186
169
|
if self.max_prompt_length is not None:
|
187
170
|
if prompt_ids.shape[1] > self.max_prompt_length:
|
188
|
-
|
171
|
+
logger.info(f"Truncating prompt to {self.max_prompt_length} tokens")
|
189
172
|
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
190
173
|
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
191
174
|
|
192
175
|
if self.max_completion_length is not None:
|
193
176
|
if completion_ids.shape[1] > self.max_completion_length:
|
194
|
-
|
177
|
+
logger.info(
|
178
|
+
f"Truncating completion to {self.max_completion_length} tokens"
|
179
|
+
)
|
195
180
|
completion_ids = completion_ids[:, : self.max_completion_length]
|
196
181
|
completion_mask = completion_mask[:, : self.max_completion_length]
|
197
182
|
|
@@ -225,7 +210,7 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
225
210
|
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
226
211
|
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
227
212
|
|
228
|
-
|
213
|
+
logger.info(
|
229
214
|
f"prompt_completion_ids.shape (after truncation, if enabled): {prompt_completion_ids.shape}, prompt_ids.shape: {prompt_ids.shape}, completion_ids.shape: {completion_ids.shape}"
|
230
215
|
)
|
231
216
|
|
@@ -354,238 +339,14 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
354
339
|
class LastStepTimeCallback(TrainerCallback):
|
355
340
|
"A callback that prints a message at the beginning of training"
|
356
341
|
|
357
|
-
def on_step_end(self, args, state, control, **kwargs):
|
358
|
-
global last_step_time
|
359
|
-
print(f"Time since last step: {time_since_last_step()}")
|
360
|
-
last_step_time = time.time()
|
361
|
-
|
362
|
-
|
363
|
-
class WeightUpdateCallback(TrainerCallback):
|
364
|
-
"""A callback that sends weight update completion status after each step"""
|
365
|
-
|
366
|
-
def __init__(self):
|
367
|
-
self.comms_handler = None
|
368
|
-
self.trainer = None
|
369
|
-
|
370
|
-
def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
|
371
|
-
self.comms_handler = comms_handler
|
372
|
-
|
373
|
-
def set_trainer(self, trainer):
|
374
|
-
self.trainer = trainer
|
342
|
+
def __init__(self, ingestion_monitor: IngestionMonitor):
|
343
|
+
self.ingestion_monitor = ingestion_monitor
|
375
344
|
|
376
345
|
def on_step_end(self, args, state, control, **kwargs):
|
377
|
-
|
378
|
-
|
379
|
-
print("Updating inference model...")
|
380
|
-
self.comms_handler.send_status({"status": "weight_update_start"})
|
381
|
-
self.trainer._move_model_to_vllm()
|
382
|
-
self.trainer._last_loaded_step = state.global_step
|
383
|
-
print("[DEBUG] Sending weight update completion status")
|
384
|
-
self.comms_handler.send_status({"status": "weight_update_complete"})
|
385
|
-
|
386
|
-
|
387
|
-
class BlockingQueueDataset(Dataset):
|
388
|
-
def __init__(
|
389
|
-
self,
|
390
|
-
accelerator: Accelerator,
|
391
|
-
comms_handler: ArborScriptCommsHandler,
|
392
|
-
size=10_000, # Just a random number
|
393
|
-
maxsize=100,
|
394
|
-
):
|
395
|
-
self.size = size
|
396
|
-
self.accelerator = accelerator
|
397
|
-
self.comms_handler = comms_handler
|
398
|
-
self.get_cached_data = lru_cache(maxsize=maxsize)(self._get_data)
|
399
|
-
self.completion_counters = {}
|
400
|
-
|
401
|
-
def __len__(self):
|
402
|
-
return self.size
|
403
|
-
|
404
|
-
def _get_data(self, idx):
|
405
|
-
rank = self.accelerator.process_index
|
406
|
-
world_size = self.accelerator.num_processes
|
407
|
-
|
408
|
-
if self.accelerator.is_main_process:
|
409
|
-
global last_queue_pop_time
|
410
|
-
last_queue_pop_time = time.time()
|
411
|
-
|
412
|
-
if idx not in self.completion_counters:
|
413
|
-
self.completion_counters[idx] = 0
|
414
|
-
|
415
|
-
try:
|
416
|
-
new_data = self.comms_handler.receive_data()
|
417
|
-
|
418
|
-
except Exception as e:
|
419
|
-
print(f"[rank {rank}] Error receiving data: {e}")
|
420
|
-
new_data = None
|
421
|
-
|
422
|
-
return new_data
|
423
|
-
|
424
|
-
def __getitem__(self, idx):
|
425
|
-
data = self.get_cached_data(idx)
|
426
|
-
# Create hash of data to detect if processes are using the same idx for the same data
|
427
|
-
data_hash = format(abs(hash(str(data))) % (16**8), "08x")
|
428
|
-
|
429
|
-
if data is None:
|
430
|
-
return None
|
431
|
-
|
432
|
-
counter = self.completion_counters.get(idx, 0)
|
433
|
-
item = data[counter]
|
434
|
-
self.completion_counters[idx] = (counter + 1) % len(data)
|
435
|
-
return item
|
436
|
-
|
437
|
-
|
438
|
-
class CommandMonitor:
|
439
|
-
def __init__(
|
440
|
-
self,
|
441
|
-
comms_handler: ArborScriptCommsHandler,
|
442
|
-
trainer: ArborGRPOTrainer,
|
443
|
-
base_model_name: str,
|
444
|
-
):
|
445
|
-
self.comms_handler = comms_handler
|
446
|
-
self.trainer = trainer
|
447
|
-
self.base_model_name = base_model_name
|
448
|
-
self.command_thread = threading.Thread(
|
449
|
-
target=self._monitor_commands, daemon=True
|
346
|
+
logger.info(
|
347
|
+
f"Time since last step: {self.ingestion_monitor.time_since_last_step()}"
|
450
348
|
)
|
451
|
-
self.command_thread.start()
|
452
|
-
|
453
|
-
def _monitor_commands(self):
|
454
|
-
"""Background thread that monitors for commands from the server."""
|
455
|
-
if not self.comms_handler:
|
456
|
-
return
|
457
|
-
try:
|
458
|
-
for command in self.comms_handler.receive_command():
|
459
|
-
print(f"Main process received command: {command}")
|
460
|
-
if (
|
461
|
-
command.get("command") == "save_model"
|
462
|
-
and self.trainer.accelerator.is_main_process
|
463
|
-
):
|
464
|
-
print(
|
465
|
-
f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
|
466
|
-
)
|
467
|
-
while (
|
468
|
-
time_since_last_step() <= 10
|
469
|
-
or get_time_since_last_queue_pop() <= 10
|
470
|
-
):
|
471
|
-
print(f"Waiting for steps to finish")
|
472
|
-
print(
|
473
|
-
f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
|
474
|
-
)
|
475
|
-
print(
|
476
|
-
f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
|
477
|
-
)
|
478
|
-
time.sleep(5)
|
479
|
-
print("[Training Script] Saving model...")
|
480
|
-
if self.trainer.peft_config:
|
481
|
-
self.trainer.save_model(
|
482
|
-
output_dir=self.trainer.args.output_dir + "/adapter/"
|
483
|
-
)
|
484
|
-
_model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
|
485
|
-
self.trainer.args.output_dir + "/adapter/",
|
486
|
-
config=self.trainer.peft_config,
|
487
|
-
)
|
488
|
-
merged_model = _model_to_merge.merge_and_unload()
|
489
|
-
merged_model.save_pretrained(
|
490
|
-
self.trainer.args.output_dir,
|
491
|
-
safe_serialization=True,
|
492
|
-
)
|
493
|
-
self.trainer.processing_class.save_pretrained(
|
494
|
-
self.trainer.args.output_dir
|
495
|
-
)
|
496
|
-
else:
|
497
|
-
self.trainer.save_model()
|
498
|
-
|
499
|
-
print("[Training Script] Model saved")
|
500
|
-
self.comms_handler.send_status(
|
501
|
-
{
|
502
|
-
"status": "model_saved",
|
503
|
-
"output_dir": self.trainer.args.output_dir,
|
504
|
-
}
|
505
|
-
)
|
506
|
-
elif command.get("command") == "save_checkpoint":
|
507
|
-
print(
|
508
|
-
f"[Training Script] Instructed to save checkpoint {command.get('checkpoint_name')}"
|
509
|
-
)
|
510
|
-
while (
|
511
|
-
time_since_last_step() <= 10
|
512
|
-
or get_time_since_last_queue_pop() <= 10
|
513
|
-
):
|
514
|
-
print(f"Waiting for steps to finish")
|
515
|
-
print(
|
516
|
-
f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
|
517
|
-
)
|
518
|
-
print(
|
519
|
-
f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
|
520
|
-
)
|
521
|
-
time.sleep(5)
|
522
|
-
if self.trainer.peft_config:
|
523
|
-
self.trainer.save_model(
|
524
|
-
output_dir=self.trainer.args.output_dir
|
525
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/adapter/"
|
526
|
-
)
|
527
|
-
_model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
|
528
|
-
self.trainer.args.output_dir
|
529
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/adapter/",
|
530
|
-
config=self.trainer.peft_config,
|
531
|
-
)
|
532
|
-
merged_model = _model_to_merge.merge_and_unload()
|
533
|
-
merged_model.save_pretrained(
|
534
|
-
self.trainer.args.output_dir
|
535
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/",
|
536
|
-
safe_serialization=True,
|
537
|
-
)
|
538
|
-
self.trainer.processing_class.save_pretrained(
|
539
|
-
self.trainer.args.output_dir
|
540
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/"
|
541
|
-
)
|
542
|
-
else:
|
543
|
-
self.trainer.save_model(
|
544
|
-
output_dir=self.trainer.args.output_dir
|
545
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/"
|
546
|
-
)
|
547
|
-
|
548
|
-
# Copy checkpoint files to root output directory
|
549
|
-
checkpoint_dir = (
|
550
|
-
self.trainer.args.output_dir
|
551
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/"
|
552
|
-
)
|
553
|
-
root_dir = self.trainer.args.output_dir
|
554
|
-
|
555
|
-
# Copy all files from checkpoint dir to root dir, overwriting if they exist
|
556
|
-
# (effectively saves the checkpoint to the output directory)
|
557
|
-
for item in os.listdir(checkpoint_dir):
|
558
|
-
src = os.path.join(checkpoint_dir, item)
|
559
|
-
dst = os.path.join(root_dir, item)
|
560
|
-
if os.path.isdir(src):
|
561
|
-
if os.path.exists(dst):
|
562
|
-
shutil.rmtree(dst)
|
563
|
-
shutil.copytree(src, dst)
|
564
|
-
else:
|
565
|
-
shutil.copy2(src, dst)
|
566
|
-
|
567
|
-
self.comms_handler.send_status(
|
568
|
-
{
|
569
|
-
"status": "checkpoint_saved",
|
570
|
-
"checkpoint_name": command.get("checkpoint_name"),
|
571
|
-
"output_dir": self.trainer.args.output_dir
|
572
|
-
+ f"/checkpoints/{command.get('checkpoint_name')}/",
|
573
|
-
}
|
574
|
-
)
|
575
|
-
self.comms_handler.send_status(
|
576
|
-
{
|
577
|
-
"status": "model_saved",
|
578
|
-
"output_dir": self.trainer.args.output_dir,
|
579
|
-
}
|
580
|
-
)
|
581
|
-
elif command.get("command") == "terminate":
|
582
|
-
print("TERMINATED")
|
583
|
-
self.trainer.accelerator.end_training()
|
584
|
-
self.comms_handler.send_status({"status": "terminated"})
|
585
|
-
|
586
|
-
except Exception as e:
|
587
|
-
print(e)
|
588
|
-
self.comms_handler.send_status({"status": "error", "error": str(e)})
|
349
|
+
self.ingestion_monitor.set_last_step_time()
|
589
350
|
|
590
351
|
|
591
352
|
def main():
|
@@ -679,7 +440,7 @@ def main():
|
|
679
440
|
# Need to set subscription for PUB/SUB pattern
|
680
441
|
server_comms_handler.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
681
442
|
for status in server_comms_handler.receive_status():
|
682
|
-
|
443
|
+
logger.info(f"Status: {status}")
|
683
444
|
|
684
445
|
status_listener_thread = threading.Thread(target=status_listener, daemon=True)
|
685
446
|
status_listener_thread.start()
|
@@ -693,7 +454,7 @@ def main():
|
|
693
454
|
if "gradient_checkpointing_kwargs" in trl_train_args and arbor_train_args.get(
|
694
455
|
"lora", False
|
695
456
|
):
|
696
|
-
|
457
|
+
logger.info(
|
697
458
|
"Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
|
698
459
|
)
|
699
460
|
trl_train_args["gradient_checkpointing_kwargs"] = {
|
@@ -703,7 +464,7 @@ def main():
|
|
703
464
|
|
704
465
|
lora_config = None
|
705
466
|
if arbor_train_args.get("lora", False):
|
706
|
-
|
467
|
+
logger.info("Using LORA for PEFT")
|
707
468
|
lora_config = LoraConfig(
|
708
469
|
r=16,
|
709
470
|
lora_alpha=64,
|
@@ -721,6 +482,12 @@ def main():
|
|
721
482
|
inference_mode=False,
|
722
483
|
)
|
723
484
|
|
485
|
+
if "report_to" in trl_train_args and trl_train_args["report_to"] == "wandb":
|
486
|
+
import wandb
|
487
|
+
|
488
|
+
if "wandb_kwargs" in arbor_train_args and arbor_train_args["wandb_kwargs"]:
|
489
|
+
wandb.init(**arbor_train_args["wandb_kwargs"])
|
490
|
+
|
724
491
|
training_args = GRPOConfig(
|
725
492
|
dataloader_num_workers=0,
|
726
493
|
shuffle_dataset=False,
|
@@ -728,15 +495,24 @@ def main():
|
|
728
495
|
**trl_train_args,
|
729
496
|
)
|
730
497
|
|
731
|
-
|
498
|
+
# Create ingestion monitor
|
499
|
+
ingestion_monitor = IngestionMonitor()
|
500
|
+
|
501
|
+
train_dataset = BlockingRotatingQueueDataset(
|
502
|
+
ingestion_monitor=ingestion_monitor,
|
503
|
+
)
|
504
|
+
|
505
|
+
weight_update_callback = WeightUpdateCallback(
|
506
|
+
ingestion_monitor=ingestion_monitor,
|
507
|
+
)
|
508
|
+
|
732
509
|
trainer = ArborGRPOTrainer(
|
733
510
|
model=args.model,
|
734
511
|
args=training_args,
|
735
|
-
train_dataset=
|
736
|
-
callbacks=[LastStepTimeCallback(), weight_update_callback],
|
512
|
+
train_dataset=train_dataset,
|
513
|
+
callbacks=[LastStepTimeCallback(ingestion_monitor), weight_update_callback],
|
737
514
|
peft_config=lora_config,
|
738
515
|
vllm_group_port=args.vllm_group_port,
|
739
|
-
**arbor_train_args,
|
740
516
|
)
|
741
517
|
# Create client handler
|
742
518
|
comms_handler = ArborScriptCommsHandler(
|
@@ -748,48 +524,53 @@ def main():
|
|
748
524
|
handshake_port=args.handshake_port,
|
749
525
|
is_main_process=trainer.accelerator.is_main_process,
|
750
526
|
)
|
527
|
+
|
528
|
+
train_dataset.set_comms_handler(comms_handler)
|
529
|
+
train_dataset.set_accelerator(trainer.accelerator)
|
530
|
+
|
751
531
|
weight_update_callback.set_comms_handler(comms_handler)
|
752
532
|
weight_update_callback.set_trainer(trainer)
|
753
533
|
trainer.comms_handler = comms_handler
|
754
534
|
|
755
|
-
# Initialize the dataset with the actual accelerator
|
756
|
-
trainer.train_dataset = BlockingQueueDataset(
|
757
|
-
accelerator=trainer.accelerator,
|
758
|
-
comms_handler=trainer.comms_handler,
|
759
|
-
)
|
760
|
-
|
761
535
|
command_monitor = CommandMonitor(
|
762
536
|
comms_handler=comms_handler,
|
763
537
|
trainer=trainer,
|
764
538
|
base_model_name=args.model,
|
539
|
+
ingestion_monitor=ingestion_monitor,
|
765
540
|
)
|
541
|
+
command_monitor.start()
|
766
542
|
|
767
543
|
# Add signal handlers for graceful shutdown
|
768
544
|
def signal_handler(signum, frame):
|
769
|
-
|
770
|
-
|
545
|
+
logger.info(f"\nReceived signal {signum}. Initiating graceful shutdown...")
|
546
|
+
logger.info("Ending training...")
|
771
547
|
trainer.accelerator.end_training()
|
772
|
-
|
548
|
+
logger.info("Closing communications...")
|
773
549
|
comms_handler.close()
|
774
550
|
sys.exit(0)
|
775
551
|
|
776
552
|
signal.signal(signal.SIGINT, signal_handler)
|
777
553
|
signal.signal(signal.SIGTERM, signal_handler)
|
778
554
|
|
779
|
-
|
780
|
-
|
555
|
+
logger.info("Starting training...")
|
556
|
+
try:
|
557
|
+
trainer.train()
|
558
|
+
except Exception as e:
|
559
|
+
logger.error(f"Error during training: {e}")
|
560
|
+
logger.error(f"Error type: {type(e).__name__}")
|
561
|
+
raise
|
781
562
|
|
782
563
|
except KeyboardInterrupt:
|
783
|
-
|
564
|
+
logger.info("\nReceived interrupt, shutting down...")
|
784
565
|
except Exception as e:
|
785
|
-
|
566
|
+
logger.error(f"Error: {e}")
|
786
567
|
comms_handler.send_status({"status": "error", "error": str(e)})
|
787
568
|
raise e
|
788
569
|
finally:
|
789
|
-
|
570
|
+
logger.info("Cleaning up resources...")
|
790
571
|
trainer.accelerator.end_training()
|
791
572
|
comms_handler.close()
|
792
|
-
|
573
|
+
logger.info("Cleanup complete")
|
793
574
|
|
794
575
|
|
795
576
|
if __name__ == "__main__":
|