arbor-ai 0.1.12__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {arbor_ai-0.1.12/arbor_ai.egg-info → arbor_ai-0.1.14}/PKG-INFO +2 -1
  2. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/models/schemas.py +22 -1
  3. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/grpo.py +15 -6
  4. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/grpo_manager.py +65 -17
  5. arbor_ai-0.1.14/arbor/server/services/inference/sgl_router_launch_server.py +226 -0
  6. arbor_ai-0.1.14/arbor/server/services/inference_manager.py +404 -0
  7. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/scripts/grpo_training.py +159 -43
  8. arbor_ai-0.1.14/arbor/server/utils/helpers.py +0 -0
  9. {arbor_ai-0.1.12 → arbor_ai-0.1.14/arbor_ai.egg-info}/PKG-INFO +2 -1
  10. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/SOURCES.txt +2 -0
  11. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/requires.txt +1 -0
  12. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/pyproject.toml +3 -2
  13. arbor_ai-0.1.12/arbor/server/services/inference_manager.py +0 -299
  14. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/LICENSE +0 -0
  15. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/README.md +0 -0
  16. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/__init__.py +0 -0
  17. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/cli.py +0 -0
  18. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/client/__init__.py +0 -0
  19. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/client/api.py +0 -0
  20. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/__init__.py +0 -0
  21. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/__init__.py +0 -0
  22. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/__init__.py +0 -0
  23. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/files.py +0 -0
  24. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/inference.py +0 -0
  25. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/jobs.py +0 -0
  26. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/core/__init__.py +0 -0
  27. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/core/config.py +0 -0
  28. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/core/logging.py +0 -0
  29. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/main.py +0 -0
  30. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/__init__.py +0 -0
  31. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/comms/__init__.py +0 -0
  32. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/comms/comms.py +0 -0
  33. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/dependencies.py +0 -0
  34. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/file_manager.py +0 -0
  35. {arbor_ai-0.1.12/arbor/server/utils → arbor_ai-0.1.14/arbor/server/services/inference}/__init__.py +0 -0
  36. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/job_manager.py +0 -0
  37. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/training_manager.py +0 -0
  38. /arbor_ai-0.1.12/arbor/server/utils/helpers.py → /arbor_ai-0.1.14/arbor/server/utils/__init__.py +0 -0
  39. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/dependency_links.txt +0 -0
  40. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/entry_points.txt +0 -0
  41. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/top_level.txt +0 -0
  42. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/setup.cfg +0 -0
  43. {arbor_ai-0.1.12 → arbor_ai-0.1.14}/tests/test_cli.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.12
+Version: 0.1.14
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -23,6 +23,7 @@ Requires-Dist: pyzmq>=26.4.0
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: sglang[all]>=0.4.5.post3
 Requires-Dist: sglang-router
+Requires-Dist: wandb
 Dynamic: license-file
 
 <p align="center">
arbor/server/api/models/schemas.py
@@ -199,10 +199,16 @@ class GRPOConfigRequest(BaseModel):
     bf16: Optional[bool] = None
     scale_rewards: Optional[bool] = None
     max_grad_norm: Optional[float] = None
+    report_to: Optional[str] = None
+    log_completions: Optional[bool] = None
+    logging_steps: Optional[int] = None
+    mask_truncated_completions: Optional[bool] = None
+    # Arbor specific
+    max_context_length: Optional[int] = None
     lora: Optional[bool] = None
-    update_interval: Optional[int] = None
     # To name the run
     suffix: Optional[str] = None
+    generation_batch_size: Optional[int] = None
 
 
 class GRPOConfigResponse(BaseModel):
@@ -216,8 +222,23 @@ class GRPOTerminateRequest(BaseModel):
 class GRPOTerminateResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: Optional[dict[str, str]] = None
+    last_checkpoint: Optional[str] = None
 
 
 class GRPOStepResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: Optional[str] = None
+
+
+class GRPOCheckpointRequest(BaseModel):
+    checkpoint_name: str
+
+
+class GRPOCheckpointResponse(BaseModel):
+    status: str
+    current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: str
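As a quick illustration of the new response shapes, the models above can be constructed directly; this is a minimal sketch, and the run paths and checkpoint name are hypothetical values, not anything defined in this release:

    from arbor.server.api.models.schemas import GRPOCheckpointRequest, GRPOStepResponse

    # Hypothetical values for illustration only
    req = GRPOCheckpointRequest(checkpoint_name="ckpt_100")
    resp = GRPOStepResponse(
        status="success",
        current_model="outputs/run/current",                           # hypothetical path
        checkpoints={"ckpt_100": "outputs/run/checkpoints/ckpt_100"},  # hypothetical mapping
        last_checkpoint="ckpt_100",
    )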
arbor/server/api/routes/grpo.py
@@ -4,6 +4,8 @@ import subprocess
 from fastapi import APIRouter, BackgroundTasks, Request
 
 from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOCheckpointResponse,
     GRPOConfigRequest,
     GRPOConfigResponse,
     GRPORequest,
@@ -31,17 +33,24 @@ def run_grpo_step(
     inference_manager = request.app.state.inference_manager
     grpo_manager = request.app.state.grpo_manager
 
-    current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
+    step_data = grpo_manager.grpo_step(grpo_request, inference_manager)
 
-    return GRPOStepResponse(status="success", current_model=current_model)
+    return GRPOStepResponse(status="success", **step_data)
 
 
 @router.post("/update_model", response_model=GRPOStepResponse)
 def update_model(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager
-    current_model = grpo_manager.update_model(request, inference_manager)
-    return GRPOStepResponse(status="success", current_model=current_model)
+    update_model_data = grpo_manager.update_model(request, inference_manager)
+    return GRPOStepResponse(status="success", **update_model_data)
+
+
+@router.post("/checkpoint", response_model=GRPOCheckpointResponse)
+def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
+    grpo_manager = request.app.state.grpo_manager
+    checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
+    return GRPOCheckpointResponse(status="success", **checkpoint_data)
 
 
 @router.post("/terminate", response_model=GRPOTerminateResponse)
@@ -50,5 +59,5 @@ def terminate_grpo(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager
 
-    final_model = grpo_manager.terminate(inference_manager)
-    return GRPOTerminateResponse(status="success", current_model=final_model)
+    terminate_data = grpo_manager.terminate(inference_manager)
+    return GRPOTerminateResponse(status="success", **terminate_data)
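A minimal sketch of calling the new checkpoint route from a client. The host, port, and router prefix are assumptions (they are not shown in this diff); only the request body and response fields come from the schemas above:

    import requests

    # Base URL and route prefix are assumed for illustration
    resp = requests.post(
        "http://localhost:8000/checkpoint",
        json={"checkpoint_name": "ckpt_100"},  # GRPOCheckpointRequest body
    )
    data = resp.json()
    # Expected keys per GRPOCheckpointResponse: status, current_model, checkpoints, last_checkpoint
    print(data["last_checkpoint"], data["checkpoints"])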
arbor/server/services/grpo_manager.py
@@ -13,7 +13,11 @@ from datetime import datetime
 from pathlib import Path
 from typing import Optional
 
-from arbor.server.api.models.schemas import GRPOConfigRequest, GRPORequest
+from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOConfigRequest,
+    GRPORequest,
+)
 from arbor.server.core.config import Settings
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
@@ -28,7 +32,10 @@ class GRPOManager:
         self.server_comms_handler = None
         self.status_thread = None
         self.model_saved_and_reload_requested = False
+        self.saving_checkpoint = False
 
+        self.checkpoints = {}
+        self.last_checkpoint = None
         self.data_count = 0
         self.last_inference_update = 0
         # Set up signal handler
@@ -86,12 +93,17 @@ class GRPOManager:
             "bf16",
             "scale_rewards",
             "max_grad_norm",
+            "report_to",
+            "log_completions",
+            "logging_steps",
+            "generation_batch_size",
+            "mask_truncated_completions",
         ]
         trl_train_kwargs = {
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }
 
-        arbor_keys = ["update_interval", "lora"]
+        arbor_keys = ["max_context_length", "lora"]
         arbor_train_kwargs = {
            key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -119,6 +131,8 @@ class GRPOManager:
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
         my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
+        # WandB can block the training process for login, so we silence it
+        my_env["WANDB_SILENT"] = "true"
 
         num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
 
@@ -209,6 +223,12 @@ class GRPOManager:
 
         # Launch the inference server
         print("Launching inference server...")
+        # launch_kwargs = {
+        #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
+        # }
+        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
+            "max_context_length", None
+        )
         inference_manager.launch(self.current_model)
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
@@ -228,6 +248,12 @@ class GRPOManager:
                     self.model_saved_and_reload_requested = False
                     self.current_model = status["output_dir"]
                     print("Model update complete")
+                elif status["status"] == "checkpoint_saved":
+                    print("Received checkpoint saved status")
+                    self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
+                    self.last_checkpoint = status["checkpoint_name"]
+                    self.saving_checkpoint = False
+                    print("Checkpoint saved")
                 elif status["status"] == "error":
                     print(f"Training error: {status.get('error', 'Unknown error')}")
                 elif status["status"] == "terminated":
@@ -249,6 +275,10 @@ class GRPOManager:
             )
             time.sleep(5)
 
+        while self.saving_checkpoint:
+            print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
+            time.sleep(5)
+
         try:
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
@@ -256,12 +286,11 @@
         except Exception as e:
             print(f"Failed to send batch to training process: {e}")
 
-        # We tell the script to save the model. The script will let us know when it's done via the status update handler
-        # Then we'll actually run the update_model function in the inference manager and finally update the last_inference_update variable
-        # if self._should_update_model():
-        #     self.server_comms_handler.send_command({"command": "save_model"})
-
-        return self.current_model
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
 
     def update_model(self, request, inference_manager: InferenceManager):
         if inference_manager._session:
@@ -286,18 +315,41 @@
                 "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
             )
             time.sleep(5)
-        return self.current_model
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
+
+    def checkpoint(self, request: GRPOCheckpointRequest):
+        self.saving_checkpoint = True
+        self.server_comms_handler.send_command(
+            {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
+        )
+        while self.saving_checkpoint:
+            print("Waiting for checkpoint to be saved...")
+            time.sleep(5)
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
 
     def terminate(self, inference_manager: InferenceManager):
         """Clean up resources and save the final model."""
+        termination_data = {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
         try:
             # Stop the inference server
             if inference_manager.process is not None:
                 inference_manager.kill()
 
             # Send termination command through REQ socket
-            # self.server_comms_handler.send_broadcast({"message": "terminate"})
-            self.training_process.terminate()
+            self.server_comms_handler.send_broadcast({"message": "terminate"})
+            # self.training_process.terminate()
             print("Waiting for training process to finish")
 
             # Wait for training process to finish
@@ -336,17 +388,13 @@
             )
             output_dir = self.train_kwargs["output_dir"]
             self.train_kwargs = None
-            return output_dir
         else:
             print("Training terminated, no output directory specified")
             self.train_kwargs = None
-            return None
+
+        return termination_data
 
     def _should_update_model(self):
-        # return (
-        #     self.data_count - self.last_inference_update
-        #     >= self.train_kwargs["update_interval"]
-        # )
         return self.model_saved_and_reload_requested
 
 
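For reference, a sketch of the status payloads that _handle_status_updates now reacts to, as implied by the hunks above. The exact messages emitted by the training script are not shown in this section, and the checkpoint name and paths below are hypothetical:

    # "model_saved" triggers the existing inference-reload path
    model_saved = {"status": "model_saved", "output_dir": "outputs/run/current"}  # hypothetical path

    # "checkpoint_saved" records the checkpoint and clears self.saving_checkpoint,
    # which grpo_step() and checkpoint() poll every 5 seconds
    checkpoint_saved = {
        "status": "checkpoint_saved",
        "checkpoint_name": "ckpt_100",                       # hypothetical name
        "output_dir": "outputs/run/checkpoints/ckpt_100",    # hypothetical path
    }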
arbor_ai-0.1.14/arbor/server/services/inference/sgl_router_launch_server.py (new file)
@@ -0,0 +1,226 @@
+import argparse
+import copy
+import json
+import logging
+import multiprocessing as mp
+import os
+import random
+import signal
+import sys
+import time
+from typing import List
+
+import requests
+import zmq
+from setproctitle import setproctitle
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_port_available
+from sglang_router.launch_router import RouterArgs, launch_router
+
+
+def setup_logger():
+    logger = logging.getLogger("router")
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logger()
+
+
+# Create new process group
+def run_server(server_args, dp_rank):
+    """
+    Note:
+
+    1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
+       This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
+
+    Terminal (PGID=100)
+    └── Main Python Process (PGID=100)
+        └── Server Process 1 (PGID=100)
+            └── Scheduler 1
+            └── Detokenizer 1
+        └── Server Process 2 (PGID=100)
+            └── Scheduler 2
+            └── Detokenizer 2
+
+    2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
+
+    Terminal (PGID=100)
+    └── Main Python Process (PGID=200)
+        └── Server Process 1 (PGID=300)
+            └── Scheduler 1
+            └── Detokenizer 1
+        └── Server Process 2 (PGID=400)
+            └── Scheduler 2
+            └── Detokenizer 2
+    """
+    # create new process group
+    os.setpgrp()
+
+    setproctitle("sglang::server")
+    # Set SGLANG_DP_RANK environment variable
+    os.environ["SGLANG_DP_RANK"] = str(dp_rank)
+
+    launch_server(server_args)
+
+
+def launch_server_process(
+    server_args: ServerArgs, worker_port: int, dp_id: int
+) -> mp.Process:
+    """Launch a single server process with the given args and port."""
+    server_args = copy.deepcopy(server_args)
+    server_args.port = worker_port
+    server_args.base_gpu_id = dp_id * server_args.tp_size
+    server_args.dp_size = 1
+
+    proc = mp.Process(target=run_server, args=(server_args, dp_id))
+    proc.start()
+    return proc
+
+
+def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
+    """Wait for server to be healthy by checking /health endpoint."""
+    start_time = time.time()
+    url = f"http://{host}:{port}/health"
+
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(url, timeout=5)
+            if response.status_code == 200:
+                return True
+        except requests.exceptions.RequestException:
+            pass
+        time.sleep(1)
+    return False
+
+
+def find_available_ports(base_port: int, count: int) -> List[int]:
+    """Find consecutive available ports starting from base_port."""
+    available_ports = []
+    current_port = base_port
+
+    while len(available_ports) < count:
+        if is_port_available(current_port):
+            available_ports.append(current_port)
+        current_port += random.randint(100, 1000)
+
+    return available_ports
+
+
+def cleanup_processes(processes: List[mp.Process]):
+    for process in processes:
+        logger.info(f"Terminating process group {process.pid}")
+        try:
+            os.killpg(process.pid, signal.SIGTERM)
+        except ProcessLookupError:
+            # Process group may already be terminated
+            pass
+
+    # Wait for processes to terminate
+    for process in processes:
+        process.join(timeout=5)
+        if process.is_alive():
+            logger.warning(
+                f"Process {process.pid} did not terminate gracefully, forcing kill"
+            )
+            try:
+                os.killpg(process.pid, signal.SIGKILL)
+            except ProcessLookupError:
+                pass
+
+    logger.info("All process groups terminated")
+
+
+def main():
+    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
+    mp.set_start_method("spawn")
+
+    parser = argparse.ArgumentParser(
+        description="Launch SGLang router and server processes"
+    )
+
+    ServerArgs.add_cli_args(parser)
+    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
+    parser.add_argument(
+        "--router-dp-worker-base-port",
+        type=int,
+        default=31000,
+        help="Base port number for data parallel workers",
+    )
+    parser.add_argument(
+        "--worker-urls-port",
+        type=int,
+        help="Port number for worker URLs publisher",
+    )
+
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
+
+    # Find available ports for workers
+    worker_ports = find_available_ports(
+        args.router_dp_worker_base_port, server_args.dp_size
+    )
+
+    # Start server processes
+    server_processes = []
+
+    for i, worker_port in enumerate(worker_ports):
+        logger.info(f"Launching DP server process {i} on port {worker_port}")
+        proc = launch_server_process(server_args, worker_port, i)
+        server_processes.append(proc)
+
+    signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
+    signal.signal(
+        signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
+    )
+    signal.signal(
+        signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
+    )
+
+    # Update router args with worker URLs
+    worker_urls = [f"http://{server_args.host}:{port}" for port in worker_ports]
+    router_args.worker_urls = worker_urls
+
+    # Publish worker URLs via ZMQ if port is specified
+    if args.worker_urls_port:
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.PUB)
+            socket.bind(f"tcp://*:{args.worker_urls_port}")
+            # Give subscribers time to connect
+            time.sleep(0.1)
+            socket.send_json({"type": "worker_urls", "urls": worker_urls})
+            logger.info(
+                f"Published worker URLs via ZMQ on port {args.worker_urls_port}"
+            )
+            socket.close()
+            context.term()
+        except Exception as e:
+            logger.error(f"Failed to publish worker URLs via ZMQ: {e}")
+            cleanup_processes(server_processes)
+            sys.exit(1)
+
+    # Start the router
+    try:
+        launch_router(router_args)
+    except Exception as e:
+        logger.error(f"Failed to start router: {e}")
+        cleanup_processes(server_processes)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
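A minimal sketch of the subscriber side for the worker-URLs publisher above, assuming the launcher was started with --worker-urls-port 31500 (the port value here is arbitrary). Because the publisher sends a single message roughly 0.1 s after binding, the subscriber must already be connected by the time the launcher publishes:

    import zmq

    context = zmq.Context()
    socket = context.socket(zmq.SUB)
    socket.connect("tcp://localhost:31500")      # must match --worker-urls-port
    socket.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to all messages
    message = socket.recv_json()                 # blocks until the launcher publishes
    if message.get("type") == "worker_urls":
        print("workers:", message["urls"])
    socket.close()
    context.term()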