titan-synapse 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CONTRIBUTING.md +187 -0
- package/Cargo.lock +3976 -0
- package/Cargo.toml +10 -0
- package/LICENSE +190 -0
- package/PROGRESS.md +151 -0
- package/README.md +514 -0
- package/TEST_LOG.md +220 -0
- package/config/default.yaml +36 -0
- package/crates/synapse/Cargo.toml +70 -0
- package/crates/synapse/src/cli/bench.rs +44 -0
- package/crates/synapse/src/cli/eval.rs +395 -0
- package/crates/synapse/src/cli/export.rs +45 -0
- package/crates/synapse/src/cli/hub.rs +179 -0
- package/crates/synapse/src/cli/import.rs +35 -0
- package/crates/synapse/src/cli/learn.rs +53 -0
- package/crates/synapse/src/cli/mod.rs +10 -0
- package/crates/synapse/src/cli/models.rs +36 -0
- package/crates/synapse/src/cli/pull.rs +60 -0
- package/crates/synapse/src/cli/status.rs +52 -0
- package/crates/synapse/src/cli/train.rs +99 -0
- package/crates/synapse/src/config.rs +220 -0
- package/crates/synapse/src/dashboard.rs +281 -0
- package/crates/synapse/src/format/manifest.rs +57 -0
- package/crates/synapse/src/format/mod.rs +4 -0
- package/crates/synapse/src/format/packer.rs +213 -0
- package/crates/synapse/src/inference/engine.rs +361 -0
- package/crates/synapse/src/inference/kv_cache.rs +97 -0
- package/crates/synapse/src/inference/lora.rs +166 -0
- package/crates/synapse/src/inference/mod.rs +9 -0
- package/crates/synapse/src/inference/model.rs +167 -0
- package/crates/synapse/src/inference/sampler.rs +133 -0
- package/crates/synapse/src/inference/speculative.rs +153 -0
- package/crates/synapse/src/learn/cloud_fallback.rs +186 -0
- package/crates/synapse/src/learn/engine.rs +109 -0
- package/crates/synapse/src/learn/mod.rs +5 -0
- package/crates/synapse/src/main.rs +185 -0
- package/crates/synapse/src/memory/extractor.rs +201 -0
- package/crates/synapse/src/memory/graph.rs +332 -0
- package/crates/synapse/src/memory/hallucination.rs +259 -0
- package/crates/synapse/src/memory/mod.rs +7 -0
- package/crates/synapse/src/openai.rs +232 -0
- package/crates/synapse/src/server.rs +166 -0
- package/crates/synapse/src/streaming.rs +80 -0
- package/crates/synapse/src/swarm/coordinator.rs +198 -0
- package/crates/synapse/src/swarm/mod.rs +8 -0
- package/crates/synapse/src/swarm/orchestrator.rs +225 -0
- package/crates/synapse/src/swarm/pool.rs +64 -0
- package/crates/synapse/src/swarm/spawner.rs +199 -0
- package/crates/synapse/src/swarm/synthesizer.rs +26 -0
- package/crates/synapse/src/vram/manager.rs +67 -0
- package/crates/synapse/src/vram/mod.rs +3 -0
- package/docker-compose.yml +19 -0
- package/install.sh +311 -0
- package/package.json +36 -0
- package/python/Dockerfile.learn +18 -0
- package/python/requirements.txt +11 -0
- package/python/synapse_learn/__init__.py +0 -0
- package/python/synapse_learn/datasets.py +233 -0
- package/python/synapse_learn/real_eval.py +616 -0
- package/python/synapse_learn/server.py +431 -0
- package/python/synapse_learn/train_base.py +672 -0
- package/python/synapse_learn/train_specialists.py +787 -0
|
@@ -0,0 +1,787 @@
|
|
|
1
|
+
"""TITAN Synapse — Real Specialist Training Pipeline
|
|
2
|
+
|
|
3
|
+
Downloads actual high-quality datasets and trains QLoRA specialist adapters
|
|
4
|
+
that will measurably improve benchmark scores.
|
|
5
|
+
|
|
6
|
+
Target improvements:
|
|
7
|
+
- HumanEval: 65.2% → 75%+ (code specialist trained on CodeAlpaca + Evol-Instruct + python_code_instructions_18k)
|
|
8
|
+
- GSM8K: 83.7% → 90%+ (math specialist trained on MetaMathQA + Orca-Math)
|
|
9
|
+
- MMLU: 61.9% → 65%+ (general specialist trained on SlimOrca + alpaca-cleaned)
|
|
10
|
+
- TruthfulQA: 89.1% → maintain (honesty/refusal training)
|
|
11
|
+
|
|
12
|
+
Hardware: RTX 5090 32GB VRAM
|
|
13
|
+
Training time: ~2-4 hours per specialist (up to 50k samples, 2-3 epochs)
|
|
14
|
+
|
|
15
|
+
Usage:
|
|
16
|
+
python train_specialists.py --specialist all
|
|
17
|
+
python train_specialists.py --specialist code
|
|
18
|
+
python train_specialists.py --specialist math
|
|
19
|
+
python train_specialists.py --specialist general
|
|
20
|
+
python train_specialists.py --specialist coordinator
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
import sys
|
|
25
|
+
import json
|
|
26
|
+
import logging
|
|
27
|
+
import argparse
|
|
28
|
+
import time
|
|
29
|
+
import gc
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
from datetime import datetime
|
|
32
|
+
from typing import Optional
|
|
33
|
+
|
|
34
|
+
# Fix import path (same as real_eval.py)
|
|
35
|
+
_script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
36
|
+
if _script_dir in sys.path:
|
|
37
|
+
sys.path.remove(_script_dir)
|
|
38
|
+
|
|
39
|
+
import torch
|
|
40
|
+
from datasets import load_dataset, Dataset, concatenate_datasets
|
|
41
|
+
from transformers import (
|
|
42
|
+
AutoTokenizer,
|
|
43
|
+
AutoModelForCausalLM,
|
|
44
|
+
BitsAndBytesConfig,
|
|
45
|
+
TrainingArguments,
|
|
46
|
+
)
|
|
47
|
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
|
|
48
|
+
from trl import SFTTrainer, SFTConfig
|
|
49
|
+
|
|
50
|
+
# Root logging is configured once at import time; every message in this module
# goes through the named "synapse-train" logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("synapse-train")

# All artifacts live under one data root, overridable via SYNAPSE_DATA_DIR.
DATA_DIR = Path(os.environ.get("SYNAPSE_DATA_DIR", os.path.expanduser("~/.synapse")))
ADAPTERS_DIR = DATA_DIR.joinpath("adapters")   # trained LoRA adapters
TRAINING_DIR = DATA_DIR.joinpath("training")   # dumped corpora, merged models, results
CACHE_DIR = DATA_DIR.joinpath("hf_cache")      # Hugging Face download cache

# Create the directory layout eagerly so later writes never hit a missing parent.
for d in (ADAPTERS_DIR, TRAINING_DIR, CACHE_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Default base model every specialist adapter is trained on.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
|
|
65
|
+
|
|
66
|
+
# ============================================================
|
|
67
|
+
# Dataset Definitions — what we train each specialist on
|
|
68
|
+
# ============================================================
|
|
69
|
+
|
|
70
|
+
# Training recipe per specialist: the source datasets (hub name, split, sample
# cap, and the schema tag consumed by format_dataset_item) plus the LoRA / SFT
# hyperparameters used when fine-tuning that specialist.
SPECIALIST_DATASETS = {
    "code": {
        "description": "Code generation specialist — targets HumanEval improvement",
        "datasets": [
            # alpaca schema: instruction/input/output
            dict(name="sahil2801/CodeAlpaca-20k", split="train", samples=20000, format="alpaca"),
            # evol schema: instruction/output
            dict(name="nickrosh/Evol-Instruct-Code-80k-v1", split="train", samples=20000, format="evol"),
            dict(
                name="iamtarun/python_code_instructions_18k_alpaca",
                split="train",
                samples=18000,
                format="alpaca",
            ),
        ],
        "system_prompt": "You are an expert programmer. Write clean, correct, well-tested code. Always include proper error handling.",
        "lora_rank": 64,
        "epochs": 2,
        "max_seq_length": 2048,
        "learning_rate": 2e-4,
    },
    "math": {
        "description": "Math reasoning specialist — targets GSM8K improvement",
        "datasets": [
            # metamath schema: query/response
            dict(name="meta-math/MetaMathQA", split="train", samples=30000, format="metamath"),
            # orca_math schema: question/answer
            dict(
                name="microsoft/orca-math-word-problems-200k",
                split="train",
                samples=20000,
                format="orca_math",
            ),
        ],
        "system_prompt": "You are a math expert. Solve problems step by step, showing your work clearly. Always verify your answer at the end.",
        "lora_rank": 64,
        "epochs": 2,
        "max_seq_length": 1024,
        "learning_rate": 2e-4,
    },
    "general": {
        "description": "General knowledge specialist — targets MMLU improvement",
        "datasets": [
            # sharegpt schema: conversations list
            dict(name="Open-Orca/SlimOrca", split="train", samples=25000, format="sharegpt"),
            dict(name="yahma/alpaca-cleaned", split="train", samples=25000, format="alpaca"),
        ],
        "system_prompt": "You are a knowledgeable assistant. Give accurate, well-structured answers. If you're unsure, say so.",
        "lora_rank": 64,
        "epochs": 2,
        "max_seq_length": 2048,
        "learning_rate": 1.5e-4,
    },
    "coordinator": {
        "description": "Swarm coordinator — routes queries to the right specialist",
        # No hub datasets: the routing corpus is generated synthetically.
        "datasets": [],
        "system_prompt": "You are a routing coordinator. Analyze each query and decide which specialist should handle it. Output JSON with your routing decision.",
        "lora_rank": 32,
        "epochs": 3,
        "max_seq_length": 512,
        "learning_rate": 2e-4,
    },
}
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# ============================================================
|
|
156
|
+
# Data Formatting — convert each dataset to chat template
|
|
157
|
+
# ============================================================
|
|
158
|
+
|
|
159
|
+
def format_chat(system: str, user: str, assistant: str) -> str:
    """Render one exchange using ``<|im_start|>role ... <|im_end|>`` markers.

    The system turn is emitted only when ``system`` is truthy; the user and
    assistant turns are always emitted. Turns are joined with a single newline.
    """
    turns = [("user", user), ("assistant", assistant)]
    if system:
        turns.insert(0, ("system", system))
    return "\n".join(
        f"<|im_start|>{role}\n{content}<|im_end|>" for role, content in turns
    )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def format_dataset_item(item: dict, fmt: str, system_prompt: str) -> Optional[str]:
    """Turn one raw dataset record into chat-template text.

    ``fmt`` selects the record schema: ``alpaca``, ``evol``, ``metamath``,
    ``orca_math``, ``sharegpt`` or ``routing``. Returns ``None`` when the
    schema is unknown, a required field is empty, or anything raises while
    reading the record (best-effort: bad rows are skipped, never fatal).
    """
    try:
        if fmt == "alpaca":
            # Some alpaca-style sets name the instruction column "prompt".
            instruction = item.get("instruction", "") or item.get("prompt", "")
            extra_input = item.get("input", "")
            completion = item.get("output", "")
            if instruction and completion:
                if extra_input:
                    user_turn = f"{instruction}\n{extra_input}".strip()
                else:
                    user_turn = instruction
                return format_chat(system_prompt, user_turn, completion)
            return None

        if fmt == "evol":
            instruction = item.get("instruction", "")
            completion = item.get("output", "")
            if instruction and completion:
                return format_chat(system_prompt, instruction, completion)
            return None

        if fmt == "metamath":
            question = item.get("query", "")
            solution = item.get("response", "")
            if question and solution:
                return format_chat(system_prompt, question, solution)
            return None

        if fmt == "orca_math":
            question = item.get("question", "")
            solution = item.get("answer", "")
            if question and solution:
                return format_chat(system_prompt, question, solution)
            return None

        if fmt == "sharegpt":
            # SlimOrca-style record: a list of {"from": role, "value": text}.
            messages = item.get("conversations", [])
            if not messages or len(messages) < 2:
                return None
            sys_turn, user_turn, reply_turn = "", "", ""
            # Later messages overwrite earlier ones, so the last turn per role wins.
            for message in messages:
                speaker = message.get("from", "")
                content = message.get("value", "")
                if speaker == "system":
                    sys_turn = content
                elif speaker == "human":
                    user_turn = content
                elif speaker == "gpt":
                    reply_turn = content
            if user_turn and reply_turn:
                # A system message embedded in the record takes precedence.
                return format_chat(sys_turn or system_prompt, user_turn, reply_turn)
            return None

        if fmt == "routing":
            # Routing examples are already rendered; pass them through.
            return item.get("text", None)

        return None
    except Exception:
        return None
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# ============================================================
|
|
234
|
+
# Coordinator Routing Data Generation
|
|
235
|
+
# ============================================================
|
|
236
|
+
|
|
237
|
+
def generate_routing_data(count: int = 5000) -> list:
    """Build ``count`` synthetic routing examples for the coordinator adapter.

    Roughly 15% of the examples are "swarm" decisions that split a broad query
    across two or three specialists; the rest route a templated query to a
    single specialist. Every example carries a randomly drawn confidence.

    Returns:
        A list of ``{"text": <chat-formatted example>}`` dicts.
    """
    import random

    router_system = "You are a routing coordinator for a specialist AI swarm. Analyze the query and decide which specialist(s) should handle it. Respond with JSON."

    # Query templates, keyed by the specialist that should receive them.
    templates = {
        "python_expert": [
            "Write a Python function to {task}",
            "Debug this Python code: {code}",
            "How do I {task} in Python?",
            "Create a Python class that {task}",
            "What's the best way to {task} with Python?",
            "Optimize this Python code for performance",
            "Write unit tests for this Python function",
            "Explain this Python error: {error}",
            "Convert this code to async Python",
            "Build a FastAPI endpoint that {task}",
        ],
        "sql_expert": [
            "Write a SQL query to {task}",
            "Optimize this SQL query for performance",
            "How do I join these tables: {tables}",
            "Create a database schema for {domain}",
            "Explain this SQL execution plan",
            "Write a stored procedure that {task}",
            "Migrate this schema from MySQL to Postgres",
            "Add an index to improve this query",
            "Write a complex GROUP BY query for {task}",
            "Design a normalized database for {domain}",
        ],
        "math_expert": [
            "Solve this equation: {equation}",
            "Calculate the probability of {event}",
            "Prove that {theorem}",
            "What is the derivative of {function}?",
            "How many ways can you {task}?",
            "Solve this word problem: {problem}",
            "Find the integral of {function}",
            "What is the expected value of {event}?",
            "Simplify this expression: {expression}",
            "Solve this system of equations",
        ],
        "general": [
            "Explain {topic} in simple terms",
            "What is {concept}?",
            "Compare {a} and {b}",
            "Summarize the key points of {topic}",
            "What are the pros and cons of {topic}?",
            "How does {concept} work?",
            "Give me an overview of {topic}",
            "What should I know about {topic}?",
            "Tell me about {topic}",
            "What's the difference between {a} and {b}?",
        ],
    }

    # Filler pools substituted into the templates above.
    python_tasks = [
        "sort a list of dictionaries", "parse JSON", "handle file uploads",
        "connect to a database", "build a web scraper", "implement a binary tree",
        "create a REST API", "process CSV files", "implement caching",
        "handle concurrent requests", "validate email addresses", "parse dates",
        "implement rate limiting", "build a CLI tool", "create a decorator",
        "manage environment variables", "implement retry logic", "stream large files",
        "build a state machine", "implement pagination",
    ]
    sql_tasks = [
        "find duplicate records", "aggregate sales by month", "rank customers",
        "calculate running totals", "find gaps in sequences", "pivot data",
        "merge overlapping ranges", "find nth highest salary",
        "recursive CTE for hierarchies", "window functions for analytics",
    ]
    math_problems = [
        "a train traveling at 60mph", "probability of rolling dice",
        "compound interest calculation", "geometric sequence sum",
        "optimization of area", "combinatorics problem",
        "linear regression coefficients", "matrix multiplication",
    ]
    topics = [
        "machine learning", "quantum computing", "blockchain", "neural networks",
        "cloud computing", "cybersecurity", "distributed systems", "microservices",
        "DevOps practices", "agile methodology", "functional programming",
        "graph databases", "event-driven architecture", "container orchestration",
        "data warehousing", "API design", "load balancing", "caching strategies",
    ]

    # Broad queries that warrant several specialists at once.
    swarm_queries = [
        "Build a REST API with a database, authentication, and tests",
        "Create a data pipeline that processes CSV, stores in PostgreSQL, and generates reports",
        "Refactor this legacy codebase and add comprehensive tests",
        "Design a microservice architecture with database schemas and API contracts",
        "Build a machine learning pipeline with data preprocessing, model training, and evaluation",
        "Create a web application with user auth, database, API, and deployment scripts",
        "Analyze this dataset statistically and create visualizations with Python",
        "Optimize both the SQL queries and the Python code calling them",
        "Build a recommendation system with a database backend and API layer",
        "Create a monitoring dashboard with alerting, database queries, and Python scripts",
    ]

    specialist_names = list(templates.keys())
    examples = []

    for _ in range(count):
        if random.random() < 0.15:
            # Swarm decision: one broad query fanned out to 2-3 specialists.
            query = random.choice(swarm_queries)
            team = random.sample(specialist_names, k=random.randint(2, 3))
            decision = {
                "mode": "swarm",
                "confidence": round(random.uniform(0.80, 0.95), 2),
                "subtasks": [
                    {
                        "specialist": member,
                        "task": f"Handle the {member.replace('_expert', '')} aspect",
                    }
                    for member in team
                ],
            }
        else:
            # Single-specialist decision: pick a specialist, then one of its templates.
            specialist = random.choice(specialist_names)
            template = random.choice(templates[specialist])

            # str.format ignores unused keywords, so each branch supplies every
            # placeholder any template of that specialist might contain.
            if specialist == "python_expert":
                fillers = {
                    "task": random.choice(python_tasks),
                    "code": "...",
                    "error": "TypeError: unsupported operand type",
                }
            elif specialist == "sql_expert":
                fillers = {
                    "task": random.choice(sql_tasks),
                    "tables": "users, orders, products",
                    "domain": random.choice(["e-commerce", "healthcare", "finance", "social media"]),
                }
            elif specialist == "math_expert":
                fillers = {
                    "task": random.choice(math_problems),
                    "equation": "2x² + 3x - 5 = 0",
                    "theorem": "the sum of angles in a triangle is 180°",
                    "function": "x³ + 2x",
                    "event": "getting at least one head in 3 coin flips",
                    "problem": random.choice(math_problems),
                    "expression": "(3x² + 2x) / x",
                }
            else:
                first_topic, second_topic = random.sample(topics, 2)
                fillers = {
                    "topic": first_topic,
                    "concept": first_topic,
                    "a": first_topic,
                    "b": second_topic,
                }
            query = template.format(**fillers)

            decision = {
                "mode": "single",
                "specialist": specialist,
                "confidence": round(random.uniform(0.82, 0.98), 2),
                "reasoning": f"Query matches {specialist} domain",
            }

        rendered = format_chat(
            router_system,
            f"Route this query: {query}",
            json.dumps(decision),
        )
        examples.append({"text": rendered})

    return examples
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
# ============================================================
|
|
411
|
+
# Training Core
|
|
412
|
+
# ============================================================
|
|
413
|
+
|
|
414
|
+
def _get_attn_impl() -> str:
    """Return the attention backend name to pass as ``attn_implementation``.

    Yields ``"flash_attention_2"`` when the ``flash_attn`` package imports
    successfully, otherwise ``"sdpa"``. The choice is logged either way.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info("FlashAttention2 not available, using SDPA (still fast)")
        return "sdpa"

    logger.info("Using FlashAttention2")
    return "flash_attention_2"
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def load_and_prepare_data(specialist: str, max_total: int = 50000) -> list:
    """Download and format training data for a specialist.

    For the ``coordinator`` specialist the corpus is generated synthetically
    (5000 routing examples). For every other specialist, each configured
    dataset is loaded in order and its records are converted with
    ``format_dataset_item``; records that fail to format or whose text is
    50 characters or shorter are skipped.

    Args:
        specialist: Key into ``SPECIALIST_DATASETS``.
        max_total: Upper bound on the number of texts collected across all
            datasets of the specialist.

    Returns:
        A list of chat-formatted training strings (possibly empty if every
        dataset failed to load).
    """
    config = SPECIALIST_DATASETS[specialist]
    all_texts = []

    if specialist == "coordinator":
        logger.info("Generating synthetic routing data...")
        data = generate_routing_data(count=5000)
        all_texts = [d["text"] for d in data]
        logger.info(f"Generated {len(all_texts)} routing examples")
        return all_texts

    system_prompt = config["system_prompt"]

    for ds_info in config["datasets"]:
        name = ds_info["name"]
        split = ds_info["split"]
        # Never request more than what is left of the overall budget.
        samples = min(ds_info["samples"], max_total - len(all_texts))
        fmt = ds_info["format"]

        if samples <= 0:
            break

        logger.info(f"Loading {name} ({samples} samples)...")
        try:
            if samples < 50000:
                # Small request: materialize just the leading slice of the split.
                dataset = load_dataset(
                    name,
                    split=f"{split}[:{samples}]",
                    cache_dir=str(CACHE_DIR),
                )
            else:
                # Large request: iterate lazily instead of materializing the split.
                dataset = load_dataset(
                    name,
                    split=split,
                    streaming=True,
                    cache_dir=str(CACHE_DIR),
                )

            count = 0
            for item in dataset:
                if count >= samples:
                    break
                text = format_dataset_item(item, fmt, system_prompt)
                if not text or len(text) <= 50:
                    continue
                all_texts.append(text)
                count += 1

                # Report progress only when the counter actually advances onto a
                # multiple of 5000; checking outside this branch would repeat the
                # same message for every skipped record that follows.
                if count % 5000 == 0:
                    logger.info(f" Processed {count}/{samples} from {name}")

            logger.info(f" Got {count} valid samples from {name}")

        except Exception as e:
            # Best-effort: a dataset that cannot be loaded is skipped so the
            # remaining sources can still contribute. Texts gathered from it
            # before the failure are kept.
            logger.error(f" Failed to load {name}: {e}")
            continue

    logger.info(f"Total training samples for {specialist}: {len(all_texts)}")
    return all_texts
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def train_specialist(
    specialist: str,
    base_model: str = BASE_MODEL,
    resume_from: Optional[str] = None,
) -> dict:
    """Fine-tune one QLoRA adapter and persist it with its metadata.

    Steps: build the text corpus, load ``base_model`` with 4-bit NF4
    quantization, attach a LoRA adapter to the attention and MLP projections,
    run ``SFTTrainer``, then save the adapter, the tokenizer and a
    ``training_meta.json`` under ``ADAPTERS_DIR/<specialist>_v1``.

    Args:
        specialist: Key into ``SPECIALIST_DATASETS``.
        base_model: Model id or path handed to ``from_pretrained``.
        resume_from: Optional checkpoint forwarded to ``trainer.train``.

    Returns:
        The metadata dict written to disk, or
        ``{"error": "no_data", "specialist": specialist}`` when the corpus
        turned out empty.
    """
    rule = "=" * 70
    spec = SPECIALIST_DATASETS[specialist]

    logger.info(rule)
    logger.info(f"TRAINING SPECIALIST: {specialist}")
    logger.info(f" {spec['description']}")
    logger.info(f" Base model: {base_model}")
    logger.info(f" LoRA rank: {spec['lora_rank']}")
    logger.info(f" Epochs: {spec['epochs']}")
    logger.info(f" Max seq length: {spec['max_seq_length']}")
    logger.info(f" Learning rate: {spec['learning_rate']}")
    logger.info(rule)

    started = time.time()

    # ---- Step 1: corpus ---------------------------------------------------
    logger.info("\n[1/4] Loading training data...")
    texts = load_and_prepare_data(specialist)
    if not texts:
        logger.error("No training data available!")
        return {"error": "no_data", "specialist": specialist}

    # Keep a JSONL copy of exactly what was trained on, for reproducibility.
    data_file = TRAINING_DIR / f"{specialist}_train.jsonl"
    with open(data_file, "w") as sink:
        sink.writelines(json.dumps({"text": sample}) + "\n" for sample in texts)
    logger.info(f" Saved {len(texts)} samples to {data_file}")

    dataset = Dataset.from_dict({"text": texts})

    # ---- Step 2: quantized base model -------------------------------------
    logger.info("\n[2/4] Loading base model with 4-bit quantization...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    hf_cache = str(CACHE_DIR)
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        cache_dir=hf_cache,
        trust_remote_code=True,
    )
    # Fall back to EOS as padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        cache_dir=hf_cache,
        trust_remote_code=True,
        attn_implementation=_get_attn_impl(),
    )
    model = prepare_model_for_kbit_training(model)

    # ---- Step 3: LoRA adapter ---------------------------------------------
    logger.info("\n[3/4] Configuring LoRA adapter...")
    rank = spec["lora_rank"]
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank * 2,  # alpha is kept at twice the rank
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Count trainable vs. total parameters in a single pass.
    trainable_params = 0
    total_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    logger.info(f" Trainable: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

    if torch.cuda.is_available():
        gib = 1024**3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        logger.info(f" GPU memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")

    # ---- Step 4: supervised fine-tuning -----------------------------------
    logger.info("\n[4/4] Starting training...")
    adapter_path = ADAPTERS_DIR / f"{specialist}_v1"
    output_dir = str(adapter_path)

    sft_kwargs = dict(
        output_dir=output_dir,
        num_train_epochs=spec["epochs"],
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=spec["learning_rate"],
        weight_decay=0.01,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        logging_steps=25,
        save_strategy="epoch",
        save_total_limit=2,
        bf16=True,
        max_length=spec["max_seq_length"],
        dataset_text_field="text",
        packing=True,                  # pack short examples for efficiency
        gradient_checkpointing=True,   # save VRAM
        optim="paged_adamw_8bit",      # 8-bit optimizer saves VRAM
        report_to="none",              # no wandb/tensorboard
        dataloader_num_workers=4,
    )
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(**sft_kwargs),
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    result = trainer.train(resume_from_checkpoint=resume_from)
    # Duration is measured from before data loading, so it covers all steps.
    duration = time.time() - started
    hours = duration / 3600

    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Metadata stored next to the adapter and returned to the caller.
    meta = {
        "specialist": specialist,
        "description": spec["description"],
        "base_model": base_model,
        "lora_rank": spec["lora_rank"],
        "training_loss": result.training_loss,
        "epochs": spec["epochs"],
        "samples": len(texts),
        "duration_seconds": round(duration),
        "duration_human": f"{hours:.1f}h",
        "trainable_params": trainable_params,
        "total_params": total_params,
        "max_seq_length": spec["max_seq_length"],
        "learning_rate": spec["learning_rate"],
        "timestamp": datetime.now().isoformat(),
        "gpu": torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu",
        "created_by": "titan-synapse",
    }
    with open(adapter_path / "training_meta.json", "w") as meta_file:
        json.dump(meta, meta_file, indent=2)

    logger.info("\n" + rule)
    logger.info(f"TRAINING COMPLETE: {specialist}")
    logger.info(f" Loss: {result.training_loss:.4f}")
    logger.info(f" Duration: {hours:.1f}h ({duration:.0f}s)")
    logger.info(f" Samples: {len(texts)}")
    logger.info(f" Adapter saved: {output_dir}")
    logger.info(rule)

    # Drop the model and trainer before returning so the next specialist
    # starts with as much free GPU memory as possible.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

    return meta
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def merge_and_export(specialist: str, quantize: str = "Q4_K_M") -> dict:
    """Merge a trained LoRA adapter into the base model and save the result.

    The merged model and its tokenizer are written to
    ``TRAINING_DIR/<specialist>_merged``. GGUF conversion itself is not run
    here; the llama.cpp commands needed to produce a ``quantize``-typed GGUF
    file are logged as follow-up instructions.

    Args:
        specialist: Name whose adapter lives in ``ADAPTERS_DIR/<specialist>_v1``.
        quantize: Target GGUF type mentioned in the logged instructions.

    Returns:
        ``{"error": "adapter_not_found"}`` when the adapter directory is
        missing, otherwise a dict with ``specialist``, ``merged_path`` and
        ``next_step``.
    """
    adapter_dir = ADAPTERS_DIR / f"{specialist}_v1"

    if not adapter_dir.exists():
        logger.error(f"Adapter not found: {adapter_dir}")
        return {"error": "adapter_not_found"}

    logger.info(f"Merging {specialist} adapter into base model...")

    # Load base + adapter (unquantized bf16 so the LoRA deltas can be folded in).
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        cache_dir=str(CACHE_DIR),
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(model, str(adapter_dir))
    merged = model.merge_and_unload()

    # Save merged model
    merge_dir = TRAINING_DIR / f"{specialist}_merged"
    merged.save_pretrained(str(merge_dir))

    # The tokenizer was saved alongside the adapter at training time.
    tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
    tokenizer.save_pretrained(str(merge_dir))

    logger.info(f"Merged model saved to {merge_dir}")
    logger.info("To convert to GGUF, run:")

    # convert_hf_to_gguf.py only emits these types itself; k-quants such as
    # Q4_K_M need an f16 conversion followed by a separate llama-quantize pass.
    direct_outtypes = ("f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto")
    outtype = quantize.lower()
    if outtype in direct_outtypes:
        logger.info(f" python llama.cpp/convert_hf_to_gguf.py {merge_dir} --outtype {outtype}")
    else:
        f16_file = TRAINING_DIR / f"{specialist}-f16.gguf"
        quant_file = TRAINING_DIR / f"{specialist}-{quantize}.gguf"
        logger.info(f" python llama.cpp/convert_hf_to_gguf.py {merge_dir} --outtype f16 --outfile {f16_file}")
        logger.info(f" llama-quantize {f16_file} {quant_file} {quantize}")

    # Free the merged weights before returning.
    del model, merged
    gc.collect()
    torch.cuda.empty_cache()

    return {
        "specialist": specialist,
        "merged_path": str(merge_dir),
        "next_step": "Convert to GGUF with llama.cpp",
    }
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
# ============================================================
|
|
699
|
+
# Main Pipeline
|
|
700
|
+
# ============================================================
|
|
701
|
+
|
|
702
|
+
def train_all_specialists():
    """Train all specialist adapters sequentially.

    Specialists are trained in the order math, code, general, coordinator,
    each against the current module-level ``BASE_MODEL``. A failure in one
    specialist is logged and recorded; the remaining specialists still run.
    A summary is logged and written to ``TRAINING_DIR/pipeline_results.json``.

    Returns:
        Dict mapping specialist name to the metadata returned by
        ``train_specialist`` or to ``{"error": <message>}`` on failure.
    """
    results = {}
    specialists = ["math", "code", "general", "coordinator"]

    logger.info("=" * 70)
    logger.info("TITAN SYNAPSE — Full Specialist Training Pipeline")
    logger.info(f"Training {len(specialists)} specialists: {', '.join(specialists)}")
    logger.info(f"Base model: {BASE_MODEL}")
    logger.info(f"GPU: {torch.cuda.get_device_name() if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        # The device-properties attribute is ``total_memory`` (bytes).
        total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logger.info(f"VRAM: {total_vram:.0f}GB")
    logger.info("=" * 70)

    total_start = time.time()

    for specialist in specialists:
        logger.info(f"\n{'='*70}")
        logger.info(f"Starting training for: {specialist}")
        logger.info(f"{'='*70}\n")

        try:
            # Pass the base model explicitly: train_specialist's default was
            # bound at definition time and would ignore a BASE_MODEL override.
            result = train_specialist(specialist, base_model=BASE_MODEL)
            results[specialist] = result
        except Exception as e:
            # Keep going: one failed specialist must not abort the pipeline.
            logger.error(f"Failed to train {specialist}: {e}", exc_info=True)
            results[specialist] = {"error": str(e)}

    total_duration = time.time() - total_start

    # Summary
    logger.info("\n" + "=" * 70)
    logger.info("TRAINING PIPELINE COMPLETE")
    logger.info(f"Total duration: {total_duration/3600:.1f}h")
    logger.info("")
    for specialist, result in results.items():
        if "error" in result:
            logger.info(f" ✗ {specialist}: FAILED — {result['error']}")
        else:
            logger.info(f" ✓ {specialist}: loss={result['training_loss']:.4f}, "
                        f"samples={result['samples']}, "
                        f"time={result['duration_human']}")
    logger.info("=" * 70)

    # Save pipeline results (default=str stringifies any non-JSON value).
    with open(TRAINING_DIR / "pipeline_results.json", "w") as f:
        json.dump({
            "timestamp": datetime.now().isoformat(),
            "total_duration": total_duration,
            "results": results,
        }, f, indent=2, default=str)

    return results
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
if __name__ == "__main__":
    # Command-line entry point: train one specialist or all of them, then
    # optionally merge each successfully trained adapter into the base model.
    parser = argparse.ArgumentParser(description="Train Synapse specialist adapters")
    parser.add_argument(
        "--specialist",
        choices=["all", "code", "math", "general", "coordinator"],
        default="all",
        help="Which specialist to train (default: all)",
    )
    parser.add_argument(
        "--base-model",
        default=BASE_MODEL,
        help=f"Base model (default: {BASE_MODEL})",
    )
    parser.add_argument(
        "--export",
        action="store_true",
        help="Merge and export trained adapter after training",
    )

    args = parser.parse_args()

    # merge_and_export and the pipeline banner read the module-level constant,
    # so an override has to be written back to it.
    if args.base_model != BASE_MODEL:
        BASE_MODEL = args.base_model

    if args.specialist == "all":
        results = train_all_specialists()
    else:
        result = train_specialist(args.specialist, base_model=BASE_MODEL)
        results = {args.specialist: result}

    # Honor --export for every specialist that trained successfully, not only
    # in single-specialist mode.
    if args.export:
        for trained_name, outcome in results.items():
            if "error" in outcome:
                continue
            # merge_and_export always loads BASE_MODEL; refuse to merge an
            # adapter whose metadata says it was trained on a different model.
            if outcome.get("base_model", BASE_MODEL) != BASE_MODEL:
                logger.warning(
                    f"Skipping export of {trained_name}: adapter was trained on "
                    f"{outcome.get('base_model')}, not {BASE_MODEL}"
                )
                continue
            merge_and_export(trained_name)
|