arbor-ai 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbor/cli.py CHANGED
@@ -1,17 +1,101 @@
  import click
  import uvicorn
+
+ from arbor.server.core.config import Settings
  from arbor.server.main import app
+ from arbor.server.services.file_manager import FileManager
+ from arbor.server.services.grpo_manager import GRPOManager
+ from arbor.server.services.inference_manager import InferenceManager
+ from arbor.server.services.job_manager import JobManager
+ from arbor.server.services.training_manager import TrainingManager
+

  @click.group()
  def cli():
      pass

+
+ def create_app(arbor_config_path: str):
+     """Create and configure the Arbor API application
+
+     Args:
+         arbor_config_path (str): Path to the Arbor YAML config file
+
+     Returns:
+         FastAPI: Configured FastAPI application
+     """
+     # Create new settings instance with overrides
+     settings = Settings.load_from_yaml(arbor_config_path)
+
+     # Initialize services with settings
+     file_manager = FileManager(settings=settings)
+     job_manager = JobManager(settings=settings)
+     training_manager = TrainingManager(settings=settings)
+     inference_manager = InferenceManager(settings=settings)
+     grpo_manager = GRPOManager(settings=settings)
+     # Inject settings into app state
+     app.state.settings = settings
+     app.state.file_manager = file_manager
+     app.state.job_manager = job_manager
+     app.state.training_manager = training_manager
+     app.state.inference_manager = inference_manager
+     app.state.grpo_manager = grpo_manager
+
+     return app
+
+
+ def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
+     """Start the Arbor API server with a single function call"""
+     import socket
+     import threading
+     import time
+     from contextlib import closing
+
+     def is_port_in_use(port):
+         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+             return sock.connect_ex(("localhost", port)) == 0
+
+     # First ensure the port is free
+     if is_port_in_use(port):
+         raise RuntimeError(f"Port {port} is already in use")
+
+     app = create_app(storage_path)
+     config = uvicorn.Config(app, host=host, port=port, log_level="info")
+     server = uvicorn.Server(config)
+
+     def run_server():
+         server.run()
+
+     thread = threading.Thread(target=run_server, daemon=True)
+     thread.start()
+
+     # Wait for server to start
+     start_time = time.time()
+     while not is_port_in_use(port):
+         if time.time() - start_time > timeout:
+             raise TimeoutError(f"Server failed to start within {timeout} seconds")
+         time.sleep(0.1)
+
+     # Give it a little extra time to fully initialize
+     time.sleep(0.5)
+
+     return server
+
+
+ def stop_server(server):
+     """Stop the Arbor API server"""
+     server.should_exit = True
+
+
  @cli.command()
- @click.option('--host', default='0.0.0.0', help='Host to bind to')
- @click.option('--port', default=8000, help='Port to bind to')
- def serve(host, port):
+ @click.option("--host", default="0.0.0.0", help="Host to bind to")
+ @click.option("--port", default=7453, help="Port to bind to")
+ @click.option("--arbor-config", required=True, help="Path to the Arbor config file")
+ def serve(host, port, arbor_config):
      """Start the Arbor API server"""
+     app = create_app(arbor_config)
      uvicorn.run(app, host=host, port=port)

- if __name__ == '__main__':
-     cli()
+
+ if __name__ == "__main__":
+     cli()
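
Note: the new `start_server`/`stop_server` helpers above add a programmatic way to run Arbor without the CLI. A minimal sketch of how they compose, based only on the 0.1.5 code shown; the parameter is still named `storage_path`, but it is forwarded to `create_app`, which now expects the path to an `arbor.yaml` config file:

```python
# Sketch only, based on the cli.py diff above: start the server on a daemon
# thread, use it, then request shutdown. storage_path is (despite its name)
# treated as a config file path by create_app in 0.1.5.
from arbor.cli import start_server, stop_server

server = start_server(host="0.0.0.0", port=7453, storage_path="arbor.yaml")
try:
    pass  # talk to the API at http://localhost:7453 here
finally:
    stop_server(server)  # sets uvicorn's should_exit flag; the daemon thread winds down
```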
arbor_ai-0.1.5.dist-info/METADATA ADDED
@@ -0,0 +1,78 @@
+ Metadata-Version: 2.4
+ Name: arbor-ai
+ Version: 0.1.5
+ Summary: A framework for fine-tuning and managing language models
+ Author-email: Noah Ziems <nziems2@nd.edu>
+ Project-URL: Homepage, https://github.com/Ziems/arbor
+ Project-URL: Issues, https://github.com/Ziems/arbor/issues
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: fastapi
+ Requires-Dist: uvicorn
+ Requires-Dist: click
+ Requires-Dist: python-multipart
+ Requires-Dist: pydantic-settings
+ Requires-Dist: torch
+ Requires-Dist: transformers
+ Requires-Dist: trl
+ Requires-Dist: peft
+ Requires-Dist: ray>=2.9
+ Requires-Dist: setuptools<77.0.0,>=76.0.0
+ Requires-Dist: pyzmq>=26.4.0
+ Requires-Dist: pyyaml>=6.0.2
+ Requires-Dist: sglang>=0.4.5.post3
+ Requires-Dist: sglang-router
+ Dynamic: license-file
+
+ <p align="center">
+   <img src="https://github.com/user-attachments/assets/ed0dd782-65fa-48b5-a762-b343b183be09" alt="Description" width="400"/>
+ </p>
+
+ **A framework for optimizing DSPy programs with RL.**
+
+ ---
+
+ ## 🚀 Installation
+
+ Install Arbor via pip:
+
+ ```bash
+ pip install git+https://github.com/Ziems/arbor.git
+ ```
+
+ ---
+
+ ## ⚡ Quick Start
+
+ ### 1️⃣ Make an `arbor.yaml` File
+
+ This depends on your setup. Here is an example:
+ ```yaml
+ inference:
+   gpu_ids: '0'
+
+ training:
+   gpu_ids: '1, 2'
+ ```
+
+ ### 2️⃣ Start the Server
+
+ **CLI:**
+
+ ```bash
+ python -m arbor.cli serve --arbor-config arbor.yaml
+ ```
+
+ ### 3️⃣ Optimize a DSPy Program
+
+ Follow the DSPy tutorials here to see usage examples:
+ [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
+
+ ---
+
+ ## 🙏 Acknowledgements
+
+ Arbor builds on the shoulders of great work. We extend our thanks to:
+ - **[Will Brown's Verifiers library](https://github.com/willccbb/verifiers)**
+ - **[Hugging Face TRL library](https://github.com/huggingface/trl)**
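
Note: the `gpu_ids` values in the example `arbor.yaml` are strings, so whatever consumes them has to split and convert. A minimal sketch of reading that file with `pyyaml` (a declared 0.1.5 dependency); `parse_gpu_ids` here is a hypothetical helper for illustration, while the package's own loader is `Settings.load_from_yaml`, called from `create_app` above:

```python
# Sketch: parsing the README's example arbor.yaml.
# parse_gpu_ids is hypothetical; it is not part of the arbor-ai package.
import yaml

def parse_gpu_ids(value: str) -> list[int]:
    # "1, 2" -> [1, 2]
    return [int(part) for part in value.split(",") if part.strip()]

with open("arbor.yaml") as f:
    config = yaml.safe_load(f)

inference_gpus = parse_gpu_ids(config["inference"]["gpu_ids"])  # [0]
training_gpus = parse_gpu_ids(config["training"]["gpu_ids"])    # [1, 2]
```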
arbor_ai-0.1.5.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
+ arbor_ai-0.1.5.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+ arbor_ai-0.1.5.dist-info/METADATA,sha256=Tney6uOytHDMIZg3iqKrn2lgtaF3NULjXo19XdG_2Dw,1823
+ arbor_ai-0.1.5.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+ arbor_ai-0.1.5.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+ arbor_ai-0.1.5.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+ arbor_ai-0.1.5.dist-info/RECORD,,
arbor_ai-0.1.5.dist-info/WHEEL CHANGED
@@ -1,4 +1,5 @@
  Wheel-Version: 1.0
- Generator: poetry-core 2.1.1
+ Generator: setuptools (80.3.1)
  Root-Is-Purelib: true
  Tag: py3-none-any
+
arbor_ai-0.1.5.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ arbor = arbor.cli:cli
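
Note: this entry point is what makes `pip install` expose an `arbor` command mapped to the `cli` group in `arbor/cli.py`. Running `arbor serve --arbor-config arbor.yaml` from a shell is therefore equivalent to this sketch:

```python
# Equivalent of the console script: the entry point above resolves the
# `arbor` command to arbor.cli:cli, a click group.
from arbor.cli import cli

# click parses the argument list and dispatches to serve(); in standalone
# mode it exits the process via SystemExit when the command returns.
cli(["serve", "--arbor-config", "arbor.yaml"])
```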
arbor_ai-0.1.5.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Noah Ziems
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
arbor_ai-0.1.5.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ arbor
arbor/client/__init__.py DELETED
File without changes
arbor/client/api.py DELETED
@@ -1,2 +0,0 @@
- from typing import Optional, Dict, Any
-
arbor/server/__init__.py DELETED
@@ -1 +0,0 @@
-
arbor/server/api/__init__.py DELETED
@@ -1 +0,0 @@
-
arbor/server/api/models/schemas.py DELETED
@@ -1,18 +0,0 @@
- from pydantic import BaseModel
-
- class FileResponse(BaseModel):
-     id: str
-     object: str = "file"
-     bytes: int
-     created_at: int
-     filename: str
-     purpose: str
-
- class FineTuneRequest(BaseModel):
-     model_name: str
-     training_file: str  # id of uploaded jsonl file
-
- class JobStatusResponse(BaseModel):
-     job_id: str
-     status: str
-     details: str = ""
arbor/server/api/routes/__init__.py DELETED
File without changes
arbor/server/api/routes/files.py DELETED
@@ -1,13 +0,0 @@
- from fastapi import APIRouter, UploadFile, File, Depends
- from arbor.server.services.file_manager import FileManager
- from arbor.server.api.models.schemas import FileResponse
- from arbor.server.services.dependencies import get_file_manager
-
- router = APIRouter()
-
- @router.post("", response_model=FileResponse)
- def upload_file(
-     file: UploadFile = File(...),
-     file_manager: FileManager = Depends(get_file_manager)
- ):
-     return file_manager.save_uploaded_file(file)
arbor/server/api/routes/jobs.py DELETED
@@ -1,14 +0,0 @@
- from fastapi import APIRouter, Depends
- from arbor.server.services.job_manager import JobManager
- from arbor.server.services.dependencies import get_job_manager
- from arbor.server.api.models.schemas import JobStatusResponse
-
- router = APIRouter()
-
- @router.get("/{job_id}", response_model=JobStatusResponse)
- def get_job_status(
-     job_id: str,
-     job_manager: JobManager = Depends(get_job_manager)
- ):
-     status = job_manager.get_job_status(job_id)
-     return JobStatusResponse(job_id=job_id, status=status.value)
arbor/server/api/routes/training.py DELETED
@@ -1,16 +0,0 @@
- from fastapi import APIRouter, BackgroundTasks, Depends
-
- from arbor.server.api.models.schemas import FineTuneRequest, JobStatusResponse
- from arbor.server.services.job_manager import JobManager, JobStatus
- from arbor.server.services.file_manager import FileManager
- from arbor.server.services.training_manager import TrainingManager
- from arbor.server.services.dependencies import get_training_manager, get_job_manager, get_file_manager
-
- router = APIRouter()
-
- @router.post("", response_model=JobStatusResponse)
- def fine_tune(request: FineTuneRequest, background_tasks: BackgroundTasks, training_manager: TrainingManager = Depends(get_training_manager), job_manager: JobManager = Depends(get_job_manager), file_manager: FileManager = Depends(get_file_manager)):
-     job = job_manager.create_job()
-     background_tasks.add_task(training_manager.fine_tune, request, job, file_manager)
-     job.status = JobStatus.QUEUED
-     return JobStatusResponse(job_id=job.id, status=job.status.value)
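
Note: taken together, the three deleted routers formed the 0.1.3 HTTP flow: upload a JSONL file, queue a fine-tune, poll the job. A client-side sketch of that removed flow, assuming the `/api` prefixes from the deleted `arbor/server/main.py` further below; `requests` and the model name are for illustration only:

```python
# Sketch of the removed 0.1.3 flow. `requests` was never an arbor-ai
# dependency, and "model-name" is a placeholder.
import requests

BASE = "http://localhost:8000"

# 1. POST /api/files -> FileResponse (upload_file in the deleted files.py)
with open("training_data.jsonl", "rb") as f:
    file_id = requests.post(f"{BASE}/api/files", files={"file": f}).json()["id"]

# 2. POST /api/fine-tune -> JobStatusResponse; the job starts out "queued"
job = requests.post(
    f"{BASE}/api/fine-tune",
    json={"model_name": "model-name", "training_file": file_id},
).json()

# 3. GET /api/job/{job_id} -> current status ("running", "completed", ...)
status = requests.get(f"{BASE}/api/job/{job['job_id']}").json()["status"]
```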
arbor/server/core/__init__.py DELETED
@@ -1 +0,0 @@
-
arbor/server/core/config.py DELETED
@@ -1,10 +0,0 @@
- from pydantic_settings import BaseSettings
-
- class Settings(BaseSettings):
-     UPLOADS_DIR: str = "uploads"
-     MODEL_CACHE_DIR: str = "model_cache"
-
-     class Config:
-         env_file = ".env"
-
- settings = Settings()
arbor/server/core/logging.py DELETED
File without changes
arbor/server/main.py DELETED
@@ -1,10 +0,0 @@
- from fastapi import FastAPI
- from arbor.server.api.routes import training, files, jobs
- from arbor.server.core.config import settings
-
- app = FastAPI(title="Arbor API")
-
- # Include routers
- app.include_router(training.router, prefix="/api/fine-tune")
- app.include_router(files.router, prefix="/api/files")
- app.include_router(jobs.router, prefix="/api/job")
arbor/server/services/__init__.py DELETED
File without changes
arbor/server/services/dependencies.py DELETED
@@ -1,16 +0,0 @@
- from functools import lru_cache
- from arbor.server.services.file_manager import FileManager
- from arbor.server.services.job_manager import JobManager
- from arbor.server.services.training_manager import TrainingManager
-
- @lru_cache()
- def get_file_manager() -> FileManager:
-     return FileManager()
-
- @lru_cache()
- def get_job_manager() -> JobManager:
-     return JobManager()
-
- @lru_cache()
- def get_training_manager() -> TrainingManager:
-     return TrainingManager()
arbor/server/services/file_manager.py DELETED
@@ -1,83 +0,0 @@
- from pathlib import Path
- import json
- import os
- import shutil
- import time
- import uuid
- from fastapi import UploadFile
- from arbor.server.api.models.schemas import FileResponse
-
- class FileManager:
-     def __init__(self):
-         self.uploads_dir = Path("uploads")
-         self.uploads_dir.mkdir(exist_ok=True)
-         self.files = self.load_files_from_uploads()
-
-     def load_files_from_uploads(self):
-         files = {}
-
-         # Scan through all directories in uploads directory
-         for dir_path in self.uploads_dir.glob("*"):
-             if not dir_path.is_dir():
-                 continue
-
-             # Check for metadata.json
-             metadata_path = dir_path / "metadata.json"
-             if not metadata_path.exists():
-                 continue
-
-             # Load metadata
-             with open(metadata_path) as f:
-                 metadata = json.load(f)
-
-             # Find the .jsonl file
-             jsonl_files = list(dir_path.glob("*.jsonl"))
-             if not jsonl_files:
-                 continue
-
-             file_path = jsonl_files[0]
-             files[dir_path.name] = {
-                 "path": str(file_path),
-                 "purpose": metadata.get("purpose", "training"),
-                 "bytes": file_path.stat().st_size,
-                 "created_at": metadata.get("created_at", int(file_path.stat().st_mtime)),
-                 "filename": metadata.get("filename", file_path.name)
-             }
-
-         return files
-
-     def save_uploaded_file(self, file: UploadFile) -> FileResponse:
-         file_id = str(uuid.uuid4())
-         dir_path = self.uploads_dir / file_id
-         dir_path.mkdir(exist_ok=True)
-
-         # Save the actual file
-         file_path = dir_path / f"data.jsonl"
-         with open(file_path, "wb") as f:
-             shutil.copyfileobj(file.file, f)
-
-         # Create metadata
-         metadata = {
-             "purpose": "training",
-             "created_at": int(time.time()),
-             "filename": file.filename
-         }
-
-         # Save metadata
-         with open(dir_path / "metadata.json", "w") as f:
-             json.dump(metadata, f)
-
-         file_data = {
-             "id": file_id,
-             "path": str(file_path),
-             "purpose": metadata["purpose"],
-             "bytes": file.size,
-             "created_at": metadata["created_at"],
-             "filename": metadata["filename"]
-         }
-
-         self.files[file_id] = file_data
-         return FileResponse(**file_data)
-
-     def get_file(self, file_id: str):
-         return self.files[file_id]
arbor/server/services/job_manager.py DELETED
@@ -1,72 +0,0 @@
- import uuid
- from enum import Enum
- import logging
- from datetime import datetime
-
- class JobStatus(Enum):
-     PENDING = "pending"
-     QUEUED = "queued"
-     RUNNING = "running"
-     COMPLETED = "completed"
-     FAILED = "failed"
-
- class JobLogHandler(logging.Handler):
-     def __init__(self, job):
-         super().__init__()
-         self.job = job
-
-     def emit(self, record):
-         log_entry = {
-             'timestamp': datetime.fromtimestamp(record.created).isoformat(),
-             'level': record.levelname,
-             'message': record.getMessage()
-         }
-         self.job.logs.append(log_entry)
-
- class Job:
-     def __init__(self, id: str, status: JobStatus):
-         self.id = id
-         self.status = status
-         self.logs = []
-         self.logger = None
-         self.log_handler = None
-
-     def setup_logger(self, logger_name: str = None) -> logging.Logger:
-         """Sets up logging for the job with a dedicated handler."""
-         if logger_name is None:
-             logger_name = f"job_{self.id}"
-
-         logger = logging.getLogger(logger_name)
-         logger.setLevel(logging.INFO)
-
-         # Create and setup handler if not already exists
-         if self.log_handler is None:
-             handler = JobLogHandler(self)
-             formatter = logging.Formatter('%(message)s')
-             handler.setFormatter(formatter)
-             logger.addHandler(handler)
-             self.log_handler = handler
-
-         self.logger = logger
-         return logger
-
-     def cleanup_logger(self):
-         """Removes the job's logging handler."""
-         if self.logger and self.log_handler:
-             self.logger.removeHandler(self.log_handler)
-             self.log_handler = None
-             self.logger = None
-
- class JobManager:
-     def __init__(self):
-         self.jobs = {}
-
-     def get_job_status(self, job_id: str):
-         if job_id not in self.jobs:
-             raise ValueError(f"Job {job_id} not found")
-         return self.jobs[job_id].status
-
-     def create_job(self):
-         job = Job(id=str(uuid.uuid4()), status=JobStatus.PENDING)
-         self.jobs[job.id] = job
-         return job
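
Note: the deleted `JobLogHandler` captures a reusable pattern: a `logging.Handler` whose `emit` appends structured records to an in-memory list, so each job carries its own log. A standalone sketch of the same idea, independent of the deleted classes:

```python
# Standalone sketch of the deleted JobLogHandler pattern: records emitted
# on the logger are collected as dicts in a plain list.
import logging
from datetime import datetime

captured: list[dict] = []

class ListHandler(logging.Handler):  # hypothetical stand-in for JobLogHandler
    def emit(self, record: logging.LogRecord) -> None:
        captured.append({
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
        })

logger = logging.getLogger("job_demo")
logger.setLevel(logging.INFO)
logger.addHandler(ListHandler())
logger.info("Starting fine-tuning job")
assert captured[0]["message"] == "Starting fine-tuning job"
```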
arbor/server/services/training_manager.py DELETED
@@ -1,263 +0,0 @@
- from arbor.server.api.models.schemas import FineTuneRequest
- from arbor.server.services.job_manager import Job, JobStatus
- from arbor.server.services.file_manager import FileManager
-
- class TrainingManager:
-     def __init__(self):
-         pass
-
-     def find_train_args(self, request: FineTuneRequest, file_manager: FileManager):
-         file = file_manager.get_file(request.training_file)
-         if file is None:
-             raise ValueError(f"Training file {request.training_file} not found")
-
-         data_path = file["path"]
-         output_dir = f"models/{request.model_name}"  # TODO: This should be updated to be unique in some way
-
-
-         default_train_kwargs = {
-             "device": None,
-             "use_peft": False,
-             "num_train_epochs": 5,
-             "per_device_train_batch_size": 1,
-             "gradient_accumulation_steps": 8,
-             "learning_rate": 1e-5,
-             "max_seq_length": None,
-             "packing": True,
-             "bf16": True,
-             "output_dir": output_dir,
-             "train_data_path": data_path,
-         }
-         train_kwargs = {'packing': False}
-         train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}
-         output_dir = train_kwargs["output_dir"]  # user might have changed the output_dir
-
-         return train_kwargs
-
-
-     def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
-         # Get logger for this job
-         logger = job.setup_logger("training")
-
-         job.status = JobStatus.RUNNING
-         logger.info("Starting fine-tuning job")
-
-         try:
-             train_kwargs = self.find_train_args(request, file_manager)
-
-             import torch
-             from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
-             from trl import SFTConfig, SFTTrainer, setup_chat_format
-
-             device = train_kwargs.get("device", None)
-             if device is None:
-                 device = (
-                     "cuda"
-                     if torch.cuda.is_available()
-                     else "mps" if torch.backends.mps.is_available() else "cpu"
-                 )
-             logger.info(f"Using device: {device}")
-
-             model = AutoModelForCausalLM.from_pretrained(
-                 pretrained_model_name_or_path=request.model_name
-             ).to(device)
-             tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=request.model_name)
-
-             # Set up the chat format; generally only for non-chat model variants, hence the try-except.
-             try:
-                 model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
-             except Exception:
-                 pass
-
-             if tokenizer.pad_token_id is None:
-                 logger.info("Adding pad token to tokenizer")
-                 tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
-
-             logger.info("Creating dataset")
-             if "max_seq_length" not in train_kwargs or train_kwargs["max_seq_length"] is None:
-                 train_kwargs["max_seq_length"] = 4096
-                 logger.info(f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}")
-
-
-             hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
-             def tokenize_function(example):
-                 return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
-             tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
-             tokenized_dataset.set_format(type="torch")
-             tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
-
-             USE_PEFT = train_kwargs.get("use_peft", False)
-             peft_config = None
-
-             if USE_PEFT:
-                 from peft import LoraConfig
-
-                 rank_dimension = 32
-                 lora_alpha = 64
-                 lora_dropout = 0.05
-
-                 peft_config = LoraConfig(
-                     r=rank_dimension,
-                     lora_alpha=lora_alpha,
-                     lora_dropout=lora_dropout,
-                     bias="none",
-                     target_modules="all-linear",
-                     task_type="CAUSAL_LM",
-                 )
-
-             sft_config = SFTConfig(
-                 output_dir=train_kwargs["output_dir"],
-                 num_train_epochs=train_kwargs["num_train_epochs"],
-                 per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
-                 gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
-                 learning_rate=train_kwargs["learning_rate"],
-                 max_grad_norm=2.0,  # note that the current SFTConfig default is 1.0
-                 logging_steps=20,
-                 warmup_ratio=0.03,
-                 lr_scheduler_type="constant",
-                 save_steps=10_000,
-                 bf16=train_kwargs["bf16"],
-                 max_seq_length=train_kwargs["max_seq_length"],
-                 packing=train_kwargs["packing"],
-                 dataset_kwargs={  # We need to pass dataset_kwargs because we are processing the dataset ourselves
-                     "add_special_tokens": False,  # Special tokens handled by template
-                     "append_concat_token": False,  # No additional separator needed
-                 },
-             )
-
-             logger.info("Starting training")
-             trainer = SFTTrainer(
-                 model=model,
-                 args=sft_config,
-                 train_dataset=tokenized_dataset,
-                 peft_config=peft_config,
-
-             )
-
-             # Train!
-             trainer.train()
-
-             # Save the model!
-             trainer.save_model()
-
-             MERGE = True
-             if USE_PEFT and MERGE:
-                 from peft import AutoPeftModelForCausalLM
-
-                 # Load PEFT model on CPU
-                 model_ = AutoPeftModelForCausalLM.from_pretrained(
-                     pretrained_model_name_or_path=sft_config.output_dir,
-                     torch_dtype=torch.float16,
-                     low_cpu_mem_usage=True,
-                 )
-
-                 merged_model = model_.merge_and_unload()
-                 merged_model.save_pretrained(
-                     sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
-                 )
-
-             # Clean up!
-             import gc
-
-             del model
-             del tokenizer
-             del trainer
-             gc.collect()
-             torch.cuda.empty_cache()
-
-             logger.info("Training completed successfully")
-             job.status = JobStatus.COMPLETED
-         except Exception as e:
-             logger.error(f"Training failed: {str(e)}")
-             job.status = JobStatus.FAILED
-             raise
-         finally:
-             job.cleanup_logger()
-
-         return sft_config.output_dir
-
- def dataset_from_file(data_path):
-     """
-     Creates a HuggingFace Dataset from a JSONL file.
-     """
-     from datasets import load_dataset
-
-     dataset = load_dataset("json", data_files=data_path, split="train")
-     return dataset
-
-
- def encode_sft_example(example, tokenizer, max_seq_length):
-     """
-     This function encodes a single example into a format that can be used for sft training.
-     Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
-     We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
-
-     Code obtained from the allenai/open-instruct repository: https://github.com/allenai/open-instruct/blob/4365dea3d1a6111e8b2712af06b22a4512a0df88/open_instruct/finetune.py
-     """
-     import torch
-
-     messages = example["messages"]
-     if len(messages) == 0:
-         raise ValueError("messages field is empty.")
-     input_ids = tokenizer.apply_chat_template(
-         conversation=messages,
-         tokenize=True,
-         return_tensors="pt",
-         padding=False,
-         truncation=True,
-         max_length=max_seq_length,
-         add_generation_prompt=False,
-     )
-     labels = input_ids.clone()
-     # mask the non-assistant part for avoiding loss
-     for message_idx, message in enumerate(messages):
-         if message["role"] != "assistant":
-             # we calculate the start index of this non-assistant message
-             if message_idx == 0:
-                 message_start_idx = 0
-             else:
-                 message_start_idx = tokenizer.apply_chat_template(
-                     conversation=messages[:message_idx],  # here marks the end of the previous messages
-                     tokenize=True,
-                     return_tensors="pt",
-                     padding=False,
-                     truncation=True,
-                     max_length=max_seq_length,
-                     add_generation_prompt=False,
-                 ).shape[1]
-             # next, we calculate the end index of this non-assistant message
-             if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
-                 # for intermediate messages that follow with an assistant message, we need to
-                 # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
-                 # (e.g., `<|assistant|>`)
-                 message_end_idx = tokenizer.apply_chat_template(
-                     conversation=messages[: message_idx + 1],
-                     tokenize=True,
-                     return_tensors="pt",
-                     padding=False,
-                     truncation=True,
-                     max_length=max_seq_length,
-                     add_generation_prompt=True,
-                 ).shape[1]
-             else:
-                 # for the last message or the message that doesn't follow with an assistant message,
-                 # we don't need to add the assistant generation prefix
-                 message_end_idx = tokenizer.apply_chat_template(
-                     conversation=messages[: message_idx + 1],
-                     tokenize=True,
-                     return_tensors="pt",
-                     padding=False,
-                     truncation=True,
-                     max_length=max_seq_length,
-                     add_generation_prompt=False,
-                 ).shape[1]
-             # set the label to -100 for the non-assistant part
-             labels[:, message_start_idx:message_end_idx] = -100
-             if max_seq_length and message_end_idx >= max_seq_length:
-                 break
-     attention_mask = torch.ones_like(input_ids)
-     return {
-         "input_ids": input_ids.flatten(),
-         "labels": labels.flatten(),
-         "attention_mask": attention_mask.flatten()
-     }
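
Note: the core of the deleted `encode_sft_example` is its label mask: positions covered by non-assistant messages are set to -100, the `ignore_index` PyTorch's cross-entropy uses, so only assistant tokens contribute to the loss. A toy illustration with made-up token ids and boundaries:

```python
# Toy illustration of the -100 masking convention from encode_sft_example.
# Token ids and the user/assistant boundary are invented for the example.
import torch

input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24]])  # user then assistant tokens
labels = input_ids.clone()
labels[:, 0:3] = -100  # mask the user span, as labels[:, start:end] = -100 above

logits = torch.randn(1, 7, 100)  # dummy (batch, seq, vocab) scores
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)  # -100 is also the default
loss = loss_fn(logits.view(-1, 100), labels.view(-1))  # masked positions are ignored
```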
arbor/server/utils/__init__.py DELETED
File without changes
arbor/server/utils/helpers.py DELETED
File without changes
arbor_ai-0.1.3.dist-info/METADATA DELETED
@@ -1,47 +0,0 @@
- Metadata-Version: 2.3
- Name: arbor-ai
- Version: 0.1.3
- Summary: A framework for fine-tuning and managing language models
- License: MIT
- Keywords: machine learning,fine-tuning,language models
- Author: Noah Ziems
- Author-email: nziems2@nd.edu
- Requires-Python: >=3.13
- Classifier: Development Status :: 3 - Alpha
- Classifier: Intended Audience :: Developers
- Classifier: License :: OSI Approved :: MIT License
- Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.13
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.8
- Classifier: Programming Language :: Python :: 3.9
- Requires-Dist: click
- Requires-Dist: fastapi
- Requires-Dist: peft (>=0.14.0,<0.15.0)
- Requires-Dist: pydantic-settings (>=2.8.1,<3.0.0)
- Requires-Dist: python-multipart (>=0.0.20,<0.0.21)
- Requires-Dist: torch (>=2.6.0,<3.0.0)
- Requires-Dist: transformers (>=4.49.0,<5.0.0)
- Requires-Dist: trl (>=0.15.2,<0.16.0)
- Requires-Dist: uvicorn
- Project-URL: Repository, https://github.com/arbor-ai/arbor
- Description-Content-Type: text/markdown
-
- # Arbor AI
-
- ## Setup
-
- ```bash
- poetry install
- ```
-
- ```bash
- poetry run arbor serve
- ```
-
- ## Uploading Data
-
- ```bash
- curl -X POST "http://localhost:8000/api/files" -F "file=@training_data.jsonl"
- ```
-
arbor_ai-0.1.3.dist-info/RECORD DELETED
@@ -1,26 +0,0 @@
- arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor/cli.py,sha256=6fT5JjpXSwhpJSQNE4pnLOY04ryHPwJBAOet3hyho8k,383
- arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor/client/api.py,sha256=WFaNtwCNWXRAHHG1Jfyl7LvTP6jiEyQOLZn2Z8Yjt5k,40
- arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- arbor/server/api/models/schemas.py,sha256=-_NOnUuEbiO8uLDwEpByYV6NAMasOmFUJXxG0eXA_D0,367
- arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor/server/api/routes/files.py,sha256=QrPY9-886NXnXjGRlT-pl5kWbnwfogCrdmv6RufJpVg,466
- arbor/server/api/routes/jobs.py,sha256=ibL0tQA2Apqa91vycv3NPT0ydhkba4vnPoclw-bVKXs,510
- arbor/server/api/routes/training.py,sha256=43NOvh1Hubg3Ocfhu5E82Tp_kXOJL8H8oQXf-4H1yMU,981
- arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- arbor/server/core/config.py,sha256=R67gNeUXz0RShvpr8XF3Lpn7-RMOfKf2xTIyqXvj4PI,215
- arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor/server/main.py,sha256=I3chVYsoG56zE7Clf88lEuOPaDzJvKsOzivOWpsFDls,350
- arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor/server/services/dependencies.py,sha256=y3EoIkwScYc811jZ8p5m0kJT4ixRo7vguimBKKMuxAQ,458
- arbor/server/services/file_manager.py,sha256=uE7Mnbn9fC1H7sAMzJt1x9fVak10duh0OyAZeDgi3iY,2200
- arbor/server/services/job_manager.py,sha256=Zx3d0h31YH9bQ4yQr3FUXUGEHd-KUiTekZ0ndGOptrY,1893
- arbor/server/services/training_manager.py,sha256=SNumrzM1B-V1HBucUHmxnmc4rmhSCHVgPPc5nNPRC4Q,10682
- arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arbor_ai-0.1.3.dist-info/METADATA,sha256=PJbOddt69fyZXJggzpaSUb5XfUt0ouPrQQFIAaeOasE,1272
- arbor_ai-0.1.3.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
- arbor_ai-0.1.3.dist-info/entry_points.txt,sha256=AaLg05CZSQeP2oGlCH_AnmZPz-zzLlVtpXToI4cM3kY,39
- arbor_ai-0.1.3.dist-info/RECORD,,
arbor_ai-0.1.3.dist-info/entry_points.txt DELETED
@@ -1,3 +0,0 @@
- [console_scripts]
- arbor=arbor.cli:cli
-