prepforge 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- major_project/__init__.py +0 -0
- major_project/cli.py +186 -0
- major_project/config.py +17 -0
- major_project/core/__init__.py +0 -0
- major_project/core/config_store.py +20 -0
- major_project/core/model.py +112 -0
- major_project/core/train.py +149 -0
- major_project/core/utils.py +222 -0
- major_project/gui_app.py +227 -0
- major_project/streamlit_app.py +171 -0
- major_project/tui_app.py +241 -0
- prepforge-0.1.0.dist-info/METADATA +149 -0
- prepforge-0.1.0.dist-info/RECORD +17 -0
- prepforge-0.1.0.dist-info/WHEEL +5 -0
- prepforge-0.1.0.dist-info/entry_points.txt +2 -0
- prepforge-0.1.0.dist-info/licenses/LICENSE +21 -0
- prepforge-0.1.0.dist-info/top_level.txt +1 -0
|
File without changes
|
major_project/cli.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import subprocess
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main():
    """prepforge command-line entry point.

    Sub-commands:
        run     Launch one of the front-ends (tui | gui | streamlit).
        train   Fine-tune a LoRA adapter on a user-supplied dataset.
        config  Persist default base-model / LoRA choices.

    Exits with status 1 when a dispatched command raises.
    """
    parser = argparse.ArgumentParser(
        prog="prepforge",
        description="prepforge - Multi-interface AI assistant with training support",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    subparsers.required = True  # a sub-command must always be given

    # -------------------------
    # RUN APPS
    # -------------------------
    run_parser = subparsers.add_parser(
        "run",
        help="Run the application (tui | gui | streamlit)",
        description=(
            # Fixed typo: the program is "prepforge", not "prepforce".
            "Run prepforge in one of the following modes:\n\n"
            "  tui        Terminal interface\n"
            "  gui        Desktop application\n"
            "  streamlit  Web interface\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    run_parser.add_argument(
        "mode",
        choices=["tui", "gui", "streamlit"],
        help="Mode to run",
    )

    # -------------------------
    # TRAIN
    # -------------------------
    train_parser = subparsers.add_parser(
        "train",
        help="Train a LoRA adapter",
        description=(
            "Train a model using your dataset\n\n"
            "Dataset Control Options:\n"
            "  --limit N    Train on first N samples\n"
            "  --subset F   Train on fraction (0.1 = 10%%)\n\n"
            "Examples:\n"
            "  prepforge train --dataset data.jsonl\n"
            "  prepforge train --dataset data.jsonl --limit 500\n"
            "  prepforge train --dataset data.jsonl --subset 0.05\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    train_parser.add_argument(
        "--dataset",
        required=True,
        help="Path to dataset file (JSON/JSONL format)",
    )

    train_parser.add_argument(
        "--output",
        default="trained_model",
        help="Output directory for trained model",
    )

    train_parser.add_argument(
        "--epochs",
        type=int,
        default=2,
        help="Number of training epochs",
    )

    train_parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate",
    )

    # --limit and --subset are alternative ways to shrink the dataset,
    # so only one may be supplied.
    group = train_parser.add_mutually_exclusive_group()

    group.add_argument(
        "--limit",
        type=int,
        help="Limit number of samples (e.g. 500)",
    )

    group.add_argument(
        "--subset",
        type=float,
        help="Use fraction of dataset (0.05 = 5%%)",
    )

    # -------------------------
    # CONFIG
    # -------------------------
    config_parser = subparsers.add_parser(
        "config",
        help="Set default model or LoRA",
        description=(
            "Configure default model and LoRA adapter\n\n"
            "Examples:\n"
            "  prepforge config --model meta-llama/Llama-3-8B\n"
            "  prepforge config --lora ~/adapter\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    config_parser.add_argument(
        "--model",
        help="Set default base model",
    )

    config_parser.add_argument(
        "--lora",
        help="Set default LoRA adapter path",
    )

    args = parser.parse_args()

    try:
        # -------------------------
        # RUN MODES
        # -------------------------
        if args.command == "run":
            # Heavy imports are deferred so `prepforge --help` stays fast.
            if args.mode == "tui":
                from .tui_app import main

                main()

            elif args.mode == "gui":
                from .gui_app import main

                main()

            elif args.mode == "streamlit":
                # Streamlit must be launched through its own CLI; point it at
                # the packaged app file next to this module.
                base_dir = os.path.dirname(__file__)
                app_path = os.path.join(base_dir, "streamlit_app.py")

                subprocess.run([sys.executable, "-m", "streamlit", "run", app_path])

        # -------------------------
        # TRAIN
        # -------------------------
        elif args.command == "train":
            from .core.train import train_model

            train_model(
                dataset_path=args.dataset,
                output_dir=args.output,
                epochs=args.epochs,
                lr=args.lr,
                limit=args.limit,
                subset=args.subset,
            )

        # -------------------------
        # CONFIG
        # -------------------------
        elif args.command == "config":
            from .core.config_store import load_user_config, save_user_config

            config = load_user_config()

            if args.model:
                config["model"] = args.model
                print(f"✅ Default model set to: {args.model}")

            if args.lora:
                config["lora"] = args.lora
                print(f"✅ Default LoRA set to: {args.lora}")

            if args.model or args.lora:
                # Only touch the config file when something actually changed
                # (previously it was rewritten even on a no-op invocation).
                save_user_config(config)
            else:
                print("⚠️ No changes provided. Use --model or --lora")

    except Exception as e:
        print(f"\n❌ Error: {str(e)}\n")
        # Signal failure to shells/scripts instead of silently exiting 0.
        sys.exit(1)


if __name__ == "__main__":
    main()
|
major_project/config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from major_project.core.config_store import load_user_config
|
|
3
|
+
|
|
4
|
+
# Fallback defaults used when neither the user config file nor the
# environment provides an override.
DEFAULT_MODEL = "unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit"
DEFAULT_LORA = None
DEFAULT_APP_NAME = "PrepForge"
MAX_HISTORY = 6


def get_model():
    """Return the base model name: user config > env var > built-in default."""
    # NOTE(review): env vars use the MENTORAI_ prefix rather than PREPFORGE_ —
    # presumably a legacy name; confirm before renaming.
    configured = load_user_config().get("model")
    if configured:
        return configured
    return os.getenv("MENTORAI_MODEL") or DEFAULT_MODEL


def get_lora():
    """Return the LoRA adapter path: user config > env var > built-in default."""
    configured = load_user_config().get("lora")
    if configured:
        return configured
    return os.getenv("MENTORAI_LORA") or DEFAULT_LORA
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
|
|
4
|
+
# NOTE(review): the directory is "~/.prepforce" although the package is
# "prepforge" — kept as-is because changing it would orphan existing configs.
CONFIG_DIR = os.path.expanduser("~/.prepforce")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")


def load_user_config():
    """Load the persisted user config as a dict.

    Returns:
        The parsed JSON dict, or ``{}`` when the file is absent, unreadable,
        or contains invalid JSON — a broken config file must never crash the
        application at startup.
    """
    if not os.path.exists(CONFIG_FILE):
        return {}

    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        # Degrade to defaults as if no config existed.
        return {}


def save_user_config(config):
    """Persist *config* as pretty-printed JSON, creating the config dir.

    Args:
        config: JSON-serializable dict of user settings.
    """
    os.makedirs(CONFIG_DIR, exist_ok=True)

    with open(CONFIG_FILE, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
4
|
+
from peft import PeftModel
|
|
5
|
+
|
|
6
|
+
# ✅ Use shared config
|
|
7
|
+
from major_project.config import get_model, get_lora
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _check_cuda():
|
|
11
|
+
"""
|
|
12
|
+
Ensure NVIDIA GPU is available.
|
|
13
|
+
"""
|
|
14
|
+
if not torch.cuda.is_available():
|
|
15
|
+
raise RuntimeError(
|
|
16
|
+
"\n❌ CUDA GPU not detected.\n"
|
|
17
|
+
"This application requires an NVIDIA GPU.\n"
|
|
18
|
+
"CPU execution is not supported.\n"
|
|
19
|
+
"\n👉 Install CUDA-enabled PyTorch:\n"
|
|
20
|
+
"pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_model(
    base_model: str | None = None,
    lora_path: str | None = None,
    device_map: str = "auto",
    dtype: torch.dtype = torch.float16,
):
    """Load the base causal-LM and tokenizer, optionally stacking a LoRA adapter.

    Args:
        base_model (str | None): HuggingFace model name or local path;
            falls back to the user-configured default when None.
        lora_path (str | None): LoRA adapter folder; falls back to the
            user-configured default when None.
        device_map (str): Device mapping ("auto", "cpu", etc.)
        dtype (torch.dtype): Torch dtype for model weights.

    Returns:
        (model, tokenizer) — model is in eval mode.

    Raises:
        RuntimeError: no CUDA GPU available.
        ValueError: no base model resolvable, or LoRA folder is invalid.
        FileNotFoundError: LoRA path does not exist.
    """
    # Refuse to run on CPU-only hosts up front.
    _check_cuda()

    # Fall back to the shared user config when arguments were not supplied.
    if base_model is None:
        base_model = get_model()
    if lora_path is None:
        lora_path = get_lora()

    if not base_model:
        raise ValueError("Base model path/name must be provided")

    if lora_path:
        lora_path = os.path.expanduser(lora_path)

        if not os.path.exists(lora_path):
            raise FileNotFoundError(f"LoRA path not found: {lora_path}")

        # A valid PEFT adapter folder always ships adapter_config.json.
        marker = os.path.join(lora_path, "adapter_config.json")
        if not os.path.exists(marker):
            raise ValueError(
                f"Invalid LoRA adapter folder (missing adapter_config.json): {lora_path}"
            )

    # Load the base weights and tokenizer from the local cache only
    # (no network access at inference time).
    print(f"[MODEL] Loading base model: {base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map=device_map,
        torch_dtype=dtype,
        local_files_only=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        local_files_only=True,
        use_fast=True,
    )

    # Stack the adapter on top of the base weights when configured.
    if lora_path:
        print(f"[MODEL] Applying LoRA adapter: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path, local_files_only=True)
        print("[MODEL] LoRA successfully loaded")

    model.eval()  # inference mode: disables dropout etc.
    return model, tokenizer
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from unsloth import FastLanguageModel
|
|
2
|
+
from datasets import load_dataset, Dataset
|
|
3
|
+
from transformers import TrainingArguments
|
|
4
|
+
from trl.trainer.sft_trainer import SFTTrainer
|
|
5
|
+
|
|
6
|
+
from major_project.config import get_model
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def train_model(
    dataset_path,
    output_dir,
    epochs=2,
    lr=2e-4,
    batch_size=2,
    max_seq_length=512,
    limit=None,
    subset=None,
):
    """Fine-tune the configured base model with a LoRA adapter via unsloth/TRL.

    Args:
        dataset_path: JSON/JSONL file of records with "instruction",
            "input" and "output" fields (missing fields default to "").
        output_dir: directory for checkpoints and the saved adapter/tokenizer.
        epochs: number of training epochs.
        lr: learning rate.
        batch_size: per-device train batch size.
        max_seq_length: tokenizer truncation length.
        limit: if given, train on only the first `limit` samples.
        subset: if given, train on the first fraction (0.05 = 5%) of the
            dataset. `limit` takes precedence (the CLI makes them exclusive).
    """
    # -----------------------------
    # LOAD MODEL FROM CONFIG
    # -----------------------------
    model_name = get_model()
    print(f"[TRAIN] Using model: {model_name}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
    )

    # -----------------------------
    # APPLY LoRA
    # -----------------------------
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
    )

    # -----------------------------
    # LOAD DATASET
    # -----------------------------
    dataset = load_dataset("json", data_files=dataset_path, split="train")
    assert isinstance(dataset, Dataset)

    print(f"[TRAIN] Original dataset size: {len(dataset)}")

    # -----------------------------
    # APPLY LIMIT / SUBSET
    # -----------------------------
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))
        print(f"[TRAIN] Using first {len(dataset)} samples")

    elif subset is not None:
        # Clamp to >= 1 so a tiny fraction cannot yield an empty dataset,
        # which would crash the trainer with an opaque error.
        size = max(1, int(len(dataset) * subset))
        dataset = dataset.select(range(size))
        print(f"[TRAIN] Using {subset * 100:.1f}% → {len(dataset)} samples")

    # -----------------------------
    # FORMAT DATA
    # -----------------------------
    def format_prompt(example):
        # Build a Llama-3 style chat prompt and mask the prompt tokens so the
        # loss is computed only on the assistant's answer.
        instruction = example.get("instruction", "")
        input_text = example.get("input", "")
        output = example.get("output", "")

        prompt = f"{instruction}\n{input_text}".strip()

        user_part = f"""<|begin_of_text|>
<|start_header_id|>user<|end_header_id|>

{prompt}

<|start_header_id|>assistant<|end_header_id|>

"""

        full_text = user_part + output

        user_tokens = tokenizer(user_part, add_special_tokens=False)
        full_tokens = tokenizer(
            full_text,
            truncation=True,
            max_length=max_seq_length,
        )

        labels = full_tokens["input_ids"].copy()
        # BUGFIX: when truncation makes the sequence shorter than the user
        # part, the previous `labels[:user_len] = [-100] * user_len`
        # slice-assign EXTENDED labels past len(input_ids), producing
        # misaligned batches. Clamp the mask to the actual label length.
        mask_len = min(len(user_tokens["input_ids"]), len(labels))
        labels[:mask_len] = [-100] * mask_len

        return {
            "input_ids": full_tokens["input_ids"],
            "attention_mask": full_tokens["attention_mask"],
            "labels": labels,
        }

    dataset = dataset.map(format_prompt)

    # Keep only the columns the trainer consumes.
    dataset = dataset.remove_columns(
        [
            c
            for c in dataset.column_names or []
            if c not in ["input_ids", "attention_mask", "labels"]
        ]
    )

    # -----------------------------
    # TRAINER
    # -----------------------------
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=TrainingArguments(
            per_device_train_batch_size=batch_size,
            num_train_epochs=epochs,
            learning_rate=lr,
            output_dir=output_dir,
            logging_steps=10,
            save_steps=500,
            optim="adamw_8bit",
        ),
    )

    # -----------------------------
    # TRAIN
    # -----------------------------
    print("[TRAIN] Starting training...")
    trainer.train()

    # -----------------------------
    # SAVE
    # -----------------------------
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("[TRAIN] Training complete.")
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from transformers import TextIteratorStreamer
|
|
3
|
+
|
|
4
|
+
# -----------------------------
|
|
5
|
+
# MEMORY (SHARED)
|
|
6
|
+
# -----------------------------
|
|
7
|
+
# -----------------------------
# MEMORY (SHARED)
# -----------------------------
chat_history = []      # rolling list of {"role": ..., "content": ...} messages
last_mcq_block = None  # most recently generated MCQ text, for follow-ups
MAX_HISTORY = 6        # max messages retained in chat_history


def trim_history():
    """Drop the oldest messages so chat_history holds at most MAX_HISTORY."""
    global chat_history
    overflow = len(chat_history) - MAX_HISTORY
    if overflow > 0:
        chat_history = chat_history[overflow:]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# -----------------------------
|
|
19
|
+
# PROMPT BUILDER (SHARED)
|
|
20
|
+
# -----------------------------
|
|
21
|
+
def build_prompt(mode, user_input, count):
    """Build the model prompt for the given UI mode.

    Args:
        mode: one of chat/notes/mcq/practice/plan (UI label variants like
            "Chat", "MCQ Generator", "Study Plan" are accepted).
        user_input: raw user text — a topic, or a follow-up question.
        count: number of questions to request (mcq/practice modes).

    Returns:
        The prompt string to send to the model. Unknown modes fall through
        to returning user_input unchanged.
    """
    global last_mcq_block

    user_lower = user_input.lower()

    # -------------------------
    # 🔥 MCQ FOLLOWUPS
    # -------------------------
    if last_mcq_block:
        # ONLY ANSWERS
        if (
            "answer" in user_lower
            and "why" not in user_lower
            and "explain" not in user_lower
        ):
            return f"""Here are the MCQs:

{last_mcq_block}

Give ONLY the answers in this format:
1. A
2. B
3. C
(continue...)
"""

        # EXPLANATION MODE
        if "why" in user_lower or "explain" in user_lower:
            return f"""Here are the MCQs:

{last_mcq_block}

For EACH question:
- Give the correct answer
- Explain WHY it is correct
- Briefly explain why other options are incorrect

FORMAT:

1. Answer: A
Explanation: ...

2. Answer: B
Explanation: ...
"""

    # -------------------------
    # PRACTICE FOLLOWUP (EXPLANATION)
    # -------------------------
    # BUGFIX: this check must run BEFORE the practice-generation branch.
    # Previously it sat after an unconditional `elif practice: return ...`,
    # which made it unreachable dead code.
    if mode in ["practice", "Practice"] and (
        "answer" in user_lower or "explain" in user_lower or "why" in user_lower
    ):
        # NOTE(review): this interpolates the user's follow-up text where the
        # original questions belong — presumably the caller passes the question
        # block back in; confirm against the UI layers.
        return f"""
Here are the practice questions:

{user_input}

Provide detailed answers and explanations for each question.
"""

    # -------------------------
    # CHAT
    # -------------------------
    if mode in ["chat", "Chat"]:
        return user_input

    # -------------------------
    # NOTES
    # -------------------------
    elif mode in ["notes", "Notes"]:
        return f"""
Explain {user_input} in structured notes.

Use:
- Clear headings
- Bullet points
- Simple explanations
"""

    # -------------------------
    # MCQ GENERATION
    # -------------------------
    elif mode in ["mcq", "MCQ Generator"]:
        return f"""
Generate EXACTLY {count} multiple choice questions on {user_input}.

STRICT RULES:
- Generate EXACTLY {count} questions (no more, no less)
- Each question MUST have 4 options: A, B, C, D
- Each option MUST be on a new line
- Do NOT combine options in one line
- Do NOT stop early
- Do NOT skip numbers

FORMAT:

SECTION 1: QUESTIONS

1. Question text
A. Option
B. Option
C. Option
D. Option

2. Question text
A. Option
B. Option
C. Option
D. Option

(continue until {count})

SECTION 2: ANSWERS

1. A
2. B
3. C
4. D
(continue until {count})

IMPORTANT:
- SECTION 1 must contain ALL {count} questions
- SECTION 2 must contain ALL {count} answers
- DO NOT include explanations
- DO NOT add extra commentary
"""

    # -------------------------
    # PRACTICE QUESTIONS
    # -------------------------
    elif mode in ["practice", "Practice"]:
        return f"""
Generate EXACTLY {count} practice questions on {user_input}.

RULES:
- These must be descriptive or short-answer questions
- DO NOT generate MCQs
- DO NOT provide answers initially
- Keep questions clear and numbered

FORMAT:

1. Question
2. Question
3. Question
(continue...)
"""

    # -------------------------
    # STUDY PLAN
    # -------------------------
    elif mode in ["plan", "Study Plan"]:
        return f"""
Create a structured study plan for {user_input}.

Include:
- Daily/weekly breakdown
- Topics to cover
- Practice suggestions
"""

    return user_input
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# -----------------------------
|
|
186
|
+
# STREAM GENERATION (SHARED)
|
|
187
|
+
# -----------------------------
|
|
188
|
+
def generate_stream(model, tokenizer, prompt, max_tokens):
    """Start threaded generation and return a token streamer.

    Appends *prompt* to the shared chat history, renders the full chat
    template, and runs `model.generate` on a background thread that feeds a
    `TextIteratorStreamer`. The caller iterates the returned streamer to
    receive decoded text chunks.

    NOTE(review): only user turns are recorded in chat_history — the
    assistant's reply is never appended here; confirm the callers do that
    if multi-turn context is expected to include model answers.

    Args:
        model: loaded causal LM (already on its device).
        tokenizer: matching tokenizer with a chat template.
        prompt: user message to append and answer.
        max_tokens: cap on newly generated tokens.

    Returns:
        TextIteratorStreamer yielding the generated text.
    """
    global chat_history

    chat_history.append({"role": "user", "content": prompt})
    trim_history()

    messages = [{"role": "system", "content": "You are a helpful AI teacher."}]
    messages.extend(chat_history)

    full_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # BUGFIX: some tokenizers define no pad token, and generate() misbehaves
    # (or warns) when pad_token_id is None — fall back to EOS in that case.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    def run():
        model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            streamer=streamer,
            pad_token_id=pad_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Daemon thread: a wedged generation must not keep the process alive
    # after the UI exits.
    threading.Thread(target=run, daemon=True).start()
    return streamer
|