arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbor/cli.py CHANGED
@@ -1,3 +1,6 @@
+ import os
+ from datetime import datetime
+
  import click
  import uvicorn
 
@@ -10,6 +13,14 @@ from arbor.server.services.job_manager import JobManager
  from arbor.server.services.training_manager import TrainingManager
 
 
+ def make_log_dir(storage_path: str):
+     # Create a timestamped log directory under the storage path
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     log_dir = os.path.join(storage_path, "logs", timestamp)
+     os.makedirs(log_dir, exist_ok=True)
+     return log_dir
+
+
  @click.group()
  def cli():
      pass
@@ -26,6 +37,7 @@ def create_app(arbor_config_path: str):
      """
      # Create new settings instance with overrides
      settings = Settings.load_from_yaml(arbor_config_path)
+     app.state.log_dir = make_log_dir(settings.STORAGE_PATH)
 
      # Initialize services with settings
      file_manager = FileManager(settings=settings)
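
Taken together, the two cli.py hunks give every server run its own timestamped log directory and expose it on app.state.log_dir. A minimal sketch of the resulting layout, assuming a hypothetical storage path (the real value comes from settings.STORAGE_PATH in the Arbor config):

    import os
    from datetime import datetime

    storage_path = "/tmp/arbor_storage"  # hypothetical; normally settings.STORAGE_PATH
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(storage_path, "logs", timestamp)
    os.makedirs(log_dir, exist_ok=True)
    print(log_dir)  # e.g. /tmp/arbor_storage/logs/20250527_142301

The inference route below appends its request/response JSONL logs under this directory.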
@@ -178,7 +178,6 @@ class ChatCompletionModel(BaseModel):
 
  class GRPORequest(BaseModel):
      model: str
-     update_inference_model: bool
      batch: List[dict]
 
 
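With update_inference_model gone, a GRPO step request only carries the model name and a batch of dicts. A hedged example payload; the batch item keys are illustrative assumptions, since this diff only constrains batch to List[dict]:

    # Illustrative payload for the slimmed-down GRPORequest; batch item keys are
    # assumptions, not taken from this diff.
    grpo_request = {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical model name
        "batch": [
            {"prompt": "2+2=", "completion": "4", "reward": 1.0},
            {"prompt": "2+2=", "completion": "5", "reward": 0.0},
        ],
    }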
@@ -38,18 +38,13 @@ def run_grpo_step(
      return GRPOStepResponse(status="success", **step_data)
 
 
- @router.post("/update_model", response_model=GRPOStepResponse)
- def update_model(request: Request):
-     grpo_manager = request.app.state.grpo_manager
-     inference_manager = request.app.state.inference_manager
-     update_model_data = grpo_manager.update_model(request, inference_manager)
-     return GRPOStepResponse(status="success", **update_model_data)
-
-
  @router.post("/checkpoint", response_model=GRPOCheckpointResponse)
  def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
      grpo_manager = request.app.state.grpo_manager
-     checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
+     inference_manager = request.app.state.inference_manager
+     checkpoint_data = grpo_manager.checkpoint(
+         grpo_checkpoint_request, inference_manager
+     )
      return GRPOCheckpointResponse(status="success", **checkpoint_data)
 
 
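The /update_model route is removed outright; /checkpoint now pulls the inference manager from app state so GRPOManager.checkpoint can wait out in-flight weight updates (see the manager changes further down). A hedged client-side sketch; the host, port, and any route prefix are assumptions, not taken from this diff:

    import requests

    resp = requests.post(
        "http://localhost:8000/checkpoint",    # hypothetical base URL and prefix
        json={"checkpoint_name": "step_100"},  # hypothetical checkpoint name
    )
    print(resp.status_code, resp.json())       # body follows GRPOCheckpointResponse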
@@ -1,4 +1,5 @@
- import time
+ import json
+ import uuid
 
  from fastapi import APIRouter, Request
 
@@ -12,29 +13,38 @@ async def run_inference(
      inference_manager = request.app.state.inference_manager
      raw_json = await request.json()
 
+     # Generate a random hex ID
+     request_id = str(uuid.uuid4())
+     # Create requests directory if it doesn't exist
+     with open(f"{request.app.state.log_dir}/inference_requests.jsonl", "a") as f:
+         f.write(json.dumps({"id": request_id, "request": raw_json}) + "\n")
+
+     request_model = raw_json["model"]
      prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
      for prefix in prefixes:
-         if raw_json["model"].startswith(prefix):
-             raw_json["model"] = raw_json["model"][len(prefix) :]
+         if request_model.startswith(prefix):
+             request_model = request_model[len(prefix) :]
 
      # if a server isnt running, launch one
-     if (
-         not inference_manager.is_server_running()
-         and not inference_manager.is_server_restarting()
-     ):
+     if not inference_manager.is_server_running():
          print("No model is running, launching model...")
-         inference_manager.launch(raw_json["model"])
+         inference_manager.launch(request_model)
 
-     if inference_manager.is_server_restarting():
-         print("Waiting for server to finish restarting...")
-         while inference_manager.is_server_restarting():
-             time.sleep(5)
-         # Update the model in the request
-         raw_json["model"] = inference_manager.current_model
+     # if the requested model is different from the launched model, swap the server
+     if request_model != inference_manager.launched_model:
+         print(
+             f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
+         )
+         inference_manager.kill()
+         inference_manager.launch(request_model)
+         print(f"Model swapped to {request_model}")
 
      # forward the request to the inference server
      completion = await inference_manager.run_inference(raw_json)
 
+     with open(f"{request.app.state.log_dir}/inference_responses.jsonl", "a") as f:
+         f.write(json.dumps({"id": request_id, "response": completion}) + "\n")
+
      return completion
 
 
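Each inference call now gets a uuid4-based request id, with the incoming payload appended to inference_requests.jsonl and the matching completion appended to inference_responses.jsonl under the same id. A small sketch of reading the two logs back and pairing them up; the log_dir value is hypothetical:

    import json

    def load_jsonl(path):
        # Read one JSON object per line
        with open(path) as f:
            return [json.loads(line) for line in f]

    log_dir = "/tmp/arbor_storage/logs/20250527_142301"  # hypothetical run directory
    requests_log = load_jsonl(f"{log_dir}/inference_requests.jsonl")
    responses_log = load_jsonl(f"{log_dir}/inference_responses.jsonl")

    responses_by_id = {entry["id"]: entry["response"] for entry in responses_log}
    for entry in requests_log:
        answered = entry["id"] in responses_by_id
        print(entry["id"], entry["request"].get("model"), "answered:", answered)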
@@ -13,6 +13,8 @@ from datetime import datetime
  from pathlib import Path
  from typing import Optional
 
+ import psutil
+
  from arbor.server.api.models.schemas import (
      GRPOCheckpointRequest,
      GRPOConfigRequest,
@@ -31,8 +33,9 @@ class GRPOManager:
          self.train_kwargs = None
          self.server_comms_handler = None
          self.status_thread = None
-         self.model_saved_and_reload_requested = False
          self.saving_checkpoint = False
+         self.saving_model = False
+         self.terminating = False
 
          self.checkpoints = {}
          self.last_checkpoint = None
@@ -45,8 +48,10 @@ class GRPOManager:
      def _signal_handler(self, signum, frame):
          """Handle keyboard interrupt (SIGINT) gracefully."""
          print("\nReceived keyboard interrupt. Shutting down gracefully...")
-         self.terminate(None)
-         sys.exit(0)
+         # Sleep for a bit to let async operations go through
+         time.sleep(2)
+         if self.training_process is not None:
+             self.cleanup_termination(None)
 
      def make_output_dir(
          self, model_name: str, run_suffix: Optional[str] = None
@@ -122,6 +127,17 @@ class GRPOManager:
 
          self.current_model = request.model
 
+         # The inference server has to be launched before the training process
+         # Launch the inference server
+         # launch_kwargs = {
+         #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
+         # }
+         inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
+             "max_context_length", None
+         )
+         print("Launching inference server...")
+         inference_manager.launch(self.current_model)
+
          # Initialize ZMQ socket manager - no need for connection acceptance thread anymore
          self.server_comms_handler = ArborServerCommsHandler()
 
@@ -171,6 +187,10 @@ class GRPOManager:
              str(self.server_comms_handler.broadcast_port),
              "--handshake_port",
              str(self.server_comms_handler.handshake_port),
+             "--vllm_port",
+             str(inference_manager.port),
+             "--vllm_group_port",
+             str(inference_manager.group_port),
              # Training args
              "--model",
              self.current_model,
@@ -221,33 +241,38 @@ class GRPOManager:
          self.status_thread.start()
          self.server_comms_handler.wait_for_clients(num_processes)
 
-         # Launch the inference server
-         print("Launching inference server...")
-         # launch_kwargs = {
-         #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
-         # }
-         inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
-             "max_context_length", None
-         )
-         inference_manager.launch(self.current_model)
+     async def _handle_weight_update_start(self, inference_manager):
+         """Handle weight update start in the event loop"""
+         await inference_manager.start_weight_update()
+
+     async def _handle_weight_update_complete(self, inference_manager):
+         """Handle weight update complete in the event loop"""
+         await inference_manager.complete_weight_update()
+
+     def _run_in_loop(self, coro):
+         """Run a coroutine in the event loop from a thread"""
+         future = asyncio.run_coroutine_threadsafe(coro, self.event_loop)
+         return future.result()
 
      def _handle_status_updates(self, inference_manager: InferenceManager):
          """Handle status updates from training process using ZMQ SUB socket"""
          print("Starting status update handler...")
          try:
-
              for status in self.server_comms_handler.receive_status():
                  print(f"Received status update: {status}")
-                 if status["status"] == "model_saved":
+                 if status["status"] == "weight_update_start":
+                     # Block inference calls by incrementing counter
+                     inference_manager.start_weight_update()
+                 elif status["status"] == "weight_update_complete":
+                     # Decrement counter to potentially allow inference calls again
+                     inference_manager.complete_weight_update()
+                 elif status["status"] == "model_saved":
                      print("Updating inference model...")
                      # There is a case where this status is sent multiple times
                      # We need to make sure we only update the model once
-                     if self._should_update_model():
-                         inference_manager.update_model(status["output_dir"])
-                         # self.last_inference_update = self.data_count
-                         self.model_saved_and_reload_requested = False
-                         self.current_model = status["output_dir"]
-                         print("Model update complete")
+                     self.current_model = status["output_dir"]
+                     self.saving_model = False
+                     print("Model update complete")
                  elif status["status"] == "checkpoint_saved":
                      print("Received checkpoint saved status")
                      self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
@@ -257,24 +282,19 @@ class GRPOManager:
                  elif status["status"] == "error":
                      print(f"Training error: {status.get('error', 'Unknown error')}")
                  elif status["status"] == "terminated":
+                     self.terminating = False
                      print("Training process terminated")
-                     break
          except Exception as e:
              print(f"Error in status update handler: {e}")
+             # Make sure to allow inference if there's an error
+             try:
+                 inference_manager.complete_weight_update()
+             except:
+                 pass
 
      def grpo_step(
          self, request: GRPORequest, inference_manager: InferenceManager
      ) -> str:
-         while inference_manager.is_server_restarting():
-             print("Inferece manager restarting, waiting for GRPO step")
-             time.sleep(5)
-
-         while self._should_update_model():
-             print(
-                 f"Waiting for model update. Data count: {self.data_count}, Last inference update: {self.last_inference_update}"
-             )
-             time.sleep(5)
-
          while self.saving_checkpoint:
              print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
              time.sleep(5)
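The status handler above replaces the old model_saved/update_model round trip with a pair of weight_update_start / weight_update_complete messages that gate inference while the trainer pushes weights. A hedged summary of the status dicts the handler dispatches on; only the fields this diff actually reads are shown, and the values are illustrative:

    # Status messages the GRPOManager status thread reacts to, as read from this diff.
    example_statuses = [
        {"status": "weight_update_start"},     # block inference calls
        {"status": "weight_update_complete"},  # allow inference calls again
        {"status": "model_saved", "output_dir": "/tmp/run/final"},
        {"status": "checkpoint_saved", "checkpoint_name": "step_100", "output_dir": "/tmp/run/step_100"},
        {"status": "error", "error": "CUDA out of memory"},
        {"status": "terminated"},              # clears the terminating flag
    ]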
@@ -283,8 +303,10 @@ class GRPOManager:
              # Send the batch to the training process
              self.server_comms_handler.send_data(request.batch)
              self.data_count += 1
+
          except Exception as e:
              print(f"Failed to send batch to training process: {e}")
+             raise
 
          return {
              "current_model": self.current_model,
@@ -292,36 +314,15 @@ class GRPOManager:
              "last_checkpoint": self.last_checkpoint,
          }
 
-     def update_model(self, request, inference_manager: InferenceManager):
-         if inference_manager._session:
-             # Create a new event loop if one doesn't exist
-             try:
-                 loop = asyncio.get_event_loop()
-             except RuntimeError:
-                 loop = asyncio.new_event_loop()
-                 asyncio.set_event_loop(loop)
-
-             # Run the session closure in the event loop
-             loop.run_until_complete(inference_manager._session.close())
-             inference_manager._session = None
-
-         inference_manager.inference_count = 0
-         inference_manager.restarting = True
-
-         self.model_saved_and_reload_requested = True
-         self.server_comms_handler.send_command({"command": "save_model"})
-         while self.model_saved_and_reload_requested:
-             print(
-                 "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
-             )
+     def checkpoint(
+         self, request: GRPOCheckpointRequest, inference_manager: InferenceManager
+     ):
+         while (
+             inference_manager.is_updating
+         ): # Use the property instead of direct access
+             print("Waiting for weight updates to finish before checkpointing...")
              time.sleep(5)
-         return {
-             "current_model": self.current_model,
-             "checkpoints": self.checkpoints,
-             "last_checkpoint": self.last_checkpoint,
-         }
 
-     def checkpoint(self, request: GRPOCheckpointRequest):
          self.saving_checkpoint = True
          self.server_comms_handler.send_command(
              {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
@@ -337,71 +338,143 @@ class GRPOManager:
 
      def terminate(self, inference_manager: InferenceManager):
          """Clean up resources and save the final model."""
+         time.sleep(5)
+
+         while (
+             inference_manager and inference_manager.is_updating
+         ): # Use the property instead of direct access
+             print("Waiting for final weight updates to finish before saving...")
+             time.sleep(5)
+
+         print("sending save model command")
+         self.saving_model = True
+         self.server_comms_handler.send_command({"command": "save_model"})
+         while self.saving_model:
+             print("Waiting for final model to be saved...")
+             time.sleep(5)
+
          termination_data = {
              "current_model": self.current_model,
              "checkpoints": self.checkpoints,
              "last_checkpoint": self.last_checkpoint,
          }
-         try:
-             # Stop the inference server
-             if inference_manager.process is not None:
-                 inference_manager.kill()
 
-             # Send termination command through REQ socket
-             self.server_comms_handler.send_broadcast({"message": "terminate"})
-             # self.training_process.terminate()
-             print("Waiting for training process to finish")
+         print("sending termination command")
+         self.terminating = True
+         self.server_comms_handler.send_command({"command": "terminate"})
+         print("Waiting for training process to finish...")
 
-             # Wait for training process to finish
+         # Wait for at most 15 seconds for termination
+         start_time = time.time()
+         while self.terminating:
+             if time.time() - start_time > 15:
+                 print(
+                     "Termination wait timed out after 15 seconds, proceeding with cleanup..."
+                 )
+                 break
+             print("Waiting for run to be terminated...")
+             time.sleep(3)
+
+         print("Doing cleanup")
+         self.cleanup_termination(inference_manager)
+
+         if self.train_kwargs and "output_dir" in self.train_kwargs:
+             print(
+                 f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
+             )
+             if not os.path.exists(self.train_kwargs["output_dir"]):
+                 print(
+                     f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
+                 )
+             output_dir = self.train_kwargs["output_dir"]
+             self.train_kwargs = None
+         else:
+             print("Training terminated, no output directory specified")
+             self.train_kwargs = None
+
+         return termination_data
+
+     def cleanup_termination(self, inference_manager):
+         try:
+             # Kill training process and all its children (accelerate launcher creates multiple processes)
              if self.training_process:
-                 self.training_process.wait(timeout=30)
+                 print("Terminating training process and its children...")
+                 try:
+                     parent = psutil.Process(self.training_process.pid)
+                     # Get all child processes including grandchildren
+                     children = parent.children(recursive=True)
+
+                     # Send SIGTERM to children first
+                     for child in children:
+                         try:
+                             child.send_signal(signal.SIGTERM)
+                         except psutil.NoSuchProcess:
+                             pass
+
+                     # Send SIGTERM to parent
+                     parent.send_signal(signal.SIGTERM)
+
+                     # Wait for processes to terminate gracefully
+                     gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+                     # If any processes are still alive, force kill them
+                     for p in alive:
+                         try:
+                             p.kill() # SIGKILL
+                         except psutil.NoSuchProcess:
+                             pass
+
+                 except psutil.NoSuchProcess:
+                     print(f"Process {self.training_process.pid} not found")
+                 except Exception as e:
+                     print(f"Error killing training process tree: {e}")
+                     # Fallback to basic termination
+                     self.training_process.terminate()
+                     try:
+                         self.training_process.wait(timeout=10)
+                     except subprocess.TimeoutExpired:
+                         self.training_process.kill()
+                         self.training_process.wait(timeout=10)
 
-         except Exception as e:
-             print(f"Error during termination: {e}")
-         finally:
              # Clean up ZMQ connections
              if self.server_comms_handler:
+                 print("Closing ZMQ connections...")
                  self.server_comms_handler.close()
 
-             # Force kill training process if still running
-             if self.training_process and self.training_process.poll() is None:
-                 self.training_process.kill()
-                 self.training_process.wait()
+             if inference_manager and inference_manager.process is not None:
+                 print("Killing inference manager...")
+                 inference_manager.kill()
 
-             # Reinitialize incase we want to start a new training run
+             # Reinitialize in case we want to start a new training run
+             self.training_process = None
+             self.current_model = None
+             self.server_comms_handler = None
+             self.status_thread = None
+             self.data_count = 0
+             print("Cleanup completed successfully")
+         except Exception as e:
+             print(f"Error during cleanup: {e}")
+             # Still reset state even if cleanup fails
              self.training_process = None
              self.current_model = None
              self.server_comms_handler = None
              self.status_thread = None
-             self.model_saved_and_reload_requested = False
-
              self.data_count = 0
-             self.last_inference_update = 0
-
-             if self.train_kwargs and "output_dir" in self.train_kwargs:
-                 print(
-                     f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
-                 )
-                 if not os.path.exists(self.train_kwargs["output_dir"]):
-                     print(
-                         f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
-                     )
-                 output_dir = self.train_kwargs["output_dir"]
-                 self.train_kwargs = None
-             else:
-                 print("Training terminated, no output directory specified")
-                 self.train_kwargs = None
-
-         return termination_data
-
-     def _should_update_model(self):
-         return self.model_saved_and_reload_requested
 
 
  def get_free_port() -> int:
      """
-     Return a free TCP port on localhost.
+     Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
      """
-     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-         s.bind(("localhost", 0))
-         return s.getsockname()[1]
+     import random
+     import socket
+
+     ports = []
+     for _ in range(random.randint(5, 10)):
+         try:
+             with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                 s.bind(("localhost", 0))
+                 ports.append(s.getsockname()[1])
+         except Exception as e:
+             print(f"Error binding to port: {e}")
+     return random.choice(ports)
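
For reference, the commands the manager now sends to the training process over the comms handler, collected from the hunks above; everything goes through send_command rather than the old send_broadcast terminate message, and the status messages shown earlier flow back on the same handler. Dict shapes are taken directly from the calls in this diff; the checkpoint name is illustrative:

    # Commands observed in this diff.
    save_checkpoint_cmd = {"command": "save_checkpoint", "checkpoint_name": "step_100"}  # name is illustrative
    save_model_cmd = {"command": "save_model"}  # sent from terminate(), waits on self.saving_model
    terminate_cmd = {"command": "terminate"}    # waits up to 15 seconds on self.terminating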