arbor-ai 0.1.5 → 0.1.6 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/client/__init__.py +0 -0
- arbor/client/api.py +1 -0
- arbor/server/__init__.py +1 -0
- arbor/server/api/__init__.py +1 -0
- arbor/server/api/models/schemas.py +223 -0
- arbor/server/api/routes/__init__.py +0 -0
- arbor/server/api/routes/files.py +52 -0
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +117 -0
- arbor/server/core/__init__.py +1 -0
- arbor/server/core/config.py +47 -0
- arbor/server/core/logging.py +0 -0
- arbor/server/main.py +11 -0
- arbor/server/services/__init__.py +0 -0
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -0
- arbor/server/services/file_manager.py +289 -0
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +81 -0
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +561 -0
- arbor/server/utils/__init__.py +0 -0
- arbor/server/utils/helpers.py +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/METADATA +1 -1
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- arbor_ai-0.1.5.dist-info/RECORD +0 -8
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/top_level.txt +0 -0
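
The file listing above can be reproduced locally; here is a minimal sketch, assuming both wheels have been downloaded to the working directory (the rename tracking on the dist-info entries is applied by the diff viewer and is not reproduced here):

import zipfile

# Wheels are plain zip archives, so comparing their name lists recovers the
# added and removed paths summarized above.
old = set(zipfile.ZipFile("arbor_ai-0.1.5-py3-none-any.whl").namelist())
new = set(zipfile.ZipFile("arbor_ai-0.1.6-py3-none-any.whl").namelist())
print(sorted(new - old))  # paths added in 0.1.6
print(sorted(old - new))  # paths removed since 0.1.5

The bulk of the release is the new arbor/server/services/training_manager.py, reproduced in full below.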
arbor/server/services/training_manager.py
@@ -0,0 +1,561 @@
+import os
+import random
+import string
+import time
+from datetime import datetime
+from pathlib import Path
+
+from arbor.server.api.models.schemas import FineTuneRequest
+from arbor.server.core.config import Settings
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.job_manager import Job, JobEvent, JobStatus
+
+
+class TrainingManager:
+    def __init__(self, settings: Settings):
+        self.settings = settings
+
+    def make_output_dir(self, request: FineTuneRequest):
+        model_name = request.model.split("/")[-1].lower()
+        suffix = (
+            request.suffix
+            if request.suffix is not None
+            else "".join(random.choices(string.ascii_letters + string.digits, k=6))
+        )
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        name = f"ft:{model_name}:{suffix}:{timestamp}"
+        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)
+
+    def find_train_args_sft(self, request: FineTuneRequest, file_manager: FileManager):
+        file = file_manager.get_file(request.training_file)
+        if file is None:
+            raise ValueError(f"Training file {request.training_file} not found")
+
+        data_path = file["path"]
+        file_manager.validate_file_format_sft(data_path)
+
+        name, output_dir = self.make_output_dir(request)
+
+        default_train_kwargs = {
+            "device": None,
+            "use_peft": False,
+            "num_train_epochs": 5,
+            "per_device_train_batch_size": 1,
+            "gradient_accumulation_steps": 8,
+            "learning_rate": 1e-5,
+            "max_seq_length": None,
+            "packing": True,
+            "bf16": True,
+            "output_dir": output_dir,
+            "train_data_path": data_path,
+        }
+        train_kwargs = {"packing": False}
+        train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}
+
+        return train_kwargs
+
+    def find_train_args_dpo(self, request: FineTuneRequest, file_manager: FileManager):
+        file = file_manager.get_file(request.training_file)
+        if file is None:
+            raise ValueError(f"Training file {request.training_file} not found")
+
+        data_path = file["path"]
+        file_manager.validate_file_format_dpo(data_path)
+
+        name, output_dir = self.make_output_dir(request)
+
+        default_train_kwargs = {
+            "device": "cuda:2",
+            "use_peft": False,
+            "num_train_epochs": 5,
+            "per_device_train_batch_size": 1,
+            "gradient_accumulation_steps": 8,
+            "learning_rate": 1e-5,
+            "packing": True,
+            "bf16": True,
+            "output_dir": output_dir,
+            "train_data_path": data_path,
+            "prompt_length": 1024,
+            "max_seq_length": 1512,
+            "use_peft": False,
+        }
+
+        # https://www.philschmid.de/dpo-align-llms-in-2024-with-trl#3-align-llm-with-trl-and-the-dpotrainer
+
+        train_kwargs = request.model_dump(exclude_unset=True)
+        train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}
+
+        return train_kwargs
+
+    def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
+
+        job.status = JobStatus.RUNNING
+        job.add_event(
+            JobEvent(level="info", message="Starting fine-tuning job", data={})
+        )
+
+        fine_tune_type = request.method["type"]
+        if fine_tune_type == "dpo":
+            self.dpo_fine_tune(request, job, file_manager)
+        else:
+            self.sft_fine_tune(request, job, file_manager)
+
+    def dpo_fine_tune(
+        self, request: FineTuneRequest, job: Job, file_manager: FileManager
+    ):
+        try:
+
+            job.status = JobStatus.RUNNING
+            job.add_event(
+                JobEvent(level="info", message="Starting fine-tuning job", data={})
+            )
+
+            train_kwargs = self.find_train_args_dpo(request, file_manager)
+            import torch
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                TrainingArguments,
+            )
+            from trl import DPOConfig, DPOTrainer, setup_chat_format
+
+            device = train_kwargs.get("device", None)
+            if device is None:
+                device = (
+                    "cuda"
+                    if torch.cuda.is_available()
+                    else "mps" if torch.backends.mps.is_available() else "cpu"
+                )
+
+            job.add_event(
+                JobEvent(level="info", message=f"Using device: {device}", data={})
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=request.model, device_map="auto"
+            )
+            tokenizer = AutoTokenizer.from_pretrained(
+                pretrained_model_name_or_path=request.model
+            )
+
+            try:
+                model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
+            except Exception:
+                pass
+
+            if tokenizer.pad_token_id is None:
+                job.add_event(
+                    JobEvent(
+                        level="info", message="Adding pad token to tokenizer", data={}
+                    )
+                )
+                tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
+
+            hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
+            train_dataset = hf_dataset
+
+            use_peft = train_kwargs.get("use_peft", False)
+            peft_config = None
+
+            if use_peft:
+                from peft import LoraConfig
+
+                peft_config = LoraConfig(
+                    lora_alpha=128,
+                    lora_dropout=0.05,
+                    r=256,
+                    bias="none",
+                    target_modules="all-linear",
+                    task_type="CAUSAL_LM",
+                )
+
+            dpo_args = DPOConfig(
+                output_dir=train_kwargs["output_dir"],
+                num_train_epochs=train_kwargs["num_train_epochs"],
+            )
+
+            trainer = DPOTrainer(
+                model=model,
+                args=dpo_args,
+                train_dataset=train_dataset,
+                processing_class=tokenizer,
+                peft_config=peft_config,
+            )
+
+            trainer.train()
+
+            if job.status == JobStatus.PENDING_PAUSE:
+                trainer.accelerator.save_state(output_dir=dpo_args.output_dir)
+                current_step = trainer.state.global_step
+                job.status = JobStatus.PAUSED
+                job.add_event(
+                    JobEvent(
+                        level="info",
+                        message="Training paused",
+                        data={"step": current_step},
+                    )
+                )
+
+            while job.status == JobStatus.PAUSED:
+                time.sleep(1)  # Sleep to avoid busy waiting
+                if job.status == JobStatus.PENDING_RESUME:
+                    job.status = JobStatus.RUNNING
+                    job.add_event(JobEvent(level="info", message="Resuming training"))
+                    trainer.accelerator.load_state(input_dir=dpo_args.output_dir)
+                    trainer.train(resume_from_checkpoint=True)
+
+            if job.status == JobStatus.PENDING_CANCEL:
+                job.status = JobStatus.CANCELLED
+                job.add_event(JobEvent(level="info", message="Training cancelled"))
+
+                _cleanup(model, tokenizer, trainer)
+                raise Exception(
+                    "Training cancelled"
+                )  # not sure if this should be raised or just return None
+
+            job.add_event(
+                JobEvent(level="info", message="Training completed successfully")
+            )
+
+            job.add_event(JobEvent(level="info", message="Saving model", data={}))
+            # Save the model!
+            trainer.save_model()
+            job.add_event(
+                JobEvent(
+                    level="info",
+                    message="Model saved",
+                    data={"location": dpo_args.output_dir},
+                )
+            )
+
+            MERGE = True
+            if use_peft and MERGE:
+                from peft import AutoPeftModelForCausalLM
+
+                # Load PEFT model on CPU
+                model_ = AutoPeftModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=dpo_args.output_dir,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                )
+
+                merged_model = model_.merge_and_unload()
+                merged_model.save_pretrained(
+                    dpo_args.output_dir, safe_serialization=True, max_shard_size="5GB"
+                )
+
+            _cleanup(model, tokenizer, trainer)
+
+            job.status = JobStatus.SUCCEEDED
+            job.fine_tuned_model = dpo_args.output_dir
+        except Exception as e:
+            job.add_event(
+                JobEvent(level="error", message=f"Training failed: {str(e)}", data={})
+            )
+            job.status = JobStatus.FAILED
+            raise
+        finally:
+            pass
+
+        return dpo_args.output_dir
+
+    def sft_fine_tune(
+        self, request: FineTuneRequest, job: Job, file_manager: FileManager
+    ):
+
+        try:
+            train_kwargs = self.find_train_args_sft(request, file_manager)
+
+            import torch
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+            from trl import SFTConfig, SFTTrainer, setup_chat_format
+
+            device = train_kwargs.get("device", None)
+            if device is None:
+                device = (
+                    "cuda"
+                    if torch.cuda.is_available()
+                    else "mps" if torch.backends.mps.is_available() else "cpu"
+                )
+            job.add_event(
+                JobEvent(level="info", message=f"Using device: {device}", data={})
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=request.model
+            ).to(device)
+            tokenizer = AutoTokenizer.from_pretrained(
+                pretrained_model_name_or_path=request.model
+            )
+
+            # Set up the chat format; generally only for non-chat model variants, hence the try-except.
+            try:
+                model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
+            except Exception:
+                pass
+
+            if tokenizer.pad_token_id is None:
+                job.add_event(
+                    JobEvent(
+                        level="info", message="Adding pad token to tokenizer", data={}
+                    )
+                )
+                tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
+
+            job.add_event(JobEvent(level="info", message="Creating dataset", data={}))
+            if (
+                "max_seq_length" not in train_kwargs
+                or train_kwargs["max_seq_length"] is None
+            ):
+                train_kwargs["max_seq_length"] = 4096
+                job.add_event(
+                    JobEvent(
+                        level="info",
+                        message=f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}",
+                        data={},
+                    )
+                )
+
+            job.add_event(JobEvent(level="info", message="Tokenizing dataset", data={}))
+            hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
+
+            def tokenize_function(example):
+                return encode_sft_example(
+                    example, tokenizer, train_kwargs["max_seq_length"]
+                )
+
+            tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
+            tokenized_dataset.set_format(type="torch")
+            tokenized_dataset = tokenized_dataset.filter(
+                lambda example: (example["labels"] != -100).any()
+            )
+
+            USE_PEFT = train_kwargs.get("use_peft", False)
+            peft_config = None
+
+            if USE_PEFT:
+                from peft import LoraConfig
+
+                rank_dimension = 32
+                lora_alpha = 64
+                lora_dropout = 0.05
+
+                peft_config = LoraConfig(
+                    r=rank_dimension,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout,
+                    bias="none",
+                    target_modules="all-linear",
+                    task_type="CAUSAL_LM",
+                )
+
+            sft_config = SFTConfig(
+                output_dir=train_kwargs["output_dir"],
+                num_train_epochs=train_kwargs["num_train_epochs"],
+                per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
+                gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
+                learning_rate=train_kwargs["learning_rate"],
+                max_grad_norm=2.0,  # note that the current SFTConfig default is 1.0
+                logging_steps=20,
+                warmup_ratio=0.03,
+                lr_scheduler_type="constant",
+                save_steps=10_000,
+                bf16=train_kwargs["bf16"],
+                max_seq_length=train_kwargs["max_seq_length"],
+                packing=train_kwargs["packing"],
+                dataset_kwargs={  # We need to pass dataset_kwargs because we are processing the dataset ourselves
+                    "add_special_tokens": False,  # Special tokens handled by template
+                    "append_concat_token": False,  # No additional separator needed
+                },
+            )
+
+            job.add_event(JobEvent(level="info", message="Starting training"))
+            trainer = SFTTrainer(
+                model=model,
+                args=sft_config,
+                train_dataset=tokenized_dataset,
+                peft_config=peft_config,
+            )
+
+            # Train!
+            trainer.train()
+
+            if job.status == JobStatus.PENDING_PAUSE:
+                trainer.accelerator.save_state(output_dir=sft_config.output_dir)
+                current_step = trainer.state.global_step
+                job.status = JobStatus.PAUSED
+                job.add_event(
+                    JobEvent(
+                        level="info",
+                        message="Training paused",
+                        data={"step": current_step},
+                    )
+                )
+
+            while job.status == JobStatus.PAUSED:
+                time.sleep(1)  # Sleep to avoid busy waiting
+                if job.status == JobStatus.PENDING_RESUME:
+                    job.status = JobStatus.RUNNING
+                    job.add_event(JobEvent(level="info", message="Resuming training"))
+                    trainer.accelerator.load_state(input_dir=sft_config.output_dir)
+                    trainer.train(resume_from_checkpoint=True)
+
+            if job.status == JobStatus.PENDING_CANCEL:
+                job.status = JobStatus.CANCELLED
+                job.add_event(JobEvent(level="info", message="Training cancelled"))
+
+                _cleanup(model, tokenizer, trainer)
+
+                raise Exception(
+                    "Training cancelled"
+                )  # not sure if this should be raised or just return None
+
+            job.add_event(
+                JobEvent(level="info", message="Training completed successfully")
+            )
+
+            job.add_event(JobEvent(level="info", message="Saving model", data={}))
+            # Save the model!
+            trainer.save_model()
+            job.add_event(
+                JobEvent(
+                    level="info",
+                    message="Model saved",
+                    data={"location": sft_config.output_dir},
+                )
+            )
+
+            MERGE = True
+            if USE_PEFT and MERGE:
+                from peft import AutoPeftModelForCausalLM
+
+                # Load PEFT model on CPU
+                model_ = AutoPeftModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=sft_config.output_dir,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                )
+
+                merged_model = model_.merge_and_unload()
+                merged_model.save_pretrained(
+                    sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
+                )
+
+            _cleanup(model, tokenizer, trainer)
+
+            job.status = JobStatus.SUCCEEDED
+            job.fine_tuned_model = sft_config.output_dir
+        except Exception as e:
+            job.add_event(
+                JobEvent(level="error", message=f"Training failed: {str(e)}", data={})
+            )
+            job.status = JobStatus.FAILED
+            raise
+        finally:
+            pass
+
+        return sft_config.output_dir
+
+
+def dataset_from_file(data_path):
+    """
+    Creates a HuggingFace Dataset from a JSONL file.
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset("json", data_files=data_path, split="train")
+    return dataset
+
+
+def encode_sft_example(example, tokenizer, max_seq_length):
+    """
+    This function encodes a single example into a format that can be used for sft training.
+    Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
+    We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
+
+    Code obtained from the allenai/open-instruct repository: https://github.com/allenai/open-instruct/blob/4365dea3d1a6111e8b2712af06b22a4512a0df88/open_instruct/finetune.py
+    """
+    import torch
+
+    messages = example["messages"]
+    if len(messages) == 0:
+        raise ValueError("messages field is empty.")
+    input_ids = tokenizer.apply_chat_template(
+        conversation=messages,
+        tokenize=True,
+        return_tensors="pt",
+        padding=False,
+        truncation=True,
+        max_length=max_seq_length,
+        add_generation_prompt=False,
+    )
+    labels = input_ids.clone()
+    # mask the non-assistant part for avoiding loss
+    for message_idx, message in enumerate(messages):
+        if message["role"] != "assistant":
+            # we calculate the start index of this non-assistant message
+            if message_idx == 0:
+                message_start_idx = 0
+            else:
+                message_start_idx = tokenizer.apply_chat_template(
+                    conversation=messages[
+                        :message_idx
+                    ],  # here marks the end of the previous messages
+                    tokenize=True,
+                    return_tensors="pt",
+                    padding=False,
+                    truncation=True,
+                    max_length=max_seq_length,
+                    add_generation_prompt=False,
+                ).shape[1]
+            # next, we calculate the end index of this non-assistant message
+            if (
+                message_idx < len(messages) - 1
+                and messages[message_idx + 1]["role"] == "assistant"
+            ):
+                # for intermediate messages that follow with an assistant message, we need to
+                # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
+                # (e.g., `<|assistant|>`)
+                message_end_idx = tokenizer.apply_chat_template(
+                    conversation=messages[: message_idx + 1],
+                    tokenize=True,
+                    return_tensors="pt",
+                    padding=False,
+                    truncation=True,
+                    max_length=max_seq_length,
+                    add_generation_prompt=True,
+                ).shape[1]
+            else:
+                # for the last message or the message that doesn't follow with an assistant message,
+                # we don't need to add the assistant generation prefix
+                message_end_idx = tokenizer.apply_chat_template(
+                    conversation=messages[: message_idx + 1],
+                    tokenize=True,
+                    return_tensors="pt",
+                    padding=False,
+                    truncation=True,
+                    max_length=max_seq_length,
+                    add_generation_prompt=False,
+                ).shape[1]
+            # set the label to -100 for the non-assistant part
+            labels[:, message_start_idx:message_end_idx] = -100
+            if max_seq_length and message_end_idx >= max_seq_length:
+                break
+    attention_mask = torch.ones_like(input_ids)
+    return {
+        "input_ids": input_ids.flatten(),
+        "labels": labels.flatten(),
+        "attention_mask": attention_mask.flatten(),
+    }
+
+
+def _cleanup(model, tokenizer, trainer):
+    import gc
+
+    import torch
+
+    del model
+    del tokenizer
+    del trainer
+    gc.collect()
+    torch.cuda.empty_cache()
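
For orientation, here is a hypothetical driver for the new TrainingManager. The FineTuneRequest field names and the job-status flow are taken from the code above; the Settings, FileManager, and Job constructor signatures are assumptions and may differ in the actual package:

from arbor.server.api.models.schemas import FineTuneRequest
from arbor.server.core.config import Settings
from arbor.server.services.file_manager import FileManager
from arbor.server.services.job_manager import Job
from arbor.server.services.training_manager import TrainingManager

settings = Settings()                 # assumes STORAGE_PATH is configured
file_manager = FileManager(settings)  # constructor signature assumed
manager = TrainingManager(settings)

request = FineTuneRequest(            # field names as used by the code above
    model="HuggingFaceTB/SmolLM2-135M-Instruct",  # any HF causal-LM id
    training_file="file-abc123",      # id of a previously uploaded JSONL file
    method={"type": "sft"},           # "dpo" dispatches to dpo_fine_tune instead
)
job = Job()                           # constructor signature assumed

manager.fine_tune(request, job, file_manager)
print(job.status, job.fine_tuned_model)  # SUCCEEDED plus the output dir on success

Note that fine_tune itself returns nothing; callers read the result from job.fine_tuned_model, which dpo_fine_tune and sft_fine_tune set to the output directory on success.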
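The most intricate piece of the new file is encode_sft_example, which builds labels by copying input_ids and overwriting every non-assistant span with -100 so that only assistant tokens contribute to the loss. A toy restatement of that masking rule, with made-up token ids and span boundaries rather than a real tokenizer:

import torch

# Hypothetical 6-token conversation: tokens 0..2 from a user turn,
# tokens 3..5 from the assistant reply.
input_ids = torch.tensor([[11, 22, 33, 44, 55, 66]])
labels = input_ids.clone()

# The rule from encode_sft_example: mask the non-assistant span with -100.
message_start_idx, message_end_idx = 0, 3
labels[:, message_start_idx:message_end_idx] = -100

assert (labels[:, :3] == -100).all()              # user span ignored by the loss
assert (labels[:, 3:] == input_ids[:, 3:]).all()  # assistant span is supervised
print(labels)  # tensor([[-100, -100, -100, 44, 55, 66]])

In the real function the span boundaries come from re-applying the chat template to successive message prefixes and measuring the tokenized lengths, which is why apply_chat_template is called repeatedly inside the loop.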
arbor_ai-0.1.6.dist-info/RECORD
@@ -0,0 +1,34 @@
+arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
+arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
+arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
+arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhcnrKIk,5594
+arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
+arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
+arbor/server/api/routes/inference.py,sha256=xlP-FMpOJAiiPZkE470l9mCR0ujLki8RrcO9hmTQD-k,1662
+arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
+arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
+arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
+arbor/server/services/grpo_manager.py,sha256=T-f1TrNSi_kmxPOcpaDphS8Xf3UMUbricocc6fuaKIM,12077
+arbor/server/services/inference_manager.py,sha256=qR9xPiYs4Is24vgeF72w7Hbe8j_PGEbl-qewcvUV-dA,9731
+arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
+arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
+arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/comms/comms.py,sha256=Dg08D2Fm5TAEiGyr0Qcr0uocabQpFD_sBVhxIkj9D2M,7424
+arbor/server/services/scripts/grpo_training.py,sha256=V36pCMZDJj2DdzquxScOddi9zP8EVPGWN3HGiftFfrY,21082
+arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor_ai-0.1.6.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.6.dist-info/METADATA,sha256=ot1XsjoFawGbNBmaAaxv29lE_uNh8TTdl8OANkTPTS8,1823
+arbor_ai-0.1.6.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+arbor_ai-0.1.6.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.6.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.6.dist-info/RECORD,,
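
Each RECORD row has the form path,sha256=<digest>,<size>, where the digest is the file's SHA-256 hash encoded as URL-safe base64 with the trailing '=' padding stripped, per the wheel specification; RECORD lists itself with empty hash and size fields since it cannot contain its own digest. A minimal sketch of recomputing an entry from an unpacked wheel:

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # RECORD digest: urlsafe-base64(sha256(file bytes)) without '=' padding.
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# Run from the root of the unpacked wheel; should match the row above:
# arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
print(record_entry("arbor/cli.py"))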
arbor_ai-0.1.5.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
-arbor_ai-0.1.5.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
-arbor_ai-0.1.5.dist-info/METADATA,sha256=Tney6uOytHDMIZg3iqKrn2lgtaF3NULjXo19XdG_2Dw,1823
-arbor_ai-0.1.5.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
-arbor_ai-0.1.5.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
-arbor_ai-0.1.5.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
-arbor_ai-0.1.5.dist-info/RECORD,,