arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/arbor/server/services/scripts/utils/comms_monitors.py
@@ -0,0 +1,169 @@
+import os
+import shutil
+import threading
+import time
+from typing import Callable, Optional
+
+from peft import AutoPeftModelForCausalLM
+from transformers import Trainer
+
+from arbor.server.services.comms.comms import ArborScriptCommsHandler
+
+
+class CommandMonitor:
+    def __init__(
+        self,
+        comms_handler: ArborScriptCommsHandler,
+        trainer: Trainer,
+        base_model_name: str,
+        ingestion_monitor: Optional["IngestionMonitor"] = None,
+    ):
+        self.comms_handler = comms_handler
+        self.trainer = trainer
+        self.base_model_name = base_model_name
+        self.command_thread = threading.Thread(
+            target=self._monitor_commands, daemon=True
+        )
+        self.ingestion_monitor = ingestion_monitor
+
+    def start(self):
+        self.command_thread.start()
+
+    def _monitor_commands(self):
+        """Background thread that monitors for commands from the server."""
+        if not self.comms_handler:
+            return
+        try:
+            for command in self.comms_handler.receive_command():
+                print(f"Main process received command: {command}")
+                if (
+                    command.get("command") == "save_model"
+                    and self.trainer.accelerator.is_main_process
+                ):
+                    print(
+                        f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
+                    )
+                    while self.ingestion_monitor and (
+                        self.ingestion_monitor.time_since_last_step() <= 10
+                        or self.ingestion_monitor.time_since_last_queue_pop() <= 10
+                    ):
+                        print(f"Waiting for steps to finish")
+                        if self.ingestion_monitor:
+                            print(
+                                f"Time since last step: {self.ingestion_monitor.time_since_last_step():.1f} (needs to be >= 10)"
+                            )
+                            print(
+                                f"Time since last queue pop: {self.ingestion_monitor.time_since_last_queue_pop():.1f} (needs to be >= 10)"
+                            )
+                        time.sleep(5)
+                    print("[Training Script] Saving model...")
+                    if self.trainer.peft_config:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir + "/adapter/"
+                        )
+                        _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
+                            self.trainer.args.output_dir + "/adapter/",
+                            config=self.trainer.peft_config,
+                        )
+                        merged_model = _model_to_merge.merge_and_unload()
+                        merged_model.save_pretrained(
+                            self.trainer.args.output_dir,
+                            safe_serialization=True,
+                        )
+                        self.trainer.processing_class.save_pretrained(
+                            self.trainer.args.output_dir
+                        )
+                    else:
+                        self.trainer.save_model()
+
+                    print("[Training Script] Model saved")
+                    self.comms_handler.send_status(
+                        {
+                            "status": "model_saved",
+                            "output_dir": self.trainer.args.output_dir,
+                        }
+                    )
+                elif command.get("command") == "save_checkpoint":
+                    print(
+                        f"[Training Script] Instructed to save checkpoint {command.get('checkpoint_name')}"
+                    )
+                    while self.ingestion_monitor and (
+                        self.ingestion_monitor.time_since_last_step() <= 10
+                        or self.ingestion_monitor.time_since_last_queue_pop() <= 10
+                    ):
+                        print(f"Waiting for steps to finish")
+                        if self.ingestion_monitor:
+                            print(
+                                f"Time since last step: {self.ingestion_monitor.time_since_last_step():.1f} (needs to be >= 10)"
+                            )
+                            print(
+                                f"Time since last queue pop: {self.ingestion_monitor.time_since_last_queue_pop():.1f} (needs to be >= 10)"
+                            )
+                        time.sleep(5)
+                    if self.trainer.peft_config:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/adapter/"
+                        )
+                        _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/adapter/",
+                            config=self.trainer.peft_config,
+                        )
+                        merged_model = _model_to_merge.merge_and_unload()
+                        merged_model.save_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/",
+                            safe_serialization=True,
+                        )
+                        self.trainer.processing_class.save_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/"
+                        )
+                    else:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/"
+                        )
+
+                    # Copy checkpoint files to root output directory
+                    checkpoint_dir = (
+                        self.trainer.args.output_dir
+                        + f"/checkpoints/{command.get('checkpoint_name')}/"
+                    )
+                    root_dir = self.trainer.args.output_dir
+
+                    # Copy all files from checkpoint dir to root dir, overwriting if they exist
+                    # (effectively saves the checkpoint to the output directory)
+                    for item in os.listdir(checkpoint_dir):
+                        src = os.path.join(checkpoint_dir, item)
+                        dst = os.path.join(root_dir, item)
+                        if os.path.isdir(src):
+                            if os.path.exists(dst):
+                                shutil.rmtree(dst)
+                            shutil.copytree(src, dst)
+                        else:
+                            shutil.copy2(src, dst)
+
+                    self.comms_handler.send_status(
+                        {
+                            "status": "checkpoint_saved",
+                            "checkpoint_name": command.get("checkpoint_name"),
+                            "output_dir": self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/",
+                        }
+                    )
+                    self.comms_handler.send_status(
+                        {
+                            "status": "model_saved",
+                            "output_dir": self.trainer.args.output_dir,
+                        }
+                    )
+                elif command.get("command") == "terminate":
+                    print("TERMINATED")
+                    self.trainer.accelerator.end_training()
+                    self.comms_handler.send_status({"status": "terminated"})
+
+        except Exception as e:
+            print(e)
+            self.comms_handler.send_status({"status": "error", "error": str(e)})
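
The hunk above (comms_monitors.py) defines a small command/status protocol: the server sends JSON commands ("save_model", "save_checkpoint", "terminate") over ZMQ, and the training script answers with status dicts ("model_saved", "checkpoint_saved", "terminated", "error"). As a rough server-side sketch of that round trip, using only the ArborServerCommsHandler calls that appear in mock_server.py later in this diff:

    import zmq

    from arbor.server.services.comms.comms import ArborServerCommsHandler

    # Sketch: drive CommandMonitor from the server side. The names below
    # are taken from mock_server.py in this same release; the handler
    # negotiates its own ports.
    server = ArborServerCommsHandler(host="localhost")
    server.wait_for_clients(1)  # block until the training script connects

    # The status channel is PUB/SUB, so an explicit subscription is required
    server.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")

    server.send_command({"command": "save_model"})
    for status in server.receive_status():
        if status.get("status") == "model_saved":
            print(f"Merged model written to {status['output_dir']}")
            break

    server.close()
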
--- /dev/null
+++ b/arbor/server/services/scripts/utils/dataset.py
@@ -0,0 +1,176 @@
+import json
+import logging
+import time
+from functools import lru_cache
+from typing import Any, Callable, Dict, List, Optional
+
+from accelerate import Accelerator
+from datasets import Dataset as HuggingFaceDataset
+from torch.utils.data import Dataset as TorchDataset
+
+from arbor.server.services.comms.comms import ArborScriptCommsHandler
+
+logger = logging.getLogger(__name__)
+
+
+class BlockingRotatingQueueDataset(TorchDataset):
+    def __init__(
+        self,
+        size=10_000,  # Just a random number
+        maxsize=100,
+        ingestion_monitor: Optional["IngestionMonitor"] = None,
+    ):
+        self.size = size
+        # Use a regular cache dict instead of lru_cache to avoid unhashable type issues
+        self._data_cache = {}
+        self._cache_maxsize = maxsize
+        self.completion_counters = {}
+        self.ingestion_monitor = ingestion_monitor
+        self.accelerator = None
+        self.comms_handler = None
+
+    def set_accelerator(self, accelerator: Accelerator):
+        self.accelerator = accelerator
+
+    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
+        self.comms_handler = comms_handler
+
+    def __len__(self):
+        return self.size
+
+    def _get_data(self, idx):
+        rank = self.accelerator.process_index
+        world_size = self.accelerator.num_processes
+
+        if self.accelerator.is_main_process and self.ingestion_monitor:
+            self.ingestion_monitor.set_last_queue_pop_time()
+
+        if idx not in self.completion_counters:
+            self.completion_counters[idx] = 0
+
+        try:
+            new_data = self.comms_handler.receive_data()
+
+        except Exception as e:
+            print(f"[rank {rank}] Error receiving data: {e}")
+            if "unhashable" in str(e):
+                print(
+                    f"[rank {rank}] DEBUGGING: Unhashable type error in data reception"
+                )
+                print(
+                    f"[rank {rank}] This might be related to caching or data structure issues"
+                )
+            new_data = None
+
+        return new_data
+
+    def get_cached_data(self, idx):
+        """Get data with simple dictionary caching instead of lru_cache"""
+        if idx in self._data_cache:
+            return self._data_cache[idx]
+
+        # If cache is full, clear oldest entries (simple FIFO)
+        if len(self._data_cache) >= self._cache_maxsize:
+            # Remove first half of cache entries
+            keys_to_remove = list(self._data_cache.keys())[: self._cache_maxsize // 2]
+            for key in keys_to_remove:
+                del self._data_cache[key]
+
+        # Get new data and cache it
+        data = self._get_data(idx)
+        self._data_cache[idx] = data
+        return data
+
+    def __getitem__(self, idx):
+        print(f"Getting item {idx}")
+        data = self.get_cached_data(idx)
+
+        if data is None:
+            return None
+
+        counter = self.completion_counters.get(idx, 0)
+        item = data[counter]
+        self.completion_counters[idx] = (counter + 1) % len(data)
+        return item
+
+
+class BlockingQueueDataset(HuggingFaceDataset):
+    def __init__(
+        self,
+        ingestion_monitor: Optional["IngestionMonitor"] = None,
+    ):
+        self._buffer: List[Dict[str, Any]] = []
+        self._logger = logging.getLogger(__name__)
+        self.ingestion_monitor = ingestion_monitor
+
+    def set_accelerator(self, accelerator: Accelerator):
+        self.accelerator = accelerator
+
+    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
+        self.comms_handler = comms_handler
+
+    def __len__(self) -> int:
+        return 1_000_000
+
+    def _fill_buffer(self, target_size: int) -> None:
+        while len(self._buffer) < target_size:
+            try:
+                if self.comms_handler is None:
+                    raise ValueError("comms_handler is not initialized")
+
+                group = self.comms_handler.receive_data()
+
+                if group is not None:
+                    self._logger.debug("Received group from comms handler")
+                    for trajectory in group:
+                        trajectory_copy = json.loads(json.dumps(trajectory))
+                        for item in trajectory:
+                            item["trajectory"] = trajectory_copy
+                            self._buffer.append(item)
+
+            except Exception as e:
+                if "Context was terminated" in str(e):
+                    self._logger.error(
+                        "ZMQ context was terminated while filling buffer"
+                    )
+                    raise RuntimeError("ZMQ context was terminated") from e
+                self._logger.warning(f"Error receiving data: {e}")
+                continue
+
+    def _transform_batch(self, items: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
+        if not items:
+            raise ValueError("Cannot transform empty batch")
+
+        return {key: [item[key] for item in items] for key in items[0].keys()}
+
+    def __getitem__(self, idx: List[int]) -> Dict[str, List[Any]]:
+        if self.accelerator is None:
+            self._logger.error("Accelerator not initialized")
+            raise ValueError("Accelerator must be initialized before getting items")
+        if self.comms_handler is None:
+            self._logger.error("Comms handler not initialized")
+            raise ValueError("Comms handler must be initialized before getting items")
+
+        batch_size = len(idx)
+        if batch_size == 0:
+            raise ValueError("Batch size must be greater than 0")
+
+        try:
+            self._fill_buffer(batch_size)
+
+            if len(self._buffer) < batch_size:
+                raise RuntimeError(
+                    f"Not enough items in buffer (got {len(self._buffer)}, need {batch_size})"
+                )
+
+            batch_items = self._buffer[:batch_size]
+            self._buffer = self._buffer[batch_size:]
+
+            if self.ingestion_monitor:
+                self.ingestion_monitor.set_last_queue_pop_time()
+
+            return self._transform_batch(batch_items)
+
+        except Exception as e:
+            self._logger.error(f"Error getting batch: {e}")
+            raise
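
Two details of BlockingQueueDataset (dataset.py, above) are worth spelling out: _fill_buffer attaches a deep copy of each trajectory to every one of its items under the "trajectory" key before flattening the group into the buffer, and _transform_batch pivots the resulting list of row dicts into a dict of column lists. A self-contained sketch of both steps on toy data:

    import json

    # Toy group: one trajectory with two items (same shape as the groups
    # produced by mock_server.py below, values shortened for brevity)
    group = [
        [
            {"messages": "m1", "completion": "c1", "advantage": 0.9},
            {"messages": "m2", "completion": "c2", "advantage": 0.8},
        ]
    ]

    # _fill_buffer: deep-copy the trajectory onto each item, then flatten
    buffer = []
    for trajectory in group:
        trajectory_copy = json.loads(json.dumps(trajectory))  # cheap deep copy
        for item in trajectory:
            item["trajectory"] = trajectory_copy
            buffer.append(item)

    # _transform_batch: list of row dicts -> dict of column lists
    batch = {key: [item[key] for item in buffer] for key in buffer[0].keys()}
    print(batch["advantage"])  # [0.9, 0.8]
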
--- /dev/null
+++ b/arbor/server/services/scripts/utils/ingestion_monitor.py
@@ -0,0 +1,35 @@
+import time
+from typing import Optional
+
+
+class IngestionMonitor:
+    """Monitor for tracking timing of training steps and data ingestion (queue pops)"""
+
+    def __init__(self):
+        self._last_step_time: Optional[float] = None
+        self._last_queue_pop_time: Optional[float] = None
+
+    def time_since_last_step(self) -> float:
+        """Get time elapsed since last training step"""
+        if self._last_step_time is None:
+            return float("inf")
+        return time.time() - self._last_step_time
+
+    def time_since_last_queue_pop(self) -> float:
+        """Get time elapsed since last queue pop"""
+        if self._last_queue_pop_time is None:
+            return float("inf")
+        return time.time() - self._last_queue_pop_time
+
+    def set_last_queue_pop_time(self, timestamp: Optional[float] = None) -> None:
+        """Set the last queue pop time"""
+        self._last_queue_pop_time = timestamp if timestamp is not None else time.time()
+
+    def set_last_step_time(self, timestamp: Optional[float] = None) -> None:
+        """Set the last step time"""
+        self._last_step_time = timestamp if timestamp is not None else time.time()
+
+    def reset(self) -> None:
+        """Reset all timing data"""
+        self._last_step_time = None
+        self._last_queue_pop_time = None
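
IngestionMonitor (above) deliberately reports infinity when no step or queue pop has been recorded yet, so the settle loops in CommandMonitor fall through immediately on a fresh run. A quick usage sketch:

    import time

    from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor

    monitor = IngestionMonitor()
    assert monitor.time_since_last_step() == float("inf")  # nothing recorded yet

    monitor.set_last_step_time()  # defaults to time.time()
    monitor.set_last_queue_pop_time()
    time.sleep(0.1)
    print(f"{monitor.time_since_last_step():.2f}s since last step")  # ~0.10s

    monitor.reset()  # both timers back to None, so inf again
    assert monitor.time_since_last_queue_pop() == float("inf")
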
--- /dev/null
+++ b/arbor/server/services/scripts/utils/mock_server.py
@@ -0,0 +1,124 @@
+## Mock arbor sending over data for testing
+import threading
+import time
+
+import zmq
+
+from arbor.server.services.comms.comms import ArborServerCommsHandler
+
+group_example = [  # Entire group of trajectories
+    [  # Trajectory with different modules
+        {  # geography module
+            "messages": [{"role": "user", "content": "What is the capital of France?"}],
+            "completion": [{"role": "assistant", "content": "Paris"}],
+            "advantage": 0.9,
+        },
+        {  # math module
+            "messages": [{"role": "user", "content": "What is 2 * 2 + 2?"}],
+            "completion": [{"role": "assistant", "content": "6"}],
+            "advantage": 0.8,
+        },
+        {  # car module
+            "messages": [
+                {"role": "user", "content": "When did the first honda civic come out?"}
+            ],
+            "completion": [{"role": "assistant", "content": "1973"}],
+            "advantage": 0.7,
+        },
+    ],
+    [  # Trajectory with different modules
+        {  # geography module
+            "messages": [
+                {"role": "user", "content": "What is the capital of Germany?"}
+            ],
+            "completion": {
+                "role": "assistant",
+                "content": "Berlin is the capital of Germany",
+            },
+            "advantage": 0.1,
+        },
+        {  # math module
+            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
+            "completion": [{"role": "assistant", "content": "3"}],
+            "advantage": 0.2,
+        },
+    ],
+]
+
+
+def flatten_batch(batch):
+    return [item for sublist in batch for item in sublist]
+
+
+def debug_data_generator(server_comms_handler):
+    idx = 0
+    while True:
+        print(f"Sending group:")  # Debug print
+        server_comms_handler.send_data(group_example)
+        idx += 1
+        time.sleep(1)
+
+        if idx >= 100:
+            server_comms_handler.send_command({"command": "save_model"})
+
+
+def status_listener(server_comms_handler):
+    # Need to set subscription for PUB/SUB pattern
+    server_comms_handler.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+    for status in server_comms_handler.receive_status():
+        print(f"Status: {status}")
+
+
+if __name__ == "__main__":
+    server_comms_handler = ArborServerCommsHandler(
+        host="localhost",
+    )
+
+    # Get available ports from the server comms handler
+    command_port = server_comms_handler.command_port
+    status_port = server_comms_handler.status_port
+    data_port = server_comms_handler.data_port
+    broadcast_port = server_comms_handler.broadcast_port
+    handshake_port = server_comms_handler.handshake_port
+
+    # Print the command that would be used to connect to this mock server
+    print("\nTo connect to this mock server, run the following command:")
+    print(
+        f"CUDA_VISIBLE_DEVICES=2 python arbor/server/services/scripts/mmgrpo_training.py \\"
+    )
+    print(f"    --debug \\")
+    print(f"    --command_port {command_port} \\")
+    print(f"    --status_port {status_port} \\")
+    print(f"    --data_port {data_port} \\")
+    print(f"    --broadcast_port {broadcast_port} \\")
+    print(f"    --handshake_port {handshake_port} \\")
+    print(f"    --vllm_group_port 0 \\")
+    print(f"    --vllm_port 0 \\")
+    print(f"    --model Qwen/Qwen3-0.6B \\")
+    print(f'    --trl_train_kwargs \'{{"output_dir": ".", "report_to": "none"}}\'')
+    print(
+        "\nThis mock server will simulate sending training data to the training process."
+    )
+    print("Press Ctrl+C to exit the mock server.\n")
+
+    server_comms_handler.wait_for_clients(1)
+
+    debug_thread = threading.Thread(
+        target=debug_data_generator, args=(server_comms_handler,), daemon=True
+    )
+    debug_thread.start()
+
+    status_listener_thread = threading.Thread(
+        target=status_listener, args=(server_comms_handler,), daemon=True
+    )
+    status_listener_thread.start()
+
+    try:
+        print("Mock server started and waiting for training process to connect...")
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        print("\nShutting down mock server...")
+    finally:
+        server_comms_handler.close()
+        print("Mock server shutdown complete.")
--- a/arbor/server/services/training_manager.py
+++ b/arbor/server/services/training_manager.py
@@ -6,14 +6,14 @@ from datetime import datetime
 from pathlib import Path
 
 from arbor.server.api.models.schemas import FineTuneRequest
-from arbor.server.core.config import Settings
+from arbor.server.core.config import Config
 from arbor.server.services.file_manager import FileManager
 from arbor.server.services.job_manager import Job, JobEvent, JobStatus
 
 
 class TrainingManager:
-    def __init__(self, settings: Settings):
-        self.settings = settings
+    def __init__(self, config: Config):
+        self.config = config
 
     def make_output_dir(self, request: FineTuneRequest):
         model_name = request.model.split("/")[-1].lower()
@@ -24,7 +24,7 @@ class TrainingManager:
         )
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         name = f"ft:{model_name}:{suffix}:{timestamp}"
-        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)
+        return name, str(Path(self.config.STORAGE_PATH).resolve() / "models" / name)
 
     def find_train_args_sft(self, request: FineTuneRequest, file_manager: FileManager):
         file = file_manager.get_file(request.training_file)
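
The training_manager.py hunks are a mechanical rename (Settings becomes Config, self.settings becomes self.config); the output naming logic is untouched. A minimal sketch of what make_output_dir produces, with a hypothetical stand-in Config since only its STORAGE_PATH attribute is visible in this hunk:

    from datetime import datetime
    from pathlib import Path


    class Config:  # hypothetical stand-in; the real class lives in arbor.server.core.config
        STORAGE_PATH = "/tmp/arbor_storage"  # illustrative value


    config = Config()
    model_name = "Qwen/Qwen3-0.6B".split("/")[-1].lower()  # -> "qwen3-0.6b"
    suffix = "demo"  # illustrative; normally derived from the FineTuneRequest
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f"ft:{model_name}:{suffix}:{timestamp}"
    print(name, str(Path(config.STORAGE_PATH).resolve() / "models" / name))
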