npcpy 1.1.28__py3-none-any.whl → 1.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- npcpy/data/audio.py +16 -38
- npcpy/data/image.py +29 -29
- npcpy/data/load.py +4 -3
- npcpy/data/text.py +28 -28
- npcpy/data/video.py +6 -6
- npcpy/data/web.py +49 -21
- npcpy/ft/__init__.py +0 -0
- npcpy/ft/diff.py +110 -0
- npcpy/ft/ge.py +115 -0
- npcpy/ft/memory_trainer.py +171 -0
- npcpy/ft/model_ensembler.py +357 -0
- npcpy/ft/rl.py +360 -0
- npcpy/ft/sft.py +248 -0
- npcpy/ft/usft.py +128 -0
- npcpy/gen/audio_gen.py +24 -0
- npcpy/gen/embeddings.py +13 -13
- npcpy/gen/image_gen.py +37 -15
- npcpy/gen/response.py +287 -111
- npcpy/gen/video_gen.py +10 -9
- npcpy/llm_funcs.py +447 -79
- npcpy/memory/command_history.py +201 -48
- npcpy/memory/kg_vis.py +74 -74
- npcpy/memory/knowledge_graph.py +482 -115
- npcpy/memory/memory_processor.py +81 -0
- npcpy/memory/search.py +70 -70
- npcpy/mix/debate.py +192 -3
- npcpy/npc_compiler.py +1541 -879
- npcpy/npc_sysenv.py +250 -78
- npcpy/serve.py +1036 -321
- npcpy/sql/ai_function_tools.py +257 -0
- npcpy/sql/database_ai_adapters.py +186 -0
- npcpy/sql/database_ai_functions.py +163 -0
- npcpy/sql/model_runner.py +19 -19
- npcpy/sql/npcsql.py +706 -507
- npcpy/sql/sql_model_compiler.py +156 -0
- npcpy/tools.py +20 -20
- npcpy/work/plan.py +8 -8
- npcpy/work/trigger.py +3 -3
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/METADATA +169 -9
- npcpy-1.2.32.dist-info/RECORD +54 -0
- npcpy-1.1.28.dist-info/RECORD +0 -40
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/WHEEL +0 -0
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/top_level.txt +0 -0
npcpy/ft/model_ensembler.py
ADDED

@@ -0,0 +1,357 @@

import time
import copy
import random
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from npcpy.llm_funcs import get_llm_response

try:
    from npcpy.ft.sft import predict_sft, load_sft_model
except:
    pass

@dataclass
class ModelGene:
    """
    Represents a specialized model with trigger patterns
    and confidence threshold
    """
    sft_path: Optional[str] = None
    rl_path: Optional[str] = None
    base_model: str = "Qwen/Qwen3-0.6B"
    specialization: str = "general"
    trigger_patterns: List[str] = field(default_factory=list)
    confidence_threshold: float = 0.7


def generate_trigger_patterns(specialization: str) -> List[str]:
    """
    Generate trigger patterns for a given specialization domain
    """
    patterns = {
        'math': ['calculate', 'solve', 'equation', 'number'],
        'code': ['function', 'class', 'bug', 'debug', 'code'],
        'creative': ['story', 'poem', 'creative', 'imagine'],
        'factual': ['what is', 'who is', 'when did', 'where is'],
        'analysis': ['analyze', 'compare', 'evaluate', 'assess']
    }

    return patterns.get(specialization, ['general'])


def create_model_genome(
    specializations: List[str],
    base_model: str = "Qwen/Qwen3-0.6B"
) -> List[ModelGene]:
    """
    Initialize a genome of specialized models
    """
    genome = []

    for spec in specializations:
        gene = ModelGene(
            base_model=base_model,
            specialization=spec,
            trigger_patterns=generate_trigger_patterns(spec),
            confidence_threshold=random.uniform(0.6, 0.9)
        )
        genome.append(gene)

    return genome


def mutate_model_genome(
    genome: List[ModelGene],
    mutation_type: str = 'random'
) -> List[ModelGene]:
    """
    Apply genetic mutation to model genome
    """
    new_genome = copy.deepcopy(genome)

    mutations = [
        'adjust_threshold',
        'add_trigger',
        'remove_gene',
        'duplicate_gene'
    ]

    if mutation_type == 'random':
        mutation_type = random.choice(mutations)

    if mutation_type == 'adjust_threshold':
        gene = random.choice(new_genome)
        gene.confidence_threshold += random.uniform(-0.1, 0.1)
        gene.confidence_threshold = max(
            0.5,
            min(0.95, gene.confidence_threshold)
        )

    elif mutation_type == 'add_trigger':
        gene = random.choice(new_genome)
        new_trigger = f"pattern_{random.randint(1, 100)}"
        if new_trigger not in gene.trigger_patterns:
            gene.trigger_patterns.append(new_trigger)

    elif mutation_type == 'remove_gene' and len(new_genome) > 1:
        new_genome.pop(random.randint(0, len(new_genome) - 1))

    elif mutation_type == 'duplicate_gene':
        gene = random.choice(new_genome)
        new_gene = copy.deepcopy(gene)
        new_gene.specialization = f"{gene.specialization}_variant"
        new_genome.append(new_gene)

    return new_genome


def crossover_model_genomes(
    genome1: List[ModelGene],
    genome2: List[ModelGene]
) -> List[ModelGene]:
    """
    Crossover two model genomes to create child genome
    """
    if not genome1 or not genome2:
        return genome1 or genome2

    split = random.randint(1, min(len(genome1), len(genome2)) - 1)

    child = genome1[:split] + genome2[split:]

    return child


def evaluate_model_genome(
    genome: List[ModelGene],
    test_cases: List[Dict[str, Any]],
    router: 'ResponseRouter'
) -> float:
    """
    Evaluate fitness of a model genome based on accuracy,
    speed and efficiency
    """
    correct = 0
    total_time = 0
    fast_responses = 0

    for test_case in test_cases:
        result = router.route_query(
            test_case['query'],
            genome,
            test_case.get('ground_truth')
        )

        if result['correct']:
            correct += 1

        total_time += result['response_time']

        if result['used_fast_path']:
            fast_responses += 1

    accuracy = correct / len(test_cases)
    speed_bonus = fast_responses / len(test_cases)
    efficiency = 1.0 / (total_time / len(test_cases))

    fitness = (
        accuracy * 0.6 +
        speed_bonus * 0.2 +
        efficiency * 0.2
    )

    return fitness


class ResponseRouter:
    """
    Routes queries through fast path, ensemble or full reasoning
    based on confidence thresholds
    """
    def __init__(
        self,
        fast_threshold: float = 0.8,
        ensemble_threshold: float = 0.6
    ):
        self.fast_threshold = fast_threshold
        self.ensemble_threshold = ensemble_threshold
        self.response_cache = {}

    def route_query(
        self,
        query: str,
        genome: List[ModelGene],
        ground_truth: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Route query through system 1 fast path,
        ensemble or system 2 reasoning
        """
        start_time = time.time()

        fast_response = self._try_fast_path(query, genome)

        if fast_response and fast_response['confidence'] > (
            self.fast_threshold
        ):
            response_time = time.time() - start_time

            return {
                'response': fast_response['answer'],
                'confidence': fast_response['confidence'],
                'used_fast_path': True,
                'response_time': response_time,
                'correct': (
                    ground_truth is None or
                    self._check_correctness(
                        fast_response['answer'],
                        ground_truth
                    )
                )
            }

        ensemble_response = self._try_ensemble(query, genome)

        if ensemble_response['confidence'] > (
            self.ensemble_threshold
        ):
            response_time = time.time() - start_time

            return {
                'response': ensemble_response['answer'],
                'confidence': ensemble_response['confidence'],
                'used_fast_path': False,
                'used_ensemble': True,
                'response_time': response_time,
                'correct': (
                    ground_truth is None or
                    self._check_correctness(
                        ensemble_response['answer'],
                        ground_truth
                    )
                )
            }

        full_response = self._full_reasoning(query)
        response_time = time.time() - start_time

        return {
            'response': full_response,
            'confidence': 0.5,
            'used_fast_path': False,
            'used_ensemble': False,
            'response_time': response_time,
            'correct': (
                ground_truth is None or
                self._check_correctness(
                    full_response,
                    ground_truth
                )
            )
        }

    def _try_fast_path(
        self,
        query: str,
        genome: List[ModelGene]
    ) -> Optional[Dict[str, Any]]:
        """
        Try fast system 1 gut reaction using pattern matching
        """
        query_lower = query.lower()

        for gene in genome:
            if any(
                pattern in query_lower
                for pattern in gene.trigger_patterns
            ):
                if gene.sft_path:
                    model, tokenizer = load_sft_model(gene.sft_path)

                    response = predict_sft(
                        model,
                        tokenizer,
                        query,
                        temperature=0.1
                    )

                    return {
                        'answer': response,
                        'confidence': gene.confidence_threshold
                    }

        return None

    def _try_ensemble(
        self,
        query: str,
        genome: List[ModelGene]
    ) -> Dict[str, Any]:
        """
        Try ensemble voting across specialized models
        """
        responses = []

        for gene in genome:
            if gene.sft_path or gene.rl_path:
                model_path = gene.rl_path or gene.sft_path

                model, tokenizer = load_sft_model(model_path)

                response = predict_sft(
                    model,
                    tokenizer,
                    query,
                    temperature=0.3
                )

                responses.append({
                    'answer': response,
                    'weight': gene.confidence_threshold
                })

        if not responses:
            return {'answer': '', 'confidence': 0.0}

        best_response = max(responses, key=lambda x: x['weight'])

        avg_confidence = sum(
            r['weight'] for r in responses
        ) / len(responses)

        return {
            'answer': best_response['answer'],
            'confidence': avg_confidence
        }

    def _full_reasoning(
        self,
        query: str,
        model: str = "qwen3:1.7b",
        provider: str = "ollama"
    ) -> str:
        """
        Fall back to full system 2 reasoning
        """
        response = get_llm_response(
            query,
            model=model,
            provider=provider
        )

        return response.get('response', '')

    def _check_correctness(
        self,
        response: str,
        ground_truth: str
    ) -> bool:
        """
        Check if response matches ground truth
        """
        response_lower = response.lower().strip()
        truth_lower = ground_truth.lower().strip()

        return response_lower == truth_lower or (
            truth_lower in response_lower
        )
npcpy/ft/rl.py
ADDED

@@ -0,0 +1,360 @@

from dataclasses import dataclass

from datetime import datetime
import glob
import json
import os
import pandas as pd
try:
    from datasets import Dataset

    from peft import LoraConfig, PeftModel
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer
    )
    from trl import DPOTrainer, DPOConfig
except:
    Dataset = None
    PeftModel = None
    DPOConfig = None
    DPOTrainer = None
    torch = None
    AutoModelForCausalLM = None
    AutoTokenizer = None


import random
from typing import List, Dict, Any, Optional, Callable
from npcpy.npc_compiler import NPC
from npcpy.llm_funcs import get_llm_response


@dataclass
class RLConfig:
    base_model_name: str = "Qwen/Qwen3-0.6B"
    adapter_path: str = "./rl_adapter"
    max_iterations: int = 8
    min_reward_gap: float = 0.4
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 2
    learning_rate: float = 1e-6
    beta: float = 0.5
    max_length: int = 512
    max_prompt_length: int = 256


class TaskExecutor:

    def __init__(
        self,
        agent: NPC,
        max_iterations: int = 8
    ):
        self.agent = agent
        self.max_iterations = max_iterations

    def execute_task(
        self,
        task_prompt: str
    ) -> Dict[str, Any]:

        messages = [
            {
                "role": "system",
                "content": self.agent.primary_directive
            }
        ]

        raw_responses = []
        current_prompt = task_prompt

        for i in range(self.max_iterations):
            response_obj = self.agent.get_llm_response(
                current_prompt,
                messages=messages,
                auto_process_tool_calls=True
            )

            raw_responses.append(response_obj)
            messages = response_obj.get('messages', messages)

            last_content = messages[-1].get('content', '')

            if self._is_complete(last_content):
                return {
                    "raw_responses": raw_responses,
                    "final_output": last_content,
                    "total_iterations": i + 1,
                    "completed": True
                }

            current_prompt = (
                "Continue or provide final answer."
            )

        return {
            "raw_responses": raw_responses,
            "final_output": messages[-1].get('content', ''),
            "total_iterations": self.max_iterations,
            "completed": False
        }

    def _is_complete(self, content: str) -> bool:

        completion_markers = [
            "final answer:",
            "conclusion:",
            "result:",
            "therefore",
            "in summary"
        ]
        content_lower = content.lower()
        return any(
            marker in content_lower
            for marker in completion_markers
        )


def collect_traces(
    tasks: List[Dict[str, Any]],
    agents: List[NPC],
    reward_fn: Callable[[Dict], float],
    config: Optional[RLConfig] = None
) -> List[Dict[str, Any]]:

    if config is None:
        config = RLConfig()

    traces = []

    for task in tasks:
        task_prompt = task.get('prompt', task.get('input', ''))

        for agent in agents:
            executor = TaskExecutor(
                agent,
                max_iterations=config.max_iterations
            )

            result = executor.execute_task(task_prompt)

            trace = {
                "agent_name": agent.name,
                "task_prompt": task_prompt,
                "final_output": result['final_output'],
                "total_iterations": result['total_iterations'],
                "completed": result['completed'],
                "task_metadata": task
            }

            trace['reward'] = reward_fn(trace)

            traces.append(trace)

            print(
                f"Agent {agent.name}: "
                f"Reward={trace['reward']:.2f}"
            )

    return traces


def create_preference_pairs(
    traces: List[Dict[str, Any]],
    min_reward_gap: float = 0.4
) -> Dataset:

    df = pd.DataFrame(traces)
    df = df[df['reward'] > -1.0].copy()

    if len(df) < 2:
        return None

    df = df.sort_values('reward', ascending=False)

    top_quantile = df['reward'].quantile(
        0.8,
        interpolation='higher'
    )
    low_quantile = df['reward'].quantile(
        0.2,
        interpolation='lower'
    )

    high_traces = df[df['reward'] >= top_quantile]
    low_traces = df[df['reward'] <= low_quantile]

    pairs = []

    for _, high_trace in high_traces.iterrows():
        for _, low_trace in low_traces.iterrows():
            reward_gap = (
                high_trace['reward'] - low_trace['reward']
            )

            if reward_gap >= min_reward_gap:
                pairs.append({
                    "prompt": str(high_trace['task_prompt']),
                    "chosen": str(high_trace['final_output']),
                    "rejected": str(low_trace['final_output'])
                })

    if len(pairs) < 5:
        print(
            f"Warning: Only {len(pairs)} pairs found. "
            "May overfit."
        )

    return Dataset.from_list(pairs[:100])


def train_with_dpo(
    traces: List[Dict[str, Any]],
    config: Optional[RLConfig] = None
) -> str:

    if config is None:
        config = RLConfig()

    preference_dataset = create_preference_pairs(
        traces,
        min_reward_gap=config.min_reward_gap
    )

    if preference_dataset is None or len(preference_dataset) == 0:
        print("No valid preference pairs. Cannot train.")
        return None

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        torch_dtype=torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj"
        ]
    )

    training_args = DPOConfig(
        output_dir="./dpo_results",
        per_device_train_batch_size=(
            config.per_device_train_batch_size
        ),
        gradient_accumulation_steps=(
            config.gradient_accumulation_steps
        ),
        learning_rate=config.learning_rate,
        num_train_epochs=config.num_train_epochs,
        weight_decay=0.1,
        beta=config.beta,
        logging_steps=2,
        save_steps=10,
        remove_unused_columns=False,
        max_length=config.max_length,
        max_prompt_length=config.max_prompt_length,
        dataloader_num_workers=0,
        fp16=False,
        bf16=False,
        optim="adamw_torch",
        warmup_steps=2,
        save_strategy="steps",
        save_total_limit=3
    )

    trainer = DPOTrainer(
        model,
        args=training_args,
        train_dataset=preference_dataset,
        peft_config=peft_config
    )

    print("Starting DPO training...")
    trainer.train()

    trainer.save_model(config.adapter_path)
    print(f"Adapter saved to {config.adapter_path}")

    return config.adapter_path


def run_rl_training(
    tasks: List[Dict[str, Any]],
    agents: List[NPC],
    reward_fn: Callable[[Dict], float],
    config: Optional[RLConfig] = None,
    save_traces: bool = True
) -> str:

    if config is None:
        config = RLConfig()

    print(f"Collecting traces from {len(tasks)} tasks...")
    traces = collect_traces(
        tasks,
        agents,
        reward_fn,
        config
    )

    if save_traces:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        traces_file = f"rl_traces_{timestamp}.csv"
        df = pd.DataFrame(traces)
        df.to_csv(traces_file, index=False)
        print(f"Traces saved to {traces_file}")

    print("Training with DPO...")
    adapter_path = train_with_dpo(traces, config)

    return adapter_path


def load_rl_model(
    base_model_id: str,
    adapter_path: str
):

    print(f"Loading base model: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float32,
        device_map="auto",
        attn_implementation='eager'
    )

    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if adapter_path and os.path.exists(adapter_path):
        print(f"Loading adapter: {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path)
        model = model.merge_and_unload()

    return model, tokenizer