arbor-ai 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbor/cli.py CHANGED
@@ -1,3 +1,6 @@
+import os
+from datetime import datetime
+
 import click
 import uvicorn
 
@@ -10,6 +13,14 @@ from arbor.server.services.job_manager import JobManager
 from arbor.server.services.training_manager import TrainingManager
 
 
+def make_log_dir(storage_path: str):
+    # Create a timestamped log directory under the storage path
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    log_dir = os.path.join(storage_path, "logs", timestamp)
+    os.makedirs(log_dir, exist_ok=True)
+    return log_dir
+
+
 @click.group()
 def cli():
     pass
@@ -26,6 +37,7 @@ def create_app(arbor_config_path: str):
     """
     # Create new settings instance with overrides
     settings = Settings.load_from_yaml(arbor_config_path)
+    app.state.log_dir = make_log_dir(settings.STORAGE_PATH)
 
     # Initialize services with settings
     file_manager = FileManager(settings=settings)
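
The two cli.py hunks above add a per-run log directory: make_log_dir() creates storage_path/logs/<timestamp> and create_app() stores it on app.state.log_dir for the route handlers. A minimal sketch of what this produces (the "./storage" argument is only an example; the real value comes from Settings.STORAGE_PATH):

import os
from datetime import datetime

def make_log_dir(storage_path: str):
    # Mirrors the helper added in arbor/cli.py: a timestamped directory under <storage_path>/logs
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(storage_path, "logs", timestamp)
    os.makedirs(log_dir, exist_ok=True)
    return log_dir

print(make_log_dir("./storage"))  # e.g. ./storage/logs/20250101_120000
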
@@ -49,7 +49,10 @@ def update_model(request: Request):
 @router.post("/checkpoint", response_model=GRPOCheckpointResponse)
 def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
     grpo_manager = request.app.state.grpo_manager
-    checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
+    inference_manager = request.app.state.inference_manager
+    checkpoint_data = grpo_manager.checkpoint(
+        grpo_checkpoint_request, inference_manager
+    )
     return GRPOCheckpointResponse(status="success", **checkpoint_data)
 
 
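
The /checkpoint route now passes the inference manager through so GRPOManager.checkpoint() can wait for in-flight weight updates (see the grpo_manager.py hunks below). A hedged client-side sketch; the host, port, and route prefix are assumptions not shown in this diff, while the checkpoint_name field matches GRPOCheckpointRequest as used later in the diff:

import requests

# Assumed base URL and route prefix; adjust to wherever the Arbor server is actually mounted.
ARBOR_URL = "http://localhost:8000"

resp = requests.post(
    f"{ARBOR_URL}/v1/fine_tuning/grpo/checkpoint",
    json={"checkpoint_name": "step_100"},
)
print(resp.json())  # expected shape: {"status": "success", ...checkpoint data...}
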
@@ -1,4 +1,5 @@
-import time
+import json
+import uuid
 
 from fastapi import APIRouter, Request
 
@@ -12,29 +13,23 @@ async def run_inference(
     inference_manager = request.app.state.inference_manager
     raw_json = await request.json()
 
-    prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
-    for prefix in prefixes:
-        if raw_json["model"].startswith(prefix):
-            raw_json["model"] = raw_json["model"][len(prefix) :]
+    # Generate a random hex ID
+    request_id = str(uuid.uuid4())
+    # Create requests directory if it doesn't exist
+    with open(f"{request.app.state.log_dir}/inference_requests.jsonl", "a") as f:
+        f.write(json.dumps({"id": request_id, "request": raw_json}) + "\n")
 
     # if a server isnt running, launch one
-    if (
-        not inference_manager.is_server_running()
-        and not inference_manager.is_server_restarting()
-    ):
+    if not inference_manager.is_server_running():
         print("No model is running, launching model...")
         inference_manager.launch(raw_json["model"])
 
-    if inference_manager.is_server_restarting():
-        print("Waiting for server to finish restarting...")
-        while inference_manager.is_server_restarting():
-            time.sleep(5)
-        # Update the model in the request
-        raw_json["model"] = inference_manager.current_model
-
     # forward the request to the inference server
     completion = await inference_manager.run_inference(raw_json)
 
+    with open(f"{request.app.state.log_dir}/inference_responses.jsonl", "a") as f:
+        f.write(json.dumps({"id": request_id, "response": completion}) + "\n")
+
     return completion
 
 
@@ -13,6 +13,8 @@ from datetime import datetime
 from pathlib import Path
 from typing import Optional
 
+import psutil
+
 from arbor.server.api.models.schemas import (
     GRPOCheckpointRequest,
     GRPOConfigRequest,
@@ -31,8 +33,9 @@ class GRPOManager:
         self.train_kwargs = None
         self.server_comms_handler = None
         self.status_thread = None
-        self.model_saved_and_reload_requested = False
         self.saving_checkpoint = False
+        self.saving_model = False
+        self.terminating = False
 
         self.checkpoints = {}
         self.last_checkpoint = None
@@ -45,8 +48,10 @@ class GRPOManager:
     def _signal_handler(self, signum, frame):
         """Handle keyboard interrupt (SIGINT) gracefully."""
         print("\nReceived keyboard interrupt. Shutting down gracefully...")
-        self.terminate(None)
-        sys.exit(0)
+        # Sleep for a bit to let async operations go through
+        time.sleep(2)
+        if self.training_process is not None:
+            self.cleanup_termination(None)
 
     def make_output_dir(
         self, model_name: str, run_suffix: Optional[str] = None
@@ -122,6 +127,17 @@ class GRPOManager:
 
         self.current_model = request.model
 
+        # The inference server has to be launched before the training process
+        # Launch the inference server
+        # launch_kwargs = {
+        #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
+        # }
+        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
+            "max_context_length", None
+        )
+        print("Launching inference server...")
+        inference_manager.launch(self.current_model)
+
         # Initialize ZMQ socket manager - no need for connection acceptance thread anymore
         self.server_comms_handler = ArborServerCommsHandler()
 
@@ -171,6 +187,10 @@ class GRPOManager:
             str(self.server_comms_handler.broadcast_port),
             "--handshake_port",
             str(self.server_comms_handler.handshake_port),
+            "--vllm_port",
+            str(inference_manager.port),
+            "--vllm_group_port",
+            str(inference_manager.group_port),
             # Training args
             "--model",
             self.current_model,
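
The launch command gains --vllm_port and --vllm_group_port so the training script can reach the vLLM-backed inference server. That script is not part of this diff; a plausible receiving-side sketch (only the two flag names are taken from the diff, the rest is illustrative):

import argparse

parser = argparse.ArgumentParser()
# Flags appended by GRPOManager when it builds the training command
parser.add_argument("--vllm_port", type=int, required=True)
parser.add_argument("--vllm_group_port", type=int, required=True)
args, _ = parser.parse_known_args()  # other handshake/training flags ignored here
print(args.vllm_port, args.vllm_group_port)
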
@@ -221,33 +241,38 @@ class GRPOManager:
         self.status_thread.start()
         self.server_comms_handler.wait_for_clients(num_processes)
 
-        # Launch the inference server
-        print("Launching inference server...")
-        # launch_kwargs = {
-        #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
-        # }
-        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
-            "max_context_length", None
-        )
-        inference_manager.launch(self.current_model)
+    async def _handle_weight_update_start(self, inference_manager):
+        """Handle weight update start in the event loop"""
+        await inference_manager.start_weight_update()
+
+    async def _handle_weight_update_complete(self, inference_manager):
+        """Handle weight update complete in the event loop"""
+        await inference_manager.complete_weight_update()
+
+    def _run_in_loop(self, coro):
+        """Run a coroutine in the event loop from a thread"""
+        future = asyncio.run_coroutine_threadsafe(coro, self.event_loop)
+        return future.result()
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
         """Handle status updates from training process using ZMQ SUB socket"""
         print("Starting status update handler...")
         try:
-
             for status in self.server_comms_handler.receive_status():
                 print(f"Received status update: {status}")
-                if status["status"] == "model_saved":
+                if status["status"] == "weight_update_start":
+                    # Block inference calls by incrementing counter
+                    inference_manager.start_weight_update()
+                elif status["status"] == "weight_update_complete":
+                    # Decrement counter to potentially allow inference calls again
+                    inference_manager.complete_weight_update()
+                elif status["status"] == "model_saved":
                     print("Updating inference model...")
                     # There is a case where this status is sent multiple times
                     # We need to make sure we only update the model once
-                    if self._should_update_model():
-                        inference_manager.update_model(status["output_dir"])
-                        # self.last_inference_update = self.data_count
-                        self.model_saved_and_reload_requested = False
-                        self.current_model = status["output_dir"]
-                        print("Model update complete")
+                    self.current_model = status["output_dir"]
+                    self.saving_model = False
+                    print("Model update complete")
                 elif status["status"] == "checkpoint_saved":
                     print("Received checkpoint saved status")
                     self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
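
The status handler now gates inference around weight updates via inference_manager.start_weight_update() / complete_weight_update() and the is_updating property used further down. InferenceManager itself is not in this diff; the following is only a plausible sketch of the counter-style gate those calls imply, not the actual arbor implementation:

import threading

class WeightUpdateGate:
    # Hypothetical stand-in for the InferenceManager bookkeeping referenced above.
    def __init__(self):
        self._lock = threading.Lock()
        self._updates_in_progress = 0

    def start_weight_update(self):
        # Called on a "weight_update_start" status; inference should be held off while updates are pending.
        with self._lock:
            self._updates_in_progress += 1

    def complete_weight_update(self):
        # Called on "weight_update_complete" (and defensively when the handler hits an error).
        with self._lock:
            self._updates_in_progress = max(0, self._updates_in_progress - 1)

    @property
    def is_updating(self) -> bool:
        return self._updates_in_progress > 0
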
@@ -257,24 +282,19 @@ class GRPOManager:
                 elif status["status"] == "error":
                     print(f"Training error: {status.get('error', 'Unknown error')}")
                 elif status["status"] == "terminated":
+                    self.terminating = False
                     print("Training process terminated")
-                    break
         except Exception as e:
             print(f"Error in status update handler: {e}")
+            # Make sure to allow inference if there's an error
+            try:
+                inference_manager.complete_weight_update()
+            except:
+                pass
 
     def grpo_step(
         self, request: GRPORequest, inference_manager: InferenceManager
     ) -> str:
-        while inference_manager.is_server_restarting():
-            print("Inferece manager restarting, waiting for GRPO step")
-            time.sleep(5)
-
-        while self._should_update_model():
-            print(
-                f"Waiting for model update. Data count: {self.data_count}, Last inference update: {self.last_inference_update}"
-            )
-            time.sleep(5)
-
         while self.saving_checkpoint:
             print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
             time.sleep(5)
@@ -283,8 +303,10 @@ class GRPOManager:
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
             self.data_count += 1
+
         except Exception as e:
             print(f"Failed to send batch to training process: {e}")
+            raise
 
         return {
             "current_model": self.current_model,
@@ -293,35 +315,22 @@ class GRPOManager:
         }
 
     def update_model(self, request, inference_manager: InferenceManager):
-        if inference_manager._session:
-            # Create a new event loop if one doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Run the session closure in the event loop
-            loop.run_until_complete(inference_manager._session.close())
-            inference_manager._session = None
-
-        inference_manager.inference_count = 0
-        inference_manager.restarting = True
-
-        self.model_saved_and_reload_requested = True
-        self.server_comms_handler.send_command({"command": "save_model"})
-        while self.model_saved_and_reload_requested:
-            print(
-                "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
-            )
-            time.sleep(5)
+        # No longer used
         return {
             "current_model": self.current_model,
             "checkpoints": self.checkpoints,
             "last_checkpoint": self.last_checkpoint,
         }
 
-    def checkpoint(self, request: GRPOCheckpointRequest):
+    def checkpoint(
+        self, request: GRPOCheckpointRequest, inference_manager: InferenceManager
+    ):
+        while (
+            inference_manager.is_updating
+        ):  # Use the property instead of direct access
+            print("Waiting for weight updates to finish before checkpointing...")
+            time.sleep(5)
+
         self.saving_checkpoint = True
         self.server_comms_handler.send_command(
             {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
@@ -337,71 +346,143 @@ class GRPOManager:
 
     def terminate(self, inference_manager: InferenceManager):
         """Clean up resources and save the final model."""
+        time.sleep(5)
+
+        while (
+            inference_manager and inference_manager.is_updating
+        ):  # Use the property instead of direct access
+            print("Waiting for final weight updates to finish before saving...")
+            time.sleep(5)
+
+        print("sending save model command")
+        self.saving_model = True
+        self.server_comms_handler.send_command({"command": "save_model"})
+        while self.saving_model:
+            print("Waiting for final model to be saved...")
+            time.sleep(5)
+
         termination_data = {
             "current_model": self.current_model,
             "checkpoints": self.checkpoints,
             "last_checkpoint": self.last_checkpoint,
         }
-        try:
-            # Stop the inference server
-            if inference_manager.process is not None:
-                inference_manager.kill()
 
-            # Send termination command through REQ socket
-            self.server_comms_handler.send_broadcast({"message": "terminate"})
-            # self.training_process.terminate()
-            print("Waiting for training process to finish")
+        print("sending termination command")
+        self.terminating = True
+        self.server_comms_handler.send_command({"command": "terminate"})
+        print("Waiting for training process to finish...")
+
+        # Wait for at most 15 seconds for termination
+        start_time = time.time()
+        while self.terminating:
+            if time.time() - start_time > 15:
+                print(
+                    "Termination wait timed out after 15 seconds, proceeding with cleanup..."
+                )
+                break
+            print("Waiting for run to be terminated...")
+            time.sleep(3)
+
+        print("Doing cleanup")
+        self.cleanup_termination(inference_manager)
+
+        if self.train_kwargs and "output_dir" in self.train_kwargs:
+            print(
+                f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
+            )
+            if not os.path.exists(self.train_kwargs["output_dir"]):
+                print(
+                    f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
+                )
+            output_dir = self.train_kwargs["output_dir"]
+            self.train_kwargs = None
+        else:
+            print("Training terminated, no output directory specified")
+            self.train_kwargs = None
+
+        return termination_data
 
-            # Wait for training process to finish
+    def cleanup_termination(self, inference_manager):
+        try:
+            # Kill training process and all its children (accelerate launcher creates multiple processes)
             if self.training_process:
-                self.training_process.wait(timeout=30)
+                print("Terminating training process and its children...")
+                try:
+                    parent = psutil.Process(self.training_process.pid)
+                    # Get all child processes including grandchildren
+                    children = parent.children(recursive=True)
+
+                    # Send SIGTERM to children first
+                    for child in children:
+                        try:
+                            child.send_signal(signal.SIGTERM)
+                        except psutil.NoSuchProcess:
+                            pass
+
+                    # Send SIGTERM to parent
+                    parent.send_signal(signal.SIGTERM)
+
+                    # Wait for processes to terminate gracefully
+                    gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+                    # If any processes are still alive, force kill them
+                    for p in alive:
+                        try:
+                            p.kill()  # SIGKILL
+                        except psutil.NoSuchProcess:
+                            pass
+
+                except psutil.NoSuchProcess:
+                    print(f"Process {self.training_process.pid} not found")
+                except Exception as e:
+                    print(f"Error killing training process tree: {e}")
+                    # Fallback to basic termination
+                    self.training_process.terminate()
+                    try:
+                        self.training_process.wait(timeout=10)
+                    except subprocess.TimeoutExpired:
+                        self.training_process.kill()
+                        self.training_process.wait(timeout=10)
 
-        except Exception as e:
-            print(f"Error during termination: {e}")
-        finally:
             # Clean up ZMQ connections
             if self.server_comms_handler:
+                print("Closing ZMQ connections...")
                 self.server_comms_handler.close()
 
-            # Force kill training process if still running
-            if self.training_process and self.training_process.poll() is None:
-                self.training_process.kill()
-                self.training_process.wait()
+            if inference_manager and inference_manager.process is not None:
+                print("Killing inference manager...")
+                inference_manager.kill()
 
-            # Reinitialize incase we want to start a new training run
+            # Reinitialize in case we want to start a new training run
+            self.training_process = None
+            self.current_model = None
+            self.server_comms_handler = None
+            self.status_thread = None
+            self.data_count = 0
+            print("Cleanup completed successfully")
+        except Exception as e:
+            print(f"Error during cleanup: {e}")
+            # Still reset state even if cleanup fails
            self.training_process = None
             self.current_model = None
             self.server_comms_handler = None
             self.status_thread = None
-            self.model_saved_and_reload_requested = False
-
             self.data_count = 0
-            self.last_inference_update = 0
-
-            if self.train_kwargs and "output_dir" in self.train_kwargs:
-                print(
-                    f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
-                )
-                if not os.path.exists(self.train_kwargs["output_dir"]):
-                    print(
-                        f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
-                    )
-                output_dir = self.train_kwargs["output_dir"]
-                self.train_kwargs = None
-            else:
-                print("Training terminated, no output directory specified")
-                self.train_kwargs = None
-
-            return termination_data
-
-    def _should_update_model(self):
-        return self.model_saved_and_reload_requested
 
 
 def get_free_port() -> int:
     """
-    Return a free TCP port on localhost.
+    Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
     """
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("localhost", 0))
-        return s.getsockname()[1]
+    import random
+    import socket
+
+    ports = []
+    for _ in range(random.randint(5, 10)):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", 0))
+                ports.append(s.getsockname()[1])
+        except Exception as e:
+            print(f"Error binding to port: {e}")
+    return random.choice(ports)
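
cleanup_termination() uses psutil to take down the whole accelerate process tree rather than only the launcher process. Condensed into a standalone helper for reference; the psutil calls mirror the diff, while the function name is ours:

import signal

import psutil

def kill_process_tree(pid: int, timeout: float = 10.0) -> None:
    # SIGTERM the children first, then the parent, then SIGKILL anything still alive (same order as the diff).
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    procs = parent.children(recursive=True) + [parent]
    for proc in procs:
        try:
            proc.send_signal(signal.SIGTERM)
        except psutil.NoSuchProcess:
            pass
    _, alive = psutil.wait_procs(procs, timeout=timeout)
    for proc in alive:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass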