npcpy 1.0.26__py3-none-any.whl → 1.2.32__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registry.
Files changed (148)
  1. npcpy/__init__.py +0 -7
  2. npcpy/data/audio.py +16 -99
  3. npcpy/data/image.py +43 -42
  4. npcpy/data/load.py +83 -124
  5. npcpy/data/text.py +28 -28
  6. npcpy/data/video.py +8 -32
  7. npcpy/data/web.py +51 -23
  8. npcpy/ft/diff.py +110 -0
  9. npcpy/ft/ge.py +115 -0
  10. npcpy/ft/memory_trainer.py +171 -0
  11. npcpy/ft/model_ensembler.py +357 -0
  12. npcpy/ft/rl.py +360 -0
  13. npcpy/ft/sft.py +248 -0
  14. npcpy/ft/usft.py +128 -0
  15. npcpy/gen/audio_gen.py +24 -0
  16. npcpy/gen/embeddings.py +13 -13
  17. npcpy/gen/image_gen.py +262 -117
  18. npcpy/gen/response.py +615 -415
  19. npcpy/gen/video_gen.py +53 -7
  20. npcpy/llm_funcs.py +1869 -437
  21. npcpy/main.py +1 -1
  22. npcpy/memory/command_history.py +844 -510
  23. npcpy/memory/kg_vis.py +833 -0
  24. npcpy/memory/knowledge_graph.py +892 -1845
  25. npcpy/memory/memory_processor.py +81 -0
  26. npcpy/memory/search.py +188 -90
  27. npcpy/mix/debate.py +192 -3
  28. npcpy/npc_compiler.py +1672 -801
  29. npcpy/npc_sysenv.py +593 -1266
  30. npcpy/serve.py +3120 -0
  31. npcpy/sql/ai_function_tools.py +257 -0
  32. npcpy/sql/database_ai_adapters.py +186 -0
  33. npcpy/sql/database_ai_functions.py +163 -0
  34. npcpy/sql/model_runner.py +19 -19
  35. npcpy/sql/npcsql.py +706 -507
  36. npcpy/sql/sql_model_compiler.py +156 -0
  37. npcpy/tools.py +183 -0
  38. npcpy/work/plan.py +13 -279
  39. npcpy/work/trigger.py +3 -3
  40. npcpy-1.2.32.dist-info/METADATA +803 -0
  41. npcpy-1.2.32.dist-info/RECORD +54 -0
  42. npcpy/data/dataframes.py +0 -171
  43. npcpy/memory/deep_research.py +0 -125
  44. npcpy/memory/sleep.py +0 -557
  45. npcpy/modes/_state.py +0 -78
  46. npcpy/modes/alicanto.py +0 -1075
  47. npcpy/modes/guac.py +0 -785
  48. npcpy/modes/mcp_npcsh.py +0 -822
  49. npcpy/modes/npc.py +0 -213
  50. npcpy/modes/npcsh.py +0 -1158
  51. npcpy/modes/plonk.py +0 -409
  52. npcpy/modes/pti.py +0 -234
  53. npcpy/modes/serve.py +0 -1637
  54. npcpy/modes/spool.py +0 -312
  55. npcpy/modes/wander.py +0 -549
  56. npcpy/modes/yap.py +0 -572
  57. npcpy/npc_team/alicanto.npc +0 -2
  58. npcpy/npc_team/alicanto.png +0 -0
  59. npcpy/npc_team/assembly_lines/test_pipeline.py +0 -181
  60. npcpy/npc_team/corca.npc +0 -13
  61. npcpy/npc_team/foreman.npc +0 -7
  62. npcpy/npc_team/frederic.npc +0 -6
  63. npcpy/npc_team/frederic4.png +0 -0
  64. npcpy/npc_team/guac.png +0 -0
  65. npcpy/npc_team/jinxs/automator.jinx +0 -18
  66. npcpy/npc_team/jinxs/bash_executer.jinx +0 -31
  67. npcpy/npc_team/jinxs/calculator.jinx +0 -11
  68. npcpy/npc_team/jinxs/edit_file.jinx +0 -96
  69. npcpy/npc_team/jinxs/file_chat.jinx +0 -14
  70. npcpy/npc_team/jinxs/gui_controller.jinx +0 -28
  71. npcpy/npc_team/jinxs/image_generation.jinx +0 -29
  72. npcpy/npc_team/jinxs/internet_search.jinx +0 -30
  73. npcpy/npc_team/jinxs/local_search.jinx +0 -152
  74. npcpy/npc_team/jinxs/npcsh_executor.jinx +0 -31
  75. npcpy/npc_team/jinxs/python_executor.jinx +0 -8
  76. npcpy/npc_team/jinxs/screen_cap.jinx +0 -25
  77. npcpy/npc_team/jinxs/sql_executor.jinx +0 -33
  78. npcpy/npc_team/kadiefa.npc +0 -3
  79. npcpy/npc_team/kadiefa.png +0 -0
  80. npcpy/npc_team/npcsh.ctx +0 -9
  81. npcpy/npc_team/npcsh_sibiji.png +0 -0
  82. npcpy/npc_team/plonk.npc +0 -2
  83. npcpy/npc_team/plonk.png +0 -0
  84. npcpy/npc_team/plonkjr.npc +0 -2
  85. npcpy/npc_team/plonkjr.png +0 -0
  86. npcpy/npc_team/sibiji.npc +0 -5
  87. npcpy/npc_team/sibiji.png +0 -0
  88. npcpy/npc_team/spool.png +0 -0
  89. npcpy/npc_team/templates/analytics/celona.npc +0 -0
  90. npcpy/npc_team/templates/hr_support/raone.npc +0 -0
  91. npcpy/npc_team/templates/humanities/eriane.npc +0 -4
  92. npcpy/npc_team/templates/it_support/lineru.npc +0 -0
  93. npcpy/npc_team/templates/marketing/slean.npc +0 -4
  94. npcpy/npc_team/templates/philosophy/maurawa.npc +0 -0
  95. npcpy/npc_team/templates/sales/turnic.npc +0 -4
  96. npcpy/npc_team/templates/software/welxor.npc +0 -0
  97. npcpy/npc_team/yap.png +0 -0
  98. npcpy/routes.py +0 -958
  99. npcpy/work/mcp_helpers.py +0 -357
  100. npcpy/work/mcp_server.py +0 -194
  101. npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.npc +0 -2
  102. npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.png +0 -0
  103. npcpy-1.0.26.data/data/npcpy/npc_team/automator.jinx +0 -18
  104. npcpy-1.0.26.data/data/npcpy/npc_team/bash_executer.jinx +0 -31
  105. npcpy-1.0.26.data/data/npcpy/npc_team/calculator.jinx +0 -11
  106. npcpy-1.0.26.data/data/npcpy/npc_team/celona.npc +0 -0
  107. npcpy-1.0.26.data/data/npcpy/npc_team/corca.npc +0 -13
  108. npcpy-1.0.26.data/data/npcpy/npc_team/edit_file.jinx +0 -96
  109. npcpy-1.0.26.data/data/npcpy/npc_team/eriane.npc +0 -4
  110. npcpy-1.0.26.data/data/npcpy/npc_team/file_chat.jinx +0 -14
  111. npcpy-1.0.26.data/data/npcpy/npc_team/foreman.npc +0 -7
  112. npcpy-1.0.26.data/data/npcpy/npc_team/frederic.npc +0 -6
  113. npcpy-1.0.26.data/data/npcpy/npc_team/frederic4.png +0 -0
  114. npcpy-1.0.26.data/data/npcpy/npc_team/guac.png +0 -0
  115. npcpy-1.0.26.data/data/npcpy/npc_team/gui_controller.jinx +0 -28
  116. npcpy-1.0.26.data/data/npcpy/npc_team/image_generation.jinx +0 -29
  117. npcpy-1.0.26.data/data/npcpy/npc_team/internet_search.jinx +0 -30
  118. npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.npc +0 -3
  119. npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.png +0 -0
  120. npcpy-1.0.26.data/data/npcpy/npc_team/lineru.npc +0 -0
  121. npcpy-1.0.26.data/data/npcpy/npc_team/local_search.jinx +0 -152
  122. npcpy-1.0.26.data/data/npcpy/npc_team/maurawa.npc +0 -0
  123. npcpy-1.0.26.data/data/npcpy/npc_team/npcsh.ctx +0 -9
  124. npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_executor.jinx +0 -31
  125. npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_sibiji.png +0 -0
  126. npcpy-1.0.26.data/data/npcpy/npc_team/plonk.npc +0 -2
  127. npcpy-1.0.26.data/data/npcpy/npc_team/plonk.png +0 -0
  128. npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.npc +0 -2
  129. npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.png +0 -0
  130. npcpy-1.0.26.data/data/npcpy/npc_team/python_executor.jinx +0 -8
  131. npcpy-1.0.26.data/data/npcpy/npc_team/raone.npc +0 -0
  132. npcpy-1.0.26.data/data/npcpy/npc_team/screen_cap.jinx +0 -25
  133. npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.npc +0 -5
  134. npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.png +0 -0
  135. npcpy-1.0.26.data/data/npcpy/npc_team/slean.npc +0 -4
  136. npcpy-1.0.26.data/data/npcpy/npc_team/spool.png +0 -0
  137. npcpy-1.0.26.data/data/npcpy/npc_team/sql_executor.jinx +0 -33
  138. npcpy-1.0.26.data/data/npcpy/npc_team/test_pipeline.py +0 -181
  139. npcpy-1.0.26.data/data/npcpy/npc_team/turnic.npc +0 -4
  140. npcpy-1.0.26.data/data/npcpy/npc_team/welxor.npc +0 -0
  141. npcpy-1.0.26.data/data/npcpy/npc_team/yap.png +0 -0
  142. npcpy-1.0.26.dist-info/METADATA +0 -827
  143. npcpy-1.0.26.dist-info/RECORD +0 -139
  144. npcpy-1.0.26.dist-info/entry_points.txt +0 -11
  145. /npcpy/{modes → ft}/__init__.py +0 -0
  146. {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/WHEEL +0 -0
  147. {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/licenses/LICENSE +0 -0
  148. {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/top_level.txt +0 -0
npcpy/ft/rl.py ADDED
@@ -0,0 +1,360 @@
from dataclasses import dataclass

from datetime import datetime
import glob
import json
import os
import pandas as pd
try:
    from datasets import Dataset

    from peft import LoraConfig, PeftModel
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer
    )
    from trl import DPOTrainer, DPOConfig
except:
    Dataset = None
    PeftModel = None
    DPOConfig = None
    DPOTrainer = None
    torch = None
    AutoModelForCausalLM = None
    AutoTokenizer = None


import random
from typing import List, Dict, Any, Optional, Callable
from npcpy.npc_compiler import NPC
from npcpy.llm_funcs import get_llm_response


@dataclass
class RLConfig:
    base_model_name: str = "Qwen/Qwen3-0.6B"
    adapter_path: str = "./rl_adapter"
    max_iterations: int = 8
    min_reward_gap: float = 0.4
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 2
    learning_rate: float = 1e-6
    beta: float = 0.5
    max_length: int = 512
    max_prompt_length: int = 256


class TaskExecutor:

    def __init__(
        self,
        agent: NPC,
        max_iterations: int = 8
    ):
        self.agent = agent
        self.max_iterations = max_iterations

    def execute_task(
        self,
        task_prompt: str
    ) -> Dict[str, Any]:

        messages = [
            {
                "role": "system",
                "content": self.agent.primary_directive
            }
        ]

        raw_responses = []
        current_prompt = task_prompt

        for i in range(self.max_iterations):
            response_obj = self.agent.get_llm_response(
                current_prompt,
                messages=messages,
                auto_process_tool_calls=True
            )

            raw_responses.append(response_obj)
            messages = response_obj.get('messages', messages)

            last_content = messages[-1].get('content', '')

            if self._is_complete(last_content):
                return {
                    "raw_responses": raw_responses,
                    "final_output": last_content,
                    "total_iterations": i + 1,
                    "completed": True
                }

            current_prompt = (
                "Continue or provide final answer."
            )

        return {
            "raw_responses": raw_responses,
            "final_output": messages[-1].get('content', ''),
            "total_iterations": self.max_iterations,
            "completed": False
        }

    def _is_complete(self, content: str) -> bool:

        completion_markers = [
            "final answer:",
            "conclusion:",
            "result:",
            "therefore",
            "in summary"
        ]
        content_lower = content.lower()
        return any(
            marker in content_lower
            for marker in completion_markers
        )


def collect_traces(
    tasks: List[Dict[str, Any]],
    agents: List[NPC],
    reward_fn: Callable[[Dict], float],
    config: Optional[RLConfig] = None
) -> List[Dict[str, Any]]:

    if config is None:
        config = RLConfig()

    traces = []

    for task in tasks:
        task_prompt = task.get('prompt', task.get('input', ''))

        for agent in agents:
            executor = TaskExecutor(
                agent,
                max_iterations=config.max_iterations
            )

            result = executor.execute_task(task_prompt)

            trace = {
                "agent_name": agent.name,
                "task_prompt": task_prompt,
                "final_output": result['final_output'],
                "total_iterations": result['total_iterations'],
                "completed": result['completed'],
                "task_metadata": task
            }

            trace['reward'] = reward_fn(trace)

            traces.append(trace)

            print(
                f"Agent {agent.name}: "
                f"Reward={trace['reward']:.2f}"
            )

    return traces


def create_preference_pairs(
    traces: List[Dict[str, Any]],
    min_reward_gap: float = 0.4
) -> Dataset:

    df = pd.DataFrame(traces)
    df = df[df['reward'] > -1.0].copy()

    if len(df) < 2:
        return None

    df = df.sort_values('reward', ascending=False)

    top_quantile = df['reward'].quantile(
        0.8,
        interpolation='higher'
    )
    low_quantile = df['reward'].quantile(
        0.2,
        interpolation='lower'
    )

    high_traces = df[df['reward'] >= top_quantile]
    low_traces = df[df['reward'] <= low_quantile]

    pairs = []

    for _, high_trace in high_traces.iterrows():
        for _, low_trace in low_traces.iterrows():
            reward_gap = (
                high_trace['reward'] - low_trace['reward']
            )

            if reward_gap >= min_reward_gap:
                pairs.append({
                    "prompt": str(high_trace['task_prompt']),
                    "chosen": str(high_trace['final_output']),
                    "rejected": str(low_trace['final_output'])
                })

    if len(pairs) < 5:
        print(
            f"Warning: Only {len(pairs)} pairs found. "
            "May overfit."
        )

    return Dataset.from_list(pairs[:100])


def train_with_dpo(
    traces: List[Dict[str, Any]],
    config: Optional[RLConfig] = None
) -> str:

    if config is None:
        config = RLConfig()

    preference_dataset = create_preference_pairs(
        traces,
        min_reward_gap=config.min_reward_gap
    )

    if preference_dataset is None or len(preference_dataset) == 0:
        print("No valid preference pairs. Cannot train.")
        return None

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        torch_dtype=torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj"
        ]
    )

    training_args = DPOConfig(
        output_dir="./dpo_results",
        per_device_train_batch_size=(
            config.per_device_train_batch_size
        ),
        gradient_accumulation_steps=(
            config.gradient_accumulation_steps
        ),
        learning_rate=config.learning_rate,
        num_train_epochs=config.num_train_epochs,
        weight_decay=0.1,
        beta=config.beta,
        logging_steps=2,
        save_steps=10,
        remove_unused_columns=False,
        max_length=config.max_length,
        max_prompt_length=config.max_prompt_length,
        dataloader_num_workers=0,
        fp16=False,
        bf16=False,
        optim="adamw_torch",
        warmup_steps=2,
        save_strategy="steps",
        save_total_limit=3
    )

    trainer = DPOTrainer(
        model,
        args=training_args,
        train_dataset=preference_dataset,
        peft_config=peft_config
    )

    print("Starting DPO training...")
    trainer.train()

    trainer.save_model(config.adapter_path)
    print(f"Adapter saved to {config.adapter_path}")

    return config.adapter_path


def run_rl_training(
    tasks: List[Dict[str, Any]],
    agents: List[NPC],
    reward_fn: Callable[[Dict], float],
    config: Optional[RLConfig] = None,
    save_traces: bool = True
) -> str:

    if config is None:
        config = RLConfig()

    print(f"Collecting traces from {len(tasks)} tasks...")
    traces = collect_traces(
        tasks,
        agents,
        reward_fn,
        config
    )

    if save_traces:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        traces_file = f"rl_traces_{timestamp}.csv"
        df = pd.DataFrame(traces)
        df.to_csv(traces_file, index=False)
        print(f"Traces saved to {traces_file}")

    print("Training with DPO...")
    adapter_path = train_with_dpo(traces, config)

    return adapter_path


def load_rl_model(
    base_model_id: str,
    adapter_path: str
):

    print(f"Loading base model: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float32,
        device_map="auto",
        attn_implementation='eager'
    )

    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if adapter_path and os.path.exists(adapter_path):
        print(f"Loading adapter: {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path)
        model = model.merge_and_unload()

    return model, tokenizer
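
For orientation, here is a minimal usage sketch of the pipeline above. It relies only on the signatures defined in rl.py (RLConfig, run_rl_training, load_rl_model, and the trace keys produced by collect_traces); the task prompts and reward heuristic are illustrative, the npcpy.ft.rl import path is inferred from the file location, and the NPC constructor keywords are assumptions chosen to match the attributes the executor reads (name, primary_directive).

from npcpy.ft.rl import RLConfig, run_rl_training, load_rl_model
from npcpy.npc_compiler import NPC

# Illustrative tasks; any dicts with a 'prompt' (or 'input') key work.
tasks = [
    {"prompt": "Summarize the tradeoffs of LoRA versus full fine-tuning."},
    {"prompt": "Explain direct preference optimization in two sentences."},
]

# Assumed constructor keywords; the NPC class itself is not shown in this diff.
agents = [
    NPC(
        name="analyst",
        primary_directive="Answer concisely and end with 'Final answer:'."
    ),
]

def reward_fn(trace):
    # Favor traces that hit a completion marker, lightly penalize iteration count.
    score = 1.0 if trace["completed"] else 0.0
    return score - 0.05 * trace["total_iterations"]

config = RLConfig(base_model_name="Qwen/Qwen3-0.6B", max_iterations=4)
adapter_path = run_rl_training(tasks, agents, reward_fn, config=config)

if adapter_path:
    model, tokenizer = load_rl_model(config.base_model_name, adapter_path)

Because create_preference_pairs only keeps pairs whose reward gap is at least min_reward_gap, a reward function that spreads scores across traces matters more than its absolute scale.
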
npcpy/ft/sft.py ADDED
@@ -0,0 +1,248 @@
# structured fine tuning of LLMs to produce structured output
from dataclasses import dataclass, field
from datasets import Dataset
import json
import numpy as np
import os
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments
    )
    from trl import SFTTrainer
    from peft import LoraConfig
except:
    torch = None
    SFTTrainer = None
    LoraConfig = None
    AutoModelForCausalLM = None
    AutoTokenizer = None
    TrainingArguments = None

from typing import List, Dict, Any, Optional


@dataclass
class SFTConfig:
    base_model_name: str = "google/gemma-3-270m-it"
    output_model_path: str = "models/sft_model"
    lora_r: int = 8
    lora_alpha: int = 16
    use_4bit: bool = False
    fp16: bool = False
    bf16: bool = False
    lora_dropout: float = 0.15
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 3e-5
    logging_steps: int = 10
    optim: str = "adamw_torch"
    lr_scheduler_type: str = "cosine_with_restarts"
    weight_decay: float = 0.01
    max_length: int = 512
    save_steps: int = 50


def format_training_examples(
    inputs: List[str],
    outputs: List[str],
    format_style: str = "gemma"
) -> List[Dict[str, str]]:

    formatted = []

    for inp, out in zip(inputs, outputs):
        if format_style == "gemma":
            text = (
                f"<start_of_turn>user\n{inp}<end_of_turn>\n"
                f"<start_of_turn>model\n{out}<end_of_turn>"
            )
        elif format_style == "llama":
            text = (
                f"<|begin_of_text|><|start_header_id|>user"
                f"<|end_header_id|>\n\n{inp}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>"
                f"\n\n{out}<|eot_id|>"
            )
        else:
            text = f"Input: {inp}\nOutput: {out}"

        formatted.append({"text": text})

    return formatted


def run_sft(
    X: List[str],
    y: List[str],
    config: Optional[SFTConfig] = None,
    validation_split: float = 0.0,
    format_style: str = "gemma"
) -> str:

    if config is None:
        config = SFTConfig()

    if len(X) != len(y):
        raise ValueError(
            f"X and y must have same length: {len(X)} vs {len(y)}"
        )

    formatted_examples = format_training_examples(
        X, y, format_style
    )

    if validation_split > 0:
        split_idx = int(len(formatted_examples) * (1 - validation_split))
        train_examples = formatted_examples[:split_idx]
        val_examples = formatted_examples[split_idx:]
        print(
            f"Split: {len(train_examples)} train, "
            f"{len(val_examples)} val"
        )
    else:
        train_examples = formatted_examples
        val_examples = []

    dataset = Dataset.from_list(train_examples)

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        trust_remote_code=True,
        attn_implementation="eager"
    )
    model.config.use_cache = False

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name,
        trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM"
    )

    training_args = TrainingArguments(
        output_dir=config.output_model_path,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=(
            config.per_device_train_batch_size
        ),
        gradient_accumulation_steps=(
            config.gradient_accumulation_steps
        ),
        optim=config.optim,
        logging_steps=config.logging_steps,
        learning_rate=config.learning_rate,
        fp16=config.fp16,
        bf16=config.bf16,
        lr_scheduler_type=config.lr_scheduler_type,
        group_by_length=True,
        save_steps=config.save_steps,
        weight_decay=config.weight_decay,
    )

    def formatting_func(example):
        return example["text"]

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
        processing_class=tokenizer,
        formatting_func=formatting_func
    )

    print(f"Training on {len(dataset)} examples")
    trainer.train()

    trainer.save_model(config.output_model_path)
    print(f"Model saved to {config.output_model_path}")

    return config.output_model_path


def load_sft_model(model_path: str):

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="auto",
        attn_implementation="eager"
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def predict_sft(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7
) -> str:

    device = next(model.parameters()).device

    formatted_prompt = (
        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )

    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.eos_token_id
        )

    full_response = tokenizer.decode(
        outputs[0],
        skip_special_tokens=False
    )

    if "<start_of_turn>model\n" in full_response:
        response = full_response.split(
            "<start_of_turn>model\n"
        )[-1]
        response = response.split("<end_of_turn>")[0].strip()
    else:
        response = tokenizer.decode(
            outputs[0][len(input_ids[0]):],
            skip_special_tokens=True
        )

    return response
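
A similarly minimal sketch for the SFT path: train on a handful of input/output pairs, reload the tuned model, and query it. The example pairs and paths are illustrative; the calls mirror the signatures defined in sft.py above, and the npcpy.ft.sft import path is inferred from the file location.

from npcpy.ft.sft import SFTConfig, run_sft, load_sft_model, predict_sft

# Illustrative structured-output pairs (prompt in, JSON out).
X = [
    'Extract the city from: "Flights to Paris are delayed."',
    'Extract the city from: "It is raining in Osaka today."',
]
y = [
    '{"city": "Paris"}',
    '{"city": "Osaka"}',
]

config = SFTConfig(
    base_model_name="google/gemma-3-270m-it",
    output_model_path="models/city_extractor",
    num_train_epochs=5,
)

model_path = run_sft(X, y, config=config, format_style="gemma")
model, tokenizer = load_sft_model(model_path)
print(predict_sft(model, tokenizer,
                  'Extract the city from: "Snow closed roads near Oslo."'))

Since run_sft wraps each pair in the Gemma chat template by default and predict_sft applies the same <start_of_turn> markers at inference time before stripping everything after <end_of_turn>, the two functions should be used with a matching format_style and model family.
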