arbor-ai 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/client/__init__.py +0 -0
- arbor/client/api.py +1 -0
- arbor/server/__init__.py +1 -0
- arbor/server/api/__init__.py +1 -0
- arbor/server/api/models/schemas.py +223 -0
- arbor/server/api/routes/__init__.py +0 -0
- arbor/server/api/routes/files.py +52 -0
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +117 -0
- arbor/server/core/__init__.py +1 -0
- arbor/server/core/config.py +47 -0
- arbor/server/core/logging.py +0 -0
- arbor/server/main.py +11 -0
- arbor/server/services/__init__.py +0 -0
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -0
- arbor/server/services/file_manager.py +289 -0
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +81 -0
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +561 -0
- arbor/server/utils/__init__.py +0 -0
- arbor/server/utils/helpers.py +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/METADATA +1 -1
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- arbor_ai-0.1.5.dist-info/RECORD +0 -8
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,226 @@
|
|
1
|
+
import os
|
2
|
+
import queue
|
3
|
+
import socket
|
4
|
+
import threading
|
5
|
+
import time
|
6
|
+
|
7
|
+
import zmq
|
8
|
+
|
9
|
+
|
10
|
+
class ArborServerCommsHandler:
    """Handles socket communication between manager and training process.

    Socket roles (all bound to random ports on *host*):
      - command:   REQ — send a command, block for the script's acknowledgment
      - status:    SUB — receive status updates published by the script
      - data:      PUB — publish training data to all script processes
      - broadcast: PUB — publish control messages to all script processes
      - handshake: REP — accept "hello" messages so we know clients are up
    """

    def __init__(self, host="localhost"):
        self.host = host
        self.context = zmq.Context()

        # Command socket (REQ/REP pattern)
        self.command_socket = self.context.socket(zmq.REQ)
        self.command_port = self.command_socket.bind_to_random_port(f"tcp://{host}")

        # Status socket (PUB/SUB pattern); subscribe to everything
        self.status_socket = self.context.socket(zmq.SUB)
        self.status_port = self.status_socket.bind_to_random_port(f"tcp://{host}")
        self.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")

        # Data socket (PUB/SUB pattern)
        self.data_socket = self.context.socket(zmq.PUB)
        self.data_port = self.data_socket.bind_to_random_port(f"tcp://{host}")

        # Broadcast socket (PUB/SUB pattern)
        self.broadcast_socket = self.context.socket(zmq.PUB)
        self.broadcast_port = self.broadcast_socket.bind_to_random_port(f"tcp://{host}")

        # Handshake socket (REQ/REP pattern)
        self.handshake_socket = self.context.socket(zmq.REP)
        self.handshake_port = self.handshake_socket.bind_to_random_port(f"tcp://{host}")

    def send_command(self, command):
        """Send *command* (JSON-serializable) and block for the acknowledgment."""
        self.command_socket.send_json(command)
        return self.command_socket.recv_json()  # Wait for acknowledgment

    def send_data(self, data):
        """Publish *data* to every connected script process."""
        self.data_socket.send_json(data)

    def send_broadcast(self, message):
        """Publish a control *message* to every connected script process."""
        self.broadcast_socket.send_json(message)

    def receive_status(self):
        """Yield status messages forever; each iteration blocks on recv."""
        while True:
            status = self.status_socket.recv_json()
            yield status

    def close(self):
        """Close every socket, then terminate the ZMQ context."""
        self.command_socket.close()
        self.status_socket.close()
        self.data_socket.close()
        self.broadcast_socket.close()
        self.handshake_socket.close()
        self.context.term()

    def wait_for_clients(self, expected_count):
        """Block until *expected_count* clients have completed the handshake.

        A REP socket must answer every request, so unexpected (non-"hello")
        messages are answered with an error reply instead of being dropped —
        leaving a request unanswered would deadlock both the REQ peer and
        this loop.
        """
        connected_clients = []
        while len(connected_clients) < expected_count:
            print(f"Waiting for {expected_count} clients to connect...")
            msg = self.handshake_socket.recv_json()
            if msg.get("type") == "hello":
                client_id = msg.get("client_id")
                connected_clients.append(client_id)
                self.handshake_socket.send_json({"status": "ack"})
                print(f"Received handshake from {client_id}")
            else:
                # Must reply anyway: REP deadlocks on an unanswered request.
                self.handshake_socket.send_json(
                    {"status": "error", "reason": "expected hello message"}
                )
        print(f"All {expected_count} clients connected!")
|
70
|
+
|
71
|
+
|
72
|
+
class ArborScriptCommsHandler:
    """Client-side ZMQ comms handler used by a training script process.

    Only the main process owns the command (REP) and status (PUB) sockets;
    every process subscribes to the data and broadcast channels and performs
    a handshake with the server before returning from __init__.
    """

    def __init__(
        self,
        host,
        command_port,
        status_port,
        data_port,
        broadcast_port,
        handshake_port,
        is_main_process,
    ):
        self.context = zmq.Context()
        self.is_main_process = is_main_process

        # Command + status sockets exist on the main process only.
        if is_main_process:
            self.command_socket = self.context.socket(zmq.REP)
            self.command_socket.connect(f"tcp://{host}:{command_port}")

            self.status_socket = self.context.socket(zmq.PUB)
            self.status_socket.connect(f"tcp://{host}:{status_port}")
        else:
            self.command_socket = None
            self.status_socket = None

        # Data subscription (all processes); a background thread drains it
        # into a local queue so callers get a simple blocking get().
        self.data_socket = self.context.socket(zmq.SUB)
        self.data_socket.connect(f"tcp://{host}:{data_port}")
        self.data_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.data_queue = queue.Queue()
        self._start_data_receiver()

        # Broadcast subscription (all processes).
        self.broadcast_socket = self.context.socket(zmq.SUB)
        self.broadcast_socket.connect(f"tcp://{host}:{broadcast_port}")
        self.broadcast_socket.setsockopt_string(zmq.SUBSCRIBE, "")

        # Announce ourselves to the server and wait for its ack.
        self.handshake_socket = self.context.socket(zmq.REQ)
        self.handshake_socket.connect(f"tcp://{host}:{handshake_port}")
        self._send_handshake()

    def send_status(self, status):
        """Publish a status update (no-op on non-main processes)."""
        if self.status_socket is not None:
            self.status_socket.send_json(status)

    def receive_command(self):
        """Yield commands forever, acking each one (empty on non-main processes)."""
        if self.command_socket is None:
            return
        while True:
            incoming = self.command_socket.recv_json()
            # Send acknowledgment
            self.command_socket.send_json({"status": "received"})
            yield incoming

    def receive_data(self):
        """Block until the next data message arrives and return it."""
        return self.data_queue.get()

    def _start_data_receiver(self):
        """Spawn the daemon thread that pumps the data socket into the queue."""

        def _pump():
            while True:
                try:
                    self.data_queue.put(self.data_socket.recv_json())
                except Exception as e:
                    print(f"Error receiving data: {e}")
                    break

        self.receiver_thread = threading.Thread(target=_pump, daemon=True)
        self.receiver_thread.start()

    def is_data_queue_empty(self):
        """True when no buffered data messages are waiting."""
        return self.data_queue.empty()

    def get_data_queue_size(self):
        """Number of buffered data messages (approximate, per queue.qsize)."""
        return self.data_queue.qsize()

    def receive_broadcast(self):
        """Yield broadcast messages forever; each iteration blocks on recv."""
        while True:
            yield self.broadcast_socket.recv_json()

    def close(self):
        """Close all sockets owned by this process, then terminate the context."""
        if self.command_socket is not None:
            self.command_socket.close()
        if self.status_socket is not None:
            self.status_socket.close()
        self.data_socket.close()
        self.broadcast_socket.close()
        self.handshake_socket.close()
        self.context.term()

    def _get_client_id(self):
        # Return a unique identifier for this client (could be hostname, PID, etc.)
        return f"{socket.gethostname()}_{os.getpid()}"

    def _send_handshake(self):
        """Send the hello message and block until the server acks it."""
        print(f"Sending handshake to {self.handshake_socket}")
        self.handshake_socket.send_json(
            {"type": "hello", "client_id": self._get_client_id()}
        )
        self.handshake_socket.recv_json()  # Wait for ack
|
173
|
+
|
174
|
+
|
175
|
+
if __name__ == "__main__":
    # Manual smoke test: one server thread, three client threads, one data
    # message fanned out over the PUB/SUB data channel.

    def _server_thread(server_comms):
        # Handshake-gate before publishing so subscribers are connected.
        server_comms.wait_for_clients(expected_count=3)
        server_comms.send_data({"data": "test"})
        # server_comms.send_command({"command": "test"})
        # print("Server sent command")

    def _client_thread(script_comms):
        # receive_data() returns ONE message (a dict), not an iterable of
        # messages — iterating it would loop over the dict's keys.
        data = script_comms.receive_data()
        print("Client received data:", data)

    server_comms = ArborServerCommsHandler()
    t1 = threading.Thread(target=_server_thread, args=(server_comms,))
    t1.start()
    print("Server started")

    client_threads = []
    script_comms_list = []
    for i in range(3):
        script_comms = ArborScriptCommsHandler(
            "localhost",
            server_comms.command_port,
            server_comms.status_port,
            server_comms.data_port,
            server_comms.broadcast_port,
            server_comms.handshake_port,
            False,
        )
        t = threading.Thread(target=_client_thread, args=(script_comms,))
        t.start()
        # Track the thread so the join loop below actually waits for it.
        client_threads.append(t)
        script_comms_list.append(script_comms)

    # Give PUB/SUB a moment to settle (slow-joiner mitigation).
    time.sleep(1)

    try:
        t1.join()
        for t in client_threads:
            t.join()
    except KeyboardInterrupt:
        print("Keyboard interrupt")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        for script_comms in script_comms_list:
            script_comms.close()
        server_comms.close()
File without changes
|
@@ -0,0 +1,289 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
import shutil
|
4
|
+
import time
|
5
|
+
import uuid
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
from fastapi import UploadFile
|
9
|
+
|
10
|
+
from arbor.server.core.config import Settings
|
11
|
+
|
12
|
+
|
13
|
+
class FileValidationError(Exception):
    """Raised when an uploaded file fails format validation."""
|
17
|
+
|
18
|
+
|
19
|
+
class FileManager:
    """Stores uploaded JSONL files and validates them for fine-tuning.

    Each file lives in its own directory under ``STORAGE_PATH/uploads/<id>/``
    holding the payload (``data.jsonl``) plus a ``metadata.json`` sidecar.
    An in-memory index (``self.files``) maps file id -> record dict.
    """

    def __init__(self, settings: Settings):
        self.uploads_dir = Path(settings.STORAGE_PATH) / "uploads"
        self.uploads_dir.mkdir(parents=True, exist_ok=True)
        self.files = self.load_files_from_uploads()

    def load_files_from_uploads(self):
        """Rebuild the file index by scanning the uploads directory.

        Directories missing a ``metadata.json`` or a ``.jsonl`` payload are
        skipped. Returns a dict keyed by file id (the directory name).
        """
        files = {}

        # Scan through all directories in uploads directory
        for dir_path in self.uploads_dir.glob("*"):
            if not dir_path.is_dir():
                continue

            # Check for metadata.json
            metadata_path = dir_path / "metadata.json"
            if not metadata_path.exists():
                continue

            # Load metadata
            with open(metadata_path) as f:
                metadata = json.load(f)

            # Find the .jsonl file
            jsonl_files = list(dir_path.glob("*.jsonl"))
            if not jsonl_files:
                continue

            file_path = jsonl_files[0]
            files[dir_path.name] = {
                # Include "id" so these records match save_uploaded_file()'s.
                "id": dir_path.name,
                "path": str(file_path),
                "purpose": metadata.get("purpose", "training"),
                "bytes": file_path.stat().st_size,
                # Fall back to the file's mtime when metadata lacks a timestamp.
                "created_at": metadata.get(
                    "created_at", int(file_path.stat().st_mtime)
                ),
                "filename": metadata.get("filename", file_path.name),
            }

        return files

    def save_uploaded_file(self, file: UploadFile):
        """Persist an uploaded file to disk, register it, and return its record."""
        file_id = f"file-{str(uuid.uuid4())}"
        dir_path = self.uploads_dir / file_id
        dir_path.mkdir(exist_ok=True)

        # Save the actual file
        file_path = dir_path / "data.jsonl"
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        # Create metadata
        metadata = {
            "purpose": "training",
            "created_at": int(time.time()),
            "filename": file.filename,
        }

        # Save metadata
        with open(dir_path / "metadata.json", "w") as f:
            json.dump(metadata, f)

        file_data = {
            "id": file_id,
            "path": str(file_path),
            "purpose": metadata["purpose"],
            "bytes": file.size,
            "created_at": metadata["created_at"],
            "filename": metadata["filename"],
        }

        self.files[file_id] = file_data
        return file_data

    def get_file(self, file_id: str):
        """Return the record for *file_id* (raises KeyError if unknown)."""
        return self.files[file_id]

    def delete_file(self, file_id: str):
        """Remove a file's directory and index entry; no-op if unknown."""
        if file_id not in self.files:
            return

        dir_path = self.uploads_dir / file_id
        if dir_path.exists():
            shutil.rmtree(dir_path)

        del self.files[file_id]

    def validate_file_format_sft(self, file_path: str) -> None:
        """
        Validates that the file at file_path is properly formatted JSONL with expected structure.
        Raises FileValidationError if validation fails.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue  # skip empty lines
                    try:
                        data = json.loads(line)

                        if not isinstance(data, dict):
                            raise FileValidationError(
                                f"Line {line_num}: Each line must be a JSON object"
                            )

                        if "messages" not in data:
                            raise FileValidationError(
                                f"Line {line_num}: Missing 'messages' field"
                            )

                        if not isinstance(data["messages"], list):
                            raise FileValidationError(
                                f"Line {line_num}: 'messages' must be an array"
                            )

                        for msg in data["messages"]:
                            if not isinstance(msg, dict):
                                raise FileValidationError(
                                    f"Line {line_num}: Each message must be an object"
                                )
                            if "role" not in msg or "content" not in msg:
                                raise FileValidationError(
                                    f"Line {line_num}: Messages must have 'role' and 'content' fields"
                                )
                            if not isinstance(msg["role"], str) or not isinstance(
                                msg["content"], str
                            ):
                                raise FileValidationError(
                                    f"Line {line_num}: Message 'role' and 'content' must be strings"
                                )

                    except json.JSONDecodeError:
                        raise FileValidationError(f"Invalid JSON on line {line_num}")

        except FileValidationError:
            # Already descriptive — don't re-wrap into a generic message.
            raise
        except Exception as e:
            raise FileValidationError(f"Failed to read or validate file: {e}")

    def validate_file_format_dpo(self, file_path: str) -> None:
        """
        Validates that the file at file_path is properly formatted JSONL
        preference data (input.messages/tools/parallel_tool_calls plus
        preferred_output and non_preferred_output message lists).
        Raises FileValidationError if validation fails.

        Side effect: after validation the file is rewritten IN PLACE into the
        {"chosen", "rejected", "prompt"} format expected by DPO training.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)

                        if not isinstance(data, dict):
                            raise FileValidationError(
                                f"Line {line_num}: Each line must be a JSON object"
                            )

                        input_data = data.get("input")
                        if not isinstance(input_data, dict):
                            raise FileValidationError(
                                f"Line {line_num}: Missing or invalid 'input' field"
                            )

                        if "messages" not in input_data or not isinstance(
                            input_data["messages"], list
                        ):
                            raise FileValidationError(
                                f"Line {line_num}: 'input.messages' must be a list"
                            )
                        for msg in input_data["messages"]:
                            if not isinstance(msg, dict):
                                raise FileValidationError(
                                    f"Line {line_num}: Each 'message' must be an object"
                                )
                            if "role" not in msg or "content" not in msg:
                                raise FileValidationError(
                                    f"Line {line_num}: Each message must have 'role' and 'content'"
                                )
                            if not isinstance(msg["role"], str) or not isinstance(
                                msg["content"], str
                            ):
                                raise FileValidationError(
                                    f"Line {line_num}: 'role' and 'content' must be strings"
                                )

                        if "tools" not in input_data or not isinstance(
                            input_data["tools"], list
                        ):
                            raise FileValidationError(
                                f"Line {line_num}: 'input.tools' must be a list"
                            )

                        if "parallel_tool_calls" not in input_data or not isinstance(
                            input_data["parallel_tool_calls"], bool
                        ):
                            raise FileValidationError(
                                f"Line {line_num}: 'input.parallel_tool_calls' must be a boolean"
                            )

                        preferred = data.get("preferred_output")
                        if not isinstance(preferred, list):
                            raise FileValidationError(
                                f"Line {line_num}: 'preferred_output' must be a list"
                            )
                        for msg in preferred:
                            if not isinstance(msg, dict):
                                raise FileValidationError(
                                    f"Line {line_num}: Each 'preferred_output' message must be an object"
                                )
                            if "role" not in msg or "content" not in msg:
                                raise FileValidationError(
                                    f"Line {line_num}: Each preferred_output message must have 'role' and 'content'"
                                )
                            if not isinstance(msg["role"], str) or not isinstance(
                                msg["content"], str
                            ):
                                raise FileValidationError(
                                    f"Line {line_num}: 'role' and 'content' in preferred_output must be strings"
                                )

                        non_preferred = data.get("non_preferred_output")
                        if not isinstance(non_preferred, list):
                            raise FileValidationError(
                                f"Line {line_num}: 'non_preferred_output' must be a list"
                            )
                        for msg in non_preferred:
                            if not isinstance(msg, dict):
                                raise FileValidationError(
                                    f"Line {line_num}: Each 'non_preferred_output' message must be an object"
                                )
                            if "role" not in msg or "content" not in msg:
                                raise FileValidationError(
                                    f"Line {line_num}: Each non_preferred_output message must have 'role' and 'content'"
                                )
                            if not isinstance(msg["role"], str) or not isinstance(
                                msg["content"], str
                            ):
                                raise FileValidationError(
                                    f"Line {line_num}: 'role' and 'content' in non_preferred_output must be strings"
                                )

                    except json.JSONDecodeError:
                        raise FileValidationError(f"Invalid JSON on line {line_num}")

        except FileValidationError:
            # Preserve the specific per-line message instead of re-wrapping.
            raise
        except Exception as e:
            raise FileValidationError(f"Failed to validate file: {e}")

        # Rewrite into DPO training format (chosen/rejected/prompt), then
        # atomically replace the original file.
        output_path = file_path.replace(".jsonl", "_formatted.jsonl")

        with (
            open(file_path, "r", encoding="utf-8") as fin,
            open(output_path, "w", encoding="utf-8") as fout,
        ):
            for line_num, line in enumerate(fin, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    prompt = data["input"]["messages"]
                    new_line = {
                        "chosen": data["preferred_output"],
                        "rejected": data["non_preferred_output"],
                        "prompt": prompt,
                    }
                    fout.write(json.dumps(new_line) + "\n")
                except Exception as e:
                    print(f"Error parsing line {line_num}: {e}")

        os.replace(output_path, file_path)
|