PraisonAI 2.0.12__cp311-cp311-macosx_15_0_arm64.whl → 2.2.16__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of PraisonAI might be problematic.

Files changed (40)
  1. praisonai/README.md +5 -0
  2. praisonai/agents_generator.py +83 -44
  3. praisonai/api/call.py +3 -3
  4. praisonai/auto.py +1 -1
  5. praisonai/cli.py +151 -16
  6. praisonai/deploy.py +1 -1
  7. praisonai/inbuilt_tools/__init__.py +1 -1
  8. praisonai/public/praison-ai-agents-architecture-dark.png +0 -0
  9. praisonai/public/praison-ai-agents-architecture.png +0 -0
  10. praisonai/setup/setup_conda_env.sh +55 -22
  11. praisonai/train.py +442 -156
  12. praisonai/train_vision.py +306 -0
  13. praisonai/ui/agents.py +822 -0
  14. praisonai/ui/callbacks.py +57 -0
  15. praisonai/ui/code.py +4 -2
  16. praisonai/ui/colab.py +474 -0
  17. praisonai/ui/colab_chainlit.py +81 -0
  18. praisonai/ui/config/chainlit.md +1 -1
  19. praisonai/ui/realtime.py +65 -10
  20. praisonai/ui/sql_alchemy.py +6 -5
  21. praisonai/ui/tools.md +133 -0
  22. praisonai/upload_vision.py +140 -0
  23. praisonai-2.2.16.dist-info/METADATA +103 -0
  24. {praisonai-2.0.12.dist-info → praisonai-2.2.16.dist-info}/RECORD +26 -29
  25. {praisonai-2.0.12.dist-info → praisonai-2.2.16.dist-info}/WHEEL +1 -1
  26. praisonai/ui/config/.chainlit/config.toml +0 -120
  27. praisonai/ui/config/.chainlit/translations/bn.json +0 -231
  28. praisonai/ui/config/.chainlit/translations/en-US.json +0 -229
  29. praisonai/ui/config/.chainlit/translations/gu.json +0 -231
  30. praisonai/ui/config/.chainlit/translations/he-IL.json +0 -231
  31. praisonai/ui/config/.chainlit/translations/hi.json +0 -231
  32. praisonai/ui/config/.chainlit/translations/kn.json +0 -231
  33. praisonai/ui/config/.chainlit/translations/ml.json +0 -231
  34. praisonai/ui/config/.chainlit/translations/mr.json +0 -231
  35. praisonai/ui/config/.chainlit/translations/ta.json +0 -231
  36. praisonai/ui/config/.chainlit/translations/te.json +0 -231
  37. praisonai/ui/config/.chainlit/translations/zh-CN.json +0 -229
  38. praisonai-2.0.12.dist-info/LICENSE +0 -20
  39. praisonai-2.0.12.dist-info/METADATA +0 -498
  40. {praisonai-2.0.12.dist-info → praisonai-2.2.16.dist-info}/entry_points.txt +0 -0
praisonai/train.py CHANGED
@@ -1,168 +1,287 @@
-import subprocess
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+This script finetunes a model using Unsloth's fast training framework.
+It supports both ShareGPT and Alpaca‑style datasets by converting raw conversation
+data into plain-text prompts using a chat template, then pre‑tokenizing the prompts.
+Extra debug logging is added to help trace the root cause of errors.
+"""
+
 import os
 import sys
 import yaml
 import torch
 import shutil
+import subprocess
 from transformers import TextStreamer
 from unsloth import FastLanguageModel, is_bfloat16_supported
 from trl import SFTTrainer
 from transformers import TrainingArguments
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets
 from psutil import virtual_memory
-
-class train:
+from unsloth.chat_templates import standardize_sharegpt, get_chat_template
+from functools import partial
+
+#####################################
+# Step 1: Formatting Raw Conversations
+#####################################
+def formatting_prompts_func(examples, tokenizer):
+    """
+    Converts each example's conversation into a single plain-text prompt.
+    If the example has a "conversations" field, process it as ShareGPT-style.
+    Otherwise, assume Alpaca-style data with "instruction", "input", and "output" fields.
+    """
+    print("DEBUG: formatting_prompts_func() received batch with keys:", list(examples.keys()))
+    texts = []
+    # Check if the example has a "conversations" field.
+    if "conversations" in examples:
+        for convo in examples["conversations"]:
+            try:
+                formatted = tokenizer.apply_chat_template(
+                    convo,
+                    tokenize=False, # Return a plain string
+                    add_generation_prompt=False
+                )
+            except Exception as e:
+                print(f"ERROR in apply_chat_template (conversations): {e}")
+                formatted = ""
+            # Flatten list if necessary
+            if isinstance(formatted, list):
+                formatted = formatted[0] if len(formatted) == 1 else "\n".join(formatted)
+            texts.append(formatted)
+    else:
+        # Assume Alpaca format: use "instruction", "input", and "output" keys.
+        instructions = examples.get("instruction", [])
+        inputs_list = examples.get("input", [])
+        outputs_list = examples.get("output", [])
+        # If any field is missing, replace with empty string.
+        for ins, inp, out in zip(instructions, inputs_list, outputs_list):
+            # Create a conversation-like structure.
+            convo = [
+                {"role": "user", "content": ins + (f"\nInput: {inp}" if inp.strip() != "" else "")},
+                {"role": "assistant", "content": out}
+            ]
+            try:
+                formatted = tokenizer.apply_chat_template(
+                    convo,
+                    tokenize=False,
+                    add_generation_prompt=False
+                )
+            except Exception as e:
+                print(f"ERROR in apply_chat_template (alpaca): {e}")
+                formatted = ""
+            if isinstance(formatted, list):
+                formatted = formatted[0] if len(formatted) == 1 else "\n".join(formatted)
+            texts.append(formatted)
+    if texts:
+        print("DEBUG: Raw texts sample (first 200 chars):", texts[0][:200])
+    return {"text": texts}
+
+#####################################
+# Step 2: Tokenizing the Prompts
+#####################################
+def tokenize_function(examples, hf_tokenizer, max_length):
+    """
+    Tokenizes a batch of text prompts with padding and truncation enabled.
+    """
+    flat_texts = []
+    for t in examples["text"]:
+        if isinstance(t, list):
+            t = t[0] if len(t) == 1 else " ".join(t)
+        flat_texts.append(t)
+    print("DEBUG: Tokenizing a batch of size:", len(flat_texts))
+    tokenized = hf_tokenizer(
+        flat_texts,
+        padding="max_length",
+        truncation=True,
+        max_length=max_length,
+        return_tensors="pt",
+    )
+    tokenized = {key: value.tolist() for key, value in tokenized.items()}
+    sample_key = list(tokenized.keys())[0]
+    print("DEBUG: Tokenized sample (first 10 tokens of", sample_key, "):", tokenized[sample_key][0][:10])
+    return tokenized
+
+#####################################
+# Main Training Class
+#####################################
+class TrainModel:
     def __init__(self, config_path="config.yaml"):
         self.load_config(config_path)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model, self.tokenizer = None, None
+        self.model = None
+        self.hf_tokenizer = None # The underlying HF tokenizer
+        self.chat_tokenizer = None # Chat wrapper for formatting
 
     def load_config(self, path):
         with open(path, "r") as file:
             self.config = yaml.safe_load(file)
+        print("DEBUG: Loaded config:", self.config)
 
     def print_system_info(self):
-        print(f"PyTorch version: {torch.__version__}")
-        print(f"CUDA version: {torch.version.cuda}")
+        print("DEBUG: PyTorch version:", torch.__version__)
+        print("DEBUG: CUDA version:", torch.version.cuda)
         if torch.cuda.is_available():
-            device_capability = torch.cuda.get_device_capability()
-            print(f"CUDA Device Capability: {device_capability}")
+            print("DEBUG: CUDA Device Capability:", torch.cuda.get_device_capability())
         else:
-            print("CUDA is not available")
-
-        python_version = sys.version
-        pip_version = subprocess.check_output(['pip', '--version']).decode().strip()
-        python_path = sys.executable
-        pip_path = subprocess.check_output(['which', 'pip']).decode().strip()
-        print(f"Python Version: {python_version}")
-        print(f"Pip Version: {pip_version}")
-        print(f"Python Path: {python_path}")
-        print(f"Pip Path: {pip_path}")
+            print("DEBUG: CUDA is not available")
+        print("DEBUG: Python Version:", sys.version)
+        print("DEBUG: Python Path:", sys.executable)
 
     def check_gpu(self):
         gpu_stats = torch.cuda.get_device_properties(0)
-        print(f"GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)} GB.")
+        print(f"DEBUG: GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory/(1024**3),3)} GB.")
 
     def check_ram(self):
         ram_gb = virtual_memory().total / 1e9
-        print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
+        print(f"DEBUG: Your runtime has {ram_gb:.1f} gigabytes of available RAM")
         if ram_gb < 20:
-            print('Not using a high-RAM runtime')
+            print("DEBUG: Not using a high-RAM runtime")
         else:
-            print('You are using a high-RAM runtime!')
-
-    # def install_packages(self):
-    #     subprocess.run(["pip", "install", "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@4e570be9ae4ced8cdc64e498125708e34942befc"])
-    #     subprocess.run(["pip", "install", "--no-deps", "trl<0.9.0", "peft==0.12.0", "accelerate==0.33.0", "bitsandbytes==0.43.3"])
+            print("DEBUG: You are using a high-RAM runtime!")
 
     def prepare_model(self):
-        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+        print("DEBUG: Preparing model and tokenizer...")
+        self.model, original_tokenizer = FastLanguageModel.from_pretrained(
             model_name=self.config["model_name"],
             max_seq_length=self.config["max_seq_length"],
             dtype=None,
-            load_in_4bit=self.config["load_in_4bit"]
+            load_in_4bit=self.config["load_in_4bit"],
         )
+        print("DEBUG: Model and original tokenizer loaded.")
+        if original_tokenizer.pad_token is None:
+            original_tokenizer.pad_token = original_tokenizer.eos_token
+        original_tokenizer.model_max_length = self.config["max_seq_length"]
+        self.chat_tokenizer = get_chat_template(original_tokenizer, chat_template="llama-3.1")
+        self.hf_tokenizer = original_tokenizer
+        print("DEBUG: Chat tokenizer created; HF tokenizer saved.")
         self.model = FastLanguageModel.get_peft_model(
             self.model,
-            r=self.config["lora_r"],
-            target_modules=self.config["lora_target_modules"],
-            lora_alpha=self.config["lora_alpha"],
-            lora_dropout=self.config["lora_dropout"],
-            bias=self.config["lora_bias"],
-            use_gradient_checkpointing=self.config["use_gradient_checkpointing"],
-            random_state=self.config["random_state"],
-            use_rslora=self.config["use_rslora"],
-            loftq_config=self.config["loftq_config"],
+            r=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_alpha=16,
+            lora_dropout=0,
+            bias="none",
+            use_gradient_checkpointing="unsloth",
+            random_state=3407,
+            use_rslora=False,
+            loftq_config=None,
         )
+        print("DEBUG: LoRA adapters added.")
 
     def process_dataset(self, dataset_info):
         dataset_name = dataset_info["name"]
         split_type = dataset_info.get("split_type", "train")
-        processing_func = getattr(self, dataset_info.get("processing_func", "format_prompts"))
-        rename = dataset_info.get("rename", {})
-        filter_data = dataset_info.get("filter_data", False)
-        filter_column_value = dataset_info.get("filter_column_value", "id")
-        filter_value = dataset_info.get("filter_value", "alpaca")
-        num_samples = dataset_info.get("num_samples", 20000)
-
+        print(f"DEBUG: Loading dataset '{dataset_name}' split '{split_type}'...")
         dataset = load_dataset(dataset_name, split=split_type)
-
-        if rename:
-            dataset = dataset.rename_columns(rename)
-        if filter_data:
-            dataset = dataset.filter(lambda example: filter_value in example[filter_column_value]).shuffle(seed=42).select(range(num_samples))
-        dataset = dataset.map(processing_func, batched=True)
+        print("DEBUG: Dataset columns:", dataset.column_names)
+        if "conversations" in dataset.column_names:
+            print("DEBUG: Standardizing dataset (ShareGPT style)...")
+            dataset = standardize_sharegpt(dataset)
+        else:
+            print("DEBUG: Dataset does not have 'conversations'; assuming Alpaca format.")
+        print("DEBUG: Applying formatting function to dataset...")
+        format_func = partial(formatting_prompts_func, tokenizer=self.chat_tokenizer)
+        dataset = dataset.map(format_func, batched=True, remove_columns=dataset.column_names)
+        sample = dataset[0]
+        print("DEBUG: Sample processed example keys:", list(sample.keys()))
+        if "text" in sample:
+            print("DEBUG: Sample processed 'text' type:", type(sample["text"]))
+            print("DEBUG: Sample processed 'text' content (first 200 chars):", sample["text"][:200])
+        else:
+            print("DEBUG: Processed sample does not contain 'text'.")
         return dataset
 
-    def format_prompts(self, examples):
-        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-### Instruction:
-{}
-
-### Input:
-{}
-
-### Response:
-{}"""
-        texts = [alpaca_prompt.format(ins, inp, out) + self.tokenizer.eos_token for ins, inp, out in zip(examples["instruction"], examples["input"], examples["output"])]
-        return {"text": texts}
+    def tokenize_dataset(self, dataset):
+        print("DEBUG: Tokenizing the entire dataset...")
+        tokenized_dataset = dataset.map(
+            lambda examples: tokenize_function(examples, self.hf_tokenizer, self.config["max_seq_length"]),
+            batched=True
+        )
+        tokenized_dataset = tokenized_dataset.remove_columns(["text"])
+        print("DEBUG: Tokenized dataset sample keys:", tokenized_dataset[0].keys())
+        return tokenized_dataset
 
     def load_datasets(self):
         datasets = []
         for dataset_info in self.config["dataset"]:
+            print("DEBUG: Processing dataset info:", dataset_info)
             datasets.append(self.process_dataset(dataset_info))
-        return concatenate_datasets(datasets)
+        combined = concatenate_datasets(datasets)
+        print("DEBUG: Combined dataset has", len(combined), "examples.")
+        return combined
 
     def train_model(self):
-        dataset = self.load_datasets()
+        print("DEBUG: Starting training...")
+        raw_dataset = self.load_datasets()
+        tokenized_dataset = self.tokenize_dataset(raw_dataset)
+        print("DEBUG: Dataset tokenization complete.")
+        # Build the training arguments parameters dynamically
+        ta_params = {
+            "per_device_train_batch_size": self.config.get("per_device_train_batch_size", 2),
+            "gradient_accumulation_steps": self.config.get("gradient_accumulation_steps", 2),
+            "warmup_steps": self.config.get("warmup_steps", 50),
+            "max_steps": self.config.get("max_steps", 2800),
+            "learning_rate": self.config.get("learning_rate", 2e-4),
+            "fp16": self.config.get("fp16", not is_bfloat16_supported()),
+            "bf16": self.config.get("bf16", is_bfloat16_supported()),
+            "logging_steps": self.config.get("logging_steps", 15),
+            "optim": self.config.get("optim", "adamw_8bit"),
+            "weight_decay": self.config.get("weight_decay", 0.01),
+            "lr_scheduler_type": self.config.get("lr_scheduler_type", "linear"),
+            "seed": self.config.get("seed", 3407),
+            "output_dir": self.config.get("output_dir", "outputs"),
+            "report_to": "none" if not os.getenv("PRAISON_WANDB") else "wandb",
+            "remove_unused_columns": self.config.get("remove_unused_columns", False)
+        }
+        if os.getenv("PRAISON_WANDB"):
+            ta_params["save_steps"] = self.config.get("save_steps", 100)
+            ta_params["run_name"] = os.getenv("PRAISON_WANDB_RUN_NAME", "praisonai-train")
+
+        training_args = TrainingArguments(**ta_params)
+        # Since the dataset is pre-tokenized, we supply a dummy dataset_text_field.
         trainer = SFTTrainer(
             model=self.model,
-            tokenizer=self.tokenizer,
-            train_dataset=dataset,
-            dataset_text_field=self.config["dataset_text_field"],
+            tokenizer=self.hf_tokenizer,
+            train_dataset=tokenized_dataset,
+            dataset_text_field="input_ids", # Dummy field since data is numeric
             max_seq_length=self.config["max_seq_length"],
-            dataset_num_proc=self.config["dataset_num_proc"],
-            packing=self.config["packing"],
-            args=TrainingArguments(
-                per_device_train_batch_size=self.config["per_device_train_batch_size"],
-                gradient_accumulation_steps=self.config["gradient_accumulation_steps"],
-                warmup_steps=self.config["warmup_steps"],
-                num_train_epochs=self.config["num_train_epochs"],
-                max_steps=self.config["max_steps"],
-                learning_rate=self.config["learning_rate"],
-                fp16=not is_bfloat16_supported(),
-                bf16=is_bfloat16_supported(),
-                logging_steps=self.config["logging_steps"],
-                optim=self.config["optim"],
-                weight_decay=self.config["weight_decay"],
-                lr_scheduler_type=self.config["lr_scheduler_type"],
-                seed=self.config["seed"],
-                output_dir=self.config["output_dir"],
-            ),
+            dataset_num_proc=1, # Use a single process to avoid pickling issues
+            packing=False,
+            args=training_args,
         )
+        from unsloth.chat_templates import train_on_responses_only
+        trainer = train_on_responses_only(
+            trainer,
+            instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
+            response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+        )
+        print("DEBUG: Beginning trainer.train() ...")
         trainer.train()
-        self.model.save_pretrained("lora_model") # Local saving
-        self.tokenizer.save_pretrained("lora_model")
+        print("DEBUG: Training complete. Saving model and tokenizer locally...")
+        self.model.save_pretrained("lora_model")
+        self.hf_tokenizer.save_pretrained("lora_model")
+        print("DEBUG: Saved model and tokenizer to 'lora_model'.")
 
     def inference(self, instruction, input_text):
         FastLanguageModel.for_inference(self.model)
-        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-### Instruction:
-{}
-
-### Input:
-{}
+        messages = [{"role": "user", "content": f"{instruction}\n\nInput: {input_text}"}]
+        inputs = self.hf_tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).to("cuda")
+        outputs = self.model.generate(
+            input_ids=inputs,
+            max_new_tokens=64,
+            use_cache=True,
+            temperature=1.5,
+            min_p=0.1
+        )
+        print("DEBUG: Inference output:", self.hf_tokenizer.batch_decode(outputs))
 
-### Response:
-{}"""
-        inputs = self.tokenizer([alpaca_prompt.format(instruction, input_text, "")], return_tensors="pt").to("cuda")
-        outputs = self.model.generate(**inputs, max_new_tokens=64, use_cache=True)
-        print(self.tokenizer.batch_decode(outputs))
-
     def load_model(self):
-        """Loads the model and tokenizer using the FastLanguageModel library."""
         from unsloth import FastLanguageModel
         model, tokenizer = FastLanguageModel.from_pretrained(
             model_name=self.config["output_dir"],
@@ -177,56 +296,235 @@ class train:
             shutil.rmtree(self.config["hf_model_name"])
         self.model.push_to_hub_merged(
             self.config["hf_model_name"],
-            self.tokenizer,
+            self.hf_tokenizer,
             save_method="merged_16bit",
-            token=os.getenv('HF_TOKEN')
+            token=os.getenv("HF_TOKEN")
         )
 
     def push_model_gguf(self):
         self.model.push_to_hub_gguf(
             self.config["hf_model_name"],
-            self.tokenizer,
+            self.hf_tokenizer,
             quantization_method=self.config["quantization_method"],
-            token=os.getenv('HF_TOKEN')
+            token=os.getenv("HF_TOKEN")
         )
-
+
     def save_model_gguf(self):
         self.model.save_pretrained_gguf(
             self.config["hf_model_name"],
-            self.tokenizer,
+            self.hf_tokenizer,
             quantization_method="q4_k_m"
         )
 
     def prepare_modelfile_content(self):
         output_model = self.config["hf_model_name"]
-        gguf_path = f"{output_model}/unsloth.Q4_K_M.gguf"
-
-        # Check if the GGUF file exists. If not, generate it ## TODO Multiple Quantisation other than Q4_K_M.gguf
-        if not os.path.exists(gguf_path):
-            self.model, self.tokenizer = self.load_model()
-            self.save_model_gguf()
-        return f"""FROM {output_model}/unsloth.Q4_K_M.gguf
-
-TEMPLATE \"\"\"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.{{{{ if .Prompt }}}}
-
-### Instruction:
-{{{{ .Prompt }}}}
-
-{{{{ end }}}}### Response:
-{{{{ .Response }}}}\"\"\"
-
-PARAMETER stop ""
-PARAMETER stop ""
-PARAMETER stop ""
-PARAMETER stop ""
-PARAMETER stop "<|reserved_special_token_"
-"""
+        model_name = self.config["model_name"].lower()
+        # Mapping from model name keywords to their default TEMPLATE and stop tokens (and optional SYSTEM/num_ctx)
+        mapping = {
+            "llama": {
+                "template": """<|start_header_id|>system<|end_header_id|>
+Cutting Knowledge Date: December 2023
+{{ if .System }}{{ .System }}
+{{- end }}
+{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
+You are a helpful assistant with tool calling capabilities.
+{{- end }}<|eot_id|>
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
+{{- if and $.Tools $last }}
+Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
+{{ range $.Tools }}
+{{- . }}
+{{ end }}
+{{ .Content }}<|eot_id|>
+{{- else }}
+{{ .Content }}<|eot_id|>
+{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
+{{ end }}
+{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
+{{- if .ToolCalls }}
+{{ range .ToolCalls }}
+{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
+{{- else }}
+{{ .Content }}
+{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
+{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
+{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
+{{ end }}
+{{- end }}
+{{- end }}""",
+                "stop_tokens": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
+            },
+            "qwen": {
+                "template": """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
+{{- else if .Messages }}
+{{- if or .System .Tools }}<|im_start|>system
+{{- if .System }}
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+# Tools
+You may call one or more functions to assist with the user query.
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{{- range .Tools }}
+{"type": "function", "function": {{ .Function }}}
+{{- end }}
+</tools>
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>
+{{- end }}<|im_end|>
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<|im_start|>user
+{{ .Content }}<|im_end|>
+{{ else if eq .Role "assistant" }}<|im_start|>assistant
+{{ if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}<tool_call>
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{ end }}</tool_call>
+{{- end }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- else if eq .Role "tool" }}<|im_start|>user
+<tool_response>
+{{ .Content }}
+</tool_response><|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+{{ end }}
+{{- end }}
+{{- else }}
+{{- if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}""",
+                "system": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+                "num_ctx": 32768,
+                "stop_tokens": ["<|endoftext|>"]
+            },
+            "mistral": {
+                "template": "[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
+                "stop_tokens": ["[INST]", "[/INST]"]
+            },
+            "phi": {
+                "template": """{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+<|im_start|>{{ .Role }}<|im_sep|>
+{{ .Content }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_end|>
+<|im_start|>assistant<|im_sep|>
+{{ end }}
+{{- end }}""",
+                "stop_tokens": ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]
+            },
+            "deepseek": {
+                "template": """{{- if .System }}{{ .System }}{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1}}
+{{- if eq .Role "user" }}
+{{ .Content }}
+{{- else if eq .Role "assistant" }}
+{{ .Content }}{{- if not $last }}
+{{- end }}
+{{- end }}
+{{- if and $last (ne .Role "assistant") }}
+{{ end }}
+{{- end }}""",
+                "stop_tokens": ["", "", "", ""]
+            },
+            "llava": {
+                "template": """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
+{{- else if .Messages }}
+{{- if or .System .Tools }}<|im_start|>system
+{{- if .System }}
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+# Tools
+You may call one or more functions to assist with the user query.
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{{- range .Tools }}
+{"type": "function", "function": {{ .Function }}}
+{{- end }}
+</tools>
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>
+{{- end }}<|im_end|>
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<|im_start|>user
+{{ .Content }}<|im_end|>
+{{ else if eq .Role "assistant" }}<|im_start|>assistant
+{{ if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}<tool_call>
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{ end }}</tool_call>
+{{- end }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- else if eq .Role "tool" }}<|im_start|>user
+<tool_response>
+{{ .Content }}
+</tool_response><|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+{{ end }}
+{{- end }}
+{{- else }}
+{{- if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}""",
+                "stop_tokens": ["</s>", "USER:", "ASSSISTANT:"]
+            }
+        }
+        # Select mapping by checking if any key is in the model_name.
+        chosen = None
+        for key, settings in mapping.items():
+            if key in model_name:
+                chosen = settings
+                break
+        if chosen is None:
+            # Fallback default
+            chosen = {
+                "template": """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+{{ .Response }}<|eot_id|>""",
+                "stop_tokens": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
+            }
+        # Build the stop parameter lines.
+        stop_params = "\n".join([f"PARAMETER stop {token}" for token in chosen["stop_tokens"]])
+        # Optionally include a SYSTEM line and num_ctx if defined in the mapping.
+        system_line = ""
+        if "system" in chosen:
+            system_line = f"SYSTEM {chosen['system']}\n"
+        num_ctx_line = ""
+        if "num_ctx" in chosen:
+            num_ctx_line = f"PARAMETER num_ctx {chosen['num_ctx']}\n"
+        # Assemble and return the modelfile content.
+        return f"""FROM {output_model}
+TEMPLATE \"\"\"{chosen['template']}\"\"\"
+{system_line}{num_ctx_line}{stop_params}
+"""
 
     def create_and_push_ollama_model(self):
         modelfile_content = self.prepare_modelfile_content()
-        with open('Modelfile', 'w') as file:
+        with open("Modelfile", "w") as file:
             file.write(modelfile_content)
-
         subprocess.run(["ollama", "serve"])
         subprocess.run(["ollama", "create", f"{self.config['ollama_model']}:{self.config['model_parameters']}", "-f", "Modelfile"])
         subprocess.run(["ollama", "push", f"{self.config['ollama_model']}:{self.config['model_parameters']}"])
@@ -235,42 +533,30 @@ PARAMETER stop "<|reserved_special_token_"
         self.print_system_info()
         self.check_gpu()
         self.check_ram()
-        # self.install_packages()
         if self.config.get("train", "true").lower() == "true":
             self.prepare_model()
             self.train_model()
-
         if self.config.get("huggingface_save", "true").lower() == "true":
-            # self.model, self.tokenizer = self.load_model()
             self.save_model_merged()
-
         if self.config.get("huggingface_save_gguf", "true").lower() == "true":
-            # self.model, self.tokenizer = self.load_model()
             self.push_model_gguf()
-
-        # if self.config.get("save_gguf", "true").lower() == "true": ## TODO
-        #     self.model, self.tokenizer = self.load_model()
-        #     self.save_model_gguf()
-
-        # if self.config.get("save_merged", "true").lower() == "true": ## TODO
-        #     self.model, self.tokenizer = self.load_model()
-        #     self.save_model_merged()
-
         if self.config.get("ollama_save", "true").lower() == "true":
             self.create_and_push_ollama_model()
 
-
 def main():
     import argparse
-    parser = argparse.ArgumentParser(description='PraisonAI Training Script')
-    parser.add_argument('command', choices=['train'], help='Command to execute')
-    parser.add_argument('--config', default='config.yaml', help='Path to configuration file')
+    parser = argparse.ArgumentParser(description="PraisonAI Training Script")
+    parser.add_argument("command", choices=["train"], help="Command to execute")
+    parser.add_argument("--config", default="config.yaml", help="Path to configuration file")
+    parser.add_argument("--model", type=str, help="Model name")
+    parser.add_argument("--hf", type=str, help="Hugging Face model name")
+    parser.add_argument("--ollama", type=str, help="Ollama model name")
+    parser.add_argument("--dataset", type=str, help="Dataset name for training")
    args = parser.parse_args()
 
-    if args.command == 'train':
-        ai = train(config_path=args.config)
-        ai.run()
-
+    if args.command == "train":
+        trainer_obj = TrainModel(config_path=args.config)
+        trainer_obj.run()
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
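
Usage note: the reworked praisonai/train.py above is driven entirely by a YAML config. The following is a minimal, hypothetical sketch based only on the main() and TrainModel code visible in this diff; it assumes the training dependencies (unsloth, trl, torch, datasets) are installed and that config.yaml supplies the keys read via self.config above (model_name, max_seq_length, load_in_4bit, dataset, output_dir, hf_model_name, and the optional save flags).

# Hypothetical usage sketch; mirrors main() in the new train.py shown in this diff.
from praisonai.train import TrainModel

trainer = TrainModel(config_path="config.yaml")  # loads and logs the YAML config
trainer.run()  # system checks, then config-gated train / HF merge / GGUF push / Ollama steps

The same flow is reachable from the command line through the script's argparse entry point, e.g. something like: python -m praisonai.train train --config config.yaml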