arbor-ai 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,47 @@
+ Metadata-Version: 2.3
+ Name: arbor-ai
+ Version: 0.1.0
+ Summary: A framework for fine-tuning and managing language models
+ License: MIT
+ Keywords: machine learning,fine-tuning,language models
+ Author: Noah Ziems
+ Author-email: nziems2@nd.edu
+ Requires-Python: >=3.13
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Requires-Dist: click
+ Requires-Dist: fastapi
+ Requires-Dist: peft (>=0.14.0,<0.15.0)
+ Requires-Dist: pydantic-settings (>=2.8.1,<3.0.0)
+ Requires-Dist: python-multipart (>=0.0.20,<0.0.21)
+ Requires-Dist: torch (>=2.6.0,<3.0.0)
+ Requires-Dist: transformers (>=4.49.0,<5.0.0)
+ Requires-Dist: trl (>=0.15.2,<0.16.0)
+ Requires-Dist: uvicorn
+ Project-URL: Repository, https://github.com/arbor-ai/arbor
+ Description-Content-Type: text/markdown
+
+ # Arbor AI
+
+ ## Setup
+
+ ```bash
+ poetry install
+ ```
+
+ ```bash
+ poetry run arbor serve
+ ```
+
+ ## Uploading Data
+
+ ```bash
+ curl -X POST "http://localhost:8000/api/files" -F "file=@training_data.jsonl"
+ ```
+
@@ -0,0 +1,17 @@
+ # Arbor AI
+
+ ## Setup
+
+ ```bash
+ poetry install
+ ```
+
+ ```bash
+ poetry run arbor serve
+ ```
+
+ ## Uploading Data
+
+ ```bash
+ curl -X POST "http://localhost:8000/api/files" -F "file=@training_data.jsonl"
+ ```
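The README stops at the file upload, but the routers registered in `arbor/server/main.py` later in this diff also expose `POST /api/fine-tune` and `GET /api/job/{job_id}`. A minimal client sketch of the full flow, assuming a local server on port 8000 and that `FineTuneRequest` accepts the `model_name` and `training_file` fields used by the training service (the schema module itself is not shown in this diff); the model name is a placeholder, and `httpx` is already listed as a dev dependency:

```python
# Sketch: upload data, start a fine-tune, and poll the job until it finishes.
# Assumptions: server at http://localhost:8000; FineTuneRequest takes
# `model_name` and `training_file` (inferred from usage, not shown in this diff).
import time

import httpx

BASE = "http://localhost:8000"

# 1. Upload a JSONL training file; the response should carry the generated file id.
with open("training_data.jsonl", "rb") as f:
    upload = httpx.post(f"{BASE}/api/files", files={"file": f}).json()

# 2. Kick off a fine-tuning job that references the uploaded file.
job = httpx.post(
    f"{BASE}/api/fine-tune",
    json={"model_name": "some-org/some-base-model", "training_file": upload["id"]},
).json()

# 3. Poll the job status route until the job completes or fails.
while True:
    status = httpx.get(f"{BASE}/api/job/{job['job_id']}").json()
    print(status)
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(10)
```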
@@ -0,0 +1,52 @@
+ [project]
+ name = "arbor-ai"
+ version = "0.1.0"
+ description = "A framework for fine-tuning and managing language models"
+ authors = [
+     {name = "Noah Ziems",email = "nziems2@nd.edu"}
+ ]
+ readme = "README.md"
+ requires-python = ">=3.13"
+ dependencies = [
+     "fastapi",
+     "uvicorn",
+     "click",
+     "python-multipart (>=0.0.20,<0.0.21)",
+     "pydantic-settings (>=2.8.1,<3.0.0)",
+     "torch (>=2.6.0,<3.0.0)",
+     "transformers (>=4.49.0,<5.0.0)",
+     "trl (>=0.15.2,<0.16.0)",
+     "peft (>=0.14.0,<0.15.0)",
+ ]
+
+ [project.scripts]
+ arbor = "arbor.cli:cli"
+
+ [tool.poetry]
+ name = "arbor"
+ version = "0.1.0"
+ description = "A framework for fine-tuning and managing language models"
+ authors = ["Noah Ziems <nziems2@nd.edu>"]
+ readme = "README.md"
+ packages = [{include = "arbor", from = "src"}]
+ repository = "https://github.com/arbor-ai/arbor"
+ keywords = ["machine learning", "fine-tuning", "language models"]
+ license = "MIT"
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.8",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+ ]
+
+ [tool.poetry.group.dev.dependencies]
+ pytest = "^8.3.5"
+ pytest-asyncio = "^0.25.3"
+ httpx = "^0.28.1"
+
+ [build-system]
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
+ build-backend = "poetry.core.masonry.api"
File without changes
@@ -0,0 +1,17 @@
+ import click
+ import uvicorn
+ from arbor.server.main import app
+
+ @click.group()
+ def cli():
+     pass
+
+ @cli.command()
+ @click.option('--host', default='0.0.0.0', help='Host to bind to')
+ @click.option('--port', default=8000, help='Port to bind to')
+ def serve(host, port):
+     """Start the Arbor API server"""
+     uvicorn.run(app, host=host, port=port)
+
+ if __name__ == '__main__':
+     cli()
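The `serve` command is a thin wrapper over `uvicorn.run`, so the CLI surface can be inspected without binding a port by using click's built-in test runner (part of click itself, not this package). A small sketch:

```python
# Inspect the Arbor CLI without starting a server, via click's test runner.
from click.testing import CliRunner

from arbor.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["serve", "--help"])
print(result.output)  # shows the --host/--port options defined above
```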
File without changes
@@ -0,0 +1,2 @@
+ from typing import Optional, Dict, Any
+
File without changes
@@ -0,0 +1,13 @@
+ from fastapi import APIRouter, UploadFile, File, Depends
+ from arbor.server.services.file_manager import FileManager
+ from arbor.server.api.models.schemas import FileResponse
+ from arbor.server.services.dependencies import get_file_manager
+
+ router = APIRouter()
+
+ @router.post("", response_model=FileResponse)
+ def upload_file(
+     file: UploadFile = File(...),
+     file_manager: FileManager = Depends(get_file_manager)
+ ):
+     return file_manager.save_uploaded_file(file)
@@ -0,0 +1,14 @@
+ from fastapi import APIRouter, Depends
+ from arbor.server.services.job_manager import JobManager
+ from arbor.server.services.dependencies import get_job_manager
+ from arbor.server.api.models.schemas import JobStatusResponse
+
+ router = APIRouter()
+
+ @router.get("/{job_id}", response_model=JobStatusResponse)
+ def get_job_status(
+     job_id: str,
+     job_manager: JobManager = Depends(get_job_manager)
+ ):
+     status = job_manager.get_job_status(job_id)
+     return JobStatusResponse(job_id=job_id, status=status.value)
@@ -0,0 +1,16 @@
+ from fastapi import APIRouter, BackgroundTasks, Depends
+
+ from arbor.server.api.models.schemas import FineTuneRequest, JobStatusResponse
+ from arbor.server.services.job_manager import JobManager, JobStatus
+ from arbor.server.services.file_manager import FileManager
+ from arbor.server.services.training_manager import TrainingManager
+ from arbor.server.services.dependencies import get_training_manager, get_job_manager, get_file_manager
+
+ router = APIRouter()
+
+ @router.post("", response_model=JobStatusResponse)
+ def fine_tune(request: FineTuneRequest, background_tasks: BackgroundTasks, training_manager: TrainingManager = Depends(get_training_manager), job_manager: JobManager = Depends(get_job_manager), file_manager: FileManager = Depends(get_file_manager)):
+     job = job_manager.create_job()
+     background_tasks.add_task(training_manager.fine_tune, request, job, file_manager)
+     job.status = JobStatus.QUEUED
+     return JobStatusResponse(job_id=job.id, status=job.status.value)
@@ -0,0 +1,10 @@
+ from pydantic_settings import BaseSettings
+
+ class Settings(BaseSettings):
+     UPLOADS_DIR: str = "uploads"
+     MODEL_CACHE_DIR: str = "model_cache"
+
+     class Config:
+         env_file = ".env"
+
+ settings = Settings()
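`Settings` is a pydantic-settings `BaseSettings`, so both directories can be overridden through environment variables or a `.env` file; because the module instantiates `settings` at import time, overrides must be in place before the import. A minimal sketch, assuming the default case-insensitive env-var matching (note that `FileManager` below currently hard-codes `"uploads"` rather than reading `UPLOADS_DIR`):

```python
# Override Arbor's settings through environment variables set before import.
import os

os.environ["UPLOADS_DIR"] = "/data/arbor/uploads"
os.environ["MODEL_CACHE_DIR"] = "/data/arbor/model_cache"

from arbor.server.core.config import Settings

settings = Settings()
print(settings.UPLOADS_DIR, settings.MODEL_CACHE_DIR)
```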
File without changes
@@ -0,0 +1,10 @@
+ from fastapi import FastAPI
+ from arbor.server.api.routes import training, files, jobs
+ from arbor.server.core.config import settings
+
+ app = FastAPI(title="Arbor API")
+
+ # Include routers
+ app.include_router(training.router, prefix="/api/fine-tune")
+ app.include_router(files.router, prefix="/api/files")
+ app.include_router(jobs.router, prefix="/api/job")
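The CLI's `serve` command and this module point at the same `app` object, so running uvicorn directly is an equivalent way to start the API. A sketch mirroring the CLI defaults:

```python
# Start the Arbor API without the click CLI; equivalent to `arbor serve`.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("arbor.server.main:app", host="0.0.0.0", port=8000)
```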
File without changes
@@ -0,0 +1,16 @@
+ from functools import lru_cache
+ from arbor.server.services.file_manager import FileManager
+ from arbor.server.services.job_manager import JobManager
+ from arbor.server.services.training_manager import TrainingManager
+
+ @lru_cache()
+ def get_file_manager() -> FileManager:
+     return FileManager()
+
+ @lru_cache()
+ def get_job_manager() -> JobManager:
+     return JobManager()
+
+ @lru_cache()
+ def get_training_manager() -> TrainingManager:
+     return TrainingManager()
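Because each provider is wrapped in `lru_cache()`, the managers act as process-wide singletons: every request shares one `FileManager`, `JobManager`, and `TrainingManager`, which is what lets the jobs route see jobs created by the training route. A test sketch using FastAPI's `dependency_overrides` with its httpx-based `TestClient` (pytest and httpx are already dev dependencies); the stub manager is hypothetical:

```python
# Swap the real JobManager for a stub in tests via FastAPI's dependency_overrides.
from fastapi.testclient import TestClient

from arbor.server.main import app
from arbor.server.services.dependencies import get_job_manager
from arbor.server.services.job_manager import JobStatus


class StubJobManager:
    def get_job_status(self, job_id: str):
        return JobStatus.COMPLETED  # hypothetical stub: every job looks finished


app.dependency_overrides[get_job_manager] = lambda: StubJobManager()

client = TestClient(app)
response = client.get("/api/job/some-id")
assert response.json()["status"] == "completed"
```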
@@ -0,0 +1,83 @@
+ from pathlib import Path
+ import json
+ import os
+ import shutil
+ import time
+ import uuid
+ from fastapi import UploadFile
+ from arbor.server.api.models.schemas import FileResponse
+
+ class FileManager:
+     def __init__(self):
+         self.uploads_dir = Path("uploads")
+         self.uploads_dir.mkdir(exist_ok=True)
+         self.files = self.load_files_from_uploads()
+
+     def load_files_from_uploads(self):
+         files = {}
+
+         # Scan through all directories in uploads directory
+         for dir_path in self.uploads_dir.glob("*"):
+             if not dir_path.is_dir():
+                 continue
+
+             # Check for metadata.json
+             metadata_path = dir_path / "metadata.json"
+             if not metadata_path.exists():
+                 continue
+
+             # Load metadata
+             with open(metadata_path) as f:
+                 metadata = json.load(f)
+
+             # Find the .jsonl file
+             jsonl_files = list(dir_path.glob("*.jsonl"))
+             if not jsonl_files:
+                 continue
+
+             file_path = jsonl_files[0]
+             files[dir_path.name] = {
+                 "path": str(file_path),
+                 "purpose": metadata.get("purpose", "training"),
+                 "bytes": file_path.stat().st_size,
+                 "created_at": metadata.get("created_at", int(file_path.stat().st_mtime)),
+                 "filename": metadata.get("filename", file_path.name)
+             }
+
+         return files
+
+     def save_uploaded_file(self, file: UploadFile) -> FileResponse:
+         file_id = str(uuid.uuid4())
+         dir_path = self.uploads_dir / file_id
+         dir_path.mkdir(exist_ok=True)
+
+         # Save the actual file
+         file_path = dir_path / f"data.jsonl"
+         with open(file_path, "wb") as f:
+             shutil.copyfileobj(file.file, f)
+
+         # Create metadata
+         metadata = {
+             "purpose": "training",
+             "created_at": int(time.time()),
+             "filename": file.filename
+         }
+
+         # Save metadata
+         with open(dir_path / "metadata.json", "w") as f:
+             json.dump(metadata, f)
+
+         file_data = {
+             "id": file_id,
+             "path": str(file_path),
+             "purpose": metadata["purpose"],
+             "bytes": file.size,
+             "created_at": metadata["created_at"],
+             "filename": metadata["filename"]
+         }
+
+         self.files[file_id] = file_data
+         return FileResponse(**file_data)
+
+     def get_file(self, file_id: str):
+         return self.files[file_id]
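`FileManager` persists each upload as a directory under `uploads/` containing the raw `data.jsonl` plus a `metadata.json` sidecar, and it rebuilds its index from that layout at construction time. A sketch recreating that layout by hand, assuming a fresh working directory:

```python
# Recreate the on-disk layout FileManager.load_files_from_uploads() expects,
# then confirm a new FileManager picks it up at construction time.
import json
import time
import uuid
from pathlib import Path

from arbor.server.services.file_manager import FileManager

file_id = str(uuid.uuid4())
dir_path = Path("uploads") / file_id
dir_path.mkdir(parents=True, exist_ok=True)

# The raw training data and its metadata sidecar, as save_uploaded_file() writes them.
(dir_path / "data.jsonl").write_text('{"messages": []}\n')
(dir_path / "metadata.json").write_text(
    json.dumps({"purpose": "training", "created_at": int(time.time()), "filename": "training_data.jsonl"})
)

fm = FileManager()
print(fm.get_file(file_id))  # {'path': 'uploads/<id>/data.jsonl', 'purpose': 'training', 'bytes': ..., ...}
```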
@@ -0,0 +1,72 @@
+ import uuid
+ from enum import Enum
+ import logging
+ from datetime import datetime
+
+ class JobStatus(Enum):
+     PENDING = "pending"
+     QUEUED = "queued"
+     RUNNING = "running"
+     COMPLETED = "completed"
+     FAILED = "failed"
+
+ class JobLogHandler(logging.Handler):
+     def __init__(self, job):
+         super().__init__()
+         self.job = job
+
+     def emit(self, record):
+         log_entry = {
+             'timestamp': datetime.fromtimestamp(record.created).isoformat(),
+             'level': record.levelname,
+             'message': record.getMessage()
+         }
+         self.job.logs.append(log_entry)
+
+ class Job:
+     def __init__(self, id: str, status: JobStatus):
+         self.id = id
+         self.status = status
+         self.logs = []
+         self.logger = None
+         self.log_handler = None
+
+     def setup_logger(self, logger_name: str = None) -> logging.Logger:
+         """Sets up logging for the job with a dedicated handler."""
+         if logger_name is None:
+             logger_name = f"job_{self.id}"
+
+         logger = logging.getLogger(logger_name)
+         logger.setLevel(logging.INFO)
+
+         # Create and setup handler if not already exists
+         if self.log_handler is None:
+             handler = JobLogHandler(self)
+             formatter = logging.Formatter('%(message)s')
+             handler.setFormatter(formatter)
+             logger.addHandler(handler)
+             self.log_handler = handler
+
+         self.logger = logger
+         return logger
+
+     def cleanup_logger(self):
+         """Removes the job's logging handler."""
+         if self.logger and self.log_handler:
+             self.logger.removeHandler(self.log_handler)
+             self.log_handler = None
+             self.logger = None
+
+ class JobManager:
+     def __init__(self):
+         self.jobs = {}
+
+     def get_job_status(self, job_id: str):
+         if job_id not in self.jobs:
+             raise ValueError(f"Job {job_id} not found")
+         return self.jobs[job_id].status
+
+     def create_job(self):
+         job = Job(id=str(uuid.uuid4()), status=JobStatus.PENDING)
+         self.jobs[job.id] = job
+         return job
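Each `Job` can attach a `JobLogHandler` to a standard `logging` logger, so anything the training code logs is captured as structured entries on `job.logs` with a timestamp, level, and message. A minimal sketch of that flow in isolation:

```python
# Capture log records onto a Job via its dedicated handler.
from arbor.server.services.job_manager import JobManager

manager = JobManager()
job = manager.create_job()              # starts as JobStatus.PENDING

logger = job.setup_logger()             # logger name defaults to f"job_{job.id}"
logger.info("Starting fine-tuning job")
logger.error("Something went wrong")

print(manager.get_job_status(job.id))   # JobStatus.PENDING
print(job.logs)                         # [{'timestamp': ..., 'level': 'INFO', ...}, ...]

job.cleanup_logger()                    # detach the handler once the job is done
```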
@@ -0,0 +1,263 @@
+ from arbor.server.api.models.schemas import FineTuneRequest
+ from arbor.server.services.job_manager import Job, JobStatus
+ from arbor.server.services.file_manager import FileManager
+
+ class TrainingManager:
+     def __init__(self):
+         pass
+
+     def find_train_args(self, request: FineTuneRequest, file_manager: FileManager):
+         file = file_manager.get_file(request.training_file)
+         if file is None:
+             raise ValueError(f"Training file {request.training_file} not found")
+
+         data_path = file["path"]
+         output_dir = f"models/{request.model_name}" # TODO: This should be updated to be unique in some way
+
+
+         default_train_kwargs = {
+             "device": None,
+             "use_peft": False,
+             "num_train_epochs": 5,
+             "per_device_train_batch_size": 1,
+             "gradient_accumulation_steps": 8,
+             "learning_rate": 1e-5,
+             "max_seq_length": None,
+             "packing": True,
+             "bf16": True,
+             "output_dir": output_dir,
+             "train_data_path": data_path,
+         }
+         train_kwargs = {'packing': False}
+         train_kwargs={**default_train_kwargs, **(train_kwargs or {})}
+         output_dir = train_kwargs["output_dir"] # user might have changed the output_dir
+
+         return train_kwargs
+
+
+     def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
+         # Get logger for this job
+         logger = job.setup_logger("training")
+
+         job.status = JobStatus.RUNNING
+         logger.info("Starting fine-tuning job")
+
+         try:
+             train_kwargs = self.find_train_args(request, file_manager)
+
+             import torch
+             from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
+             from trl import SFTConfig, SFTTrainer, setup_chat_format
+
+             device = train_kwargs.get("device", None)
+             if device is None:
+                 device = (
+                     "cuda"
+                     if torch.cuda.is_available()
+                     else "mps" if torch.backends.mps.is_available() else "cpu"
+                 )
+             logger.info(f"Using device: {device}")
+
+             model = AutoModelForCausalLM.from_pretrained(
+                 pretrained_model_name_or_path=request.model_name
+             ).to(device)
+             tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=request.model_name)
+
+             # Set up the chat format; generally only for non-chat model variants, hence the try-except.
+             try:
+                 model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
+             except Exception:
+                 pass
+
+             if tokenizer.pad_token_id is None:
+                 logger.info("Adding pad token to tokenizer")
+                 tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
+
+             logger.info("Creating dataset")
+             if "max_seq_length" not in train_kwargs or train_kwargs["max_seq_length"] is None:
+                 train_kwargs["max_seq_length"] = 4096
+                 logger.info(f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}")
+
+
+             hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
+             def tokenize_function(example):
+                 return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
+             tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
+             tokenized_dataset.set_format(type="torch")
+             tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
+
+             USE_PEFT = train_kwargs.get("use_peft", False)
+             peft_config = None
+
+             if USE_PEFT:
+                 from peft import LoraConfig
+
+                 rank_dimension = 32
+                 lora_alpha = 64
+                 lora_dropout = 0.05
+
+                 peft_config = LoraConfig(
+                     r=rank_dimension,
+                     lora_alpha=lora_alpha,
+                     lora_dropout=lora_dropout,
+                     bias="none",
+                     target_modules="all-linear",
+                     task_type="CAUSAL_LM",
+                 )
+
+             sft_config = SFTConfig(
+                 output_dir=train_kwargs["output_dir"],
+                 num_train_epochs=train_kwargs["num_train_epochs"],
+                 per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
+                 gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
+                 learning_rate=train_kwargs["learning_rate"],
+                 max_grad_norm=2.0, # note that the current SFTConfig default is 1.0
+                 logging_steps=20,
+                 warmup_ratio=0.03,
+                 lr_scheduler_type="constant",
+                 save_steps=10_000,
+                 bf16=train_kwargs["bf16"],
+                 max_seq_length=train_kwargs["max_seq_length"],
+                 packing=train_kwargs["packing"],
+                 dataset_kwargs={ # We need to pass dataset_kwargs because we are processing the dataset ourselves
+                     "add_special_tokens": False, # Special tokens handled by template
+                     "append_concat_token": False, # No additional separator needed
+                 },
+             )
+
+             logger.info("Starting training")
+             trainer = SFTTrainer(
+                 model=model,
+                 args=sft_config,
+                 train_dataset=tokenized_dataset,
+                 peft_config=peft_config,
+
+             )
+
+             # Train!
+             trainer.train()
+
+             # Save the model!
+             trainer.save_model()
+
+             MERGE = True
+             if USE_PEFT and MERGE:
+                 from peft import AutoPeftModelForCausalLM
+
+                 # Load PEFT model on CPU
+                 model_ = AutoPeftModelForCausalLM.from_pretrained(
+                     pretrained_model_name_or_path=sft_config.output_dir,
+                     torch_dtype=torch.float16,
+                     low_cpu_mem_usage=True,
+                 )
+
+                 merged_model = model_.merge_and_unload()
+                 merged_model.save_pretrained(
+                     sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
+                 )
+
+             # Clean up!
+             import gc
+
+             del model
+             del tokenizer
+             del trainer
+             gc.collect()
+             torch.cuda.empty_cache()
+
+             logger.info("Training completed successfully")
+             job.status = JobStatus.COMPLETED
+         except Exception as e:
+             logger.error(f"Training failed: {str(e)}")
+             job.status = JobStatus.FAILED
+             raise
+         finally:
+             job.cleanup_logger()
+
+         return sft_config.output_dir
+
+ def dataset_from_file(data_path):
+     """
+     Creates a HuggingFace Dataset from a JSONL file.
+     """
+     from datasets import load_dataset
+
+     dataset = load_dataset("json", data_files=data_path, split="train")
+     return dataset
+
+
+ def encode_sft_example(example, tokenizer, max_seq_length):
+     """
+     This function encodes a single example into a format that can be used for sft training.
+     Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
+     We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
+
+     Code obtained from the allenai/open-instruct repository: https://github.com/allenai/open-instruct/blob/4365dea3d1a6111e8b2712af06b22a4512a0df88/open_instruct/finetune.py
+     """
+     import torch
+
+     messages = example["messages"]
+     if len(messages) == 0:
+         raise ValueError("messages field is empty.")
+     input_ids = tokenizer.apply_chat_template(
+         conversation=messages,
+         tokenize=True,
+         return_tensors="pt",
+         padding=False,
+         truncation=True,
+         max_length=max_seq_length,
+         add_generation_prompt=False,
+     )
+     labels = input_ids.clone()
+     # mask the non-assistant part for avoiding loss
+     for message_idx, message in enumerate(messages):
+         if message["role"] != "assistant":
+             # we calculate the start index of this non-assistant message
+             if message_idx == 0:
+                 message_start_idx = 0
+             else:
+                 message_start_idx = tokenizer.apply_chat_template(
+                     conversation=messages[:message_idx], # here marks the end of the previous messages
+                     tokenize=True,
+                     return_tensors="pt",
+                     padding=False,
+                     truncation=True,
+                     max_length=max_seq_length,
+                     add_generation_prompt=False,
+                 ).shape[1]
+             # next, we calculate the end index of this non-assistant message
+             if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
+                 # for intermediate messages that follow with an assistant message, we need to
+                 # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
+                 # (e.g., `<|assistant|>`)
+                 message_end_idx = tokenizer.apply_chat_template(
+                     conversation=messages[: message_idx + 1],
+                     tokenize=True,
+                     return_tensors="pt",
+                     padding=False,
+                     truncation=True,
+                     max_length=max_seq_length,
+                     add_generation_prompt=True,
+                 ).shape[1]
+             else:
+                 # for the last message or the message that doesn't follow with an assistant message,
+                 # we don't need to add the assistant generation prefix
+                 message_end_idx = tokenizer.apply_chat_template(
+                     conversation=messages[: message_idx + 1],
+                     tokenize=True,
+                     return_tensors="pt",
+                     padding=False,
+                     truncation=True,
+                     max_length=max_seq_length,
+                     add_generation_prompt=False,
+                 ).shape[1]
+             # set the label to -100 for the non-assistant part
+             labels[:, message_start_idx:message_end_idx] = -100
+             if max_seq_length and message_end_idx >= max_seq_length:
+                 break
+     attention_mask = torch.ones_like(input_ids)
+     return {
+         "input_ids": input_ids.flatten(),
+         "labels": labels.flatten(),
+         "attention_mask": attention_mask.flatten()
+     }
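`encode_sft_example` assumes each record in the uploaded JSONL file has a `messages` list of `{"role", "content"}` dicts, and it masks non-assistant spans with `-100` so only assistant turns contribute to the loss. A sketch of a compatible `training_data.jsonl`, written from Python (the dialogue content is illustrative only):

```python
# Write a minimal chat-format JSONL file matching what dataset_from_file /
# encode_sft_example expect: one {"messages": [...]} object per line.
import json

examples = [
    {
        "messages": [
            {"role": "user", "content": "What does Arbor do?"},
            {"role": "assistant", "content": "It fine-tunes and manages language models."},
        ]
    },
    {
        "messages": [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Name a tree."},
            {"role": "assistant", "content": "Oak."},
        ]
    },
]

with open("training_data.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```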
File without changes
File without changes