arbor-ai 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbor/cli.py CHANGED
@@ -1,17 +1,101 @@
1
1
  import click
2
2
  import uvicorn
3
+
4
+ from arbor.server.core.config import Settings
3
5
  from arbor.server.main import app
6
+ from arbor.server.services.file_manager import FileManager
7
+ from arbor.server.services.grpo_manager import GRPOManager
8
+ from arbor.server.services.inference_manager import InferenceManager
9
+ from arbor.server.services.job_manager import JobManager
10
+ from arbor.server.services.training_manager import TrainingManager
11
+
4
12
 
5
13
  @click.group()
6
14
  def cli():
7
15
  pass
8
16
 
17
+
18
+ def create_app(arbor_config_path: str):
19
+ """Create and configure the Arbor API application
20
+
21
+ Args:
22
+ arbor_config_path (str): Path to the Arbor YAML config file that settings are loaded from
23
+
24
+ Returns:
25
+ FastAPI: Configured FastAPI application
26
+ """
27
+ # Load settings from the YAML config file
28
+ settings = Settings.load_from_yaml(arbor_config_path)
29
+
30
+ # Initialize services with settings
31
+ file_manager = FileManager(settings=settings)
32
+ job_manager = JobManager(settings=settings)
33
+ training_manager = TrainingManager(settings=settings)
34
+ inference_manager = InferenceManager(settings=settings)
35
+ grpo_manager = GRPOManager(settings=settings)
36
+ # Inject settings and service managers into app state
37
+ app.state.settings = settings
38
+ app.state.file_manager = file_manager
39
+ app.state.job_manager = job_manager
40
+ app.state.training_manager = training_manager
41
+ app.state.inference_manager = inference_manager
42
+ app.state.grpo_manager = grpo_manager
43
+
44
+ return app
45
+
46
+
47
+ def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
48
+ """Start the Arbor API server in a background daemon thread and return it once the port accepts connections"""
49
+ import socket
50
+ import threading
51
+ import time
52
+ from contextlib import closing
53
+
54
+ def is_port_in_use(port):
55
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
56
+ return sock.connect_ex(("localhost", port)) == 0
57
+
58
+ # First ensure the port is free
59
+ if is_port_in_use(port):
60
+ raise RuntimeError(f"Port {port} is already in use")
61
+
62
+ app = create_app(storage_path)
63
+ config = uvicorn.Config(app, host=host, port=port, log_level="info")
64
+ server = uvicorn.Server(config)
65
+
66
+ def run_server():
67
+ server.run()
68
+
69
+ thread = threading.Thread(target=run_server, daemon=True)
70
+ thread.start()
71
+
72
+ # Wait for server to start
73
+ start_time = time.time()
74
+ while not is_port_in_use(port):
75
+ if time.time() - start_time > timeout:
76
+ raise TimeoutError(f"Server failed to start within {timeout} seconds")
77
+ time.sleep(0.1)
78
+
79
+ # Give it a little extra time to fully initialize
80
+ time.sleep(0.5)
81
+
82
+ return server
83
+
84
+
85
+ def stop_server(server):
86
+ """Signal the Arbor API server to stop by setting its should_exit flag"""
87
+ server.should_exit = True
88
+
89
+
9
90
  @cli.command()
10
- @click.option('--host', default='0.0.0.0', help='Host to bind to')
11
- @click.option('--port', default=8000, help='Port to bind to')
12
- def serve(host, port):
91
+ @click.option("--host", default="0.0.0.0", help="Host to bind to")
92
+ @click.option("--port", default=7453, help="Port to bind to")
93
+ @click.option("--arbor-config", required=True, help="Path to the Arbor config file")
94
+ def serve(host, port, arbor_config):
13
95
  """Start the Arbor API server"""
96
+ app = create_app(arbor_config)
14
97
  uvicorn.run(app, host=host, port=port)
15
98
 
16
- if __name__ == '__main__':
17
- cli()
99
+
100
+ if __name__ == "__main__":
101
+ cli()
@@ -0,0 +1,78 @@
1
+ Metadata-Version: 2.4
2
+ Name: arbor-ai
3
+ Version: 0.1.5
4
+ Summary: A framework for fine-tuning and managing language models
5
+ Author-email: Noah Ziems <nziems2@nd.edu>
6
+ Project-URL: Homepage, https://github.com/Ziems/arbor
7
+ Project-URL: Issues, https://github.com/Ziems/arbor/issues
8
+ Requires-Python: >=3.10
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: fastapi
12
+ Requires-Dist: uvicorn
13
+ Requires-Dist: click
14
+ Requires-Dist: python-multipart
15
+ Requires-Dist: pydantic-settings
16
+ Requires-Dist: torch
17
+ Requires-Dist: transformers
18
+ Requires-Dist: trl
19
+ Requires-Dist: peft
20
+ Requires-Dist: ray>=2.9
21
+ Requires-Dist: setuptools<77.0.0,>=76.0.0
22
+ Requires-Dist: pyzmq>=26.4.0
23
+ Requires-Dist: pyyaml>=6.0.2
24
+ Requires-Dist: sglang>=0.4.5.post3
25
+ Requires-Dist: sglang-router
26
+ Dynamic: license-file
27
+
28
+ <p align="center">
29
+ <img src="https://github.com/user-attachments/assets/ed0dd782-65fa-48b5-a762-b343b183be09" alt="Description" width="400"/>
30
+ </p>
31
+
32
+ **A framework for optimizing DSPy programs with RL.**
33
+
34
+ ---
35
+
36
+ ## 🚀 Installation
37
+
38
+ Install Arbor via pip:
39
+
40
+ ```bash
41
+ pip install git+https://github.com/Ziems/arbor.git
42
+ ```
43
+
44
+ ---
45
+
46
+ ## ⚡ Quick Start
47
+
48
+ ### 1️⃣ Make an `arbor.yaml` File
49
+
50
+ The contents depend on your hardware setup. Here is an example that assigns GPU 0 to inference and GPUs 1 and 2 to training:
51
+ ```yaml
52
+ inference:
53
+ gpu_ids: '0'
54
+
55
+ training:
56
+ gpu_ids: '1, 2'
57
+ ```
58
+
59
+ ### 2️⃣ Start the Server
60
+
61
+ **CLI:**
62
+
63
+ ```bash
64
+ python -m arbor.cli serve --arbor-config arbor.yaml
65
+ ```
66
+
67
+ ### 3️⃣ Optimize a DSPy Program
68
+
69
+ Follow the DSPy tutorials here to see usage examples:
70
+ [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
71
+
72
+ ---
73
+
74
+ ## 🙏 Acknowledgements
75
+
76
+ Arbor stands on the shoulders of great work. We extend our thanks to:
77
+ - **[Will Brown's Verifiers library](https://github.com/willccbb/verifiers)**
78
+ - **[Hugging Face TRL library](https://github.com/huggingface/trl)**
@@ -0,0 +1,8 @@
1
+ arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
3
+ arbor_ai-0.1.5.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
4
+ arbor_ai-0.1.5.dist-info/METADATA,sha256=Tney6uOytHDMIZg3iqKrn2lgtaF3NULjXo19XdG_2Dw,1823
5
+ arbor_ai-0.1.5.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
6
+ arbor_ai-0.1.5.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
7
+ arbor_ai-0.1.5.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
8
+ arbor_ai-0.1.5.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.1
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ arbor = arbor.cli:cli
@@ -0,0 +1 @@
1
+ arbor
arbor/client/__init__.py DELETED
File without changes
arbor/client/api.py DELETED
@@ -1,2 +0,0 @@
1
- from typing import Optional, Dict, Any
2
-
arbor/server/__init__.py DELETED
@@ -1 +0,0 @@
1
-
@@ -1 +0,0 @@
1
-
@@ -1,19 +0,0 @@
1
- from pydantic import BaseModel
2
-
3
- class FileResponse(BaseModel):
4
- id: str
5
- object: str = "file"
6
- bytes: int
7
- created_at: int
8
- filename: str
9
- purpose: str
10
-
11
- class FineTuneRequest(BaseModel):
12
- model: str
13
- training_file: str # id of uploaded jsonl file
14
-
15
- class JobStatusResponse(BaseModel):
16
- id: str
17
- status: str
18
- details: str = ""
19
- fine_tuned_model: str | None = None
File without changes
@@ -1,23 +0,0 @@
1
- from fastapi import APIRouter, UploadFile, File, Depends, HTTPException
2
- from arbor.server.services.file_manager import FileManager
3
- from arbor.server.api.models.schemas import FileResponse
4
- from arbor.server.services.dependencies import get_file_manager
5
- from arbor.server.services.file_manager import FileValidationError
6
-
7
- router = APIRouter()
8
-
9
- @router.post("", response_model=FileResponse)
10
- async def upload_file(
11
- file: UploadFile = File(...),
12
- file_manager: FileManager = Depends(get_file_manager)
13
- ):
14
- if not file.filename.endswith('.jsonl'):
15
- raise HTTPException(status_code=400, detail="Only .jsonl files are allowed")
16
-
17
- try:
18
- content = await file.read()
19
- file_manager.validate_file_format(content)
20
- await file.seek(0) # Reset file pointer to beginning
21
- return file_manager.save_uploaded_file(file)
22
- except FileValidationError as e:
23
- raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
@@ -1,14 +0,0 @@
1
- from fastapi import APIRouter, Depends
2
- from arbor.server.services.job_manager import JobManager
3
- from arbor.server.services.dependencies import get_job_manager
4
- from arbor.server.api.models.schemas import JobStatusResponse
5
-
6
- router = APIRouter()
7
-
8
- @router.get("/{job_id}", response_model=JobStatusResponse)
9
- def get_job_status(
10
- job_id: str,
11
- job_manager: JobManager = Depends(get_job_manager)
12
- ):
13
- job = job_manager.get_job(job_id)
14
- return JobStatusResponse(id=job_id, status=job.status.value, fine_tuned_model=job.fine_tuned_model)
@@ -1,16 +0,0 @@
1
- from fastapi import APIRouter, BackgroundTasks, Depends
2
-
3
- from arbor.server.api.models.schemas import FineTuneRequest, JobStatusResponse
4
- from arbor.server.services.job_manager import JobManager, JobStatus
5
- from arbor.server.services.file_manager import FileManager
6
- from arbor.server.services.training_manager import TrainingManager
7
- from arbor.server.services.dependencies import get_training_manager, get_job_manager, get_file_manager
8
-
9
- router = APIRouter()
10
-
11
- @router.post("", response_model=JobStatusResponse)
12
- def fine_tune(request: FineTuneRequest, background_tasks: BackgroundTasks, training_manager: TrainingManager = Depends(get_training_manager), job_manager: JobManager = Depends(get_job_manager), file_manager: FileManager = Depends(get_file_manager)):
13
- job = job_manager.create_job()
14
- background_tasks.add_task(training_manager.fine_tune, request, job, file_manager)
15
- job.status = JobStatus.QUEUED
16
- return JobStatusResponse(id=job.id, status=job.status.value)
@@ -1 +0,0 @@
1
-
@@ -1,10 +0,0 @@
1
- from pydantic_settings import BaseSettings
2
-
3
- class Settings(BaseSettings):
4
- UPLOADS_DIR: str = "uploads"
5
- MODEL_CACHE_DIR: str = "model_cache"
6
-
7
- class Config:
8
- env_file = ".env"
9
-
10
- settings = Settings()
File without changes
arbor/server/main.py DELETED
@@ -1,10 +0,0 @@
1
- from fastapi import FastAPI
2
- from arbor.server.api.routes import training, files, jobs
3
- from arbor.server.core.config import settings
4
-
5
- app = FastAPI(title="Arbor API")
6
-
7
- # Include routers
8
- app.include_router(training.router, prefix="/api/fine-tune")
9
- app.include_router(files.router, prefix="/api/files")
10
- app.include_router(jobs.router, prefix="/api/job")
File without changes
@@ -1,16 +0,0 @@
1
- from functools import lru_cache
2
- from arbor.server.services.file_manager import FileManager
3
- from arbor.server.services.job_manager import JobManager
4
- from arbor.server.services.training_manager import TrainingManager
5
-
6
- @lru_cache()
7
- def get_file_manager() -> FileManager:
8
- return FileManager()
9
-
10
- @lru_cache()
11
- def get_job_manager() -> JobManager:
12
- return JobManager()
13
-
14
- @lru_cache()
15
- def get_training_manager() -> TrainingManager:
16
- return TrainingManager()
@@ -1,128 +0,0 @@
1
- from pathlib import Path
2
- import json
3
- import os
4
- import shutil
5
- import time
6
- import uuid
7
- from fastapi import UploadFile
8
- from arbor.server.api.models.schemas import FileResponse
9
-
10
- class FileValidationError(Exception):
11
- """Custom exception for file validation errors"""
12
- pass
13
-
14
- class FileManager:
15
- def __init__(self):
16
- self.uploads_dir = Path("uploads")
17
- self.uploads_dir.mkdir(exist_ok=True)
18
- self.files = self.load_files_from_uploads()
19
-
20
- def load_files_from_uploads(self):
21
- files = {}
22
-
23
- # Scan through all directories in uploads directory
24
- for dir_path in self.uploads_dir.glob("*"):
25
- if not dir_path.is_dir():
26
- continue
27
-
28
- # Check for metadata.json
29
- metadata_path = dir_path / "metadata.json"
30
- if not metadata_path.exists():
31
- continue
32
-
33
- # Load metadata
34
- with open(metadata_path) as f:
35
- metadata = json.load(f)
36
-
37
- # Find the .jsonl file
38
- jsonl_files = list(dir_path.glob("*.jsonl"))
39
- if not jsonl_files:
40
- continue
41
-
42
- file_path = jsonl_files[0]
43
- files[dir_path.name] = {
44
- "path": str(file_path),
45
- "purpose": metadata.get("purpose", "training"),
46
- "bytes": file_path.stat().st_size,
47
- "created_at": metadata.get("created_at", int(file_path.stat().st_mtime)),
48
- "filename": metadata.get("filename", file_path.name)
49
- }
50
-
51
- return files
52
-
53
- def save_uploaded_file(self, file: UploadFile) -> FileResponse:
54
- file_id = str(uuid.uuid4())
55
- dir_path = self.uploads_dir / file_id
56
- dir_path.mkdir(exist_ok=True)
57
-
58
- # Save the actual file
59
- file_path = dir_path / f"data.jsonl"
60
- with open(file_path, "wb") as f:
61
- shutil.copyfileobj(file.file, f)
62
-
63
- # Create metadata
64
- metadata = {
65
- "purpose": "training",
66
- "created_at": int(time.time()),
67
- "filename": file.filename
68
- }
69
-
70
- # Save metadata
71
- with open(dir_path / "metadata.json", "w") as f:
72
- json.dump(metadata, f)
73
-
74
- file_data = {
75
- "id": file_id,
76
- "path": str(file_path),
77
- "purpose": metadata["purpose"],
78
- "bytes": file.size,
79
- "created_at": metadata["created_at"],
80
- "filename": metadata["filename"]
81
- }
82
-
83
- self.files[file_id] = file_data
84
- return FileResponse(**file_data)
85
-
86
- def get_file(self, file_id: str):
87
- return self.files[file_id]
88
-
89
- def validate_file_format(self, file_content: bytes) -> None:
90
- """
91
- Validates that the file content is properly formatted JSONL with expected structure.
92
- Raises FileValidationError if validation fails.
93
- """
94
- if not file_content:
95
- raise FileValidationError("File is empty")
96
-
97
- try:
98
- lines = file_content.decode('utf-8').strip().split('\n')
99
- if not lines:
100
- raise FileValidationError("File contains no valid data")
101
-
102
- for line_num, line in enumerate(lines, 1):
103
- try:
104
- data = json.loads(line)
105
-
106
- # Validate required structure
107
- if not isinstance(data, dict):
108
- raise FileValidationError(f"Line {line_num}: Each line must be a JSON object")
109
-
110
- if "messages" not in data:
111
- raise FileValidationError(f"Line {line_num}: Missing 'messages' field")
112
-
113
- if not isinstance(data["messages"], list):
114
- raise FileValidationError(f"Line {line_num}: 'messages' must be an array")
115
-
116
- for msg in data["messages"]:
117
- if not isinstance(msg, dict):
118
- raise FileValidationError(f"Line {line_num}: Each message must be an object")
119
- if "role" not in msg or "content" not in msg:
120
- raise FileValidationError(f"Line {line_num}: Messages must have 'role' and 'content' fields")
121
- if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
122
- raise FileValidationError(f"Line {line_num}: Message 'role' and 'content' must be strings")
123
-
124
- except json.JSONDecodeError:
125
- raise FileValidationError(f"Invalid JSON on line {line_num}")
126
-
127
- except UnicodeDecodeError:
128
- raise FileValidationError("File must be valid UTF-8 encoded text")
@@ -1,76 +0,0 @@
1
- import uuid
2
- from enum import Enum
3
- import logging
4
- from datetime import datetime
5
-
6
- # https://platform.openai.com/docs/api-reference/fine-tuning/object
7
- class JobStatus(Enum):
8
- PENDING = "pending" # Not in OAI
9
- VALIDATING_FILES = "validating_files"
10
- QUEUED = "queued"
11
- RUNNING = "running"
12
- SUCCEEDED = "succeeded"
13
- FAILED = "failed"
14
- CANCELLED = "cancelled"
15
-
16
- class JobLogHandler(logging.Handler):
17
- def __init__(self, job):
18
- super().__init__()
19
- self.job = job
20
-
21
- def emit(self, record):
22
- log_entry = {
23
- 'timestamp': datetime.fromtimestamp(record.created).isoformat(),
24
- 'level': record.levelname,
25
- 'message': record.getMessage()
26
- }
27
- self.job.logs.append(log_entry)
28
-
29
- class Job:
30
- def __init__(self, id: str, status: JobStatus):
31
- self.id = id
32
- self.status = status
33
- self.fine_tuned_model = None
34
- self.logs = []
35
- self.logger = None
36
- self.log_handler = None
37
-
38
- def setup_logger(self, logger_name: str = None) -> logging.Logger:
39
- """Sets up logging for the job with a dedicated handler."""
40
- if logger_name is None:
41
- logger_name = f"job_{self.id}"
42
-
43
- logger = logging.getLogger(logger_name)
44
- logger.setLevel(logging.INFO)
45
-
46
- # Create and setup handler if not already exists
47
- if self.log_handler is None:
48
- handler = JobLogHandler(self)
49
- formatter = logging.Formatter('%(message)s')
50
- handler.setFormatter(formatter)
51
- logger.addHandler(handler)
52
- self.log_handler = handler
53
-
54
- self.logger = logger
55
- return logger
56
-
57
- def cleanup_logger(self):
58
- """Removes the job's logging handler."""
59
- if self.logger and self.log_handler:
60
- self.logger.removeHandler(self.log_handler)
61
- self.log_handler = None
62
- self.logger = None
63
-
64
- class JobManager:
65
- def __init__(self):
66
- self.jobs = {}
67
-
68
- def get_job(self, job_id: str):
69
- if job_id not in self.jobs:
70
- raise ValueError(f"Job {job_id} not found")
71
- return self.jobs[job_id]
72
-
73
- def create_job(self):
74
- job = Job(id=str(uuid.uuid4()), status=JobStatus.PENDING)
75
- self.jobs[job.id] = job
76
- return job
@@ -1,264 +0,0 @@
1
- from arbor.server.api.models.schemas import FineTuneRequest
2
- from arbor.server.services.job_manager import Job, JobStatus
3
- from arbor.server.services.file_manager import FileManager
4
-
5
- class TrainingManager:
6
- def __init__(self):
7
- pass
8
-
9
- def find_train_args(self, request: FineTuneRequest, file_manager: FileManager):
10
- file = file_manager.get_file(request.training_file)
11
- if file is None:
12
- raise ValueError(f"Training file {request.training_file} not found")
13
-
14
- data_path = file["path"]
15
- output_dir = f"models/{request.model}" # TODO: This should be updated to be unique in some way
16
-
17
-
18
- default_train_kwargs = {
19
- "device": None,
20
- "use_peft": False,
21
- "num_train_epochs": 5,
22
- "per_device_train_batch_size": 1,
23
- "gradient_accumulation_steps": 8,
24
- "learning_rate": 1e-5,
25
- "max_seq_length": None,
26
- "packing": True,
27
- "bf16": True,
28
- "output_dir": output_dir,
29
- "train_data_path": data_path,
30
- }
31
- train_kwargs = {'packing': False}
32
- train_kwargs={**default_train_kwargs, **(train_kwargs or {})}
33
- output_dir = train_kwargs["output_dir"] # user might have changed the output_dir
34
-
35
- return train_kwargs
36
-
37
-
38
- def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
39
- # Get logger for this job
40
- logger = job.setup_logger("training")
41
-
42
- job.status = JobStatus.RUNNING
43
- logger.info("Starting fine-tuning job")
44
-
45
- try:
46
- train_kwargs = self.find_train_args(request, file_manager)
47
-
48
- import torch
49
- from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
50
- from trl import SFTConfig, SFTTrainer, setup_chat_format
51
-
52
- device = train_kwargs.get("device", None)
53
- if device is None:
54
- device = (
55
- "cuda"
56
- if torch.cuda.is_available()
57
- else "mps" if torch.backends.mps.is_available() else "cpu"
58
- )
59
- logger.info(f"Using device: {device}")
60
-
61
- model = AutoModelForCausalLM.from_pretrained(
62
- pretrained_model_name_or_path=request.model
63
- ).to(device)
64
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=request.model)
65
-
66
- # Set up the chat format; generally only for non-chat model variants, hence the try-except.
67
- try:
68
- model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
69
- except Exception:
70
- pass
71
-
72
- if tokenizer.pad_token_id is None:
73
- logger.info("Adding pad token to tokenizer")
74
- tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
75
-
76
- logger.info("Creating dataset")
77
- if "max_seq_length" not in train_kwargs or train_kwargs["max_seq_length"] is None:
78
- train_kwargs["max_seq_length"] = 4096
79
- logger.info(f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}")
80
-
81
-
82
- hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
83
- def tokenize_function(example):
84
- return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
85
- tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
86
- tokenized_dataset.set_format(type="torch")
87
- tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
88
-
89
- USE_PEFT = train_kwargs.get("use_peft", False)
90
- peft_config = None
91
-
92
- if USE_PEFT:
93
- from peft import LoraConfig
94
-
95
- rank_dimension = 32
96
- lora_alpha = 64
97
- lora_dropout = 0.05
98
-
99
- peft_config = LoraConfig(
100
- r=rank_dimension,
101
- lora_alpha=lora_alpha,
102
- lora_dropout=lora_dropout,
103
- bias="none",
104
- target_modules="all-linear",
105
- task_type="CAUSAL_LM",
106
- )
107
-
108
- sft_config = SFTConfig(
109
- output_dir=train_kwargs["output_dir"],
110
- num_train_epochs=train_kwargs["num_train_epochs"],
111
- per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
112
- gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
113
- learning_rate=train_kwargs["learning_rate"],
114
- max_grad_norm=2.0, # note that the current SFTConfig default is 1.0
115
- logging_steps=20,
116
- warmup_ratio=0.03,
117
- lr_scheduler_type="constant",
118
- save_steps=10_000,
119
- bf16=train_kwargs["bf16"],
120
- max_seq_length=train_kwargs["max_seq_length"],
121
- packing=train_kwargs["packing"],
122
- dataset_kwargs={ # We need to pass dataset_kwargs because we are processing the dataset ourselves
123
- "add_special_tokens": False, # Special tokens handled by template
124
- "append_concat_token": False, # No additional separator needed
125
- },
126
- )
127
-
128
- logger.info("Starting training")
129
- trainer = SFTTrainer(
130
- model=model,
131
- args=sft_config,
132
- train_dataset=tokenized_dataset,
133
- peft_config=peft_config,
134
-
135
- )
136
-
137
- # Train!
138
- trainer.train()
139
-
140
- # Save the model!
141
- trainer.save_model()
142
-
143
- MERGE = True
144
- if USE_PEFT and MERGE:
145
- from peft import AutoPeftModelForCausalLM
146
-
147
- # Load PEFT model on CPU
148
- model_ = AutoPeftModelForCausalLM.from_pretrained(
149
- pretrained_model_name_or_path=sft_config.output_dir,
150
- torch_dtype=torch.float16,
151
- low_cpu_mem_usage=True,
152
- )
153
-
154
- merged_model = model_.merge_and_unload()
155
- merged_model.save_pretrained(
156
- sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
157
- )
158
-
159
- # Clean up!
160
- import gc
161
-
162
- del model
163
- del tokenizer
164
- del trainer
165
- gc.collect()
166
- torch.cuda.empty_cache()
167
-
168
- logger.info("Training completed successfully")
169
- job.status = JobStatus.SUCCEEDED
170
- job.fine_tuned_model = sft_config.output_dir
171
- except Exception as e:
172
- logger.error(f"Training failed: {str(e)}")
173
- job.status = JobStatus.FAILED
174
- raise
175
- finally:
176
- job.cleanup_logger()
177
-
178
- return sft_config.output_dir
179
-
180
- def dataset_from_file(data_path):
181
- """
182
- Creates a HuggingFace Dataset from a JSONL file.
183
- """
184
- from datasets import load_dataset
185
-
186
- dataset = load_dataset("json", data_files=data_path, split="train")
187
- return dataset
188
-
189
-
190
- def encode_sft_example(example, tokenizer, max_seq_length):
191
- """
192
- This function encodes a single example into a format that can be used for sft training.
193
- Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
194
- We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
195
-
196
- Code obtained from the allenai/open-instruct repository: https://github.com/allenai/open-instruct/blob/4365dea3d1a6111e8b2712af06b22a4512a0df88/open_instruct/finetune.py
197
- """
198
- import torch
199
-
200
- messages = example["messages"]
201
- if len(messages) == 0:
202
- raise ValueError("messages field is empty.")
203
- input_ids = tokenizer.apply_chat_template(
204
- conversation=messages,
205
- tokenize=True,
206
- return_tensors="pt",
207
- padding=False,
208
- truncation=True,
209
- max_length=max_seq_length,
210
- add_generation_prompt=False,
211
- )
212
- labels = input_ids.clone()
213
- # mask the non-assistant part for avoiding loss
214
- for message_idx, message in enumerate(messages):
215
- if message["role"] != "assistant":
216
- # we calculate the start index of this non-assistant message
217
- if message_idx == 0:
218
- message_start_idx = 0
219
- else:
220
- message_start_idx = tokenizer.apply_chat_template(
221
- conversation=messages[:message_idx], # here marks the end of the previous messages
222
- tokenize=True,
223
- return_tensors="pt",
224
- padding=False,
225
- truncation=True,
226
- max_length=max_seq_length,
227
- add_generation_prompt=False,
228
- ).shape[1]
229
- # next, we calculate the end index of this non-assistant message
230
- if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
231
- # for intermediate messages that follow with an assistant message, we need to
232
- # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
233
- # (e.g., `<|assistant|>`)
234
- message_end_idx = tokenizer.apply_chat_template(
235
- conversation=messages[: message_idx + 1],
236
- tokenize=True,
237
- return_tensors="pt",
238
- padding=False,
239
- truncation=True,
240
- max_length=max_seq_length,
241
- add_generation_prompt=True,
242
- ).shape[1]
243
- else:
244
- # for the last message or the message that doesn't follow with an assistant message,
245
- # we don't need to add the assistant generation prefix
246
- message_end_idx = tokenizer.apply_chat_template(
247
- conversation=messages[: message_idx + 1],
248
- tokenize=True,
249
- return_tensors="pt",
250
- padding=False,
251
- truncation=True,
252
- max_length=max_seq_length,
253
- add_generation_prompt=False,
254
- ).shape[1]
255
- # set the label to -100 for the non-assistant part
256
- labels[:, message_start_idx:message_end_idx] = -100
257
- if max_seq_length and message_end_idx >= max_seq_length:
258
- break
259
- attention_mask = torch.ones_like(input_ids)
260
- return {
261
- "input_ids": input_ids.flatten(),
262
- "labels": labels.flatten(),
263
- "attention_mask": attention_mask.flatten()
264
- }
File without changes
File without changes
@@ -1,97 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: arbor-ai
3
- Version: 0.1.4
4
- Summary: A framework for fine-tuning and managing language models
5
- License: MIT
6
- Keywords: machine learning,fine-tuning,language models
7
- Author: Noah Ziems
8
- Author-email: nziems2@nd.edu
9
- Requires-Python: >=3.9, <3.14
10
- Classifier: Development Status :: 3 - Alpha
11
- Classifier: Intended Audience :: Developers
12
- Classifier: License :: OSI Approved :: MIT License
13
- Classifier: Programming Language :: Python :: 3
14
- Classifier: Programming Language :: Python :: 3.9
15
- Classifier: Programming Language :: Python :: 3.10
16
- Classifier: Programming Language :: Python :: 3.11
17
- Classifier: Programming Language :: Python :: 3.12
18
- Classifier: Programming Language :: Python :: 3.13
19
- Requires-Dist: click
20
- Requires-Dist: fastapi
21
- Requires-Dist: peft (>=0.14.0,<0.15.0)
22
- Requires-Dist: pydantic-settings (>=2.8.1,<3.0.0)
23
- Requires-Dist: python-multipart (>=0.0.20,<0.0.21)
24
- Requires-Dist: torch (>=2.6.0,<3.0.0)
25
- Requires-Dist: transformers (>=4.49.0,<5.0.0)
26
- Requires-Dist: trl (>=0.15.2,<0.16.0)
27
- Requires-Dist: uvicorn
28
- Project-URL: Repository, https://github.com/arbor-ai/arbor
29
- Description-Content-Type: text/markdown
30
-
31
- # Arbor 🌳
32
-
33
- A drop-in replacement for OpenAI's fine-tuning API that lets you fine-tune and manage open-source language models locally. Train and deploy custom models with the same API you already know.
34
-
35
- ## Installation
36
-
37
- ```bash
38
- pip install arbor-ai
39
- ```
40
-
41
- ## Quick Start
42
-
43
- 1. Start the Arbor server:
44
-
45
- ```bash
46
- arbor serve
47
- ```
48
-
49
- 2. The server will be available at `http://localhost:8000`.
50
-
51
- 3. Upload your training data:
52
-
53
- ```python
54
- import requests
55
-
56
- requests.post('http://127.0.0.1:8000/api/files', files={'file': open('your_file.jsonl', 'rb')})
57
- ```
58
-
59
- 4. Submit a fine-tuning job:
60
-
61
- ```python
62
- requests.post('http://127.0.0.1:8000/api/fine-tune', json={'model': 'HuggingFaceTB/SmolLM2-135M-Instruct', 'training_file': 'Returned file ID from Step 3'})
63
- ```
64
-
65
- 5. Monitor the job status:
66
-
67
- ```python
68
- requests.get('http://127.0.0.1:8000/api/job/{Returned job ID from Step 4}')
69
- ```
70
-
71
-
72
-
73
- ## Development Setup
74
-
75
- ```bash
76
- poetry install
77
- ```
78
-
79
- ```bash
80
- poetry run arbor serve
81
- ```
82
-
83
- ```bash
84
- poetry run pytest
85
- ```
86
-
87
- ## Contributing
88
-
89
- Contributions are welcome! Please feel free to submit a Pull Request.
90
-
91
- ## License
92
-
93
- This project is licensed under the MIT License - see the LICENSE file for details.
94
-
95
- ## Support
96
-
97
- If you encounter any issues or have questions, please file an issue on the [GitHub repository](https://github.com/Ziems/arbor/issues).
@@ -1,27 +0,0 @@
1
- arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- arbor/cli.py,sha256=6fT5JjpXSwhpJSQNE4pnLOY04ryHPwJBAOet3hyho8k,383
3
- arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- arbor/client/api.py,sha256=WFaNtwCNWXRAHHG1Jfyl7LvTP6jiEyQOLZn2Z8Yjt5k,40
5
- arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
6
- arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
7
- arbor/server/api/models/schemas.py,sha256=19uDproKWhPQvVTit0hWuqmPb80zrELtCgnLybDuBKw,398
8
- arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- arbor/server/api/routes/files.py,sha256=U5QPC05VzqgDirB77lpy6BJLvg3zo1eGz7RUEk3HgRw,970
10
- arbor/server/api/routes/jobs.py,sha256=W2Y-rByaULxT0pEy3_YSNWO2CEKR5obyax-uR4ax_6Y,539
11
- arbor/server/api/routes/training.py,sha256=5M6OAtl9i8L-jBefmvPWvyf1M_x30-IlXzgleBg41Yc,977
12
- arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
13
- arbor/server/core/config.py,sha256=R67gNeUXz0RShvpr8XF3Lpn7-RMOfKf2xTIyqXvj4PI,215
14
- arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- arbor/server/main.py,sha256=I3chVYsoG56zE7Clf88lEuOPaDzJvKsOzivOWpsFDls,350
16
- arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- arbor/server/services/dependencies.py,sha256=y3EoIkwScYc811jZ8p5m0kJT4ixRo7vguimBKKMuxAQ,458
18
- arbor/server/services/file_manager.py,sha256=VUCn0cUtd-Q1BrUPtKStS1hGtV_OlymUyA0I8zeG9Po,4037
19
- arbor/server/services/job_manager.py,sha256=rZjuhwwbvL7yCJi653tv7z36iFFvp1w5J9j5DntSWKM,2073
20
- arbor/server/services/training_manager.py,sha256=BQsUsxOyRlgFDEFM77tyIahmm4NqcoOwxq8Tlmp66dY,10724
21
- arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- arbor_ai-0.1.4.dist-info/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
24
- arbor_ai-0.1.4.dist-info/METADATA,sha256=977OGIuruJzS8wkFntELEoO7Ey5VzEhv88v1Pt81pa0,2451
25
- arbor_ai-0.1.4.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
26
- arbor_ai-0.1.4.dist-info/entry_points.txt,sha256=AaLg05CZSQeP2oGlCH_AnmZPz-zzLlVtpXToI4cM3kY,39
27
- arbor_ai-0.1.4.dist-info/RECORD,,
@@ -1,3 +0,0 @@
1
- [console_scripts]
2
- arbor=arbor.cli:cli
3
-