arbor-ai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/__init__.py +0 -0
- arbor/cli.py +17 -0
- arbor/client/__init__.py +0 -0
- arbor/client/api.py +2 -0
- arbor/server/__init__.py +1 -0
- arbor/server/api/__init__.py +1 -0
- arbor/server/api/routes/__init__.py +0 -0
- arbor/server/api/routes/files.py +13 -0
- arbor/server/api/routes/jobs.py +14 -0
- arbor/server/api/routes/training.py +16 -0
- arbor/server/core/__init__.py +1 -0
- arbor/server/core/config.py +10 -0
- arbor/server/core/logging.py +0 -0
- arbor/server/main.py +10 -0
- arbor/server/services/__init__.py +0 -0
- arbor/server/services/dependencies.py +16 -0
- arbor/server/services/file_manager.py +83 -0
- arbor/server/services/job_manager.py +72 -0
- arbor/server/services/training_manager.py +263 -0
- arbor/server/utils/__init__.py +0 -0
- arbor/server/utils/helpers.py +0 -0
- arbor_ai-0.1.0.dist-info/METADATA +47 -0
- arbor_ai-0.1.0.dist-info/RECORD +25 -0
- arbor_ai-0.1.0.dist-info/WHEEL +4 -0
- arbor_ai-0.1.0.dist-info/entry_points.txt +3 -0
arbor/__init__.py
ADDED
File without changes
arbor/cli.py
ADDED
@@ -0,0 +1,17 @@
+import click
+import uvicorn
+from arbor.server.main import app
+
+@click.group()
+def cli():
+    pass
+
+@cli.command()
+@click.option('--host', default='0.0.0.0', help='Host to bind to')
+@click.option('--port', default=8000, help='Port to bind to')
+def serve(host, port):
+    """Start the Arbor API server"""
+    uvicorn.run(app, host=host, port=port)
+
+if __name__ == '__main__':
+    cli()
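
The `serve` command simply hands the FastAPI app to uvicorn. A minimal sketch of the programmatic equivalent, assuming the wheel is installed so these imports resolve:

```python
# Same effect as `arbor serve --host 0.0.0.0 --port 8000`
# (a sketch, not part of the package):
import uvicorn

from arbor.server.main import app

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```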
arbor/client/__init__.py
ADDED
File without changes
arbor/client/api.py
ADDED
arbor/server/__init__.py
ADDED
@@ -0,0 +1 @@
+
arbor/server/api/__init__.py
ADDED
@@ -0,0 +1 @@
+
arbor/server/api/routes/__init__.py
ADDED
File without changes
arbor/server/api/routes/files.py
ADDED
@@ -0,0 +1,13 @@
+from fastapi import APIRouter, UploadFile, File, Depends
+from arbor.server.services.file_manager import FileManager
+from arbor.server.api.models.schemas import FileResponse
+from arbor.server.services.dependencies import get_file_manager
+
+router = APIRouter()
+
+@router.post("", response_model=FileResponse)
+def upload_file(
+    file: UploadFile = File(...),
+    file_manager: FileManager = Depends(get_file_manager)
+):
+    return file_manager.save_uploaded_file(file)
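
This route delegates straight to `FileManager.save_uploaded_file`. Note that `arbor.server.api.models.schemas` is imported here (and in the other modules below) yet does not appear in this wheel's RECORD, so the import may fail on the released build. Assuming a working install, a client call might look like this sketch (`requests` is not a declared dependency and is used purely for illustration):

```python
import requests

# Upload a JSONL training file; main.py mounts this router at /api/files.
with open("training_data.jsonl", "rb") as f:
    resp = requests.post("http://localhost:8000/api/files", files={"file": f})

# The response mirrors the dict built in FileManager.save_uploaded_file:
# id, path, purpose, bytes, created_at, filename.
print(resp.json())
```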
arbor/server/api/routes/jobs.py
ADDED
@@ -0,0 +1,14 @@
+from fastapi import APIRouter, Depends
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.dependencies import get_job_manager
+from arbor.server.api.models.schemas import JobStatusResponse
+
+router = APIRouter()
+
+@router.get("/{job_id}", response_model=JobStatusResponse)
+def get_job_status(
+    job_id: str,
+    job_manager: JobManager = Depends(get_job_manager)
+):
+    status = job_manager.get_job_status(job_id)
+    return JobStatusResponse(job_id=job_id, status=status.value)
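
`get_job_status` raises `ValueError` for unknown ids, which FastAPI surfaces as a 500 rather than a 404. A hedged polling sketch against this endpoint (the job id is a placeholder):

```python
import time

import requests

# Poll until the job reaches a terminal state; main.py mounts this router
# at /api/job, and statuses come from the JobStatus enum in job_manager.py.
job_id = "00000000-0000-0000-0000-000000000000"  # placeholder value
while True:
    status = requests.get(f"http://localhost:8000/api/job/{job_id}").json()["status"]
    print(status)
    if status in ("completed", "failed"):
        break
    time.sleep(5)
```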
arbor/server/api/routes/training.py
ADDED
@@ -0,0 +1,16 @@
+from fastapi import APIRouter, BackgroundTasks, Depends
+
+from arbor.server.api.models.schemas import FineTuneRequest, JobStatusResponse
+from arbor.server.services.job_manager import JobManager, JobStatus
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.training_manager import TrainingManager
+from arbor.server.services.dependencies import get_training_manager, get_job_manager, get_file_manager
+
+router = APIRouter()
+
+@router.post("", response_model=JobStatusResponse)
+def fine_tune(request: FineTuneRequest, background_tasks: BackgroundTasks, training_manager: TrainingManager = Depends(get_training_manager), job_manager: JobManager = Depends(get_job_manager), file_manager: FileManager = Depends(get_file_manager)):
+    job = job_manager.create_job()
+    background_tasks.add_task(training_manager.fine_tune, request, job, file_manager)
+    job.status = JobStatus.QUEUED
+    return JobStatusResponse(job_id=job.id, status=job.status.value)
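
The `FineTuneRequest` schema is not shipped in this wheel, but `training_manager.py` reads `request.training_file` and `request.model_name`, so the request body presumably looks like the sketch below (field names inferred, values hypothetical):

```python
import requests

# Kick off a fine-tune; the endpoint returns immediately with status
# "queued" while training runs in a FastAPI background task.
payload = {
    "training_file": "<file id returned by /api/files>",  # hypothetical value
    "model_name": "<any-causal-lm-on-the-hub>",           # hypothetical value
}
resp = requests.post("http://localhost:8000/api/fine-tune", json=payload)
print(resp.json())  # {"job_id": "...", "status": "queued"}
```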
arbor/server/core/__init__.py
ADDED
@@ -0,0 +1 @@
+
arbor/server/core/config.py
ADDED
arbor/server/core/logging.py
ADDED
File without changes
arbor/server/main.py
ADDED
@@ -0,0 +1,10 @@
+from fastapi import FastAPI
+from arbor.server.api.routes import training, files, jobs
+from arbor.server.core.config import settings
+
+app = FastAPI(title="Arbor API")
+
+# Include routers
+app.include_router(training.router, prefix="/api/fine-tune")
+app.include_router(files.router, prefix="/api/files")
+app.include_router(jobs.router, prefix="/api/job")
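
Combining these prefixes with the empty route paths in the routers gives three endpoints. A quick way to confirm the wiring, assuming the wheel's imports resolve (`settings` is imported above but never used):

```python
from arbor.server.main import app

# Expect /api/fine-tune, /api/files and /api/job/{job_id} alongside
# FastAPI's built-in docs routes.
for route in app.routes:
    print(route.path)
```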
arbor/server/services/__init__.py
ADDED
File without changes
arbor/server/services/dependencies.py
ADDED
@@ -0,0 +1,16 @@
+from functools import lru_cache
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+
+@lru_cache()
+def get_file_manager() -> FileManager:
+    return FileManager()
+
+@lru_cache()
+def get_job_manager() -> JobManager:
+    return JobManager()
+
+@lru_cache()
+def get_training_manager() -> TrainingManager:
+    return TrainingManager()
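
Because each factory takes no arguments and is wrapped in `lru_cache()`, every `Depends(...)` resolution receives the same instance, so the managers behave as process-wide singletons:

```python
from arbor.server.services.dependencies import get_file_manager

# lru_cache() memoizes the zero-argument call, so repeated dependency
# resolutions share one FileManager (likewise for the other managers).
assert get_file_manager() is get_file_manager()
```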
arbor/server/services/file_manager.py
ADDED
@@ -0,0 +1,83 @@
+from pathlib import Path
+import json
+import os
+import shutil
+import time
+import uuid
+from fastapi import UploadFile
+from arbor.server.api.models.schemas import FileResponse
+
+class FileManager:
+    def __init__(self):
+        self.uploads_dir = Path("uploads")
+        self.uploads_dir.mkdir(exist_ok=True)
+        self.files = self.load_files_from_uploads()
+
+    def load_files_from_uploads(self):
+        files = {}
+
+        # Scan through all directories in uploads directory
+        for dir_path in self.uploads_dir.glob("*"):
+            if not dir_path.is_dir():
+                continue
+
+            # Check for metadata.json
+            metadata_path = dir_path / "metadata.json"
+            if not metadata_path.exists():
+                continue
+
+            # Load metadata
+            with open(metadata_path) as f:
+                metadata = json.load(f)
+
+            # Find the .jsonl file
+            jsonl_files = list(dir_path.glob("*.jsonl"))
+            if not jsonl_files:
+                continue
+
+            file_path = jsonl_files[0]
+            files[dir_path.name] = {
+                "path": str(file_path),
+                "purpose": metadata.get("purpose", "training"),
+                "bytes": file_path.stat().st_size,
+                "created_at": metadata.get("created_at", int(file_path.stat().st_mtime)),
+                "filename": metadata.get("filename", file_path.name)
+            }
+
+        return files
+
+    def save_uploaded_file(self, file: UploadFile) -> FileResponse:
+        file_id = str(uuid.uuid4())
+        dir_path = self.uploads_dir / file_id
+        dir_path.mkdir(exist_ok=True)
+
+        # Save the actual file
+        file_path = dir_path / f"data.jsonl"
+        with open(file_path, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+
+        # Create metadata
+        metadata = {
+            "purpose": "training",
+            "created_at": int(time.time()),
+            "filename": file.filename
+        }
+
+        # Save metadata
+        with open(dir_path / "metadata.json", "w") as f:
+            json.dump(metadata, f)
+
+        file_data = {
+            "id": file_id,
+            "path": str(file_path),
+            "purpose": metadata["purpose"],
+            "bytes": file.size,
+            "created_at": metadata["created_at"],
+            "filename": metadata["filename"]
+        }
+
+        self.files[file_id] = file_data
+        return FileResponse(**file_data)
+
+    def get_file(self, file_id: str):
+        return self.files[file_id]
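
Each upload gets a UUID-named directory holding the data plus a metadata sidecar, which `load_files_from_uploads` replays on startup. Two caveats: `get_file` raises `KeyError` for unknown ids, so the `is None` check in `training_manager.py` never fires, and entries rebuilt from disk lack the `"id"` key that fresh uploads carry. A sketch of the layout and of reading it back:

```python
# uploads/
#   <uuid4>/
#     data.jsonl       <- bytes copied from the upload
#     metadata.json    <- {"purpose": "training", "created_at": ..., "filename": ...}
import json
from pathlib import Path

for dir_path in Path("uploads").glob("*"):
    meta_path = dir_path / "metadata.json"
    if meta_path.exists():
        meta = json.loads(meta_path.read_text())
        print(dir_path.name, meta["filename"], meta["created_at"])
```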
arbor/server/services/job_manager.py
ADDED
@@ -0,0 +1,72 @@
+import uuid
+from enum import Enum
+import logging
+from datetime import datetime
+
+class JobStatus(Enum):
+    PENDING = "pending"
+    QUEUED = "queued"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+class JobLogHandler(logging.Handler):
+    def __init__(self, job):
+        super().__init__()
+        self.job = job
+
+    def emit(self, record):
+        log_entry = {
+            'timestamp': datetime.fromtimestamp(record.created).isoformat(),
+            'level': record.levelname,
+            'message': record.getMessage()
+        }
+        self.job.logs.append(log_entry)
+
+class Job:
+    def __init__(self, id: str, status: JobStatus):
+        self.id = id
+        self.status = status
+        self.logs = []
+        self.logger = None
+        self.log_handler = None
+
+    def setup_logger(self, logger_name: str = None) -> logging.Logger:
+        """Sets up logging for the job with a dedicated handler."""
+        if logger_name is None:
+            logger_name = f"job_{self.id}"
+
+        logger = logging.getLogger(logger_name)
+        logger.setLevel(logging.INFO)
+
+        # Create and setup handler if not already exists
+        if self.log_handler is None:
+            handler = JobLogHandler(self)
+            formatter = logging.Formatter('%(message)s')
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+            self.log_handler = handler
+
+        self.logger = logger
+        return logger
+
+    def cleanup_logger(self):
+        """Removes the job's logging handler."""
+        if self.logger and self.log_handler:
+            self.logger.removeHandler(self.log_handler)
+            self.log_handler = None
+            self.logger = None
+
+class JobManager:
+    def __init__(self):
+        self.jobs = {}
+
+    def get_job_status(self, job_id: str):
+        if job_id not in self.jobs:
+            raise ValueError(f"Job {job_id} not found")
+        return self.jobs[job_id].status
+
+    def create_job(self):
+        job = Job(id=str(uuid.uuid4()), status=JobStatus.PENDING)
+        self.jobs[job.id] = job
+        return job
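
The handler-per-job pattern means every `Job` accumulates its own structured log records in memory, which is how training progress becomes inspectable per job. A minimal sketch of the flow:

```python
from arbor.server.services.job_manager import JobManager

manager = JobManager()
job = manager.create_job()   # status starts as JobStatus.PENDING

logger = job.setup_logger()  # logger name defaults to f"job_{job.id}"
logger.info("hello from the job")

# JobLogHandler.emit captured the record as a dict on the job itself:
print(job.logs)  # [{'timestamp': '...', 'level': 'INFO', 'message': 'hello from the job'}]
job.cleanup_logger()
```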
arbor/server/services/training_manager.py
ADDED
@@ -0,0 +1,263 @@
+from arbor.server.api.models.schemas import FineTuneRequest
+from arbor.server.services.job_manager import Job, JobStatus
+from arbor.server.services.file_manager import FileManager
+
+class TrainingManager:
+    def __init__(self):
+        pass
+
+    def find_train_args(self, request: FineTuneRequest, file_manager: FileManager):
+        file = file_manager.get_file(request.training_file)
+        if file is None:
+            raise ValueError(f"Training file {request.training_file} not found")
+
+        data_path = file["path"]
+        output_dir = f"models/{request.model_name}" # TODO: This should be updated to be unique in some way
+
+
+        default_train_kwargs = {
+            "device": None,
+            "use_peft": False,
+            "num_train_epochs": 5,
+            "per_device_train_batch_size": 1,
+            "gradient_accumulation_steps": 8,
+            "learning_rate": 1e-5,
+            "max_seq_length": None,
+            "packing": True,
+            "bf16": True,
+            "output_dir": output_dir,
+            "train_data_path": data_path,
+        }
+        train_kwargs = {'packing': False}
+        train_kwargs={**default_train_kwargs, **(train_kwargs or {})}
+        output_dir = train_kwargs["output_dir"] # user might have changed the output_dir
+
+        return train_kwargs
+
+
+    def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
+        # Get logger for this job
+        logger = job.setup_logger("training")
+
+        job.status = JobStatus.RUNNING
+        logger.info("Starting fine-tuning job")
+
+        try:
+            train_kwargs = self.find_train_args(request, file_manager)
+
+            import torch
+            from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
+            from trl import SFTConfig, SFTTrainer, setup_chat_format
+
+            device = train_kwargs.get("device", None)
+            if device is None:
+                device = (
+                    "cuda"
+                    if torch.cuda.is_available()
+                    else "mps" if torch.backends.mps.is_available() else "cpu"
+                )
+            logger.info(f"Using device: {device}")
+
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=request.model_name
+            ).to(device)
+            tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=request.model_name)
+
+            # Set up the chat format; generally only for non-chat model variants, hence the try-except.
+            try:
+                model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
+            except Exception:
+                pass
+
+            if tokenizer.pad_token_id is None:
+                logger.info("Adding pad token to tokenizer")
+                tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
+
+            logger.info("Creating dataset")
+            if "max_seq_length" not in train_kwargs or train_kwargs["max_seq_length"] is None:
+                train_kwargs["max_seq_length"] = 4096
+                logger.info(f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}")
+
+
+            hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
+            def tokenize_function(example):
+                return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
+            tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
+            tokenized_dataset.set_format(type="torch")
+            tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
+
+            USE_PEFT = train_kwargs.get("use_peft", False)
+            peft_config = None
+
+            if USE_PEFT:
+                from peft import LoraConfig
+
+                rank_dimension = 32
+                lora_alpha = 64
+                lora_dropout = 0.05
+
+                peft_config = LoraConfig(
+                    r=rank_dimension,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout,
+                    bias="none",
+                    target_modules="all-linear",
+                    task_type="CAUSAL_LM",
+                )
+
+            sft_config = SFTConfig(
+                output_dir=train_kwargs["output_dir"],
+                num_train_epochs=train_kwargs["num_train_epochs"],
+                per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
+                gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
+                learning_rate=train_kwargs["learning_rate"],
+                max_grad_norm=2.0, # note that the current SFTConfig default is 1.0
+                logging_steps=20,
+                warmup_ratio=0.03,
+                lr_scheduler_type="constant",
+                save_steps=10_000,
+                bf16=train_kwargs["bf16"],
+                max_seq_length=train_kwargs["max_seq_length"],
+                packing=train_kwargs["packing"],
+                dataset_kwargs={ # We need to pass dataset_kwargs because we are processing the dataset ourselves
+                    "add_special_tokens": False, # Special tokens handled by template
+                    "append_concat_token": False, # No additional separator needed
+                },
+            )
+
+            logger.info("Starting training")
+            trainer = SFTTrainer(
+                model=model,
+                args=sft_config,
+                train_dataset=tokenized_dataset,
+                peft_config=peft_config,
+
+            )
+
+            # Train!
+            trainer.train()
+
+            # Save the model!
+            trainer.save_model()
+
+            MERGE = True
+            if USE_PEFT and MERGE:
+                from peft import AutoPeftModelForCausalLM
+
+                # Load PEFT model on CPU
+                model_ = AutoPeftModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=sft_config.output_dir,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                )
+
+                merged_model = model_.merge_and_unload()
+                merged_model.save_pretrained(
+                    sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
+                )
+
+            # Clean up!
+            import gc
+
+            del model
+            del tokenizer
+            del trainer
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            logger.info("Training completed successfully")
+            job.status = JobStatus.COMPLETED
+        except Exception as e:
+            logger.error(f"Training failed: {str(e)}")
+            job.status = JobStatus.FAILED
+            raise
+        finally:
+            job.cleanup_logger()
+
+        return sft_config.output_dir
+
+def dataset_from_file(data_path):
+    """
+    Creates a HuggingFace Dataset from a JSONL file.
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset("json", data_files=data_path, split="train")
+    return dataset
+
+
+def encode_sft_example(example, tokenizer, max_seq_length):
+    """
+    This function encodes a single example into a format that can be used for sft training.
+    Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
+    We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
+
+    Code obtained from the allenai/open-instruct repository: https://github.com/allenai/open-instruct/blob/4365dea3d1a6111e8b2712af06b22a4512a0df88/open_instruct/finetune.py
+    """
+    import torch
+
+    messages = example["messages"]
+    if len(messages) == 0:
+        raise ValueError("messages field is empty.")
+    input_ids = tokenizer.apply_chat_template(
+        conversation=messages,
+        tokenize=True,
+        return_tensors="pt",
+        padding=False,
+        truncation=True,
+        max_length=max_seq_length,
+        add_generation_prompt=False,
+    )
+    labels = input_ids.clone()
+    # mask the non-assistant part for avoiding loss
+    for message_idx, message in enumerate(messages):
+        if message["role"] != "assistant":
+            # we calculate the start index of this non-assistant message
+            if message_idx == 0:
+                message_start_idx = 0
+            else:
+                message_start_idx = tokenizer.apply_chat_template(
+                    conversation=messages[:message_idx], # here marks the end of the previous messages
+                    tokenize=True,
+                    return_tensors="pt",
+                    padding=False,
+                    truncation=True,
+                    max_length=max_seq_length,
+                    add_generation_prompt=False,
+                ).shape[1]
+            # next, we calculate the end index of this non-assistant message
+            if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
+                # for intermediate messages that follow with an assistant message, we need to
+                # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
+                # (e.g., `<|assistant|>`)
+                message_end_idx = tokenizer.apply_chat_template(
+                    conversation=messages[: message_idx + 1],
+                    tokenize=True,
+                    return_tensors="pt",
+                    padding=False,
+                    truncation=True,
+                    max_length=max_seq_length,
+                    add_generation_prompt=True,
+                ).shape[1]
+            else:
+                # for the last message or the message that doesn't follow with an assistant message,
+                # we don't need to add the assistant generation prefix
+                message_end_idx = tokenizer.apply_chat_template(
+                    conversation=messages[: message_idx + 1],
+                    tokenize=True,
+                    return_tensors="pt",
+                    padding=False,
+                    truncation=True,
+                    max_length=max_seq_length,
+                    add_generation_prompt=False,
+                ).shape[1]
+            # set the label to -100 for the non-assistant part
+            labels[:, message_start_idx:message_end_idx] = -100
+            if max_seq_length and message_end_idx >= max_seq_length:
+                break
+    attention_mask = torch.ones_like(input_ids)
+    return {
+        "input_ids": input_ids.flatten(),
+        "labels": labels.flatten(),
+        "attention_mask": attention_mask.flatten()
+    }
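
`encode_sft_example` masks every non-assistant span to -100 so the loss is computed only on assistant tokens; note also that `find_train_args` hardcodes `train_kwargs = {'packing': False}`, which always overrides the `packing: True` default. A hedged sketch of exercising the encoder standalone (the tokenizer repo id is an illustrative choice, and on this released wheel the module's import of `arbor.server.api.models.schemas` would have to resolve first):

```python
from transformers import AutoTokenizer

from arbor.server.services.training_manager import encode_sft_example

# Any tokenizer with a chat template will do; this repo id is illustrative.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

example = {"messages": [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]}
out = encode_sft_example(example, tokenizer, max_seq_length=512)

# The user turn is masked to -100 (no loss); the assistant turn keeps its
# real token ids, so only those positions contribute to the training loss.
print(out["labels"].tolist())
```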
arbor/server/utils/__init__.py
ADDED
File without changes
arbor/server/utils/helpers.py
ADDED
File without changes
arbor_ai-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,47 @@
+Metadata-Version: 2.3
+Name: arbor-ai
+Version: 0.1.0
+Summary: A framework for fine-tuning and managing language models
+License: MIT
+Keywords: machine learning,fine-tuning,language models
+Author: Noah Ziems
+Author-email: nziems2@nd.edu
+Requires-Python: >=3.13
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.13
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Requires-Dist: click
+Requires-Dist: fastapi
+Requires-Dist: peft (>=0.14.0,<0.15.0)
+Requires-Dist: pydantic-settings (>=2.8.1,<3.0.0)
+Requires-Dist: python-multipart (>=0.0.20,<0.0.21)
+Requires-Dist: torch (>=2.6.0,<3.0.0)
+Requires-Dist: transformers (>=4.49.0,<5.0.0)
+Requires-Dist: trl (>=0.15.2,<0.16.0)
+Requires-Dist: uvicorn
+Project-URL: Repository, https://github.com/arbor-ai/arbor
+Description-Content-Type: text/markdown
+
+# Arbor AI
+
+## Setup
+
+```bash
+poetry install
+```
+
+```bash
+poetry run arbor serve
+```
+
+## Uploading Data
+
+```bash
+curl -X POST "http://localhost:8000/api/files" -F "file=@training_data.jsonl"
+```
+
arbor_ai-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,25 @@
+arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/cli.py,sha256=6fT5JjpXSwhpJSQNE4pnLOY04ryHPwJBAOet3hyho8k,383
+arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/client/api.py,sha256=WFaNtwCNWXRAHHG1Jfyl7LvTP6jiEyQOLZn2Z8Yjt5k,40
+arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/api/routes/files.py,sha256=QrPY9-886NXnXjGRlT-pl5kWbnwfogCrdmv6RufJpVg,466
+arbor/server/api/routes/jobs.py,sha256=ibL0tQA2Apqa91vycv3NPT0ydhkba4vnPoclw-bVKXs,510
+arbor/server/api/routes/training.py,sha256=43NOvh1Hubg3Ocfhu5E82Tp_kXOJL8H8oQXf-4H1yMU,981
+arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/core/config.py,sha256=R67gNeUXz0RShvpr8XF3Lpn7-RMOfKf2xTIyqXvj4PI,215
+arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/main.py,sha256=I3chVYsoG56zE7Clf88lEuOPaDzJvKsOzivOWpsFDls,350
+arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/dependencies.py,sha256=y3EoIkwScYc811jZ8p5m0kJT4ixRo7vguimBKKMuxAQ,458
+arbor/server/services/file_manager.py,sha256=uE7Mnbn9fC1H7sAMzJt1x9fVak10duh0OyAZeDgi3iY,2200
+arbor/server/services/job_manager.py,sha256=Zx3d0h31YH9bQ4yQr3FUXUGEHd-KUiTekZ0ndGOptrY,1893
+arbor/server/services/training_manager.py,sha256=SNumrzM1B-V1HBucUHmxnmc4rmhSCHVgPPc5nNPRC4Q,10682
+arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor_ai-0.1.0.dist-info/METADATA,sha256=1KSvnCQMVYhdQ9ONUupWtO14MuEyirX2yD58i7xQfVc,1272
+arbor_ai-0.1.0.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+arbor_ai-0.1.0.dist-info/entry_points.txt,sha256=AaLg05CZSQeP2oGlCH_AnmZPz-zzLlVtpXToI4cM3kY,39
+arbor_ai-0.1.0.dist-info/RECORD,,