PraisonAI 2.0.12__cp311-cp311-macosx_15_0_arm64.whl → 2.2.16__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of PraisonAI might be problematic.
- praisonai/README.md +5 -0
- praisonai/agents_generator.py +83 -44
- praisonai/api/call.py +3 -3
- praisonai/auto.py +1 -1
- praisonai/cli.py +151 -16
- praisonai/deploy.py +1 -1
- praisonai/inbuilt_tools/__init__.py +1 -1
- praisonai/public/praison-ai-agents-architecture-dark.png +0 -0
- praisonai/public/praison-ai-agents-architecture.png +0 -0
- praisonai/setup/setup_conda_env.sh +55 -22
- praisonai/train.py +442 -156
- praisonai/train_vision.py +306 -0
- praisonai/ui/agents.py +822 -0
- praisonai/ui/callbacks.py +57 -0
- praisonai/ui/code.py +4 -2
- praisonai/ui/colab.py +474 -0
- praisonai/ui/colab_chainlit.py +81 -0
- praisonai/ui/config/chainlit.md +1 -1
- praisonai/ui/realtime.py +65 -10
- praisonai/ui/sql_alchemy.py +6 -5
- praisonai/ui/tools.md +133 -0
- praisonai/upload_vision.py +140 -0
- praisonai-2.2.16.dist-info/METADATA +103 -0
- {praisonai-2.0.12.dist-info → praisonai-2.2.16.dist-info}/RECORD +26 -29
- {praisonai-2.0.12.dist-info → praisonai-2.2.16.dist-info}/WHEEL +1 -1
- praisonai/ui/config/.chainlit/config.toml +0 -120
- praisonai/ui/config/.chainlit/translations/bn.json +0 -231
- praisonai/ui/config/.chainlit/translations/en-US.json +0 -229
- praisonai/ui/config/.chainlit/translations/gu.json +0 -231
- praisonai/ui/config/.chainlit/translations/he-IL.json +0 -231
- praisonai/ui/config/.chainlit/translations/hi.json +0 -231
- praisonai/ui/config/.chainlit/translations/kn.json +0 -231
- praisonai/ui/config/.chainlit/translations/ml.json +0 -231
- praisonai/ui/config/.chainlit/translations/mr.json +0 -231
- praisonai/ui/config/.chainlit/translations/ta.json +0 -231
- praisonai/ui/config/.chainlit/translations/te.json +0 -231
- praisonai/ui/config/.chainlit/translations/zh-CN.json +0 -229
- praisonai-2.0.12.dist-info/LICENSE +0 -20
- praisonai-2.0.12.dist-info/METADATA +0 -498
- {praisonai-2.0.12.dist-info → praisonai-2.2.16.dist-info}/entry_points.txt +0 -0
praisonai/train.py
CHANGED
@@ -1,168 +1,287 @@
-
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+This script finetunes a model using Unsloth's fast training framework.
+It supports both ShareGPT and Alpaca‑style datasets by converting raw conversation
+data into plain-text prompts using a chat template, then pre‑tokenizing the prompts.
+Extra debug logging is added to help trace the root cause of errors.
+"""
+
 import os
 import sys
 import yaml
 import torch
 import shutil
+import subprocess
 from transformers import TextStreamer
 from unsloth import FastLanguageModel, is_bfloat16_supported
 from trl import SFTTrainer
 from transformers import TrainingArguments
-from datasets import load_dataset, concatenate_datasets
+from datasets import load_dataset, concatenate_datasets
 from psutil import virtual_memory
-
-
+from unsloth.chat_templates import standardize_sharegpt, get_chat_template
+from functools import partial
+
+#####################################
+# Step 1: Formatting Raw Conversations
+#####################################
+def formatting_prompts_func(examples, tokenizer):
+    """
+    Converts each example's conversation into a single plain-text prompt.
+    If the example has a "conversations" field, process it as ShareGPT-style.
+    Otherwise, assume Alpaca-style data with "instruction", "input", and "output" fields.
+    """
+    print("DEBUG: formatting_prompts_func() received batch with keys:", list(examples.keys()))
+    texts = []
+    # Check if the example has a "conversations" field.
+    if "conversations" in examples:
+        for convo in examples["conversations"]:
+            try:
+                formatted = tokenizer.apply_chat_template(
+                    convo,
+                    tokenize=False,  # Return a plain string
+                    add_generation_prompt=False
+                )
+            except Exception as e:
+                print(f"ERROR in apply_chat_template (conversations): {e}")
+                formatted = ""
+            # Flatten list if necessary
+            if isinstance(formatted, list):
+                formatted = formatted[0] if len(formatted) == 1 else "\n".join(formatted)
+            texts.append(formatted)
+    else:
+        # Assume Alpaca format: use "instruction", "input", and "output" keys.
+        instructions = examples.get("instruction", [])
+        inputs_list = examples.get("input", [])
+        outputs_list = examples.get("output", [])
+        # If any field is missing, replace with empty string.
+        for ins, inp, out in zip(instructions, inputs_list, outputs_list):
+            # Create a conversation-like structure.
+            convo = [
+                {"role": "user", "content": ins + (f"\nInput: {inp}" if inp.strip() != "" else "")},
+                {"role": "assistant", "content": out}
+            ]
+            try:
+                formatted = tokenizer.apply_chat_template(
+                    convo,
+                    tokenize=False,
+                    add_generation_prompt=False
+                )
+            except Exception as e:
+                print(f"ERROR in apply_chat_template (alpaca): {e}")
+                formatted = ""
+            if isinstance(formatted, list):
+                formatted = formatted[0] if len(formatted) == 1 else "\n".join(formatted)
+            texts.append(formatted)
+    if texts:
+        print("DEBUG: Raw texts sample (first 200 chars):", texts[0][:200])
+    return {"text": texts}
+
+#####################################
+# Step 2: Tokenizing the Prompts
+#####################################
+def tokenize_function(examples, hf_tokenizer, max_length):
+    """
+    Tokenizes a batch of text prompts with padding and truncation enabled.
+    """
+    flat_texts = []
+    for t in examples["text"]:
+        if isinstance(t, list):
+            t = t[0] if len(t) == 1 else " ".join(t)
+        flat_texts.append(t)
+    print("DEBUG: Tokenizing a batch of size:", len(flat_texts))
+    tokenized = hf_tokenizer(
+        flat_texts,
+        padding="max_length",
+        truncation=True,
+        max_length=max_length,
+        return_tensors="pt",
+    )
+    tokenized = {key: value.tolist() for key, value in tokenized.items()}
+    sample_key = list(tokenized.keys())[0]
+    print("DEBUG: Tokenized sample (first 10 tokens of", sample_key, "):", tokenized[sample_key][0][:10])
+    return tokenized
+
+#####################################
+# Main Training Class
+#####################################
+class TrainModel:
     def __init__(self, config_path="config.yaml"):
         self.load_config(config_path)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model
+        self.model = None
+        self.hf_tokenizer = None  # The underlying HF tokenizer
+        self.chat_tokenizer = None  # Chat wrapper for formatting
 
     def load_config(self, path):
         with open(path, "r") as file:
             self.config = yaml.safe_load(file)
+        print("DEBUG: Loaded config:", self.config)
 
     def print_system_info(self):
-        print(
-        print(
+        print("DEBUG: PyTorch version:", torch.__version__)
+        print("DEBUG: CUDA version:", torch.version.cuda)
         if torch.cuda.is_available():
-
-            print(f"CUDA Device Capability: {device_capability}")
+            print("DEBUG: CUDA Device Capability:", torch.cuda.get_device_capability())
         else:
-            print("CUDA is not available")
-
-
-        pip_version = subprocess.check_output(['pip', '--version']).decode().strip()
-        python_path = sys.executable
-        pip_path = subprocess.check_output(['which', 'pip']).decode().strip()
-        print(f"Python Version: {python_version}")
-        print(f"Pip Version: {pip_version}")
-        print(f"Python Path: {python_path}")
-        print(f"Pip Path: {pip_path}")
+            print("DEBUG: CUDA is not available")
+        print("DEBUG: Python Version:", sys.version)
+        print("DEBUG: Python Path:", sys.executable)
 
     def check_gpu(self):
         gpu_stats = torch.cuda.get_device_properties(0)
-        print(f"GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory
+        print(f"DEBUG: GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory/(1024**3),3)} GB.")
 
     def check_ram(self):
         ram_gb = virtual_memory().total / 1e9
-        print(
+        print(f"DEBUG: Your runtime has {ram_gb:.1f} gigabytes of available RAM")
        if ram_gb < 20:
-            print(
+            print("DEBUG: Not using a high-RAM runtime")
         else:
-            print(
-
-    # def install_packages(self):
-    #     subprocess.run(["pip", "install", "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@4e570be9ae4ced8cdc64e498125708e34942befc"])
-    #     subprocess.run(["pip", "install", "--no-deps", "trl<0.9.0", "peft==0.12.0", "accelerate==0.33.0", "bitsandbytes==0.43.3"])
+            print("DEBUG: You are using a high-RAM runtime!")
 
     def prepare_model(self):
-
+        print("DEBUG: Preparing model and tokenizer...")
+        self.model, original_tokenizer = FastLanguageModel.from_pretrained(
             model_name=self.config["model_name"],
             max_seq_length=self.config["max_seq_length"],
             dtype=None,
-            load_in_4bit=self.config["load_in_4bit"]
+            load_in_4bit=self.config["load_in_4bit"],
         )
+        print("DEBUG: Model and original tokenizer loaded.")
+        if original_tokenizer.pad_token is None:
+            original_tokenizer.pad_token = original_tokenizer.eos_token
+        original_tokenizer.model_max_length = self.config["max_seq_length"]
+        self.chat_tokenizer = get_chat_template(original_tokenizer, chat_template="llama-3.1")
+        self.hf_tokenizer = original_tokenizer
+        print("DEBUG: Chat tokenizer created; HF tokenizer saved.")
         self.model = FastLanguageModel.get_peft_model(
             self.model,
-            r=
-            target_modules=
-            lora_alpha=
-            lora_dropout=
-            bias=
-            use_gradient_checkpointing=
-            random_state=
-            use_rslora=
-            loftq_config=
+            r=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_alpha=16,
+            lora_dropout=0,
+            bias="none",
+            use_gradient_checkpointing="unsloth",
+            random_state=3407,
+            use_rslora=False,
+            loftq_config=None,
         )
+        print("DEBUG: LoRA adapters added.")
 
     def process_dataset(self, dataset_info):
         dataset_name = dataset_info["name"]
         split_type = dataset_info.get("split_type", "train")
-
-        rename = dataset_info.get("rename", {})
-        filter_data = dataset_info.get("filter_data", False)
-        filter_column_value = dataset_info.get("filter_column_value", "id")
-        filter_value = dataset_info.get("filter_value", "alpaca")
-        num_samples = dataset_info.get("num_samples", 20000)
-
+        print(f"DEBUG: Loading dataset '{dataset_name}' split '{split_type}'...")
         dataset = load_dataset(dataset_name, split=split_type)
-
-        if
-
-
-
-
+        print("DEBUG: Dataset columns:", dataset.column_names)
+        if "conversations" in dataset.column_names:
+            print("DEBUG: Standardizing dataset (ShareGPT style)...")
+            dataset = standardize_sharegpt(dataset)
+        else:
+            print("DEBUG: Dataset does not have 'conversations'; assuming Alpaca format.")
+        print("DEBUG: Applying formatting function to dataset...")
+        format_func = partial(formatting_prompts_func, tokenizer=self.chat_tokenizer)
+        dataset = dataset.map(format_func, batched=True, remove_columns=dataset.column_names)
+        sample = dataset[0]
+        print("DEBUG: Sample processed example keys:", list(sample.keys()))
+        if "text" in sample:
+            print("DEBUG: Sample processed 'text' type:", type(sample["text"]))
+            print("DEBUG: Sample processed 'text' content (first 200 chars):", sample["text"][:200])
+        else:
+            print("DEBUG: Processed sample does not contain 'text'.")
         return dataset
 
-    def
-
-
-
-
-
-
-
-
-### Response:
-{}"""
-        texts = [alpaca_prompt.format(ins, inp, out) + self.tokenizer.eos_token for ins, inp, out in zip(examples["instruction"], examples["input"], examples["output"])]
-        return {"text": texts}
+    def tokenize_dataset(self, dataset):
+        print("DEBUG: Tokenizing the entire dataset...")
+        tokenized_dataset = dataset.map(
+            lambda examples: tokenize_function(examples, self.hf_tokenizer, self.config["max_seq_length"]),
+            batched=True
+        )
+        tokenized_dataset = tokenized_dataset.remove_columns(["text"])
+        print("DEBUG: Tokenized dataset sample keys:", tokenized_dataset[0].keys())
+        return tokenized_dataset
 
     def load_datasets(self):
         datasets = []
         for dataset_info in self.config["dataset"]:
+            print("DEBUG: Processing dataset info:", dataset_info)
             datasets.append(self.process_dataset(dataset_info))
-
+        combined = concatenate_datasets(datasets)
+        print("DEBUG: Combined dataset has", len(combined), "examples.")
+        return combined
 
     def train_model(self):
-
+        print("DEBUG: Starting training...")
+        raw_dataset = self.load_datasets()
+        tokenized_dataset = self.tokenize_dataset(raw_dataset)
+        print("DEBUG: Dataset tokenization complete.")
+        # Build the training arguments parameters dynamically
+        ta_params = {
+            "per_device_train_batch_size": self.config.get("per_device_train_batch_size", 2),
+            "gradient_accumulation_steps": self.config.get("gradient_accumulation_steps", 2),
+            "warmup_steps": self.config.get("warmup_steps", 50),
+            "max_steps": self.config.get("max_steps", 2800),
+            "learning_rate": self.config.get("learning_rate", 2e-4),
+            "fp16": self.config.get("fp16", not is_bfloat16_supported()),
+            "bf16": self.config.get("bf16", is_bfloat16_supported()),
+            "logging_steps": self.config.get("logging_steps", 15),
+            "optim": self.config.get("optim", "adamw_8bit"),
+            "weight_decay": self.config.get("weight_decay", 0.01),
+            "lr_scheduler_type": self.config.get("lr_scheduler_type", "linear"),
+            "seed": self.config.get("seed", 3407),
+            "output_dir": self.config.get("output_dir", "outputs"),
+            "report_to": "none" if not os.getenv("PRAISON_WANDB") else "wandb",
+            "remove_unused_columns": self.config.get("remove_unused_columns", False)
+        }
+        if os.getenv("PRAISON_WANDB"):
+            ta_params["save_steps"] = self.config.get("save_steps", 100)
+            ta_params["run_name"] = os.getenv("PRAISON_WANDB_RUN_NAME", "praisonai-train")
+
+        training_args = TrainingArguments(**ta_params)
+        # Since the dataset is pre-tokenized, we supply a dummy dataset_text_field.
         trainer = SFTTrainer(
             model=self.model,
-            tokenizer=self.
-            train_dataset=
-            dataset_text_field=
+            tokenizer=self.hf_tokenizer,
+            train_dataset=tokenized_dataset,
+            dataset_text_field="input_ids",  # Dummy field since data is numeric
             max_seq_length=self.config["max_seq_length"],
-            dataset_num_proc=
-            packing=
-            args=
-                per_device_train_batch_size=self.config["per_device_train_batch_size"],
-                gradient_accumulation_steps=self.config["gradient_accumulation_steps"],
-                warmup_steps=self.config["warmup_steps"],
-                num_train_epochs=self.config["num_train_epochs"],
-                max_steps=self.config["max_steps"],
-                learning_rate=self.config["learning_rate"],
-                fp16=not is_bfloat16_supported(),
-                bf16=is_bfloat16_supported(),
-                logging_steps=self.config["logging_steps"],
-                optim=self.config["optim"],
-                weight_decay=self.config["weight_decay"],
-                lr_scheduler_type=self.config["lr_scheduler_type"],
-                seed=self.config["seed"],
-                output_dir=self.config["output_dir"],
-            ),
+            dataset_num_proc=1,  # Use a single process to avoid pickling issues
+            packing=False,
+            args=training_args,
         )
+        from unsloth.chat_templates import train_on_responses_only
+        trainer = train_on_responses_only(
+            trainer,
+            instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
+            response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+        )
+        print("DEBUG: Beginning trainer.train() ...")
         trainer.train()
-
-        self.
+        print("DEBUG: Training complete. Saving model and tokenizer locally...")
+        self.model.save_pretrained("lora_model")
+        self.hf_tokenizer.save_pretrained("lora_model")
+        print("DEBUG: Saved model and tokenizer to 'lora_model'.")
 
     def inference(self, instruction, input_text):
         FastLanguageModel.for_inference(self.model)
-
-
-
-
-
-
-
+        messages = [{"role": "user", "content": f"{instruction}\n\nInput: {input_text}"}]
+        inputs = self.hf_tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).to("cuda")
+        outputs = self.model.generate(
+            input_ids=inputs,
+            max_new_tokens=64,
+            use_cache=True,
+            temperature=1.5,
+            min_p=0.1
+        )
+        print("DEBUG: Inference output:", self.hf_tokenizer.batch_decode(outputs))
 
-### Response:
-{}"""
-        inputs = self.tokenizer([alpaca_prompt.format(instruction, input_text, "")], return_tensors="pt").to("cuda")
-        outputs = self.model.generate(**inputs, max_new_tokens=64, use_cache=True)
-        print(self.tokenizer.batch_decode(outputs))
-
     def load_model(self):
-        """Loads the model and tokenizer using the FastLanguageModel library."""
         from unsloth import FastLanguageModel
         model, tokenizer = FastLanguageModel.from_pretrained(
             model_name=self.config["output_dir"],
@@ -177,56 +296,235 @@ class train:
             shutil.rmtree(self.config["hf_model_name"])
         self.model.push_to_hub_merged(
             self.config["hf_model_name"],
-            self.
+            self.hf_tokenizer,
             save_method="merged_16bit",
-            token=os.getenv(
+            token=os.getenv("HF_TOKEN")
         )
 
     def push_model_gguf(self):
         self.model.push_to_hub_gguf(
             self.config["hf_model_name"],
-            self.
+            self.hf_tokenizer,
             quantization_method=self.config["quantization_method"],
-            token=os.getenv(
+            token=os.getenv("HF_TOKEN")
        )
-
+
     def save_model_gguf(self):
         self.model.save_pretrained_gguf(
             self.config["hf_model_name"],
-            self.
+            self.hf_tokenizer,
             quantization_method="q4_k_m"
         )
 
     def prepare_modelfile_content(self):
         output_model = self.config["hf_model_name"]
-
-
-
-
-
-
-
-
-
-
-
-{{
-
-{{
-{{
-
-
-
-
-
-
-
+        model_name = self.config["model_name"].lower()
+        # Mapping from model name keywords to their default TEMPLATE and stop tokens (and optional SYSTEM/num_ctx)
+        mapping = {
+            "llama": {
+                "template": """<|start_header_id|>system<|end_header_id|>
+Cutting Knowledge Date: December 2023
+{{ if .System }}{{ .System }}
+{{- end }}
+{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
+You are a helpful assistant with tool calling capabilities.
+{{- end }}<|eot_id|>
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
+{{- if and $.Tools $last }}
+Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
+{{ range $.Tools }}
+{{- . }}
+{{ end }}
+{{ .Content }}<|eot_id|>
+{{- else }}
+{{ .Content }}<|eot_id|>
+{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
+{{ end }}
+{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
+{{- if .ToolCalls }}
+{{ range .ToolCalls }}
+{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
+{{- else }}
+{{ .Content }}
+{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
+{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
+{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
+{{ end }}
+{{- end }}
+{{- end }}""",
+                "stop_tokens": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
+            },
+            "qwen": {
+                "template": """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
+{{- else if .Messages }}
+{{- if or .System .Tools }}<|im_start|>system
+{{- if .System }}
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+# Tools
+You may call one or more functions to assist with the user query.
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{{- range .Tools }}
+{"type": "function", "function": {{ .Function }}}
+{{- end }}
+</tools>
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>
+{{- end }}<|im_end|>
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<|im_start|>user
+{{ .Content }}<|im_end|>
+{{ else if eq .Role "assistant" }}<|im_start|>assistant
+{{ if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}<tool_call>
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{ end }}</tool_call>
+{{- end }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- else if eq .Role "tool" }}<|im_start|>user
+<tool_response>
+{{ .Content }}
+</tool_response><|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+{{ end }}
+{{- end }}
+{{- else }}
+{{- if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}""",
+                "system": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+                "num_ctx": 32768,
+                "stop_tokens": ["<|endoftext|>"]
+            },
+            "mistral": {
+                "template": "[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
+                "stop_tokens": ["[INST]", "[/INST]"]
+            },
+            "phi": {
+                "template": """{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+<|im_start|>{{ .Role }}<|im_sep|>
+{{ .Content }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_end|>
+<|im_start|>assistant<|im_sep|>
+{{ end }}
+{{- end }}""",
+                "stop_tokens": ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]
+            },
+            "deepseek": {
+                "template": """{{- if .System }}{{ .System }}{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1}}
+{{- if eq .Role "user" }}
+{{ .Content }}
+{{- else if eq .Role "assistant" }}
+{{ .Content }}{{- if not $last }}
+{{- end }}
+{{- end }}
+{{- if and $last (ne .Role "assistant") }}
+{{ end }}
+{{- end }}""",
+                "stop_tokens": ["", "", "", ""]
+            },
+            "llava": {
+                "template": """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
+{{- else if .Messages }}
+{{- if or .System .Tools }}<|im_start|>system
+{{- if .System }}
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+# Tools
+You may call one or more functions to assist with the user query.
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{{- range .Tools }}
+{"type": "function", "function": {{ .Function }}}
+{{- end }}
+</tools>
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>
+{{- end }}<|im_end|>
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<|im_start|>user
+{{ .Content }}<|im_end|>
+{{ else if eq .Role "assistant" }}<|im_start|>assistant
+{{ if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}<tool_call>
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{ end }}</tool_call>
+{{- end }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- else if eq .Role "tool" }}<|im_start|>user
+<tool_response>
+{{ .Content }}
+</tool_response><|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+{{ end }}
+{{- end }}
+{{- else }}
+{{- if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}""",
+                "stop_tokens": ["</s>", "USER:", "ASSSISTANT:"]
+            }
+        }
+        # Select mapping by checking if any key is in the model_name.
+        chosen = None
+        for key, settings in mapping.items():
+            if key in model_name:
+                chosen = settings
+                break
+        if chosen is None:
+            # Fallback default
+            chosen = {
+                "template": """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+{{ .Response }}<|eot_id|>""",
+                "stop_tokens": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
+            }
+        # Build the stop parameter lines.
+        stop_params = "\n".join([f"PARAMETER stop {token}" for token in chosen["stop_tokens"]])
+        # Optionally include a SYSTEM line and num_ctx if defined in the mapping.
+        system_line = ""
+        if "system" in chosen:
+            system_line = f"SYSTEM {chosen['system']}\n"
+        num_ctx_line = ""
+        if "num_ctx" in chosen:
+            num_ctx_line = f"PARAMETER num_ctx {chosen['num_ctx']}\n"
+        # Assemble and return the modelfile content.
+        return f"""FROM {output_model}
+TEMPLATE \"\"\"{chosen['template']}\"\"\"
+{system_line}{num_ctx_line}{stop_params}
+"""
 
     def create_and_push_ollama_model(self):
         modelfile_content = self.prepare_modelfile_content()
-        with open(
+        with open("Modelfile", "w") as file:
             file.write(modelfile_content)
-
         subprocess.run(["ollama", "serve"])
         subprocess.run(["ollama", "create", f"{self.config['ollama_model']}:{self.config['model_parameters']}", "-f", "Modelfile"])
         subprocess.run(["ollama", "push", f"{self.config['ollama_model']}:{self.config['model_parameters']}"])
@@ -235,42 +533,30 @@ PARAMETER stop "<|reserved_special_token_"
         self.print_system_info()
         self.check_gpu()
         self.check_ram()
-        # self.install_packages()
         if self.config.get("train", "true").lower() == "true":
             self.prepare_model()
             self.train_model()
-
         if self.config.get("huggingface_save", "true").lower() == "true":
-            # self.model, self.tokenizer = self.load_model()
             self.save_model_merged()
-
         if self.config.get("huggingface_save_gguf", "true").lower() == "true":
-            # self.model, self.tokenizer = self.load_model()
             self.push_model_gguf()
-
-        # if self.config.get("save_gguf", "true").lower() == "true": ## TODO
-        #     self.model, self.tokenizer = self.load_model()
-        #     self.save_model_gguf()
-
-        # if self.config.get("save_merged", "true").lower() == "true": ## TODO
-        #     self.model, self.tokenizer = self.load_model()
-        #     self.save_model_merged()
-
         if self.config.get("ollama_save", "true").lower() == "true":
             self.create_and_push_ollama_model()
 
-
 def main():
     import argparse
-    parser = argparse.ArgumentParser(description=
-    parser.add_argument(
-    parser.add_argument(
+    parser = argparse.ArgumentParser(description="PraisonAI Training Script")
+    parser.add_argument("command", choices=["train"], help="Command to execute")
+    parser.add_argument("--config", default="config.yaml", help="Path to configuration file")
+    parser.add_argument("--model", type=str, help="Model name")
+    parser.add_argument("--hf", type=str, help="Hugging Face model name")
+    parser.add_argument("--ollama", type=str, help="Ollama model name")
+    parser.add_argument("--dataset", type=str, help="Dataset name for training")
     args = parser.parse_args()
 
-    if args.command ==
-
-
-
+    if args.command == "train":
+        trainer_obj = TrainModel(config_path=args.config)
+        trainer_obj.run()
 
-if __name__ ==
+if __name__ == "__main__":
     main()