npcpy 1.0.26__py3-none-any.whl → 1.2.32__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as published to a supported public registry. It is provided for informational purposes only.
- npcpy/__init__.py +0 -7
- npcpy/data/audio.py +16 -99
- npcpy/data/image.py +43 -42
- npcpy/data/load.py +83 -124
- npcpy/data/text.py +28 -28
- npcpy/data/video.py +8 -32
- npcpy/data/web.py +51 -23
- npcpy/ft/diff.py +110 -0
- npcpy/ft/ge.py +115 -0
- npcpy/ft/memory_trainer.py +171 -0
- npcpy/ft/model_ensembler.py +357 -0
- npcpy/ft/rl.py +360 -0
- npcpy/ft/sft.py +248 -0
- npcpy/ft/usft.py +128 -0
- npcpy/gen/audio_gen.py +24 -0
- npcpy/gen/embeddings.py +13 -13
- npcpy/gen/image_gen.py +262 -117
- npcpy/gen/response.py +615 -415
- npcpy/gen/video_gen.py +53 -7
- npcpy/llm_funcs.py +1869 -437
- npcpy/main.py +1 -1
- npcpy/memory/command_history.py +844 -510
- npcpy/memory/kg_vis.py +833 -0
- npcpy/memory/knowledge_graph.py +892 -1845
- npcpy/memory/memory_processor.py +81 -0
- npcpy/memory/search.py +188 -90
- npcpy/mix/debate.py +192 -3
- npcpy/npc_compiler.py +1672 -801
- npcpy/npc_sysenv.py +593 -1266
- npcpy/serve.py +3120 -0
- npcpy/sql/ai_function_tools.py +257 -0
- npcpy/sql/database_ai_adapters.py +186 -0
- npcpy/sql/database_ai_functions.py +163 -0
- npcpy/sql/model_runner.py +19 -19
- npcpy/sql/npcsql.py +706 -507
- npcpy/sql/sql_model_compiler.py +156 -0
- npcpy/tools.py +183 -0
- npcpy/work/plan.py +13 -279
- npcpy/work/trigger.py +3 -3
- npcpy-1.2.32.dist-info/METADATA +803 -0
- npcpy-1.2.32.dist-info/RECORD +54 -0
- npcpy/data/dataframes.py +0 -171
- npcpy/memory/deep_research.py +0 -125
- npcpy/memory/sleep.py +0 -557
- npcpy/modes/_state.py +0 -78
- npcpy/modes/alicanto.py +0 -1075
- npcpy/modes/guac.py +0 -785
- npcpy/modes/mcp_npcsh.py +0 -822
- npcpy/modes/npc.py +0 -213
- npcpy/modes/npcsh.py +0 -1158
- npcpy/modes/plonk.py +0 -409
- npcpy/modes/pti.py +0 -234
- npcpy/modes/serve.py +0 -1637
- npcpy/modes/spool.py +0 -312
- npcpy/modes/wander.py +0 -549
- npcpy/modes/yap.py +0 -572
- npcpy/npc_team/alicanto.npc +0 -2
- npcpy/npc_team/alicanto.png +0 -0
- npcpy/npc_team/assembly_lines/test_pipeline.py +0 -181
- npcpy/npc_team/corca.npc +0 -13
- npcpy/npc_team/foreman.npc +0 -7
- npcpy/npc_team/frederic.npc +0 -6
- npcpy/npc_team/frederic4.png +0 -0
- npcpy/npc_team/guac.png +0 -0
- npcpy/npc_team/jinxs/automator.jinx +0 -18
- npcpy/npc_team/jinxs/bash_executer.jinx +0 -31
- npcpy/npc_team/jinxs/calculator.jinx +0 -11
- npcpy/npc_team/jinxs/edit_file.jinx +0 -96
- npcpy/npc_team/jinxs/file_chat.jinx +0 -14
- npcpy/npc_team/jinxs/gui_controller.jinx +0 -28
- npcpy/npc_team/jinxs/image_generation.jinx +0 -29
- npcpy/npc_team/jinxs/internet_search.jinx +0 -30
- npcpy/npc_team/jinxs/local_search.jinx +0 -152
- npcpy/npc_team/jinxs/npcsh_executor.jinx +0 -31
- npcpy/npc_team/jinxs/python_executor.jinx +0 -8
- npcpy/npc_team/jinxs/screen_cap.jinx +0 -25
- npcpy/npc_team/jinxs/sql_executor.jinx +0 -33
- npcpy/npc_team/kadiefa.npc +0 -3
- npcpy/npc_team/kadiefa.png +0 -0
- npcpy/npc_team/npcsh.ctx +0 -9
- npcpy/npc_team/npcsh_sibiji.png +0 -0
- npcpy/npc_team/plonk.npc +0 -2
- npcpy/npc_team/plonk.png +0 -0
- npcpy/npc_team/plonkjr.npc +0 -2
- npcpy/npc_team/plonkjr.png +0 -0
- npcpy/npc_team/sibiji.npc +0 -5
- npcpy/npc_team/sibiji.png +0 -0
- npcpy/npc_team/spool.png +0 -0
- npcpy/npc_team/templates/analytics/celona.npc +0 -0
- npcpy/npc_team/templates/hr_support/raone.npc +0 -0
- npcpy/npc_team/templates/humanities/eriane.npc +0 -4
- npcpy/npc_team/templates/it_support/lineru.npc +0 -0
- npcpy/npc_team/templates/marketing/slean.npc +0 -4
- npcpy/npc_team/templates/philosophy/maurawa.npc +0 -0
- npcpy/npc_team/templates/sales/turnic.npc +0 -4
- npcpy/npc_team/templates/software/welxor.npc +0 -0
- npcpy/npc_team/yap.png +0 -0
- npcpy/routes.py +0 -958
- npcpy/work/mcp_helpers.py +0 -357
- npcpy/work/mcp_server.py +0 -194
- npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.npc +0 -2
- npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/automator.jinx +0 -18
- npcpy-1.0.26.data/data/npcpy/npc_team/bash_executer.jinx +0 -31
- npcpy-1.0.26.data/data/npcpy/npc_team/calculator.jinx +0 -11
- npcpy-1.0.26.data/data/npcpy/npc_team/celona.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/corca.npc +0 -13
- npcpy-1.0.26.data/data/npcpy/npc_team/edit_file.jinx +0 -96
- npcpy-1.0.26.data/data/npcpy/npc_team/eriane.npc +0 -4
- npcpy-1.0.26.data/data/npcpy/npc_team/file_chat.jinx +0 -14
- npcpy-1.0.26.data/data/npcpy/npc_team/foreman.npc +0 -7
- npcpy-1.0.26.data/data/npcpy/npc_team/frederic.npc +0 -6
- npcpy-1.0.26.data/data/npcpy/npc_team/frederic4.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/guac.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/gui_controller.jinx +0 -28
- npcpy-1.0.26.data/data/npcpy/npc_team/image_generation.jinx +0 -29
- npcpy-1.0.26.data/data/npcpy/npc_team/internet_search.jinx +0 -30
- npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.npc +0 -3
- npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/lineru.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/local_search.jinx +0 -152
- npcpy-1.0.26.data/data/npcpy/npc_team/maurawa.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/npcsh.ctx +0 -9
- npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_executor.jinx +0 -31
- npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_sibiji.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/plonk.npc +0 -2
- npcpy-1.0.26.data/data/npcpy/npc_team/plonk.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.npc +0 -2
- npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/python_executor.jinx +0 -8
- npcpy-1.0.26.data/data/npcpy/npc_team/raone.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/screen_cap.jinx +0 -25
- npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.npc +0 -5
- npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/slean.npc +0 -4
- npcpy-1.0.26.data/data/npcpy/npc_team/spool.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/sql_executor.jinx +0 -33
- npcpy-1.0.26.data/data/npcpy/npc_team/test_pipeline.py +0 -181
- npcpy-1.0.26.data/data/npcpy/npc_team/turnic.npc +0 -4
- npcpy-1.0.26.data/data/npcpy/npc_team/welxor.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/yap.png +0 -0
- npcpy-1.0.26.dist-info/METADATA +0 -827
- npcpy-1.0.26.dist-info/RECORD +0 -139
- npcpy-1.0.26.dist-info/entry_points.txt +0 -11
- /npcpy/{modes → ft}/__init__.py +0 -0
- {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/WHEEL +0 -0
- {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/top_level.txt +0 -0
npcpy/ft/rl.py
ADDED
@@ -0,0 +1,360 @@
from dataclasses import dataclass

from datetime import datetime
import glob
import json
import os
import pandas as pd
try:
    from datasets import Dataset

    from peft import LoraConfig, PeftModel
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer
    )
    from trl import DPOTrainer, DPOConfig
except:
    Dataset = None
    PeftModel = None
    DPOConfig = None
    DPOTrainer = None
    torch = None
    AutoModelForCausalLM = None
    AutoTokenizer = None


import random
from typing import List, Dict, Any, Optional, Callable
from npcpy.npc_compiler import NPC
from npcpy.llm_funcs import get_llm_response


@dataclass
class RLConfig:
    base_model_name: str = "Qwen/Qwen3-0.6B"
    adapter_path: str = "./rl_adapter"
    max_iterations: int = 8
    min_reward_gap: float = 0.4
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 2
    learning_rate: float = 1e-6
    beta: float = 0.5
    max_length: int = 512
    max_prompt_length: int = 256


class TaskExecutor:

    def __init__(
        self,
        agent: NPC,
        max_iterations: int = 8
    ):
        self.agent = agent
        self.max_iterations = max_iterations

    def execute_task(
        self,
        task_prompt: str
    ) -> Dict[str, Any]:

        messages = [
            {
                "role": "system",
                "content": self.agent.primary_directive
            }
        ]

        raw_responses = []
        current_prompt = task_prompt

        for i in range(self.max_iterations):
            response_obj = self.agent.get_llm_response(
                current_prompt,
                messages=messages,
                auto_process_tool_calls=True
            )

            raw_responses.append(response_obj)
            messages = response_obj.get('messages', messages)

            last_content = messages[-1].get('content', '')

            if self._is_complete(last_content):
                return {
                    "raw_responses": raw_responses,
                    "final_output": last_content,
                    "total_iterations": i + 1,
                    "completed": True
                }

            current_prompt = (
                "Continue or provide final answer."
            )

        return {
            "raw_responses": raw_responses,
            "final_output": messages[-1].get('content', ''),
            "total_iterations": self.max_iterations,
            "completed": False
        }

    def _is_complete(self, content: str) -> bool:

        completion_markers = [
            "final answer:",
            "conclusion:",
            "result:",
            "therefore",
            "in summary"
        ]
        content_lower = content.lower()
        return any(
            marker in content_lower
            for marker in completion_markers
        )


def collect_traces(
    tasks: List[Dict[str, Any]],
    agents: List[NPC],
    reward_fn: Callable[[Dict], float],
    config: Optional[RLConfig] = None
) -> List[Dict[str, Any]]:

    if config is None:
        config = RLConfig()

    traces = []

    for task in tasks:
        task_prompt = task.get('prompt', task.get('input', ''))

        for agent in agents:
            executor = TaskExecutor(
                agent,
                max_iterations=config.max_iterations
            )

            result = executor.execute_task(task_prompt)

            trace = {
                "agent_name": agent.name,
                "task_prompt": task_prompt,
                "final_output": result['final_output'],
                "total_iterations": result['total_iterations'],
                "completed": result['completed'],
                "task_metadata": task
            }

            trace['reward'] = reward_fn(trace)

            traces.append(trace)

            print(
                f"Agent {agent.name}: "
                f"Reward={trace['reward']:.2f}"
            )

    return traces


def create_preference_pairs(
    traces: List[Dict[str, Any]],
    min_reward_gap: float = 0.4
) -> Dataset:

    df = pd.DataFrame(traces)
    df = df[df['reward'] > -1.0].copy()

    if len(df) < 2:
        return None

    df = df.sort_values('reward', ascending=False)

    top_quantile = df['reward'].quantile(
        0.8,
        interpolation='higher'
    )
    low_quantile = df['reward'].quantile(
        0.2,
        interpolation='lower'
    )

    high_traces = df[df['reward'] >= top_quantile]
    low_traces = df[df['reward'] <= low_quantile]

    pairs = []

    for _, high_trace in high_traces.iterrows():
        for _, low_trace in low_traces.iterrows():
            reward_gap = (
                high_trace['reward'] - low_trace['reward']
            )

            if reward_gap >= min_reward_gap:
                pairs.append({
                    "prompt": str(high_trace['task_prompt']),
                    "chosen": str(high_trace['final_output']),
                    "rejected": str(low_trace['final_output'])
                })

    if len(pairs) < 5:
        print(
            f"Warning: Only {len(pairs)} pairs found. "
            "May overfit."
        )

    return Dataset.from_list(pairs[:100])


def train_with_dpo(
    traces: List[Dict[str, Any]],
    config: Optional[RLConfig] = None
) -> str:

    if config is None:
        config = RLConfig()

    preference_dataset = create_preference_pairs(
        traces,
        min_reward_gap=config.min_reward_gap
    )

    if preference_dataset is None or len(preference_dataset) == 0:
        print("No valid preference pairs. Cannot train.")
        return None

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        torch_dtype=torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj"
        ]
    )

    training_args = DPOConfig(
        output_dir="./dpo_results",
        per_device_train_batch_size=(
            config.per_device_train_batch_size
        ),
        gradient_accumulation_steps=(
            config.gradient_accumulation_steps
        ),
        learning_rate=config.learning_rate,
        num_train_epochs=config.num_train_epochs,
        weight_decay=0.1,
        beta=config.beta,
        logging_steps=2,
        save_steps=10,
        remove_unused_columns=False,
        max_length=config.max_length,
        max_prompt_length=config.max_prompt_length,
        dataloader_num_workers=0,
        fp16=False,
        bf16=False,
        optim="adamw_torch",
        warmup_steps=2,
        save_strategy="steps",
        save_total_limit=3
    )

    trainer = DPOTrainer(
        model,
        args=training_args,
        train_dataset=preference_dataset,
        peft_config=peft_config
    )

    print("Starting DPO training...")
    trainer.train()

    trainer.save_model(config.adapter_path)
    print(f"Adapter saved to {config.adapter_path}")

    return config.adapter_path


def run_rl_training(
    tasks: List[Dict[str, Any]],
    agents: List[NPC],
    reward_fn: Callable[[Dict], float],
    config: Optional[RLConfig] = None,
    save_traces: bool = True
) -> str:

    if config is None:
        config = RLConfig()

    print(f"Collecting traces from {len(tasks)} tasks...")
    traces = collect_traces(
        tasks,
        agents,
        reward_fn,
        config
    )

    if save_traces:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        traces_file = f"rl_traces_{timestamp}.csv"
        df = pd.DataFrame(traces)
        df.to_csv(traces_file, index=False)
        print(f"Traces saved to {traces_file}")

    print("Training with DPO...")
    adapter_path = train_with_dpo(traces, config)

    return adapter_path


def load_rl_model(
    base_model_id: str,
    adapter_path: str
):

    print(f"Loading base model: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float32,
        device_map="auto",
        attn_implementation='eager'
    )

    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if adapter_path and os.path.exists(adapter_path):
        print(f"Loading adapter: {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path)
        model = model.merge_and_unload()

    return model, tokenizer
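For orientation, here is a minimal sketch of how this new npcpy.ft.rl module could be driven. It is not part of the diff above: the stand-in agent, toy tasks, keyword reward, and adapter path are illustrative assumptions, and a real NPC from npcpy.npc_compiler would normally take the agent's place.

from npcpy.ft.rl import RLConfig, run_rl_training

# Stand-in agent for illustration only: collect_traces() touches just
# .name, .primary_directive and .get_llm_response(), so a real NPC
# (constructor not shown in this diff) can be dropped in here instead.
class EchoAgent:
    name = "echo"
    primary_directive = "Answer concisely."

    def get_llm_response(self, prompt, messages=None, **kwargs):
        messages = list(messages or [])
        messages.append({
            "role": "assistant",
            "content": f"Final answer: {prompt}"
        })
        return {"messages": messages}

def keyword_reward(trace):
    # Toy reward: 1.0 when the run completed and the expected keyword
    # (stored on the task dict) appears in the final output, else 0.0.
    expected = trace["task_metadata"].get("expected", "")
    hit = expected.lower() in trace["final_output"].lower()
    return 1.0 if trace["completed"] and hit else 0.0

tasks = [
    {"prompt": "What is 2 + 2? Reply with the digit.", "expected": "4"},
    {"prompt": "Name a prime number below 10.", "expected": "7"},
]

adapter_path = run_rl_training(
    tasks,
    agents=[EchoAgent()],
    reward_fn=keyword_reward,
    config=RLConfig(adapter_path="./rl_adapter_demo"),
)

DPO training only runs when the reward function spreads traces far enough apart (at least min_reward_gap, 0.4 by default) for create_preference_pairs to build chosen/rejected pairs, and it requires the optional torch, trl, peft, and datasets dependencies guarded by the try/except at the top of the module.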
npcpy/ft/sft.py
ADDED
@@ -0,0 +1,248 @@
# structured fine tuning of LLMs to produce structured output
from dataclasses import dataclass, field
from datasets import Dataset
import json
import numpy as np
import os
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments
    )
    from trl import SFTTrainer
    from peft import LoraConfig
except:
    torch = None
    SFTTrainer = None
    LoraConfig = None
    AutoModelForCausalLM = None
    AutoTokenizer = None
    TrainingArguments = None

from typing import List, Dict, Any, Optional


@dataclass
class SFTConfig:
    base_model_name: str = "google/gemma-3-270m-it"
    output_model_path: str = "models/sft_model"
    lora_r: int = 8
    lora_alpha: int = 16
    use_4bit: bool = False
    fp16: bool = False
    bf16: bool = False
    lora_dropout: float = 0.15
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 3e-5
    logging_steps: int = 10
    optim: str = "adamw_torch"
    lr_scheduler_type: str = "cosine_with_restarts"
    weight_decay: float = 0.01
    max_length: int = 512
    save_steps: int = 50


def format_training_examples(
    inputs: List[str],
    outputs: List[str],
    format_style: str = "gemma"
) -> List[Dict[str, str]]:

    formatted = []

    for inp, out in zip(inputs, outputs):
        if format_style == "gemma":
            text = (
                f"<start_of_turn>user\n{inp}<end_of_turn>\n"
                f"<start_of_turn>model\n{out}<end_of_turn>"
            )
        elif format_style == "llama":
            text = (
                f"<|begin_of_text|><|start_header_id|>user"
                f"<|end_header_id|>\n\n{inp}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>"
                f"\n\n{out}<|eot_id|>"
            )
        else:
            text = f"Input: {inp}\nOutput: {out}"

        formatted.append({"text": text})

    return formatted


def run_sft(
    X: List[str],
    y: List[str],
    config: Optional[SFTConfig] = None,
    validation_split: float = 0.0,
    format_style: str = "gemma"
) -> str:

    if config is None:
        config = SFTConfig()

    if len(X) != len(y):
        raise ValueError(
            f"X and y must have same length: {len(X)} vs {len(y)}"
        )

    formatted_examples = format_training_examples(
        X, y, format_style
    )

    if validation_split > 0:
        split_idx = int(len(formatted_examples) * (1 - validation_split))
        train_examples = formatted_examples[:split_idx]
        val_examples = formatted_examples[split_idx:]
        print(
            f"Split: {len(train_examples)} train, "
            f"{len(val_examples)} val"
        )
    else:
        train_examples = formatted_examples
        val_examples = []

    dataset = Dataset.from_list(train_examples)

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        trust_remote_code=True,
        attn_implementation="eager"
    )
    model.config.use_cache = False

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name,
        trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM"
    )

    training_args = TrainingArguments(
        output_dir=config.output_model_path,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=(
            config.per_device_train_batch_size
        ),
        gradient_accumulation_steps=(
            config.gradient_accumulation_steps
        ),
        optim=config.optim,
        logging_steps=config.logging_steps,
        learning_rate=config.learning_rate,
        fp16=config.fp16,
        bf16=config.bf16,
        lr_scheduler_type=config.lr_scheduler_type,
        group_by_length=True,
        save_steps=config.save_steps,
        weight_decay=config.weight_decay,
    )

    def formatting_func(example):
        return example["text"]

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
        processing_class=tokenizer,
        formatting_func=formatting_func
    )

    print(f"Training on {len(dataset)} examples")
    trainer.train()

    trainer.save_model(config.output_model_path)
    print(f"Model saved to {config.output_model_path}")

    return config.output_model_path


def load_sft_model(model_path: str):

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="auto",
        attn_implementation="eager"
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
def predict_sft(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7
) -> str:

    device = next(model.parameters()).device

    formatted_prompt = (
        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )

    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.eos_token_id
        )

    full_response = tokenizer.decode(
        outputs[0],
        skip_special_tokens=False
    )

    if "<start_of_turn>model\n" in full_response:
        response = full_response.split(
            "<start_of_turn>model\n"
        )[-1]
        response = response.split("<end_of_turn>")[0].strip()
    else:
        response = tokenizer.decode(
            outputs[0][len(input_ids[0]):],
            skip_special_tokens=True
        )

    return response
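For orientation, a minimal sketch of how the new npcpy.ft.sft helpers fit together, assuming the optional torch/transformers/trl/peft/datasets dependencies are installed. The toy extraction dataset, output path, and single epoch below are illustrative choices; google/gemma-3-270m-it is the module's own default base model, so it is not overridden here.

from npcpy.ft.sft import SFTConfig, run_sft, load_sft_model, predict_sft

# Toy supervised pairs: X holds prompts, y the target completions,
# aligned by index (run_sft raises ValueError if the lengths differ).
X = [
    "Extract the city as JSON: 'Flights from Paris leave at noon.'",
    "Extract the city as JSON: 'The Tokyo office opens on Monday.'",
]
y = [
    '{"city": "Paris"}',
    '{"city": "Tokyo"}',
]

config = SFTConfig(
    output_model_path="models/sft_demo",  # illustrative path
    num_train_epochs=1,                   # keep the demo cheap
)

model_path = run_sft(X, y, config=config, format_style="gemma")

# Reload the saved model and generate with the Gemma-style turn
# template that predict_sft applies internally.
model, tokenizer = load_sft_model(model_path)
print(predict_sft(model, tokenizer, X[0], max_new_tokens=32))

Passing format_style="llama" (or any other value, which falls back to a plain Input/Output template) switches the chat template applied by format_training_examples.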