arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +98 -62
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +6 -4
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
- arbor_ai-0.2.2.dist-info/RECORD +51 -0
- arbor_ai-0.2.1.dist-info/RECORD +0 -42
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
|
|
1
|
+
import os
|
2
|
+
import shutil
|
3
|
+
import threading
|
4
|
+
import time
|
5
|
+
from typing import Callable, Optional
|
6
|
+
|
7
|
+
from peft import AutoPeftModelForCausalLM
|
8
|
+
from transformers import Trainer
|
9
|
+
|
10
|
+
from arbor.server.services.comms.comms import ArborScriptCommsHandler
|
11
|
+
|
12
|
+
|
13
|
+
class CommandMonitor:
    """Listen for server commands on a background thread and act on them.

    Commands are read from ``comms_handler.receive_command()``; the action is
    chosen by each message's ``"command"`` key:

    * ``save_model`` -- (main process only) save the model to
      ``trainer.args.output_dir``. When ``trainer.peft_config`` is set, the
      adapter is saved first, then merged into the base weights and saved.
    * ``save_checkpoint`` -- save the model under
      ``<output_dir>/checkpoints/<checkpoint_name>/`` and copy that
      checkpoint's files into ``<output_dir>`` as well.
    * ``terminate`` -- call ``trainer.accelerator.end_training()``.

    Outcomes are reported with ``comms_handler.send_status``; a failure while
    handling a command is reported as ``{"status": "error", ...}`` and the
    monitor keeps listening for further commands.
    """

    # Both the last training step and the last queue pop must be more than
    # this many seconds old before a save is started.
    _IDLE_SECONDS = 10
    # Seconds to sleep between idle checks while waiting to save.
    _POLL_SECONDS = 5

    def __init__(
        self,
        comms_handler: ArborScriptCommsHandler,
        trainer: Trainer,
        base_model_name: str,
        ingestion_monitor: Optional["IngestionMonitor"] = None,
    ):
        self.comms_handler = comms_handler
        self.trainer = trainer
        self.base_model_name = base_model_name
        # Daemon thread: a blocked receive loop must not keep the process alive.
        self.command_thread = threading.Thread(
            target=self._monitor_commands, daemon=True
        )
        self.ingestion_monitor = ingestion_monitor

    def start(self):
        """Start the background command-listening thread."""
        self.command_thread.start()

    def _monitor_commands(self):
        """Background thread that monitors for commands from the server.

        Errors raised while handling a single command are reported and the
        loop continues; only a failure of the command stream itself (or of
        reporting) ends the thread.
        """
        if not self.comms_handler:
            return
        try:
            for command in self.comms_handler.receive_command():
                print(f"Main process received command: {command}")
                try:
                    self._handle_command(command)
                except Exception as e:
                    # Report and keep listening, so one failed save does not
                    # silently disable all later commands.
                    print(e)
                    self.comms_handler.send_status(
                        {"status": "error", "error": str(e)}
                    )
        except Exception as e:
            print(e)
            self.comms_handler.send_status({"status": "error", "error": str(e)})

    def _handle_command(self, command):
        """Dispatch one command message to the matching action.

        ``save_model`` on a non-main process and unknown commands are ignored.
        """
        name = command.get("command")
        if name == "save_model" and self.trainer.accelerator.is_main_process:
            self._save_model()
        elif name == "save_checkpoint":
            self._save_checkpoint(command.get("checkpoint_name"))
        elif name == "terminate":
            print("TERMINATED")
            self.trainer.accelerator.end_training()
            self.comms_handler.send_status({"status": "terminated"})

    def _wait_until_idle(self):
        """Block until no training step or queue pop happened recently.

        Returns immediately when no ingestion monitor was supplied.
        """
        monitor = self.ingestion_monitor
        if not monitor:
            return
        while (
            monitor.time_since_last_step() <= self._IDLE_SECONDS
            or monitor.time_since_last_queue_pop() <= self._IDLE_SECONDS
        ):
            print("Waiting for steps to finish")
            print(
                f"Time since last step: {monitor.time_since_last_step():.1f} "
                f"(needs to be > {self._IDLE_SECONDS})"
            )
            print(
                f"Time since last queue pop: {monitor.time_since_last_queue_pop():.1f} "
                f"(needs to be > {self._IDLE_SECONDS})"
            )
            time.sleep(self._POLL_SECONDS)

    def _save_peft_merged(self, adapter_dir, target_dir):
        """Save the PEFT adapter, merge it into the base model, and save that.

        Args:
            adapter_dir: Directory the adapter weights are written to.
            target_dir: Directory for the merged model and the trainer's
                ``processing_class`` (presumably its tokenizer/processor).
        """
        self.trainer.save_model(output_dir=adapter_dir)
        model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
            adapter_dir,
            config=self.trainer.peft_config,
        )
        merged_model = model_to_merge.merge_and_unload()
        merged_model.save_pretrained(target_dir, safe_serialization=True)
        self.trainer.processing_class.save_pretrained(target_dir)

    def _save_model(self):
        """Handle ``save_model``: wait for idle, save, report ``model_saved``."""
        output_dir = self.trainer.args.output_dir
        print(f"[Training Script] Instructed to save model at {output_dir}")
        self._wait_until_idle()
        print("[Training Script] Saving model...")
        if self.trainer.peft_config:
            self._save_peft_merged(output_dir + "/adapter/", output_dir)
        else:
            self.trainer.save_model()
        print("[Training Script] Model saved")
        self.comms_handler.send_status(
            {
                "status": "model_saved",
                "output_dir": output_dir,
            }
        )

    def _save_checkpoint(self, checkpoint_name):
        """Handle ``save_checkpoint``: save, mirror into output_dir, report.

        Raises:
            ValueError: if ``checkpoint_name`` is missing or empty (it would
                otherwise be saved under ``checkpoints/None/`` or directly
                into ``checkpoints/``).
        """
        print(f"[Training Script] Instructed to save checkpoint {checkpoint_name}")
        if not checkpoint_name:
            raise ValueError("save_checkpoint command requires a 'checkpoint_name'")
        self._wait_until_idle()

        output_dir = self.trainer.args.output_dir
        checkpoint_dir = output_dir + f"/checkpoints/{checkpoint_name}/"
        if self.trainer.peft_config:
            self._save_peft_merged(checkpoint_dir + "adapter/", checkpoint_dir)
        else:
            self.trainer.save_model(output_dir=checkpoint_dir)

        # Copy the checkpoint's files into the root output directory too,
        # overwriting existing ones (effectively saves the checkpoint there).
        self._copy_dir_contents(checkpoint_dir, output_dir)

        self.comms_handler.send_status(
            {
                "status": "checkpoint_saved",
                "checkpoint_name": checkpoint_name,
                "output_dir": checkpoint_dir,
            }
        )
        self.comms_handler.send_status(
            {
                "status": "model_saved",
                "output_dir": output_dir,
            }
        )

    @staticmethod
    def _copy_dir_contents(src_dir, dst_dir):
        """Copy every entry of ``src_dir`` into ``dst_dir``.

        Existing destination directories are removed first so they end up as
        exact copies; files are overwritten with ``shutil.copy2``.
        """
        for entry in os.listdir(src_dir):
            src = os.path.join(src_dir, entry)
            dst = os.path.join(dst_dir, entry)
            if os.path.isdir(src):
                if os.path.exists(dst):
                    shutil.rmtree(dst)
                shutil.copytree(src, dst)
            else:
                shutil.copy2(src, dst)
|
@@ -0,0 +1,176 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
import time
|
4
|
+
from functools import lru_cache
|
5
|
+
from typing import Any, Callable, Dict, List, Optional
|
6
|
+
|
7
|
+
from accelerate import Accelerator
|
8
|
+
from datasets import Dataset as HuggingFaceDataset
|
9
|
+
from torch.utils.data import Dataset as TorchDataset
|
10
|
+
|
11
|
+
from arbor.server.services.comms.comms import ArborScriptCommsHandler
|
12
|
+
|
13
|
+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class BlockingRotatingQueueDataset(TorchDataset):
    """Map-style dataset whose items come from the comms handler.

    On a cache miss for an index, a new data list is received with
    ``comms_handler.receive_data()`` and cached. Repeated lookups of the same
    index then return that list's items round-robin. The cache is bounded by
    ``maxsize`` with FIFO eviction. ``set_accelerator`` and
    ``set_comms_handler`` must be called before items are requested.
    """

    def __init__(
        self,
        size=10_000,  # Reported length; arbitrary, items are not read from a fixed corpus
        maxsize=100,
        ingestion_monitor: Optional["IngestionMonitor"] = None,
    ):
        self.size = size
        # Use a regular cache dict instead of lru_cache to avoid unhashable type issues
        self._data_cache = {}
        self._cache_maxsize = maxsize
        # Per-index position of the next item to hand out from the cached data.
        self.completion_counters = {}
        self.ingestion_monitor = ingestion_monitor
        self.accelerator = None
        self.comms_handler = None

    def set_accelerator(self, accelerator: Accelerator):
        self.accelerator = accelerator

    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
        self.comms_handler = comms_handler

    def __len__(self):
        return self.size

    def _get_data(self, idx):
        """Receive a fresh data list for ``idx``.

        Returns ``None`` (after printing the error) if receiving fails.

        Raises:
            ValueError: if no accelerator has been set.
        """
        if self.accelerator is None:
            raise ValueError("Accelerator must be initialized before getting items")
        rank = self.accelerator.process_index

        if self.accelerator.is_main_process and self.ingestion_monitor:
            self.ingestion_monitor.set_last_queue_pop_time()

        # Fresh data always starts rotating from its first item; a counter left
        # over from evicted data could otherwise index past the new list.
        self.completion_counters[idx] = 0

        try:
            return self.comms_handler.receive_data()
        except Exception as e:
            print(f"[rank {rank}] Error receiving data: {e}")
            if "unhashable" in str(e):
                print(
                    f"[rank {rank}] DEBUGGING: Unhashable type error in data reception"
                )
                print(
                    f"[rank {rank}] This might be related to caching or data structure issues"
                )
            return None

    def get_cached_data(self, idx):
        """Get data with simple dictionary caching instead of lru_cache.

        Failed receives (``None``) are not cached, so the next lookup of the
        same index retries instead of returning ``None`` forever.
        """
        if idx in self._data_cache:
            return self._data_cache[idx]

        # If the cache is full, evict the oldest half (simple FIFO). Always
        # evict at least one entry so a maxsize of 1 still bounds the cache.
        if len(self._data_cache) >= self._cache_maxsize:
            evict_count = max(1, self._cache_maxsize // 2)
            for key in list(self._data_cache)[:evict_count]:
                del self._data_cache[key]
                self.completion_counters.pop(key, None)

        data = self._get_data(idx)
        if data is None:
            self.completion_counters.pop(idx, None)
        else:
            self._data_cache[idx] = data
        return data

    def __getitem__(self, idx):
        """Return the next item of ``idx``'s data, rotating round-robin.

        Returns ``None`` when no data could be received or the data is empty.
        """
        logger.debug("Getting item %s", idx)
        data = self.get_cached_data(idx)

        if data is None or len(data) == 0:
            return None

        # Modulo guards against any counter that is out of range for this data.
        counter = self.completion_counters.get(idx, 0) % len(data)
        item = data[counter]
        self.completion_counters[idx] = (counter + 1) % len(data)
        return item
|
95
|
+
|
96
|
+
|
97
|
+
class BlockingQueueDataset(HuggingFaceDataset):
    """Dataset that blocks until enough items arrive to fill each batch.

    Each group received from ``comms_handler.receive_data()`` is iterated as
    trajectories, and every item of a trajectory is buffered individually,
    tagged with a JSON-round-tripped copy of its whole trajectory under the
    ``"trajectory"`` key. ``set_accelerator`` and ``set_comms_handler`` must be
    called before batches are requested.
    """

    def __init__(
        self,
        ingestion_monitor: Optional["IngestionMonitor"] = None,
    ):
        self._buffer: List[Dict[str, Any]] = []
        self._logger = logging.getLogger(__name__)
        self.ingestion_monitor = ingestion_monitor
        # Populated by set_accelerator / set_comms_handler. Initialized here so
        # the "not initialized" checks raise their intended ValueError rather
        # than AttributeError.
        self.accelerator = None
        self.comms_handler = None

    def set_accelerator(self, accelerator: Accelerator):
        self.accelerator = accelerator

    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
        self.comms_handler = comms_handler

    def __len__(self) -> int:
        # Nominal length; batches are produced from received data on demand.
        return 1_000_000

    def _fill_buffer(self, target_size: int) -> None:
        """Receive groups until the buffer holds at least ``target_size`` items.

        Errors while receiving or unpacking a group are logged and the group
        is skipped.

        Raises:
            ValueError: if no comms handler has been set.
            RuntimeError: if the ZMQ context is terminated while receiving.
        """
        # Checked before the retry loop: inside it, the broad ``except`` below
        # would swallow this error and spin forever.
        if self.comms_handler is None:
            raise ValueError("comms_handler is not initialized")

        while len(self._buffer) < target_size:
            try:
                group = self.comms_handler.receive_data()

                if group is not None:
                    self._logger.debug("Received group from comms handler")
                    for trajectory in group:
                        # JSON round-trip makes an independent deep copy that
                        # every item of this trajectory shares.
                        trajectory_copy = json.loads(json.dumps(trajectory))
                        for item in trajectory:
                            item["trajectory"] = trajectory_copy
                            self._buffer.append(item)

            except Exception as e:
                if "Context was terminated" in str(e):
                    self._logger.error(
                        "ZMQ context was terminated while filling buffer"
                    )
                    raise RuntimeError("ZMQ context was terminated") from e
                self._logger.warning(f"Error receiving data: {e}")
                continue

    def _transform_batch(self, items: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        """Convert a list of row dicts into a dict of columns.

        Column names are taken from the first item.

        Raises:
            ValueError: if ``items`` is empty.
        """
        if not items:
            raise ValueError("Cannot transform empty batch")

        return {key: [item[key] for item in items] for key in items[0].keys()}

    def __getitem__(self, idx: List[int]) -> Dict[str, List[Any]]:
        """Return the next ``len(idx)`` buffered items as a column dict.

        The index values themselves are ignored; only the batch size matters.

        Raises:
            ValueError: if the accelerator or comms handler is not set, or
                ``idx`` is empty.
        """
        if self.accelerator is None:
            self._logger.error("Accelerator not initialized")
            raise ValueError("Accelerator must be initialized before getting items")
        if self.comms_handler is None:
            self._logger.error("Comms handler not initialized")
            raise ValueError("Comms handler must be initialized before getting items")

        batch_size = len(idx)
        if batch_size == 0:
            raise ValueError("Batch size must be greater than 0")

        try:
            self._fill_buffer(batch_size)

            if len(self._buffer) < batch_size:
                raise RuntimeError(
                    f"Not enough items in buffer (got {len(self._buffer)}, need {batch_size})"
                )

            batch_items = self._buffer[:batch_size]
            del self._buffer[:batch_size]

            if self.ingestion_monitor:
                self.ingestion_monitor.set_last_queue_pop_time()

            return self._transform_batch(batch_items)

        except Exception as e:
            self._logger.error(f"Error getting batch: {e}")
            raise
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import time
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
|
5
|
+
class IngestionMonitor:
    """Record when the last training step and the last data-queue pop occurred.

    Both timestamps start unset. While a timestamp is unset, the matching
    ``time_since_*`` query reports ``float("inf")``. Timestamps are
    ``time.time()`` values unless the caller supplies its own.
    """

    def __init__(self):
        self._last_step_time: Optional[float] = None
        self._last_queue_pop_time: Optional[float] = None

    @staticmethod
    def _elapsed(since: Optional[float]) -> float:
        """Seconds between ``since`` and now; infinity if ``since`` is unset."""
        return float("inf") if since is None else time.time() - since

    def time_since_last_step(self) -> float:
        """Return seconds since the last recorded training step."""
        return self._elapsed(self._last_step_time)

    def time_since_last_queue_pop(self) -> float:
        """Return seconds since the last recorded queue pop."""
        return self._elapsed(self._last_queue_pop_time)

    def set_last_queue_pop_time(self, timestamp: Optional[float] = None) -> None:
        """Record a queue pop at ``timestamp`` (defaults to now)."""
        if timestamp is None:
            timestamp = time.time()
        self._last_queue_pop_time = timestamp

    def set_last_step_time(self, timestamp: Optional[float] = None) -> None:
        """Record a training step at ``timestamp`` (defaults to now)."""
        if timestamp is None:
            timestamp = time.time()
        self._last_step_time = timestamp

    def reset(self) -> None:
        """Forget both recorded timestamps."""
        self._last_step_time = self._last_queue_pop_time = None
|
@@ -0,0 +1,124 @@
|
|
1
|
+
# Mock Arbor server that sends training data to a training script, for testing.
|
2
|
+
import threading
|
3
|
+
import time
|
4
|
+
|
5
|
+
import zmq
|
6
|
+
|
7
|
+
from arbor.server.services.comms.comms import ArborServerCommsHandler
|
8
|
+
|
9
|
+
# Mock group of trajectories. Each trajectory is a list of per-module steps,
# and every step has chat-style "messages", a "completion" given as a list of
# message dicts, and a scalar "advantage".
group_example = [  # Entire group of trajectories
    [  # Trajectory with different modules
        {  # geography module
            "messages": [{"role": "user", "content": "What is the capital of France?"}],
            "completion": [{"role": "assistant", "content": "Paris"}],
            "advantage": 0.9,
        },
        {  # math module
            "messages": [{"role": "user", "content": "What is 2 * 2 + 2?"}],
            "completion": [{"role": "assistant", "content": "6"}],
            "advantage": 0.8,
        },
        {  # car module
            "messages": [
                {"role": "user", "content": "When did the first honda civic come out?"}
            ],
            "completion": [{"role": "assistant", "content": "1973"}],
            "advantage": 0.7,
        },
    ],
    [  # Trajectory with different modules
        {  # geography module
            "messages": [
                {"role": "user", "content": "What is the capital of Germany?"}
            ],
            # A list of messages, like every other completion; this used to be
            # a bare dict, which made the "completion" column heterogeneous.
            "completion": [
                {
                    "role": "assistant",
                    "content": "Berlin is the capital of Germany",
                }
            ],
            "advantage": 0.1,
        },
        {  # math module
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
            "completion": [{"role": "assistant", "content": "3"}],
            "advantage": 0.2,
        },
    ],
]
|
47
|
+
|
48
|
+
|
49
|
+
def flatten_batch(batch):
    """Concatenate the sub-sequences of ``batch`` into one flat list."""
    flat = []
    for group in batch:
        flat.extend(group)
    return flat
|
51
|
+
|
52
|
+
|
53
|
+
def debug_data_generator(server_comms_handler, group=None, interval=1.0, save_after=100):
    """Send a mock data group to the training process in an endless loop.

    Args:
        server_comms_handler: Handler providing ``send_data`` and
            ``send_command``.
        group: Payload sent on every iteration; defaults to ``group_example``.
        interval: Seconds to sleep after each send.
        save_after: After exactly this many sends, issue a single
            ``{"command": "save_model"}``.
    """
    payload = group_example if group is None else group
    sent = 0
    while True:
        print("Sending group:")  # Debug print
        server_comms_handler.send_data(payload)
        sent += 1
        time.sleep(interval)

        # Request one save when the threshold is reached, rather than re-sending
        # save_model on every iteration afterwards.
        if sent == save_after:
            server_comms_handler.send_command({"command": "save_model"})
|
63
|
+
|
64
|
+
|
65
|
+
def status_listener(server_comms_handler):
    """Print every status message received from the training process.

    The status socket is used in a PUB/SUB pattern, so it is first subscribed
    to all topics (empty prefix) before messages are consumed.
    """
    status_socket = server_comms_handler.status_socket
    status_socket.setsockopt_string(zmq.SUBSCRIBE, "")
    for message in server_comms_handler.receive_status():
        print(f"Status: {message}")
|
70
|
+
|
71
|
+
|
72
|
+
if __name__ == "__main__":
    server_comms_handler = ArborServerCommsHandler(
        host="localhost",
    )

    # Get available ports from the server comms handler
    command_port = server_comms_handler.command_port
    status_port = server_comms_handler.status_port
    data_port = server_comms_handler.data_port
    broadcast_port = server_comms_handler.broadcast_port
    handshake_port = server_comms_handler.handshake_port

    # Print the command that would be used to connect to this mock server
    print("\nTo connect to this mock server, run the following command:")
    print(
        "CUDA_VISIBLE_DEVICES=2 python arbor/server/services/scripts/mmgrpo_training.py \\"
    )
    print(" --debug \\")
    print(f" --command_port {command_port} \\")
    print(f" --status_port {status_port} \\")
    print(f" --data_port {data_port} \\")
    print(f" --broadcast_port {broadcast_port} \\")
    print(f" --handshake_port {handshake_port} \\")
    print(" --vllm_group_port 0 \\")
    print(" --vllm_port 0 \\")
    print(" --model Qwen/Qwen3-0.6B \\")
    print(f' --trl_train_kwargs \'{{"output_dir": ".", "report_to": "none"}}\'')
    print(
        "\nThis mock server will simulate sending training data to the training process."
    )
    print("Press Ctrl+C to exit the mock server.\n")

    try:
        # Waiting for the client happens inside the try so that Ctrl+C during
        # the wait still reaches the finally clause and closes the handler.
        print("Mock server started and waiting for training process to connect...")
        server_comms_handler.wait_for_clients(1)

        debug_thread = threading.Thread(
            target=debug_data_generator, args=(server_comms_handler,), daemon=True
        )
        debug_thread.start()

        status_listener_thread = threading.Thread(
            target=status_listener, args=(server_comms_handler,), daemon=True
        )
        status_listener_thread.start()

        # Keep the main thread alive; the worker threads are daemons.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nShutting down mock server...")
    finally:
        server_comms_handler.close()
        print("Mock server shutdown complete.")
|
@@ -6,14 +6,14 @@ from datetime import datetime
|
|
6
6
|
from pathlib import Path
|
7
7
|
|
8
8
|
from arbor.server.api.models.schemas import FineTuneRequest
|
9
|
-
from arbor.server.core.config import
|
9
|
+
from arbor.server.core.config import Config
|
10
10
|
from arbor.server.services.file_manager import FileManager
|
11
11
|
from arbor.server.services.job_manager import Job, JobEvent, JobStatus
|
12
12
|
|
13
13
|
|
14
14
|
class TrainingManager:
|
15
|
-
def __init__(self,
|
16
|
-
self.
|
15
|
+
def __init__(self, config: Config):
|
16
|
+
self.config = config
|
17
17
|
|
18
18
|
def make_output_dir(self, request: FineTuneRequest):
|
19
19
|
model_name = request.model.split("/")[-1].lower()
|
@@ -24,7 +24,7 @@ class TrainingManager:
|
|
24
24
|
)
|
25
25
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
26
26
|
name = f"ft:{model_name}:{suffix}:{timestamp}"
|
27
|
-
return name, str(Path(self.
|
27
|
+
return name, str(Path(self.config.STORAGE_PATH).resolve() / "models" / name)
|
28
28
|
|
29
29
|
def find_train_args_sft(self, request: FineTuneRequest, file_manager: FileManager):
|
30
30
|
file = file_manager.get_file(request.training_file)
|