arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbor/server/main.py CHANGED
@@ -1,10 +1,11 @@
1
1
from fastapi import FastAPI

from arbor.server.api.routes import files, grpo, inference, jobs

app = FastAPI(title="Arbor API")

# Register every API router under its OpenAI-style URL prefix.
_ROUTES = (
    (files.router, "/v1/files"),
    (jobs.router, "/v1/fine_tuning/jobs"),
    (grpo.router, "/v1/fine_tuning/grpo"),
    (inference.router, "/v1/chat"),
)
for _router, _prefix in _ROUTES:
    app.include_router(_router, prefix=_prefix)
File without changes
@@ -0,0 +1,226 @@
1
+ import os
2
+ import queue
3
+ import socket
4
+ import threading
5
+ import time
6
+
7
+ import zmq
8
+
9
+
10
class ArborServerCommsHandler:
    """Handles socket communication between manager and training process"""

    def __init__(self, host="localhost"):
        self.host = host
        self.context = zmq.Context()

        # REQ side of the command channel; the script process holds the REP end.
        self.command_socket = self.context.socket(zmq.REQ)
        self.command_port = self.command_socket.bind_to_random_port(f"tcp://{host}")

        # SUB side of the status channel; scripts publish status updates here.
        self.status_socket = self.context.socket(zmq.SUB)
        self.status_port = self.status_socket.bind_to_random_port(f"tcp://{host}")
        self.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")

        # PUB side of the data channel; every script process subscribes.
        self.data_socket = self.context.socket(zmq.PUB)
        self.data_port = self.data_socket.bind_to_random_port(f"tcp://{host}")

        # PUB side of the broadcast channel for messages to all processes.
        self.broadcast_socket = self.context.socket(zmq.PUB)
        self.broadcast_port = self.broadcast_socket.bind_to_random_port(f"tcp://{host}")

        # REP side of the hello/ack handshake used to count connected clients.
        self.handshake_socket = self.context.socket(zmq.REP)
        self.handshake_port = self.handshake_socket.bind_to_random_port(f"tcp://{host}")

    def send_command(self, command):
        """Send a command and block until the script acknowledges it."""
        self.command_socket.send_json(command)
        return self.command_socket.recv_json()  # Wait for acknowledgment

    def send_data(self, data):
        """Publish a data payload to all subscribed script processes."""
        self.data_socket.send_json(data)

    def send_broadcast(self, message):
        """Publish a broadcast message to all subscribed script processes."""
        self.broadcast_socket.send_json(message)

    def receive_status(self):
        """Yield status payloads published by script processes, forever."""
        while True:
            yield self.status_socket.recv_json()

    def close(self):
        """Close every socket, then terminate the ZMQ context."""
        for sock in (
            self.command_socket,
            self.status_socket,
            self.data_socket,
            self.broadcast_socket,
            self.handshake_socket,
        ):
            sock.close()
        self.context.term()

    def wait_for_clients(self, expected_count):
        """Block until `expected_count` clients complete the hello/ack handshake."""
        connected_clients = []
        while len(connected_clients) < expected_count:
            print(f"Waiting for {expected_count} clients to connect...")
            msg = self.handshake_socket.recv_json()
            if msg.get("type") == "hello":
                client_id = msg.get("client_id")
                connected_clients.append(client_id)
                self.handshake_socket.send_json({"status": "ack"})
                print(f"Received handshake from {client_id}")
        print(f"All {expected_count} clients connected!")
70
+
71
+
72
class ArborScriptCommsHandler:
    """Script-side endpoint for the manager's ZMQ channels.

    Only the main process gets the command (REP) / status (PUB) pair;
    every process subscribes to the data and broadcast channels and
    performs the hello/ack handshake with the server.
    """

    def __init__(
        self,
        host,
        command_port,
        status_port,
        data_port,
        broadcast_port,
        handshake_port,
        is_main_process,
    ):
        self.context = zmq.Context()
        self.is_main_process = is_main_process

        # Command/status sockets exist only on the main process.
        self.command_socket = None
        self.status_socket = None
        if is_main_process:
            self.command_socket = self.context.socket(zmq.REP)
            self.command_socket.connect(f"tcp://{host}:{command_port}")
            self.status_socket = self.context.socket(zmq.PUB)
            self.status_socket.connect(f"tcp://{host}:{status_port}")

        # Data subscription (all processes), drained by a background thread.
        self.data_socket = self.context.socket(zmq.SUB)
        self.data_socket.connect(f"tcp://{host}:{data_port}")
        self.data_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.data_queue = queue.Queue()
        self._start_data_receiver()

        # Broadcast subscription (all processes).
        self.broadcast_socket = self.context.socket(zmq.SUB)
        self.broadcast_socket.connect(f"tcp://{host}:{broadcast_port}")
        self.broadcast_socket.setsockopt_string(zmq.SUBSCRIBE, "")

        # Handshake (REQ) socket; announce ourselves to the server right away.
        self.handshake_socket = self.context.socket(zmq.REQ)
        self.handshake_socket.connect(f"tcp://{host}:{handshake_port}")
        self._send_handshake()

    def send_status(self, status):
        """Publish a status update (main process only; silently skipped otherwise)."""
        if self.status_socket is not None:
            self.status_socket.send_json(status)

    def receive_command(self):
        """Yield commands from the server forever, acking each one.

        Yields nothing when called on a non-main process (no command socket).
        """
        if self.command_socket is None:
            return
        while True:
            command = self.command_socket.recv_json()
            # Send acknowledgment
            self.command_socket.send_json({"status": "received"})
            yield command

    def receive_data(self):
        """Block until the next data payload arrives, then return it."""
        return self.data_queue.get()

    def _start_data_receiver(self):
        """Spawn a daemon thread moving incoming data payloads into the queue."""

        def _drain():
            while True:
                try:
                    self.data_queue.put(self.data_socket.recv_json())
                except Exception as e:
                    # Socket closed (e.g. during close()) or recv failed; stop draining.
                    print(f"Error receiving data: {e}")
                    break

        self.receiver_thread = threading.Thread(target=_drain, daemon=True)
        self.receiver_thread.start()

    def is_data_queue_empty(self):
        """True when no data payload is currently queued."""
        return self.data_queue.empty()

    def get_data_queue_size(self):
        """Approximate number of queued data payloads."""
        return self.data_queue.qsize()

    def receive_broadcast(self):
        """Yield broadcast payloads from the server, forever."""
        while True:
            yield self.broadcast_socket.recv_json()

    def close(self):
        """Close all sockets (role-specific ones only if present) and terminate the context."""
        for sock in (self.command_socket, self.status_socket):
            if sock is not None:
                sock.close()
        self.data_socket.close()
        self.broadcast_socket.close()
        self.handshake_socket.close()
        self.context.term()

    def _get_client_id(self):
        # Return a unique identifier for this client (could be hostname, PID, etc.)
        return f"{socket.gethostname()}_{os.getpid()}"

    def _send_handshake(self):
        """Send the hello message and block until the server acks it."""
        print(f"Sending handshake to {self.handshake_socket}")
        self.handshake_socket.send_json(
            {"type": "hello", "client_id": self._get_client_id()}
        )
        self.handshake_socket.recv_json()  # Wait for ack
173
+
174
+
175
if __name__ == "__main__":
    # Manual smoke test: start a server, connect three clients, push one payload.

    def _server_thread(server_comms):
        # Wait for every client handshake before publishing so the PUB/SUB
        # subscriptions are established and the payload is not dropped.
        server_comms.wait_for_clients(expected_count=3)
        server_comms.send_data({"data": "test"})

    def _client_thread(script_comms):
        # BUG FIX: receive_data() returns a single payload dict (it is not a
        # generator); the old `for data in script_comms.receive_data():`
        # iterated the dict's keys instead of receiving messages.
        data = script_comms.receive_data()
        print("Client received data:", data)

    server_comms = ArborServerCommsHandler()
    t1 = threading.Thread(target=_server_thread, args=(server_comms,))
    t1.start()
    print("Server started")

    client_threads = []
    script_comms_list = []
    for i in range(3):
        script_comms = ArborScriptCommsHandler(
            "localhost",
            server_comms.command_port,
            server_comms.status_port,
            server_comms.data_port,
            server_comms.broadcast_port,
            server_comms.handshake_port,
            False,
        )
        t = threading.Thread(target=_client_thread, args=(script_comms,))
        t.start()
        # BUG FIX: the threads were never appended, so the join loop below
        # was a no-op and clients were abandoned.
        client_threads.append(t)
        script_comms_list.append(script_comms)

    # Give the PUB/SUB connections a moment to settle before shutdown.
    # (Removed the leftover pdb.set_trace() and the redundant re-imports of
    # time/pdb; `time` is already imported at module level.)
    time.sleep(1)

    try:
        t1.join()
        for t in client_threads:
            t.join()
    except KeyboardInterrupt:
        print("Keyboard interrupt")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        for script_comms in script_comms_list:
            script_comms.close()
        server_comms.close()
@@ -1,16 +0,0 @@
1
- from functools import lru_cache
2
- from arbor.server.services.file_manager import FileManager
3
- from arbor.server.services.job_manager import JobManager
4
- from arbor.server.services.training_manager import TrainingManager
5
-
6
- @lru_cache()
7
- def get_file_manager() -> FileManager:
8
- return FileManager()
9
-
10
- @lru_cache()
11
- def get_job_manager() -> JobManager:
12
- return JobManager()
13
-
14
- @lru_cache()
15
- def get_training_manager() -> TrainingManager:
16
- return TrainingManager()
@@ -1,128 +1,289 @@
1
- from pathlib import Path
2
1
  import json
3
2
  import os
4
3
  import shutil
5
4
  import time
6
5
  import uuid
6
+ from pathlib import Path
7
+
7
8
  from fastapi import UploadFile
8
- from arbor.server.api.models.schemas import FileResponse
9
+
10
+ from arbor.server.core.config import Settings
11
+
9
12
 
10
13
class FileValidationError(Exception):
    """Raised when an uploaded file fails JSONL format validation."""

    pass


class FileManager:
    """Manages uploaded training files: on-disk storage, metadata, and
    JSONL format validation for SFT and DPO fine-tuning data."""

    def __init__(self, settings: "Settings"):
        self.uploads_dir = Path(settings.STORAGE_PATH) / "uploads"
        self.uploads_dir.mkdir(parents=True, exist_ok=True)
        self.files = self.load_files_from_uploads()

    def load_files_from_uploads(self):
        """Rebuild the in-memory index by scanning the uploads directory.

        Each upload lives in its own subdirectory containing a metadata.json
        plus a single .jsonl data file; directories missing either are skipped.
        Returns a dict mapping file id (directory name) -> metadata record.
        """
        files = {}
        for dir_path in self.uploads_dir.glob("*"):
            if not dir_path.is_dir():
                continue

            metadata_path = dir_path / "metadata.json"
            if not metadata_path.exists():
                continue

            with open(metadata_path) as f:
                metadata = json.load(f)

            jsonl_files = list(dir_path.glob("*.jsonl"))
            if not jsonl_files:
                continue

            file_path = jsonl_files[0]
            files[dir_path.name] = {
                "path": str(file_path),
                "purpose": metadata.get("purpose", "training"),
                "bytes": file_path.stat().st_size,
                # Fall back to the file's mtime when metadata lacks a timestamp.
                "created_at": metadata.get(
                    "created_at", int(file_path.stat().st_mtime)
                ),
                "filename": metadata.get("filename", file_path.name),
            }

        return files

    def save_uploaded_file(self, file: "UploadFile"):
        """Persist an uploaded file under a fresh `file-<uuid>` directory,
        write its metadata.json, index it, and return the metadata record."""
        file_id = f"file-{uuid.uuid4()}"
        dir_path = self.uploads_dir / file_id
        dir_path.mkdir(exist_ok=True)

        # Save the actual file (plain string, not a pointless f-string).
        file_path = dir_path / "data.jsonl"
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        metadata = {
            "purpose": "training",
            "created_at": int(time.time()),
            "filename": file.filename,
        }
        with open(dir_path / "metadata.json", "w") as f:
            json.dump(metadata, f)

        file_data = {
            "id": file_id,
            "path": str(file_path),
            "purpose": metadata["purpose"],
            "bytes": file.size,
            "created_at": metadata["created_at"],
            "filename": metadata["filename"],
        }
        self.files[file_id] = file_data
        return file_data

    def get_file(self, file_id: str):
        """Return the metadata record for file_id (KeyError if unknown)."""
        return self.files[file_id]

    def delete_file(self, file_id: str):
        """Remove a file's directory and drop it from the index; no-op if unknown."""
        if file_id not in self.files:
            return

        dir_path = self.uploads_dir / file_id
        if dir_path.exists():
            shutil.rmtree(dir_path)

        del self.files[file_id]

    def _validate_message_list(self, messages, line_num, field_name):
        """Ensure `messages` is a list of {'role': str, 'content': str} objects.

        Raises FileValidationError naming `field_name` on the first violation.
        """
        if not isinstance(messages, list):
            raise FileValidationError(
                f"Line {line_num}: '{field_name}' must be a list"
            )
        for msg in messages:
            if not isinstance(msg, dict):
                raise FileValidationError(
                    f"Line {line_num}: Each '{field_name}' message must be an object"
                )
            if "role" not in msg or "content" not in msg:
                raise FileValidationError(
                    f"Line {line_num}: Each '{field_name}' message must have 'role' and 'content'"
                )
            if not isinstance(msg["role"], str) or not isinstance(
                msg["content"], str
            ):
                raise FileValidationError(
                    f"Line {line_num}: '{field_name}' message 'role' and 'content' must be strings"
                )

    def validate_file_format_sft(self, file_path: str) -> None:
        """Validate that file_path is JSONL where each non-empty line is
        {"messages": [{"role": str, "content": str}, ...]}.

        Raises FileValidationError describing the first malformed line.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue  # skip empty lines
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        raise FileValidationError(f"Invalid JSON on line {line_num}")

                    if not isinstance(data, dict):
                        raise FileValidationError(
                            f"Line {line_num}: Each line must be a JSON object"
                        )
                    if "messages" not in data:
                        raise FileValidationError(
                            f"Line {line_num}: Missing 'messages' field"
                        )
                    self._validate_message_list(data["messages"], line_num, "messages")
        except FileValidationError:
            # BUG FIX: previously validation errors fell into the generic
            # `except Exception` below and were re-wrapped, destroying the
            # precise per-line message. Re-raise them untouched.
            raise
        except Exception as e:
            raise FileValidationError(f"Failed to read or validate file: {e}")

    def validate_file_format_dpo(self, file_path: str) -> None:
        """Validate a DPO (preference) JSONL file and normalize it in place.

        Each non-empty line must be a JSON object of the form:
            {"input": {"messages": [...], "tools": [...],
                       "parallel_tool_calls": bool},
             "preferred_output": [...], "non_preferred_output": [...]}
        where every message list holds {'role': str, 'content': str} objects.

        After validation, the file is rewritten into the
        {"chosen", "rejected", "prompt"} layout expected by DPO trainers.
        Raises FileValidationError if validation fails.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        raise FileValidationError(f"Invalid JSON on line {line_num}")

                    if not isinstance(data, dict):
                        raise FileValidationError(
                            f"Line {line_num}: Each line must be a JSON object"
                        )

                    input_data = data.get("input")
                    if not isinstance(input_data, dict):
                        raise FileValidationError(
                            f"Line {line_num}: Missing or invalid 'input' field"
                        )

                    self._validate_message_list(
                        input_data.get("messages"), line_num, "input.messages"
                    )

                    if "tools" not in input_data or not isinstance(
                        input_data["tools"], list
                    ):
                        raise FileValidationError(
                            f"Line {line_num}: 'input.tools' must be a list"
                        )

                    if "parallel_tool_calls" not in input_data or not isinstance(
                        input_data["parallel_tool_calls"], bool
                    ):
                        raise FileValidationError(
                            f"Line {line_num}: 'input.parallel_tool_calls' must be a boolean"
                        )

                    self._validate_message_list(
                        data.get("preferred_output"), line_num, "preferred_output"
                    )
                    self._validate_message_list(
                        data.get("non_preferred_output"),
                        line_num,
                        "non_preferred_output",
                    )
        except FileValidationError:
            raise  # keep the precise validation message (see SFT validator)
        except Exception as e:
            raise FileValidationError(f"Failed to validate file: {e}")

        # Rewrite the validated file into the (chosen, rejected, prompt) layout.
        # BUG FIX: the old code derived the temp path via
        # file_path.replace(".jsonl", "_formatted.jsonl"), which yields the
        # *same* path when the name lacks ".jsonl" — opening one file for
        # simultaneous read and write and truncating it mid-read.
        src = Path(file_path)
        output_path = src.with_name(src.stem + "_formatted" + src.suffix)

        with (
            open(file_path, "r", encoding="utf-8") as fin,
            open(output_path, "w", encoding="utf-8") as fout,
        ):
            for line_num, line in enumerate(fin, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    new_line = {
                        "chosen": data["preferred_output"],
                        "rejected": data["non_preferred_output"],
                        "prompt": data["input"]["messages"],
                    }
                    fout.write(json.dumps(new_line) + "\n")
                except Exception as e:
                    # Best-effort: skip unparseable lines, matching prior behavior.
                    print(f"Error parsing line {line_num}: {e}")

        os.replace(output_path, file_path)