arbor-ai 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff shows the changes between publicly available versions of the package as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
arbor/server/api/models/schemas.py
@@ -199,10 +199,16 @@ class GRPOConfigRequest(BaseModel):
199
199
  bf16: Optional[bool] = None
200
200
  scale_rewards: Optional[bool] = None
201
201
  max_grad_norm: Optional[float] = None
202
+ report_to: Optional[str] = None
203
+ log_completions: Optional[bool] = None
204
+ logging_steps: Optional[int] = None
205
+ mask_truncated_completions: Optional[bool] = None
206
+ # Arbor specific
207
+ max_context_length: Optional[int] = None
202
208
  lora: Optional[bool] = None
203
- update_interval: Optional[int] = None
204
209
  # To name the run
205
210
  suffix: Optional[str] = None
211
+ generation_batch_size: Optional[int] = None
206
212
 
207
213
 
208
214
  class GRPOConfigResponse(BaseModel):
@@ -216,8 +222,23 @@ class GRPOTerminateRequest(BaseModel):
216
222
  class GRPOTerminateResponse(BaseModel):
217
223
  status: str
218
224
  current_model: str
225
+ checkpoints: Optional[dict[str, str]] = None
226
+ last_checkpoint: Optional[str] = None
219
227
 
220
228
 
221
229
  class GRPOStepResponse(BaseModel):
222
230
  status: str
223
231
  current_model: str
232
+ checkpoints: dict[str, str]
233
+ last_checkpoint: Optional[str] = None
234
+
235
+
236
+ class GRPOCheckpointRequest(BaseModel):
237
+ checkpoint_name: str
238
+
239
+
240
+ class GRPOCheckpointResponse(BaseModel):
241
+ status: str
242
+ current_model: str
243
+ checkpoints: dict[str, str]
244
+ last_checkpoint: str
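
The new GRPOConfigRequest fields mirror options forwarded to TRL's GRPOConfig (report_to, log_completions, logging_steps, generation_batch_size, mask_truncated_completions), while max_context_length is Arbor-specific and update_interval is removed. A minimal sketch of a config payload using the new fields (the values are illustrative, not defaults):

    # Illustrative payload for the GRPO config endpoint; field names come from the
    # diff above, the values are made up for the example.
    grpo_config = {
        "report_to": "wandb",            # forwarded to TRL's GRPOConfig
        "log_completions": True,
        "logging_steps": 10,
        "mask_truncated_completions": True,
        "generation_batch_size": 8,
        "max_context_length": 4096,      # Arbor-specific: passed through to the inference server
        "lora": True,
        "suffix": "my-run",
    }
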
arbor/server/api/routes/grpo.py
@@ -4,6 +4,8 @@ import subprocess
4
4
  from fastapi import APIRouter, BackgroundTasks, Request
5
5
 
6
6
  from arbor.server.api.models.schemas import (
7
+ GRPOCheckpointRequest,
8
+ GRPOCheckpointResponse,
7
9
  GRPOConfigRequest,
8
10
  GRPOConfigResponse,
9
11
  GRPORequest,
@@ -31,17 +33,24 @@ def run_grpo_step(
31
33
  inference_manager = request.app.state.inference_manager
32
34
  grpo_manager = request.app.state.grpo_manager
33
35
 
34
- current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
36
+ step_data = grpo_manager.grpo_step(grpo_request, inference_manager)
35
37
 
36
- return GRPOStepResponse(status="success", current_model=current_model)
38
+ return GRPOStepResponse(status="success", **step_data)
37
39
 
38
40
 
39
41
  @router.post("/update_model", response_model=GRPOStepResponse)
40
42
  def update_model(request: Request):
41
43
  grpo_manager = request.app.state.grpo_manager
42
44
  inference_manager = request.app.state.inference_manager
43
- current_model = grpo_manager.update_model(request, inference_manager)
44
- return GRPOStepResponse(status="success", current_model=current_model)
45
+ update_model_data = grpo_manager.update_model(request, inference_manager)
46
+ return GRPOStepResponse(status="success", **update_model_data)
47
+
48
+
49
+ @router.post("/checkpoint", response_model=GRPOCheckpointResponse)
50
+ def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
51
+ grpo_manager = request.app.state.grpo_manager
52
+ checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
53
+ return GRPOCheckpointResponse(status="success", **checkpoint_data)
45
54
 
46
55
 
47
56
  @router.post("/terminate", response_model=GRPOTerminateResponse)
@@ -50,5 +59,5 @@ def terminate_grpo(request: Request):
50
59
  grpo_manager = request.app.state.grpo_manager
51
60
  inference_manager = request.app.state.inference_manager
52
61
 
53
- final_model = grpo_manager.terminate(inference_manager)
54
- return GRPOTerminateResponse(status="success", current_model=final_model)
62
+ terminate_data = grpo_manager.terminate(inference_manager)
63
+ return GRPOTerminateResponse(status="success", **terminate_data)
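
With the routes now returning the manager's dict directly, GRPOStepResponse and GRPOTerminateResponse carry the checkpoint map alongside current_model, and a new /checkpoint route is exposed. A rough client-side sketch of calling it (the host, port, and the /grpo prefix are assumptions for the example, not confirmed by the diff):

    import requests

    # Hypothetical server address; the router mount prefix is assumed to be /grpo.
    resp = requests.post(
        "http://localhost:7453/grpo/checkpoint",
        json={"checkpoint_name": "step-100"},
    )
    data = resp.json()
    # Expected shape, per GRPOCheckpointResponse:
    # {"status": "success", "current_model": ..., "checkpoints": {...}, "last_checkpoint": ...}
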
arbor/server/services/grpo_manager.py
@@ -13,7 +13,11 @@ from datetime import datetime
13
13
  from pathlib import Path
14
14
  from typing import Optional
15
15
 
16
- from arbor.server.api.models.schemas import GRPOConfigRequest, GRPORequest
16
+ from arbor.server.api.models.schemas import (
17
+ GRPOCheckpointRequest,
18
+ GRPOConfigRequest,
19
+ GRPORequest,
20
+ )
17
21
  from arbor.server.core.config import Settings
18
22
  from arbor.server.services.comms.comms import ArborServerCommsHandler
19
23
  from arbor.server.services.inference_manager import InferenceManager
@@ -28,7 +32,10 @@ class GRPOManager:
28
32
  self.server_comms_handler = None
29
33
  self.status_thread = None
30
34
  self.model_saved_and_reload_requested = False
35
+ self.saving_checkpoint = False
31
36
 
37
+ self.checkpoints = {}
38
+ self.last_checkpoint = None
32
39
  self.data_count = 0
33
40
  self.last_inference_update = 0
34
41
  # Set up signal handler
@@ -86,12 +93,17 @@ class GRPOManager:
86
93
  "bf16",
87
94
  "scale_rewards",
88
95
  "max_grad_norm",
96
+ "report_to",
97
+ "log_completions",
98
+ "logging_steps",
99
+ "generation_batch_size",
100
+ "mask_truncated_completions",
89
101
  ]
90
102
  trl_train_kwargs = {
91
103
  key: train_kwargs[key] for key in trl_keys if key in train_kwargs
92
104
  }
93
105
 
94
- arbor_keys = ["update_interval", "lora"]
106
+ arbor_keys = ["max_context_length", "lora"]
95
107
  arbor_train_kwargs = {
96
108
  key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
97
109
  }
@@ -119,6 +131,8 @@ class GRPOManager:
119
131
  # Start the training process with ZMQ ports
120
132
  my_env = os.environ.copy()
121
133
  my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
134
+ # WandB can block the training process for login, so we silence it
135
+ my_env["WANDB_SILENT"] = "true"
122
136
 
123
137
  num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
124
138
 
@@ -209,6 +223,12 @@ class GRPOManager:
209
223
 
210
224
  # Launch the inference server
211
225
  print("Launching inference server...")
226
+ # launch_kwargs = {
227
+ # k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
228
+ # }
229
+ inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
230
+ "max_context_length", None
231
+ )
212
232
  inference_manager.launch(self.current_model)
213
233
 
214
234
  def _handle_status_updates(self, inference_manager: InferenceManager):
@@ -228,6 +248,12 @@ class GRPOManager:
228
248
  self.model_saved_and_reload_requested = False
229
249
  self.current_model = status["output_dir"]
230
250
  print("Model update complete")
251
+ elif status["status"] == "checkpoint_saved":
252
+ print("Received checkpoint saved status")
253
+ self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
254
+ self.last_checkpoint = status["checkpoint_name"]
255
+ self.saving_checkpoint = False
256
+ print("Checkpoint saved")
231
257
  elif status["status"] == "error":
232
258
  print(f"Training error: {status.get('error', 'Unknown error')}")
233
259
  elif status["status"] == "terminated":
@@ -249,6 +275,10 @@ class GRPOManager:
249
275
  )
250
276
  time.sleep(5)
251
277
 
278
+ while self.saving_checkpoint:
279
+ print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
280
+ time.sleep(5)
281
+
252
282
  try:
253
283
  # Send the batch to the training process
254
284
  self.server_comms_handler.send_data(request.batch)
@@ -256,12 +286,11 @@ class GRPOManager:
256
286
  except Exception as e:
257
287
  print(f"Failed to send batch to training process: {e}")
258
288
 
259
- # We tell the script to save the model. The script will let us know when it's done via the status update handler
260
- # Then we'll actually run the update_model function in the inference manager and finally update the last_inference_update variable
261
- # if self._should_update_model():
262
- # self.server_comms_handler.send_command({"command": "save_model"})
263
-
264
- return self.current_model
289
+ return {
290
+ "current_model": self.current_model,
291
+ "checkpoints": self.checkpoints,
292
+ "last_checkpoint": self.last_checkpoint,
293
+ }
265
294
 
266
295
  def update_model(self, request, inference_manager: InferenceManager):
267
296
  if inference_manager._session:
@@ -286,18 +315,41 @@ class GRPOManager:
286
315
  "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
287
316
  )
288
317
  time.sleep(5)
289
- return self.current_model
318
+ return {
319
+ "current_model": self.current_model,
320
+ "checkpoints": self.checkpoints,
321
+ "last_checkpoint": self.last_checkpoint,
322
+ }
323
+
324
+ def checkpoint(self, request: GRPOCheckpointRequest):
325
+ self.saving_checkpoint = True
326
+ self.server_comms_handler.send_command(
327
+ {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
328
+ )
329
+ while self.saving_checkpoint:
330
+ print("Waiting for checkpoint to be saved...")
331
+ time.sleep(5)
332
+ return {
333
+ "current_model": self.current_model,
334
+ "checkpoints": self.checkpoints,
335
+ "last_checkpoint": self.last_checkpoint,
336
+ }
290
337
 
291
338
  def terminate(self, inference_manager: InferenceManager):
292
339
  """Clean up resources and save the final model."""
340
+ termination_data = {
341
+ "current_model": self.current_model,
342
+ "checkpoints": self.checkpoints,
343
+ "last_checkpoint": self.last_checkpoint,
344
+ }
293
345
  try:
294
346
  # Stop the inference server
295
347
  if inference_manager.process is not None:
296
348
  inference_manager.kill()
297
349
 
298
350
  # Send termination command through REQ socket
299
- # self.server_comms_handler.send_broadcast({"message": "terminate"})
300
- self.training_process.terminate()
351
+ self.server_comms_handler.send_broadcast({"message": "terminate"})
352
+ # self.training_process.terminate()
301
353
  print("Waiting for training process to finish")
302
354
 
303
355
  # Wait for training process to finish
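
The checkpoint flow is a blocking round trip over the comms channel: checkpoint() sets saving_checkpoint, sends a save_checkpoint command, and the status-update thread clears the flag once the training script reports back. The message shapes, as they appear elsewhere in this diff (the checkpoint name is illustrative):

    # Command sent by GRPOManager.checkpoint():
    {"command": "save_checkpoint", "checkpoint_name": "step-100"}

    # Status emitted by the training script once the checkpoint is written
    # (handled in _handle_status_updates, which records it in self.checkpoints):
    {
        "status": "checkpoint_saved",
        "checkpoint_name": "step-100",
        "output_dir": ".../checkpoints/step-100/",
    }
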
@@ -336,17 +388,13 @@ class GRPOManager:
336
388
  )
337
389
  output_dir = self.train_kwargs["output_dir"]
338
390
  self.train_kwargs = None
339
- return output_dir
340
391
  else:
341
392
  print("Training terminated, no output directory specified")
342
393
  self.train_kwargs = None
343
- return None
394
+
395
+ return termination_data
344
396
 
345
397
  def _should_update_model(self):
346
- # return (
347
- # self.data_count - self.last_inference_update
348
- # >= self.train_kwargs["update_interval"]
349
- # )
350
398
  return self.model_saved_and_reload_requested
351
399
 
352
400
 
File without changes
arbor/server/services/inference/sgl_router_launch_server.py (new file)
@@ -0,0 +1,226 @@
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import logging
5
+ import multiprocessing as mp
6
+ import os
7
+ import random
8
+ import signal
9
+ import sys
10
+ import time
11
+ from typing import List
12
+
13
+ import requests
14
+ import zmq
15
+ from setproctitle import setproctitle
16
+ from sglang.srt.entrypoints.http_server import launch_server
17
+ from sglang.srt.server_args import ServerArgs
18
+ from sglang.srt.utils import is_port_available
19
+ from sglang_router.launch_router import RouterArgs, launch_router
20
+
21
+
22
+ def setup_logger():
23
+ logger = logging.getLogger("router")
24
+ logger.setLevel(logging.INFO)
25
+
26
+ formatter = logging.Formatter(
27
+ "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ )
30
+
31
+ handler = logging.StreamHandler()
32
+ handler.setFormatter(formatter)
33
+ logger.addHandler(handler)
34
+
35
+ return logger
36
+
37
+
38
+ logger = setup_logger()
39
+
40
+
41
+ # Create new process group
42
+ def run_server(server_args, dp_rank):
43
+ """
44
+ Note:
45
+
46
+ 1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
47
+ This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
48
+
49
+ Terminal (PGID=100)
50
+ └── Main Python Process (PGID=100)
51
+ └── Server Process 1 (PGID=100)
52
+ └── Scheduler 1
53
+ └── Detokenizer 1
54
+ └── Server Process 2 (PGID=100)
55
+ └── Scheduler 2
56
+ └── Detokenizer 2
57
+
58
+ 2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
59
+
60
+ Terminal (PGID=100)
61
+ └── Main Python Process (PGID=200)
62
+ └── Server Process 1 (PGID=300)
63
+ └── Scheduler 1
64
+ └── Detokenizer 1
65
+ └── Server Process 2 (PGID=400)
66
+ └── Scheduler 2
67
+ └── Detokenizer 2
68
+ """
69
+ # create new process group
70
+ os.setpgrp()
71
+
72
+ setproctitle("sglang::server")
73
+ # Set SGLANG_DP_RANK environment variable
74
+ os.environ["SGLANG_DP_RANK"] = str(dp_rank)
75
+
76
+ launch_server(server_args)
77
+
78
+
79
+ def launch_server_process(
80
+ server_args: ServerArgs, worker_port: int, dp_id: int
81
+ ) -> mp.Process:
82
+ """Launch a single server process with the given args and port."""
83
+ server_args = copy.deepcopy(server_args)
84
+ server_args.port = worker_port
85
+ server_args.base_gpu_id = dp_id * server_args.tp_size
86
+ server_args.dp_size = 1
87
+
88
+ proc = mp.Process(target=run_server, args=(server_args, dp_id))
89
+ proc.start()
90
+ return proc
91
+
92
+
93
+ def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
94
+ """Wait for server to be healthy by checking /health endpoint."""
95
+ start_time = time.time()
96
+ url = f"http://{host}:{port}/health"
97
+
98
+ while time.time() - start_time < timeout:
99
+ try:
100
+ response = requests.get(url, timeout=5)
101
+ if response.status_code == 200:
102
+ return True
103
+ except requests.exceptions.RequestException:
104
+ pass
105
+ time.sleep(1)
106
+ return False
107
+
108
+
109
+ def find_available_ports(base_port: int, count: int) -> List[int]:
110
+ """Find consecutive available ports starting from base_port."""
111
+ available_ports = []
112
+ current_port = base_port
113
+
114
+ while len(available_ports) < count:
115
+ if is_port_available(current_port):
116
+ available_ports.append(current_port)
117
+ current_port += random.randint(100, 1000)
118
+
119
+ return available_ports
120
+
121
+
122
+ def cleanup_processes(processes: List[mp.Process]):
123
+ for process in processes:
124
+ logger.info(f"Terminating process group {process.pid}")
125
+ try:
126
+ os.killpg(process.pid, signal.SIGTERM)
127
+ except ProcessLookupError:
128
+ # Process group may already be terminated
129
+ pass
130
+
131
+ # Wait for processes to terminate
132
+ for process in processes:
133
+ process.join(timeout=5)
134
+ if process.is_alive():
135
+ logger.warning(
136
+ f"Process {process.pid} did not terminate gracefully, forcing kill"
137
+ )
138
+ try:
139
+ os.killpg(process.pid, signal.SIGKILL)
140
+ except ProcessLookupError:
141
+ pass
142
+
143
+ logger.info("All process groups terminated")
144
+
145
+
146
+ def main():
147
+ # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
148
+ mp.set_start_method("spawn")
149
+
150
+ parser = argparse.ArgumentParser(
151
+ description="Launch SGLang router and server processes"
152
+ )
153
+
154
+ ServerArgs.add_cli_args(parser)
155
+ RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
156
+ parser.add_argument(
157
+ "--router-dp-worker-base-port",
158
+ type=int,
159
+ default=31000,
160
+ help="Base port number for data parallel workers",
161
+ )
162
+ parser.add_argument(
163
+ "--worker-urls-port",
164
+ type=int,
165
+ help="Port number for worker URLs publisher",
166
+ )
167
+
168
+ args = parser.parse_args()
169
+ server_args = ServerArgs.from_cli_args(args)
170
+ router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
171
+
172
+ # Find available ports for workers
173
+ worker_ports = find_available_ports(
174
+ args.router_dp_worker_base_port, server_args.dp_size
175
+ )
176
+
177
+ # Start server processes
178
+ server_processes = []
179
+
180
+ for i, worker_port in enumerate(worker_ports):
181
+ logger.info(f"Launching DP server process {i} on port {worker_port}")
182
+ proc = launch_server_process(server_args, worker_port, i)
183
+ server_processes.append(proc)
184
+
185
+ signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
186
+ signal.signal(
187
+ signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
188
+ )
189
+ signal.signal(
190
+ signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
191
+ )
192
+
193
+ # Update router args with worker URLs
194
+ worker_urls = [f"http://{server_args.host}:{port}" for port in worker_ports]
195
+ router_args.worker_urls = worker_urls
196
+
197
+ # Publish worker URLs via ZMQ if port is specified
198
+ if args.worker_urls_port:
199
+ try:
200
+ context = zmq.Context()
201
+ socket = context.socket(zmq.PUB)
202
+ socket.bind(f"tcp://*:{args.worker_urls_port}")
203
+ # Give subscribers time to connect
204
+ time.sleep(0.1)
205
+ socket.send_json({"type": "worker_urls", "urls": worker_urls})
206
+ logger.info(
207
+ f"Published worker URLs via ZMQ on port {args.worker_urls_port}"
208
+ )
209
+ socket.close()
210
+ context.term()
211
+ except Exception as e:
212
+ logger.error(f"Failed to publish worker URLs via ZMQ: {e}")
213
+ cleanup_processes(server_processes)
214
+ sys.exit(1)
215
+
216
+ # Start the router
217
+ try:
218
+ launch_router(router_args)
219
+ except Exception as e:
220
+ logger.error(f"Failed to start router: {e}")
221
+ cleanup_processes(server_processes)
222
+ sys.exit(1)
223
+
224
+
225
+ if __name__ == "__main__":
226
+ main()
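
The InferenceManager (further down in this diff) launches this module as a subprocess and reads the one-shot worker-URLs publication over ZMQ before waiting for the router to come up. A condensed sketch of that handshake, with placeholder ports and model path (the real values are chosen with get_free_port() in inference_manager.py):

    import subprocess
    import zmq

    # Placeholder model path and ports for the example.
    cmd = (
        "python -m arbor.server.services.inference.sgl_router_launch_server "
        "--model-path Qwen/Qwen2.5-1.5B-Instruct --dp-size 2 --port 30000 "
        "--host 0.0.0.0 --disable-radix-cache "
        "--router-dp-worker-base-port 31000 --worker-urls-port 32000"
    )
    proc = subprocess.Popen(cmd.split())

    # Subscribe to the worker-URLs publication (the same pattern as get_worker_urls below).
    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.connect("tcp://localhost:32000")
    sub.setsockopt_string(zmq.SUBSCRIBE, "")
    sub.setsockopt(zmq.RCVTIMEO, 30_000)   # 30 s receive timeout
    msg = sub.recv_json()                  # {"type": "worker_urls", "urls": [...]}
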
arbor/server/services/inference_manager.py
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  import json
3
3
  import os
4
+ import random
4
5
  import signal
5
6
  import socket
6
7
  import subprocess
@@ -12,6 +13,7 @@ from typing import Any, Dict, Optional
12
13
 
13
14
  import aiohttp
14
15
  import requests
16
+ import zmq
15
17
 
16
18
  from arbor.server.core.config import Settings
17
19
 
@@ -27,6 +29,7 @@ class InferenceManager:
27
29
  self.current_model = None
28
30
  self.inference_count = 0
29
31
  self._session = None
32
+ self.worker_urls = []
30
33
  # Set up signal handler for graceful shutdown
31
34
  signal.signal(signal.SIGINT, self._signal_handler)
32
35
  signal.signal(signal.SIGTERM, self._signal_handler)
@@ -47,7 +50,12 @@ class InferenceManager:
47
50
  def is_server_restarting(self):
48
51
  return self.restarting
49
52
 
50
- def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
53
+ def launch(
54
+ self,
55
+ model: str,
56
+ launch_kwargs: Optional[Dict[str, Any]] = None,
57
+ max_retries: int = 3,
58
+ ):
51
59
  if self.is_server_running():
52
60
  print("Server is already launched.")
53
61
  return
@@ -59,77 +67,122 @@ class InferenceManager:
59
67
  if model.startswith(prefix):
60
68
  model = model[len(prefix) :]
61
69
 
62
- print(f"Grabbing a free port to launch an SGLang server for model {model}")
63
- port = get_free_port()
64
- timeout = launch_kwargs.get("timeout", 1800)
65
- my_env = os.environ.copy()
66
- my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
67
- n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
68
- # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
69
- command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
70
- print(f"Running command: {command}")
71
-
72
- # We will manually stream & capture logs.
73
- process = subprocess.Popen(
74
- command.replace("\\\n", " ").replace("\\", " ").split(),
75
- text=True,
76
- stdout=subprocess.PIPE, # We'll read from pipe
77
- stderr=subprocess.STDOUT, # Merge stderr into stdout
78
- env=my_env,
79
- )
80
-
81
- # A threading.Event to control printing after the server is ready.
82
- # This will store *all* lines (both before and after readiness).
83
- print(f"SGLang server process started with PID {process.pid}.")
84
- stop_printing_event = threading.Event()
85
- logs_buffer = []
86
-
87
- def _tail_process(proc, buffer, stop_event):
88
- while True:
89
- line = proc.stdout.readline()
90
- if not line and proc.poll() is not None:
91
- # Process ended and no new line
92
- break
93
- if line:
94
- buffer.append(line)
95
- # Print only if stop_event is not set
96
- if not stop_event.is_set():
97
- print(f"[SGLang LOG] {line}", end="")
98
-
99
- # Start a background thread to read from the process continuously
100
- thread = threading.Thread(
101
- target=_tail_process,
102
- args=(process, logs_buffer, stop_printing_event),
103
- daemon=True,
104
- )
105
- thread.start()
106
-
107
- # Wait until the server is ready (or times out)
108
- base_url = f"http://localhost:{port}"
109
- try:
110
- wait_for_server(base_url, timeout=timeout)
111
- except TimeoutError:
112
- # If the server doesn't come up, we might want to kill it:
113
- process.kill()
114
- raise
115
-
116
- # Once server is ready, we tell the thread to stop printing further lines.
117
- stop_printing_event.set()
118
-
119
- # A convenience getter so the caller can see all logs so far (and future).
120
- def get_logs() -> str:
121
- # Join them all into a single string, or you might return a list
122
- return "".join(logs_buffer)
123
-
124
- # Let the user know server is up
125
- print(f"Server ready on random port {port}!")
70
+ retries = 0
71
+ while retries < max_retries:
72
+ try:
73
+ print(
74
+ f"Attempt {retries + 1} of {max_retries} to launch server for model {model}"
75
+ )
76
+ print(
77
+ f"Grabbing a free port to launch an SGLang server for model {model}"
78
+ )
79
+ router_port = get_free_port()
80
+ dp_worker_base_port = get_free_port()
81
+ worker_urls_port = get_free_port() # Get a port for worker URLs
82
+
83
+ timeout = launch_kwargs.get("timeout", 1800)
84
+ my_env = os.environ.copy()
85
+ my_env["CUDA_VISIBLE_DEVICES"] = (
86
+ self.settings.arbor_config.inference.gpu_ids
87
+ )
88
+ n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
89
+ command = f"python -m arbor.server.services.inference.sgl_router_launch_server --model-path {model} --dp-size {n_gpus} --port {router_port} --host 0.0.0.0 --disable-radix-cache --router-dp-worker-base-port {dp_worker_base_port} --worker-urls-port {worker_urls_port}"
90
+ print(f"Running command: {command}")
91
+ if launch_kwargs.get("max_context_length"):
92
+ command += (
93
+ f" --context-length {launch_kwargs['max_context_length']}"
94
+ )
126
95
 
127
- self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
128
- self.launch_kwargs["api_key"] = "local"
129
- self.get_logs = get_logs
130
- self.process = process
131
- self.thread = thread
132
- self.current_model = model
96
+ # We will manually stream & capture logs.
97
+ process = subprocess.Popen(
98
+ command.replace("\\\n", " ").replace("\\", " ").split(),
99
+ text=True,
100
+ stdout=subprocess.PIPE, # We'll read from pipe
101
+ stderr=subprocess.STDOUT, # Merge stderr into stdout
102
+ env=my_env,
103
+ )
104
+
105
+ # A threading.Event to control printing after the server is ready.
106
+ # This will store *all* lines (both before and after readiness).
107
+ print(f"SGLang server process started with PID {process.pid}.")
108
+ stop_printing_event = threading.Event()
109
+ logs_buffer = []
110
+
111
+ def _tail_process(proc, buffer, stop_event):
112
+ while True:
113
+ line = proc.stdout.readline()
114
+ if not line and proc.poll() is not None:
115
+ # Process ended and no new line
116
+ break
117
+ if line:
118
+ buffer.append(line)
119
+ # Print only if stop_event is not set
120
+ if not stop_event.is_set():
121
+ print(f"[SGLang LOG] {line}", end="")
122
+
123
+ # Start a background thread to read from the process continuously
124
+ thread = threading.Thread(
125
+ target=_tail_process,
126
+ args=(process, logs_buffer, stop_printing_event),
127
+ daemon=True,
128
+ )
129
+ thread.start()
130
+
131
+ # Get worker URLs before waiting for server
132
+ try:
133
+ worker_urls = get_worker_urls(worker_urls_port)
134
+ print(f"Received worker URLs: {worker_urls}")
135
+ self.worker_urls = worker_urls
136
+ except TimeoutError as e:
137
+ raise Exception(f"Failed to get worker URLs: {e}")
138
+
139
+ # Wait until the server is ready (or times out)
140
+ base_url = f"http://localhost:{router_port}"
141
+ try:
142
+ wait_for_server(base_url, timeout=timeout)
143
+ except TimeoutError:
144
+ # If the server doesn't come up, we might want to kill it:
145
+ process.kill()
146
+ raise
147
+
148
+ # Once server is ready, we tell the thread to stop printing further lines.
149
+ stop_printing_event.set()
150
+
151
+ # A convenience getter so the caller can see all logs so far (and future).
152
+ def get_logs() -> str:
153
+ # Join them all into a single string, or you might return a list
154
+ return "".join(logs_buffer)
155
+
156
+ # Let the user know server is up
157
+ print(f"Server ready on random port {router_port}!")
158
+
159
+ self.launch_kwargs["api_base"] = f"http://localhost:{router_port}/v1"
160
+ self.launch_kwargs["api_key"] = "local"
161
+ self.get_logs = get_logs
162
+ self.process = process
163
+ self.thread = thread
164
+ self.current_model = model
165
+
166
+ # If we get here, the launch was successful
167
+ return
168
+
169
+ except Exception as e:
170
+ retries += 1
171
+ print(
172
+ f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
173
+ )
174
+ # Clean up any failed processes
175
+ if "process" in locals():
176
+ try:
177
+ process.kill()
178
+ except:
179
+ pass
180
+ if retries == max_retries:
181
+ raise Exception(
182
+ f"Failed to launch server after {max_retries} attempts"
183
+ ) from e
184
+ # Wait a bit before retrying
185
+ time.sleep(min(2**retries, 30)) # Exponential backoff, max 30 seconds
133
186
 
134
187
  def kill(self):
135
188
  from sglang.utils import terminate_process
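
The retry loop backs off exponentially between failed launch attempts, capped at 30 seconds:

    # Delay applied after failed attempt r (r = 1, 2, 3, ...), as used in launch():
    delays = [min(2**r, 30) for r in range(1, 6)]
    # -> [2, 4, 8, 16, 30]

With the default max_retries=3, the final failure raises before sleeping, so only the 2 s and 4 s delays are ever used.
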
@@ -184,7 +237,7 @@ class InferenceManager:
184
237
  print(f"Running inference for model {model}")
185
238
  # Monkeypatch:
186
239
  if model != self.current_model:
187
- print(f"MONKEYPATCH: Model changed from {model} to {self.current_model}")
240
+ print(f"Model changed from {model} to {self.current_model}")
188
241
  model = self.current_model
189
242
  request_json["model"] = model
190
243
 
@@ -214,6 +267,12 @@ class InferenceManager:
214
267
  await self._session.close()
215
268
  self._session = None
216
269
  return None
270
+ except json.decoder.JSONDecodeError:
271
+ print(f"JSON Decode Error during inference: {content}")
272
+ return {
273
+ "error": "JSON Decode Error",
274
+ "content": content if content else "Content is null",
275
+ }
217
276
  except Exception as e:
218
277
  print(f"Error during inference: {e}")
219
278
  raise
@@ -239,8 +298,10 @@ class InferenceManager:
239
298
  self.inference_count = 0
240
299
 
241
300
  tik = time.time()
242
- self.kill()
243
- print("Just killed server")
301
+ # self.kill()
302
+ # print("Just killed server")
303
+ # time.sleep(5)
304
+
244
305
  # Check that output directory exists and was created successfully
245
306
  print(f"Checking that output directory {output_dir} exists")
246
307
  if not os.path.exists(output_dir):
@@ -248,8 +309,27 @@ class InferenceManager:
248
309
  f"Failed to save model - output directory {output_dir} does not exist"
249
310
  )
250
311
 
251
- print("Launching new server")
252
- self.launch(output_dir, self.launch_kwargs)
312
+ print("Directly updating weights from disk")
313
+ for worker_url in self.worker_urls:
314
+ print(f"Updating weights from disk for worker {worker_url}")
315
+ try:
316
+ response = requests.post(
317
+ f"{worker_url}/update_weights_from_disk",
318
+ json={"model_path": output_dir},
319
+ )
320
+ response_json = response.json()
321
+ print(f"Response from update_weights_from_disk: {response_json}")
322
+ # TODO: Check that the response is successful
323
+ except Exception as e:
324
+ print(f"Error during update_weights_from_disk: {e}")
325
+ print(f"Full error during update_weights_from_disk: {str(e)}")
326
+ if hasattr(e, "response") and e.response is not None:
327
+ print(f"Response status code: {e.response.status_code}")
328
+ print(f"Response text: {e.response.text}")
329
+ self.current_model = output_dir
330
+
331
+ # print("Launching new server")
332
+ # self.launch(output_dir, self.launch_kwargs)
253
333
  tok = time.time()
254
334
  self.restarting = False
255
335
  print(f"Time taken to update model: {tok - tik} seconds")
@@ -297,3 +377,28 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
297
377
  except requests.exceptions.RequestException:
298
378
  # Server not up yet, wait and retry
299
379
  time.sleep(1)
380
+
381
+
382
+ def get_worker_urls(zmq_port: int, timeout: float = 30.0) -> list:
383
+ print(f"Attempting to get worker URLs on port {zmq_port} with timeout {timeout}s")
384
+ context = zmq.Context()
385
+ socket = context.socket(zmq.SUB)
386
+ socket.connect(f"tcp://localhost:{zmq_port}")
387
+ socket.setsockopt_string(zmq.SUBSCRIBE, "") # Subscribe to all messages
388
+
389
+ # Set a timeout for receiving
390
+ socket.setsockopt(zmq.RCVTIMEO, int(timeout * 1000))
391
+
392
+ try:
393
+ print("Waiting for worker URLs message...")
394
+ message = socket.recv_json()
395
+ print(f"Received message: {message}")
396
+ if message.get("type") == "worker_urls":
397
+ return message["urls"]
398
+ else:
399
+ raise ValueError(f"Unexpected message type: {message.get('type')}")
400
+ except zmq.error.Again:
401
+ raise TimeoutError(f"Timeout waiting for worker URLs on port {zmq_port}")
402
+ finally:
403
+ socket.close()
404
+ context.term()
arbor/server/services/scripts/grpo_training.py
@@ -14,7 +14,7 @@ from typing import Any, List, Optional, Union
14
14
  import torch
15
15
  import zmq
16
16
  from accelerate import Accelerator
17
- from accelerate.utils import gather
17
+ from accelerate.utils import broadcast_object_list, gather, gather_object
18
18
  from datasets import Dataset, IterableDataset, load_dataset
19
19
  from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig # type: ignore
20
20
  from torch.utils.data import Dataset
@@ -23,6 +23,7 @@ from transformers import (
23
23
  PreTrainedTokenizerBase,
24
24
  Trainer,
25
25
  TrainerCallback,
26
+ is_wandb_available,
26
27
  )
27
28
  from trl import GRPOConfig, GRPOTrainer
28
29
  from trl.data_utils import maybe_apply_chat_template
@@ -32,6 +33,9 @@ from arbor.server.services.comms.comms import (
32
33
  ArborServerCommsHandler,
33
34
  )
34
35
 
36
+ if is_wandb_available():
37
+ import wandb
38
+
35
39
  last_step_time = None
36
40
  last_queue_pop_time = None
37
41
 
@@ -65,8 +69,9 @@ class ArborGRPOTrainer(GRPOTrainer):
65
69
  ] = (None, None),
66
70
  peft_config: Optional["PeftConfig"] = None,
67
71
  comms_handler: Optional[ArborScriptCommsHandler] = None,
68
- update_interval: Optional[int] = 5,
69
72
  lora: Optional[bool] = False,
73
+ # We do nothing with max_context_length right now
74
+ max_context_length: Optional[int] = None,
70
75
  **kwargs,
71
76
  ):
72
77
 
@@ -85,12 +90,12 @@ class ArborGRPOTrainer(GRPOTrainer):
85
90
  self.peft_config = peft_config
86
91
  self.scale_rewards = scale_rewards
87
92
  self.comms_handler = comms_handler
88
- self.update_interval = update_interval
89
93
 
90
94
  def _generate_and_score_completions(
91
95
  self, batch: List[dict[str, Any]]
92
96
  ) -> dict[str, Union[torch.Tensor, Any]]:
93
97
  device = self.accelerator.device
98
+ mode = "train" if self.model.training else "eval"
94
99
 
95
100
  # Process prompts and completions
96
101
  prompt_completion_texts = []
@@ -106,12 +111,12 @@ class ArborGRPOTrainer(GRPOTrainer):
106
111
  )
107
112
 
108
113
  # Tokenize prompts
109
- prompt_texts = [
114
+ prompts_text = [
110
115
  prompt_completion_text["prompt"]
111
116
  for prompt_completion_text in prompt_completion_texts
112
117
  ]
113
118
  prompt_inputs = self.processing_class(
114
- prompt_texts,
119
+ prompts_text,
115
120
  return_tensors="pt",
116
121
  padding=True,
117
122
  padding_side="left",
@@ -124,12 +129,12 @@ class ArborGRPOTrainer(GRPOTrainer):
124
129
  )
125
130
 
126
131
  # Tokenize completions
127
- completion_texts = [
132
+ completions_text = [
128
133
  prompt_completion_text["completion"]
129
134
  for prompt_completion_text in prompt_completion_texts
130
135
  ]
131
136
  completion_ids = self.processing_class(
132
- completion_texts,
137
+ completions_text,
133
138
  return_tensors="pt",
134
139
  padding=True,
135
140
  add_special_tokens=False,
@@ -156,6 +161,30 @@ class ArborGRPOTrainer(GRPOTrainer):
156
161
  # self._move_model_to_vllm()
157
162
  # self._last_loaded_step = self.state.global_step
158
163
 
164
+ prompt_ids = broadcast_object_list(prompt_ids)
165
+ prompt_mask = broadcast_object_list(prompt_mask)
166
+ completion_ids = broadcast_object_list(completion_ids)
167
+ completion_mask = broadcast_object_list(completion_mask)
168
+
169
+ process_slice = slice(
170
+ self.accelerator.process_index * len(batch),
171
+ (self.accelerator.process_index + 1) * len(batch),
172
+ )
173
+
174
+ prompt_ids = prompt_ids[process_slice]
175
+ prompt_mask = prompt_mask[process_slice]
176
+ completion_ids = completion_ids[process_slice]
177
+ completion_mask = completion_mask[process_slice]
178
+
179
+ is_eos = completion_ids == self.processing_class.eos_token_id
180
+
181
+ # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
182
+ if self.mask_truncated_completions:
183
+ truncated_completions = ~is_eos.any(dim=1)
184
+ completion_mask = (
185
+ completion_mask * (~truncated_completions).unsqueeze(1).int()
186
+ )
187
+
159
188
  prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
160
189
  attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
161
190
 
@@ -164,34 +193,29 @@ class ArborGRPOTrainer(GRPOTrainer):
164
193
  )
165
194
 
166
195
  logits_to_keep = completion_ids.size(1)
196
+ batch_size = (
197
+ self.args.per_device_train_batch_size
198
+ if mode == "train"
199
+ else self.args.per_device_eval_batch_size
200
+ )
167
201
 
168
202
  with torch.no_grad():
169
203
  # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
170
204
  # computation here, and use per_token_logps.detach() instead.
171
- if self.num_iterations > 1:
205
+ if (
206
+ self.num_iterations > 1
207
+ or self.args.steps_per_generation
208
+ > self.args.gradient_accumulation_steps
209
+ ):
172
210
  old_per_token_logps = self._get_per_token_logps(
173
- self.model, prompt_completion_ids, attention_mask, logits_to_keep
174
- )
175
- else:
176
- old_per_token_logps = None
177
-
178
- if self.beta == 0.0:
179
- ref_per_token_logps = None
180
- elif self.ref_model is not None:
181
- ref_per_token_logps = self._get_per_token_logps(
182
- self.ref_model,
211
+ self.model,
183
212
  prompt_completion_ids,
184
213
  attention_mask,
185
214
  logits_to_keep,
215
+ batch_size,
186
216
  )
187
217
  else:
188
- with self.accelerator.unwrap_model(self.model).disable_adapter():
189
- ref_per_token_logps = self._get_per_token_logps(
190
- self.model,
191
- prompt_completion_ids,
192
- attention_mask,
193
- logits_to_keep,
194
- )
218
+ old_per_token_logps = None
195
219
 
196
220
  rewards = torch.tensor(
197
221
  [example["reward"] for example in batch], dtype=torch.float32
@@ -219,7 +243,56 @@ class ArborGRPOTrainer(GRPOTrainer):
219
243
  )
220
244
  advantages = advantages[process_slice]
221
245
 
222
- ## Logged Metrics Removed Here
246
+ # Log the metrics
247
+ if mode == "train":
248
+ self.state.num_input_tokens_seen += (
249
+ self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
250
+ )
251
+ self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
252
+
253
+ # log completion lengths, mean, min, max
254
+ agg_completion_mask = self.accelerator.gather_for_metrics(
255
+ completion_mask.sum(1)
256
+ )
257
+ self._metrics[mode]["completions/mean_length"].append(
258
+ agg_completion_mask.float().mean().item()
259
+ )
260
+ self._metrics[mode]["completions/min_length"].append(
261
+ agg_completion_mask.float().min().item()
262
+ )
263
+ self._metrics[mode]["completions/max_length"].append(
264
+ agg_completion_mask.float().max().item()
265
+ )
266
+
267
+ # identify sequences that terminated with EOS and log their lengths
268
+ agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
269
+ term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
270
+ clipped_completions_ratio = 1 - len(term_completion_mask) / len(
271
+ agg_completion_mask
272
+ )
273
+ self._metrics[mode]["completions/clipped_ratio"].append(
274
+ clipped_completions_ratio
275
+ )
276
+ if len(term_completion_mask) == 0:
277
+ # edge case where no completed sequences are found
278
+ term_completion_mask = torch.zeros(1, device=device)
279
+ self._metrics[mode]["completions/mean_terminated_length"].append(
280
+ term_completion_mask.float().mean().item()
281
+ )
282
+ self._metrics[mode]["completions/min_terminated_length"].append(
283
+ term_completion_mask.float().min().item()
284
+ )
285
+ self._metrics[mode]["completions/max_terminated_length"].append(
286
+ term_completion_mask.float().max().item()
287
+ )
288
+
289
+ # Calculate mean reward
290
+ self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
291
+ self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
292
+
293
+ # Log prompt and completion texts
294
+ self._textual_logs["prompt"].extend(gather_object(prompts_text))
295
+ self._textual_logs["completion"].extend(gather_object(completions_text))
223
296
 
224
297
  return {
225
298
  "prompt_ids": prompt_ids,
@@ -227,7 +300,6 @@ class ArborGRPOTrainer(GRPOTrainer):
227
300
  "completion_ids": completion_ids,
228
301
  "completion_mask": completion_mask,
229
302
  "old_per_token_logps": old_per_token_logps,
230
- "ref_per_token_logps": ref_per_token_logps,
231
303
  "advantages": advantages,
232
304
  }
233
305
 
@@ -326,15 +398,10 @@ class CommandMonitor:
326
398
  print(
327
399
  f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
328
400
  )
329
- # Wait until data queue is empty before saving
330
-
331
401
  while (
332
402
  time_since_last_step() <= 10
333
403
  or get_time_since_last_queue_pop() <= 10
334
404
  ):
335
- # print(
336
- # f"Waiting for data queue to empty...{self.comms_handler.get_data_queue_size()}"
337
- # )
338
405
  print(f"Waiting for steps to finish")
339
406
  print(
340
407
  f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
@@ -342,15 +409,12 @@ class CommandMonitor:
342
409
  print(
343
410
  f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
344
411
  )
345
- time.sleep(5) # Small delay to prevent busy waiting)
412
+ time.sleep(5)
346
413
  print("[Training Script] Saving model...")
347
-
348
414
  if self.trainer.peft_config:
349
-
350
415
  self.trainer.save_model(
351
416
  output_dir=self.trainer.args.output_dir + "/adapter/"
352
417
  )
353
-
354
418
  _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
355
419
  self.trainer.args.output_dir + "/adapter/",
356
420
  config=self.trainer.peft_config,
@@ -373,6 +437,56 @@ class CommandMonitor:
373
437
  "output_dir": self.trainer.args.output_dir,
374
438
  }
375
439
  )
440
+ elif command.get("command") == "save_checkpoint":
441
+ print(
442
+ f"[Training Script] Instructed to save checkpoint {command.get('checkpoint_name')}"
443
+ )
444
+ while (
445
+ time_since_last_step() <= 10
446
+ or get_time_since_last_queue_pop() <= 10
447
+ ):
448
+ print(f"Waiting for steps to finish")
449
+ print(
450
+ f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
451
+ )
452
+ print(
453
+ f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
454
+ )
455
+ time.sleep(5)
456
+ if self.trainer.peft_config:
457
+ self.trainer.save_model(
458
+ output_dir=self.trainer.args.output_dir
459
+ + f"/checkpoints/{command.get('checkpoint_name')}/adapter/"
460
+ )
461
+ _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
462
+ self.trainer.args.output_dir
463
+ + f"/checkpoints/{command.get('checkpoint_name')}/adapter/",
464
+ config=self.trainer.peft_config,
465
+ )
466
+ merged_model = _model_to_merge.merge_and_unload()
467
+ merged_model.save_pretrained(
468
+ self.trainer.args.output_dir
469
+ + f"/checkpoints/{command.get('checkpoint_name')}/",
470
+ safe_serialization=True,
471
+ )
472
+ self.trainer.processing_class.save_pretrained(
473
+ self.trainer.args.output_dir
474
+ + f"/checkpoints/{command.get('checkpoint_name')}/"
475
+ )
476
+ else:
477
+ self.trainer.save_model(
478
+ output_dir=self.trainer.args.output_dir
479
+ + f"/checkpoints/{command.get('checkpoint_name')}/"
480
+ )
481
+ self.comms_handler.send_status(
482
+ {
483
+ "status": "checkpoint_saved",
484
+ "checkpoint_name": command.get("checkpoint_name"),
485
+ "output_dir": self.trainer.args.output_dir
486
+ + f"/checkpoints/{command.get('checkpoint_name')}/",
487
+ }
488
+ )
489
+
376
490
  except Exception as e:
377
491
  print(e)
378
492
  self.comms_handler.send_status({"status": "error", "error": str(e)})
@@ -385,13 +499,15 @@ class CommandMonitor:
385
499
  for broadcast in self.comms_handler.receive_broadcast():
386
500
  print(f"!!!Received broadcast: {broadcast}")
387
501
  if broadcast.get("message") == "terminate":
388
- self.trainer.control.should_training_stop = True
389
- self.comms_handler.send_status(
390
- {
391
- "status": "Received termination command",
392
- "process_id": self.trainer.accelerator.process_index,
393
- }
394
- )
502
+ # self.trainer.control.should_training_stop = True
503
+ # self.comms_handler.send_status(
504
+ # {
505
+ # "status": "Received termination command",
506
+ # "process_id": self.trainer.accelerator.process_index,
507
+ # }
508
+ # )
509
+ if self.trainer.accelerator.is_main_process:
510
+ self.trainer.accelerator.end_training()
395
511
  except Exception as e:
396
512
  self.comms_handler.send_status({"status": "error", "error": str(e)})
397
513
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arbor-ai
3
- Version: 0.1.12
3
+ Version: 0.1.14
4
4
  Summary: A framework for fine-tuning and managing language models
5
5
  Author-email: Noah Ziems <nziems2@nd.edu>
6
6
  Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -23,6 +23,7 @@ Requires-Dist: pyzmq>=26.4.0
23
23
  Requires-Dist: pyyaml>=6.0.2
24
24
  Requires-Dist: sglang[all]>=0.4.5.post3
25
25
  Requires-Dist: sglang-router
26
+ Requires-Dist: wandb
26
27
  Dynamic: license-file
27
28
 
28
29
  <p align="center">
@@ -5,10 +5,10 @@ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
5
5
  arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
6
6
  arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
7
7
  arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
- arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhcnrKIk,5594
8
+ arbor/server/api/models/schemas.py,sha256=KCHav1nPFbQEynrcO-MObhRmoOrdFvfGuVogApynOCA,6210
9
9
  arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
11
- arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
11
+ arbor/server/api/routes/grpo.py,sha256=AbQ_BHgk-Om5U0qSt_FeJfyBJ0vItpfrnCNtJgD6p5k,2245
12
12
  arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
13
13
  arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
14
14
  arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
@@ -17,18 +17,20 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
17
17
  arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
20
- arbor/server/services/grpo_manager.py,sha256=TAU2BMHgbCgiAvKNVd2Y8N20SR4qEms3lChA4Z0ZzyY,13777
21
- arbor/server/services/inference_manager.py,sha256=q4RVUqh1snGfW-AADkCqW8hC5x3WAZNe0jwXKOY5joU,10685
20
+ arbor/server/services/grpo_manager.py,sha256=-_0xjENvIrOAtHACkFPMYox9YAeckHbpX2FkrmKrWuU,15448
21
+ arbor/server/services/inference_manager.py,sha256=Ju39_7EWySzAAk7ftz-AzSNBEo0tlayloPVS0XRAp8E,15304
22
22
  arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
23
23
  arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
24
24
  arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
26
- arbor/server/services/scripts/grpo_training.py,sha256=Q9jwnbRdXAv_jVgrChLX6IiB3BLZU1F3BP6mBV0DVik,20889
26
+ arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ arbor/server/services/inference/sgl_router_launch_server.py,sha256=eqTW6nDqqoRMISHfv5ScBCrolqLBp9zyxPXqHUlP6uo,6988
28
+ arbor/server/services/scripts/grpo_training.py,sha256=eMT5cIMolAzhukANH1WRmPdxIkvLbsbrggdGFCMGMHc,26474
27
29
  arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
30
  arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
- arbor_ai-0.1.12.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
30
- arbor_ai-0.1.12.dist-info/METADATA,sha256=upqnB_F9JDLytHm4AFrDnvPaOHdj8XiBCdrlam0rgRc,2413
31
- arbor_ai-0.1.12.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
- arbor_ai-0.1.12.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
33
- arbor_ai-0.1.12.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
34
- arbor_ai-0.1.12.dist-info/RECORD,,
31
+ arbor_ai-0.1.14.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
32
+ arbor_ai-0.1.14.dist-info/METADATA,sha256=vw8RnMPdGi36ji4rpjAldkOuCbxxjV4MFVi6yW-0kas,2434
33
+ arbor_ai-0.1.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
34
+ arbor_ai-0.1.14.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
35
+ arbor_ai-0.1.14.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
36
+ arbor_ai-0.1.14.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.4.0)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5