hegelion-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hegelion/__init__.py +45 -0
- hegelion/core/__init__.py +29 -0
- hegelion/core/agent.py +166 -0
- hegelion/core/autocoding_state.py +293 -0
- hegelion/core/backends.py +442 -0
- hegelion/core/cache.py +92 -0
- hegelion/core/config.py +276 -0
- hegelion/core/core.py +649 -0
- hegelion/core/engine.py +865 -0
- hegelion/core/logging_utils.py +67 -0
- hegelion/core/models.py +293 -0
- hegelion/core/parsing.py +271 -0
- hegelion/core/personas.py +81 -0
- hegelion/core/prompt_autocoding.py +353 -0
- hegelion/core/prompt_dialectic.py +414 -0
- hegelion/core/prompts.py +127 -0
- hegelion/core/schema.py +67 -0
- hegelion/core/validation.py +68 -0
- hegelion/council.py +254 -0
- hegelion/examples_data/__init__.py +6 -0
- hegelion/examples_data/glm4_6_examples.jsonl +2 -0
- hegelion/judge.py +230 -0
- hegelion/mcp/__init__.py +3 -0
- hegelion/mcp/server.py +918 -0
- hegelion/scripts/hegelion_agent_cli.py +90 -0
- hegelion/scripts/hegelion_bench.py +117 -0
- hegelion/scripts/hegelion_cli.py +497 -0
- hegelion/scripts/hegelion_dataset.py +99 -0
- hegelion/scripts/hegelion_eval.py +137 -0
- hegelion/scripts/mcp_setup.py +150 -0
- hegelion/search_providers.py +151 -0
- hegelion/training/__init__.py +7 -0
- hegelion/training/datasets.py +123 -0
- hegelion/training/generator.py +232 -0
- hegelion/training/mlx_scu_trainer.py +379 -0
- hegelion/training/mlx_trainer.py +181 -0
- hegelion/training/unsloth_trainer.py +136 -0
- hegelion-0.4.0.dist-info/METADATA +295 -0
- hegelion-0.4.0.dist-info/RECORD +43 -0
- hegelion-0.4.0.dist-info/WHEEL +5 -0
- hegelion-0.4.0.dist-info/entry_points.txt +8 -0
- hegelion-0.4.0.dist-info/licenses/LICENSE +21 -0
- hegelion-0.4.0.dist-info/top_level.txt +1 -0
hegelion/training/generator.py
@@ -0,0 +1,232 @@
import asyncio
import json
from pathlib import Path
from typing import Optional

from hegelion.core.core import run_dialectic
from hegelion.core.models import HegelionResult
from hegelion.core.config import get_config, set_config_value
from hegelion.training.wrappers.kimi_cli import get_kimi_cli

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None


# The "Teacher" system prompt.
# Forces the model to explicate its dialectical reasoning process.
TEACHER_SYSTEM_PROMPT = """You are a dialectical reasoning engine. For every user query, you MUST follow this strict thought process:

1. **THESIS**: Propose the strongest initial argument or solution.
2. **ANTITHESIS**: Critically attack the thesis. Find flaws, edge cases, or opposing evidence.
3. **SYNTHESIS**: Resolve the conflict. Create a new, stronger solution that incorporates the valid points of both.

Format your response exactly like this:
<thought>
[The full dialectical trace goes here: Thesis -> Antithesis -> Synthesis]
</thought>
[The final answer goes here]"""
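

# A minimal sketch (assumption, not part of this module) of how a trace
# emitted under the contract above can be split back into its thought block
# and final answer; `split_trace` is a hypothetical helper.
def split_trace(text: str) -> tuple[str, str]:
    import re

    match = re.search(r"<thought>\s*(.*?)\s*</thought>\s*(.*)", text, re.DOTALL)
    if match is None:
        # The model ignored the format contract; treat everything as the answer.
        return "", text.strip()
    return match.group(1).strip(), match.group(2).strip()
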
async def generate_dataset(
    dataset_name: str,
    output_file: str,
    split: str = "train",
    column: str = "text",
    limit: int = 100,
    resume: bool = True,
    max_tokens: int = 4000,
    model: str = "kimi",  # Default to our teacher
    prompt_file: Optional[str] = None,
):
    """
    Generate dialectical traces for a HuggingFace dataset using a Teacher model (Kimi).
    """
    if load_dataset is None and prompt_file is None:
        raise ImportError("Please install 'datasets' to use this feature: pip install datasets")

    # Configure the Teacher
    use_cli = False
    if model == "kimi-cli":
        use_cli = True
        print("Configured to use Moonshot AI (Kimi) via CLI wrapper.")
    elif model == "kimi":
        # Ensure Kimi is configured
        config = get_config()
        if not config.moonshot_key:
            raise ValueError("MOONSHOT_API_KEY not found. Set it in your .env or environment.")
        set_config_value("provider", "moonshot")
        set_config_value("model", "kimi-k2-thinking")  # Use the reasoning model
        print("Configured to use Moonshot AI (Kimi) as Teacher.")
    elif model:
        set_config_value("model", model)
        set_config_value("provider", "auto")

    prompts = None
    if prompt_file:
        prompts = [
            line.strip() for line in Path(prompt_file).read_text().splitlines() if line.strip()
        ]
        print(f"Loaded {len(prompts)} prompts from {prompt_file}")
        ds = [{"prompt": p} for p in prompts]
    else:
        print(f"Loading dataset {dataset_name} ({split})...")
        ds = load_dataset(dataset_name, split=split, streaming=True)

    output_path = Path(output_file)
    processed_count = 0

    # Resume logic: count the lines already written and skip that many items
    if resume and output_path.exists():
        with open(output_path, "r") as f:
            processed_count = sum(1 for _ in f)
        print(f"Resuming from {processed_count} examples...")

    # Iterate and generate
    current_idx = 0
    buffer_size = 1  # Write every N examples
    buffer = []

    for item in ds:
        if current_idx < processed_count:
            current_idx += 1
            continue

        if current_idx >= processed_count + limit:
            break

        # Extract the prompt (handle different dataset formats)
        if prompt_file:
            query = item.get("prompt", "")
        elif column in item:
            query = item[column]
        elif "prompt" in item:
            query = item["prompt"]
        elif "instruction" in item:
            query = item["instruction"]
        else:
            print(f"Skipping item {current_idx}: No suitable text column found.")
            current_idx += 1
            continue

        # Truncate extremely long inputs to keep the teacher focused
        query = query[:2000]

        try:
            print(f"[{current_idx}] Generating dialectic for: {query[:50]}...")

            # Two ways to produce a trace: ask Kimi to emit the whole
            # Thesis -> Antithesis -> Synthesis trace in one shot (distillation),
            # or let the Hegelion Engine orchestrate the three phases itself.
            # Decision: let the Hegelion Engine orchestrate. It produces
            # structured JSON, which is cleaner for training than parsing raw
            # text; Kimi only needs to be the backend intelligence.

            if use_cli:
                # Bypass the engine for the CLI wrapper. The Kimi CLI may not
                # honor a system prompt strictly, so we also prepend our
                # instructions to the query.
                cli = get_kimi_cli()

                # Kimi CLI has built-in tools (like web search). We want purely
                # internal reasoning, to teach "thinking" rather than
                # "searching", so we instruct it not to use tools.
                no_search_instruction = "Do not use any external tools or web search. Rely only on your internal knowledge."

                cli_prompt = f"{no_search_instruction}\n\n{query}"
                raw_response = await cli.generate(cli_prompt, system_prompt=TEACHER_SYSTEM_PROMPT)

                # Assuming the CLI followed TEACHER_SYSTEM_PROMPT's format, wrap
                # the raw response in a HegelionResult for consistency.
                result = HegelionResult(
                    query=query,
                    mode="synthesis",
                    thesis="[Generated via CLI]",
                    antithesis="[Generated via CLI]",
                    synthesis=raw_response,  # The whole response is the synthesis/trace
                    contradictions=[],
                    research_proposals=[],
                    metadata={"source": "kimi_cli"},
                )
            else:
                result = await run_dialectic(
                    query=query, max_tokens_per_phase=max_tokens, use_search=False
                )

            # Format for training (Unsloth / MLX): make the output look like a
            # "thinking" model's stream.
            if use_cli:
                final_output = result.synthesis  # CLI returns the full trace directly
            else:
                trace_text = (
                    f"THESIS:\n{result.thesis}\n\n"
                    f"ANTITHESIS:\n{result.antithesis}\n\n"
                    f"SYNTHESIS:\n{result.synthesis}"
                )
                final_output = f"<thought>\n{trace_text}\n</thought>\n{result.synthesis}"

            entry = {
                "instruction": query,
                "input": "",
                "output": final_output,
                "system": "You are a dialectical reasoning engine.",
                "hegelion_trace": result.to_dict(),
            }

            buffer.append(json.dumps(entry, ensure_ascii=False))

            if len(buffer) >= buffer_size:
                with open(output_path, "a", encoding="utf-8") as f:
                    f.write("\n".join(buffer) + "\n")
                buffer = []

        except Exception as e:
            print(f"Error processing item {current_idx}: {e}")

        current_idx += 1

    # Flush remaining
    if buffer:
        with open(output_path, "a", encoding="utf-8") as f:
            f.write("\n".join(buffer) + "\n")

    print(f"Done. Saved to {output_file}")

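
# Example invocation (illustrative; flags as defined below). Note that --seed
# is parsed below but not currently passed to generate_dataset:
#   python -m hegelion.training.generator --limit 10 --output hegelion_kimi_data.jsonl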
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="HuggingFaceH4/ultrafeedback_binarized")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for shuffling")
    parser.add_argument("--output", default="hegelion_kimi_data.jsonl")
    parser.add_argument("--limit", type=int, default=10)
    parser.add_argument("--split", default="train", help="Dataset split to use")
    parser.add_argument("--model", help="Teacher model", default="kimi")
    parser.add_argument(
        "--prompt-file", help="Optional newline-delimited prompts to bypass HF datasets"
    )
    args = parser.parse_args()

    asyncio.run(
        generate_dataset(
            dataset_name=args.dataset,
            output_file=args.output,
            split=args.split,
            limit=args.limit,
            model=args.model,
            prompt_file=args.prompt_file,
        )
    )
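
A minimal sketch of consuming the generated JSONL for fine-tuning, assuming the entry fields written by generate_dataset above (illustrative; load_entries is not shipped in the package):

    import json

    def load_entries(path: str) -> list[dict]:
        entries = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    entries.append(json.loads(line))
        return entries

    # Each entry carries "instruction", "output" (the <thought> trace plus the
    # final answer), and the structured "hegelion_trace" for inspection.
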
hegelion/training/mlx_scu_trainer.py
@@ -0,0 +1,379 @@
"""
Hegelion MLX SCU Trainer
Combines Apple Silicon optimization (MLX) with Shannon Control Unit (SCU) adaptive regularization.

Logic:
1. Train Loop in MLX
2. Calculate DataBPT (Loss)
3. Calculate ParamBPT (LoRA weights complexity)
4. Update Lambda (Regularization strength) via PI Control
5. Optimize: Loss = CE + lambda * L2
"""

import math
import json
import argparse
from pathlib import Path
import numpy as np

# Optional import, used only for resource monitoring
try:
    import psutil
except ImportError:
    psutil = None

# MLX imports
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers


# SCU control logic (re-implemented for independence)
def calculate_data_bpt(loss_nats):
    return loss_nats / math.log(2)

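
# Worked example (illustrative): calculate_data_bpt(2.0) == 2.0 / ln 2 ≈ 2.885,
# i.e. a cross-entropy of 2.0 nats is about 2.885 bits per token.
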
def calculate_param_bpt(model, sigma=0.01, tokens_per_epoch=1000000):
    """Calculate ParamBPT - optimized to batch operations."""
    param_squares = []
    total_params = 0

    # Collect all parameter squares first (the computation stays lazy on device)
    for name, weight in tree_flatten(model.trainable_parameters()):
        param_squares.append(mx.sum(weight**2))
        total_params += weight.size

    if total_params == 0:
        return 1e-9, 0.0

    # Sum everything, then force a single evaluation with one .item() call
    param_sum = mx.sum(mx.stack(param_squares)).item()

    # Convert to bits
    param_bpt = param_sum / (2 * sigma**2 * tokens_per_epoch * math.log(2))
    return param_bpt, param_sum
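
# Worked example (illustrative): with sum(w^2) = 1.0, sigma = 0.01, and
# tokens_per_epoch = 1e6, ParamBPT = 1 / (2 * 1e-4 * 1e6 * ln 2) ≈ 0.0072
# bits per token.
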
def update_lambda(
    lmbda,
    S_meas,
    S_target,
    integral_term,
    Kp=0.8,
    Ki=0.15,
    deadband=0.002,
    lmin=1e-4,
    lmax=10.0,
    i_min=-0.2,
    i_max=0.2,
):
    """PI controller for lambda (proportional term plus leaky, clamped integral)."""
    error = S_meas - S_target

    if abs(error) <= deadband:
        return lmbda, integral_term * 0.995  # Leak the integral inside the deadband

    integral_term = integral_term * 0.995
    integral_term = max(i_min, min(i_max, integral_term + Ki * error))

    control_effort = Kp * error + integral_term
    lmbda_new = lmbda * math.exp(control_effort)
    lmbda_new = max(lmin, min(lmax, lmbda_new))

    return lmbda_new, integral_term
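
# Behavior sketch (illustrative): with S_meas = 0.03 against S_target = 0.01,
# error = 0.02; starting from integral_term = 0.0 the clamped integral becomes
# 0.15 * 0.02 = 0.003, control_effort = 0.8 * 0.02 + 0.003 = 0.019, and
# lambda is scaled by exp(0.019) ≈ 1.019, a gentle multiplicative increase.
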
def check_system_resources():
    """Check CPU and memory usage, warn if high."""
    if psutil is None:
        return None, None

    try:
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory = psutil.virtual_memory()

        if cpu_percent > 90:
            print(f"⚠️ WARNING: High CPU usage: {cpu_percent:.1f}%")

        if memory.percent > 85:
            print(f"⚠️ WARNING: High memory usage: {memory.percent:.1f}%")

        return cpu_percent, memory.percent
    except Exception:
        # If psutil fails, return None values silently
        return None, None


def loss_fn(model, inputs, targets, lmbda, sigma, tokens_per_epoch, lengths):
    # Forward pass (lengths is accepted for future masking but currently unused)
    logits = model(inputs)
    logits = logits.astype(mx.float32)

    # Cross-entropy loss. Padding positions carry the target -100 and are
    # masked out below, so a plain per-token CE is sufficient here.
    ce_loss = nn.losses.cross_entropy(logits, targets, reduction="none")

    # Average over non-padding tokens only
    mask = targets != -100  # Standard ignore index
    ce_loss = mx.sum(ce_loss * mask) / mx.sum(mask)

    # SCU regularization term: L2 = sum(w^2) / (2*sigma^2), computed here so it
    # participates in the gradients; ParamBPT is computed separately but is the
    # related quantity.
    param_squares = [mx.sum(weight**2) for _, weight in tree_flatten(model.trainable_parameters())]
    l2_sum = mx.sum(mx.stack(param_squares)) if param_squares else mx.array(0.0)

    reg_term = l2_sum / (2 * sigma**2)

    # Total loss = CE + lambda * regularization per token. In the SCU
    # derivation, Reg_loss_per_token = ParamBPT * ln(2) = sum(w^2) / (2*sigma^2 * N),
    # so we divide by tokens_per_epoch (N).
    total_reg = (lmbda * reg_term) / tokens_per_epoch

    return ce_loss + total_reg, ce_loss
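
# Masking sketch (illustrative): for targets mx.array([[2, 7, -100]]) the mask
# is [[True, True, False]], so mx.sum(mask) == 2 and the padded position
# contributes nothing to the averaged cross-entropy above.
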

def train_scu(args):
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    print(f"Loading model: {args.model}")
    model, tokenizer = load(args.model)

    # Freeze base model before applying LoRA so quantized weights stay constant
    model.freeze()

    # Apply LoRA
    print("Applying LoRA adapters...")
    # We use the standard config usually passed to mlx_lm
    lora_config = {
        "rank": args.lora_rank,
        "scale": float(args.lora_alpha),
        "dropout": args.lora_dropout,
        # "keys": ["q_proj", "v_proj", "k_proj", "o_proj"]  # Let auto-detect work
    }

    # Apply to all layers; auto-detect the layer count or fall back to a default
    try:
        if hasattr(model, "layers"):
            num_layers = len(model.layers)
        elif hasattr(model, "model") and hasattr(model.model, "layers"):
            num_layers = len(model.model.layers)
        else:
            # Reasonable default for 1.5B models
            num_layers = 24
    except Exception:
        num_layers = 24  # Safe default

    print(f"Applying LoRA to {num_layers} layers")
    linear_to_lora_layers(model, num_layers=num_layers, config=lora_config)

    # Ensure only LoRA adapters are trainable to avoid gradients through quantized weights
    model.freeze()  # freeze newly created modules (base weights stay frozen)

    def _unfreeze_lora_params(_, module):
        if hasattr(module, "lora_a") and hasattr(module, "lora_b"):
            module.unfreeze(keys=["lora_a", "lora_b"], recurse=False)

    model.apply_to_modules(_unfreeze_lora_params)

    n_trainable = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
    print(f"Trainable parameters: {n_trainable}")
    if n_trainable == 0:
        raise ValueError("No trainable parameters! LoRA failed to apply.")

    # Optimizer
    optimizer = optim.AdamW(learning_rate=args.lr)

    # SCU state
    lmbda = args.lambda_init
    integral_term = 0.0
    sigma = args.prior_sigma

    # Data loading
    print(f"Loading data from {args.data}")

    # Simple dataset loader
    def load_dataset(path):
        data = []
        with open(path, "r") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                text = obj.get("text", "")
                if not text:
                    continue

                # Tokenize and append EOS
                ids = tokenizer.encode(text) + [tokenizer.eos_token_id]
                data.append(np.array(ids))
        return data

    dataset = load_dataset(args.data)
    total_tokens = sum(len(x) for x in dataset)
    print(f"Loaded {len(dataset)} examples | {total_tokens} tokens (pre-truncation)")

    tokens_per_epoch = args.tokens_per_epoch
    if tokens_per_epoch <= 0:
        tokens_per_epoch = max(1, total_tokens)
        print(f"Auto-setting tokens_per_epoch to {tokens_per_epoch}")

    # Training loop
    steps = 0
    max_steps = args.iters
    batch_size = args.batch_size

    # nn.value_and_grad differentiates with respect to the model's trainable
    # parameters; the second tuple element (ce_loss) passes through as an
    # auxiliary value.
    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

    def step(model, inputs, targets, lmbda):
        (loss, ce_loss), grads = loss_value_and_grad(
            model, inputs, targets, lmbda, sigma, tokens_per_epoch, None
        )
        return loss, ce_loss, grads

    print("Starting SCU Training...")

    # Create batches manually for control: shuffle and slice
    while steps < max_steps:
        # Shuffle
        indices = np.random.permutation(len(dataset))

        for i in range(0, len(dataset), batch_size):
            if steps >= max_steps:
                break

            batch_idx = indices[i : i + batch_size]
            batch_data = [dataset[k] for k in batch_idx]

            # Pad the batch, truncating to the configured maximum length
            max_len = max(len(x) for x in batch_data)
            if max_len > args.max_seq_length:
                max_len = args.max_seq_length

            inputs_np = np.zeros((len(batch_data), max_len), dtype=np.int32)
            targets_np = np.full(
                (len(batch_data), max_len), -100, dtype=np.int32
            )  # -100 for ignore

            for j, seq in enumerate(batch_data):
                L = min(len(seq), max_len)
                # Causal LM shift: input is seq[:-1], target is seq[1:]
                if L < 2:
                    continue

                # Inputs
                inputs_np[j, : L - 1] = seq[: L - 1]
                # Targets
                targets_np[j, : L - 1] = seq[1:L]
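
            # Shift example (illustrative): for a toy sequence [5, 6, 7, 8] with
            # max_len = 4, row j gets inputs [5, 6, 7, 0] (trailing padding) and
            # targets [6, 7, 8, -100].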

            inputs = mx.array(inputs_np)
            targets = mx.array(targets_np)

            # Step
            (total_loss, ce_loss_val, grads) = step(model, inputs, targets, lmbda)

            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)

            # SCU updates (post-step measurement). Both branches fire at step 0,
            # so the BPT values logged below are always bound, though they may
            # be stale between SCU updates.
            if steps % args.scu_update_freq == 0:
                # Calculate BPTs
                data_bpt = calculate_data_bpt(ce_loss_val.item())
                param_bpt, _ = calculate_param_bpt(model, sigma, tokens_per_epoch)

                S_meas = param_bpt / (data_bpt + param_bpt + 1e-9)

                # Update lambda
                lmbda_old = lmbda
                lmbda, integral_term = update_lambda(
                    lmbda, S_meas, args.target_s, integral_term, Kp=args.kp, Ki=args.ki
                )
            else:
                # Keep lambda constant between updates
                lmbda_old = lmbda

            # System monitoring
            if steps % 50 == 0:
                cpu, mem = check_system_resources()
                if cpu is not None and cpu > 95:
                    print("⚠️ CRITICAL: CPU usage > 95%, consider increasing --scu_update_freq")

            # Logging
            if steps % 10 == 0:
                print(
                    f"Step {steps}: Loss={ce_loss_val.item():.3f}, DataBPT={data_bpt:.3f}, "
                    f"ParamBPT={param_bpt:.5f}, S={S_meas:.2%}, λ={lmbda_old:.3f} -> {lmbda:.3f}"
                )

            steps += 1

            # Save the adapter occasionally
            if steps % 100 == 0:
                print(f"Saving adapter to {args.adapter_path}")
                Path(args.adapter_path).mkdir(parents=True, exist_ok=True)
                model.save_weights(str(Path(args.adapter_path) / "weights.safetensors"))
                with open(Path(args.adapter_path) / "adapter_config.json", "w") as f:
                    json.dump(lora_config, f, indent=2)

    # Final save
    Path(args.adapter_path).mkdir(parents=True, exist_ok=True)
    model.save_weights(str(Path(args.adapter_path) / "weights.safetensors"))
    with open(Path(args.adapter_path) / "adapter_config.json", "w") as f:
        json.dump(lora_config, f, indent=2)
    print("Training complete.")

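
# A minimal restore sketch (assumption, not part of the package): rebuild the
# same LoRA layout, then load the checkpoint saved above with MLX's
# load_weights.
#
#   model, tokenizer = load(model_name)
#   model.freeze()
#   linear_to_lora_layers(model, num_layers=num_layers, config=lora_config)
#   model.load_weights(str(Path(adapter_path) / "weights.safetensors"), strict=False)
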

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    parser.add_argument("--data", required=True)
    parser.add_argument("--adapter_path", default="artifacts/adapters/hegelion_mlx_scu")
    parser.add_argument("--iters", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)

    # LoRA Args
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # SCU Args
    parser.add_argument("--target_s", type=float, default=0.01)
    parser.add_argument("--kp", type=float, default=0.8)
    parser.add_argument("--ki", type=float, default=0.15)
    parser.add_argument("--lambda_init", type=float, default=1.0)
    parser.add_argument("--prior_sigma", type=float, default=0.01)
    parser.add_argument(
        "--tokens_per_epoch",
        type=float,
        default=-1,
        help="If <=0, auto-compute from dataset token count",
    )
    parser.add_argument(
        "--scu_update_freq",
        type=int,
        default=10,
        help="Update SCU lambda every N steps (default: 10)",
    )
    parser.add_argument("--max_seq_length", type=int, default=4096)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    train_scu(args)
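
# Example invocation (illustrative; flags as defined above):
#   python -m hegelion.training.mlx_scu_trainer --data train.jsonl --iters 500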