hegelion-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. hegelion/__init__.py +45 -0
  2. hegelion/core/__init__.py +29 -0
  3. hegelion/core/agent.py +166 -0
  4. hegelion/core/autocoding_state.py +293 -0
  5. hegelion/core/backends.py +442 -0
  6. hegelion/core/cache.py +92 -0
  7. hegelion/core/config.py +276 -0
  8. hegelion/core/core.py +649 -0
  9. hegelion/core/engine.py +865 -0
  10. hegelion/core/logging_utils.py +67 -0
  11. hegelion/core/models.py +293 -0
  12. hegelion/core/parsing.py +271 -0
  13. hegelion/core/personas.py +81 -0
  14. hegelion/core/prompt_autocoding.py +353 -0
  15. hegelion/core/prompt_dialectic.py +414 -0
  16. hegelion/core/prompts.py +127 -0
  17. hegelion/core/schema.py +67 -0
  18. hegelion/core/validation.py +68 -0
  19. hegelion/council.py +254 -0
  20. hegelion/examples_data/__init__.py +6 -0
  21. hegelion/examples_data/glm4_6_examples.jsonl +2 -0
  22. hegelion/judge.py +230 -0
  23. hegelion/mcp/__init__.py +3 -0
  24. hegelion/mcp/server.py +918 -0
  25. hegelion/scripts/hegelion_agent_cli.py +90 -0
  26. hegelion/scripts/hegelion_bench.py +117 -0
  27. hegelion/scripts/hegelion_cli.py +497 -0
  28. hegelion/scripts/hegelion_dataset.py +99 -0
  29. hegelion/scripts/hegelion_eval.py +137 -0
  30. hegelion/scripts/mcp_setup.py +150 -0
  31. hegelion/search_providers.py +151 -0
  32. hegelion/training/__init__.py +7 -0
  33. hegelion/training/datasets.py +123 -0
  34. hegelion/training/generator.py +232 -0
  35. hegelion/training/mlx_scu_trainer.py +379 -0
  36. hegelion/training/mlx_trainer.py +181 -0
  37. hegelion/training/unsloth_trainer.py +136 -0
  38. hegelion-0.4.0.dist-info/METADATA +295 -0
  39. hegelion-0.4.0.dist-info/RECORD +43 -0
  40. hegelion-0.4.0.dist-info/WHEEL +5 -0
  41. hegelion-0.4.0.dist-info/entry_points.txt +8 -0
  42. hegelion-0.4.0.dist-info/licenses/LICENSE +21 -0
  43. hegelion-0.4.0.dist-info/top_level.txt +1 -0
hegelion/training/generator.py
@@ -0,0 +1,232 @@
+ import asyncio
+ import json
+ from pathlib import Path
+ from typing import Optional
+
+ from hegelion.core.core import run_dialectic
+ from hegelion.core.models import HegelionResult
+ from hegelion.core.config import get_config, set_config_value
+ from hegelion.training.wrappers.kimi_cli import get_kimi_cli
+
+ try:
+     from datasets import load_dataset
+ except ImportError:
+     load_dataset = None
+
+
+ # The "Teacher" System Prompt
+ # Forces the model to explicate its dialectical reasoning process
+ TEACHER_SYSTEM_PROMPT = """You are a dialectical reasoning engine. For every user query, you MUST follow this strict thought process:
+
+ 1. **THESIS**: Propose the strongest initial argument or solution.
+ 2. **ANTITHESIS**: Critically attack the thesis. Find flaws, edge cases, or opposing evidence.
+ 3. **SYNTHESIS**: Resolve the conflict. Create a new, stronger solution that incorporates the valid points of both.
+
+ Format your response exactly like this:
+ <thought>
+ [The full dialectical trace goes here: Thesis -> Antithesis -> Synthesis]
+ </thought>
+ [The final answer goes here]"""
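# Editor's illustration (not part of the packaged file): a minimal sketch of how a
# teacher response in the format required above could be split back into the
# dialectical trace and the final answer. The helper name `split_teacher_response`
# is hypothetical and is not defined anywhere in hegelion.
def split_teacher_response(raw: str) -> tuple[str, str]:
    """Return (trace, final_answer) from '<thought>...</thought>answer' text."""
    head, sep, tail = raw.partition("</thought>")
    if not sep:  # The model ignored the format; treat the whole text as the answer.
        return "", raw.strip()
    return head.replace("<thought>", "", 1).strip(), tail.strip()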
+
+
+ async def generate_dataset(
+     dataset_name: str,
+     output_file: str,
+     split: str = "train",
+     column: str = "text",
+     limit: int = 100,
+     resume: bool = True,
+     max_tokens: int = 4000,
+     model: str = "kimi",  # Default to our teacher
+     prompt_file: Optional[str] = None,
+ ):
+     """
+     Generate dialectical traces for a HuggingFace dataset using a Teacher model (Kimi).
+     """
+     if load_dataset is None and prompt_file is None:
+         raise ImportError("Please install 'datasets' to use this feature: pip install datasets")
+
+     # Configure the teacher backend
+     use_cli = False
+     if model == "kimi-cli":
+         use_cli = True
+         print("Configured to use Moonshot AI (Kimi) via CLI wrapper.")
+     elif model == "kimi":
+         # Ensure Kimi is configured
+         config = get_config()
+         if not config.moonshot_key:
+             raise ValueError("MOONSHOT_API_KEY not found. Set it in your .env or environment.")
+         set_config_value("provider", "moonshot")
+         set_config_value("model", "kimi-k2-thinking")  # Use the reasoning model
+         print("Configured to use Moonshot AI (Kimi) as Teacher.")
+     elif model:
+         set_config_value("model", model)
+         set_config_value("provider", "auto")
+
+     prompts = None
+     if prompt_file:
+         prompts = [
+             line.strip() for line in Path(prompt_file).read_text().splitlines() if line.strip()
+         ]
+         print(f"Loaded {len(prompts)} prompts from {prompt_file}")
+         ds = [{"prompt": p} for p in prompts]
+     else:
+         print(f"Loading dataset {dataset_name} ({split})...")
+         ds = load_dataset(dataset_name, split=split, streaming=True)
+
+     output_path = Path(output_file)
+     processed_count = 0
+
+     # Resume logic: count the lines already written to the output file
+     if resume and output_path.exists():
+         with open(output_path, "r") as f:
+             processed_count = sum(1 for _ in f)
+         print(f"Resuming from {processed_count} examples...")
+
+     # Iterate and generate
+     current_idx = 0
+     buffer_size = 1  # Flush to disk after every example
+     buffer = []
+
+     for item in ds:
+         if current_idx < processed_count:
+             current_idx += 1
+             continue
+
+         if current_idx >= processed_count + limit:
+             break
+
+         # Extract the prompt (handle different dataset formats)
+         if prompt_file:
+             query = item.get("prompt", "")
+         elif column in item:
+             query = item[column]
+         elif "prompt" in item:
+             query = item["prompt"]
+         elif "instruction" in item:
+             query = item["instruction"]
+         else:
+             print(f"Skipping item {current_idx}: No suitable text column found.")
+             current_idx += 1
+             continue
+
+         # Truncate extremely long inputs to keep the teacher focused
+         query = query[:2000]
+
+         try:
+             print(f"[{current_idx}] Generating dialectic for: {query[:50]}...")
+
+             # Two ways to get a trace: (a) single-shot distillation, where the
+             # teacher produces the whole Thesis -> Antithesis -> Synthesis trace
+             # from TEACHER_SYSTEM_PROMPT in one completion, or (b) letting
+             # `run_dialectic` orchestrate the three phases itself.
+             # Decision: let the Hegelion engine orchestrate it. It returns
+             # structured JSON, which is cleaner for training data than parsing
+             # raw text; the teacher (Kimi) only supplies the backend intelligence.
+             # The CLI path below is the exception and uses the single-shot prompt.
+
+             if use_cli:
+                 # Bypass the engine for the CLI path
+                 cli = get_kimi_cli()
+                 # Kimi CLI might not follow the system prompt strictly in standard mode,
+                 # so we also prepend an instruction to the query.
+
+                 # IMPORTANT: ensure no external tools are called unless explicitly desired.
+                 # Kimi CLI has built-in tools (like search); we want purely internal
+                 # reasoning to teach "thinking", not "searching". There is no reliable
+                 # flag to disable CLI tools, so we instruct the model instead:
+                 no_search_instruction = "Do not use any external tools or web search. Rely only on your internal knowledge."
+
+                 cli_prompt = f"{no_search_instruction}\n\n{query}"
+                 raw_response = await cli.generate(cli_prompt, system_prompt=TEACHER_SYSTEM_PROMPT)
+
+                 # The raw response is used as-is: if the CLI followed TEACHER_SYSTEM_PROMPT,
+                 # it already contains the full <thought> trace plus the final answer.
+
+                 # Create a placeholder result so both paths produce a HegelionResult
+                 result = HegelionResult(
+                     query=query,
+                     mode="synthesis",
+                     thesis="[Generated via CLI]",
+                     antithesis="[Generated via CLI]",
+                     synthesis=raw_response,  # The whole response is the synthesis/trace
+                     contradictions=[],
+                     research_proposals=[],
+                     metadata={"source": "kimi_cli"},
+                 )
+             else:
+                 result = await run_dialectic(
+                     query=query, max_tokens_per_phase=max_tokens, use_search=False
+                 )
+
+             # Format for Training (Unsloth / MLX)
+             # We format the output to look like a "Thinking" model's stream
+
+             if use_cli:
+                 final_output = result.synthesis  # CLI returns the full trace directly
+             else:
+                 trace_text = (
+                     f"THESIS:\n{result.thesis}\n\n"
+                     f"ANTITHESIS:\n{result.antithesis}\n\n"
+                     f"SYNTHESIS:\n{result.synthesis}"
+                 )
+                 final_output = f"<thought>\n{trace_text}\n</thought>\n{result.synthesis}"
+
+             entry = {
+                 "instruction": query,
+                 "input": "",
+                 "output": final_output,
+                 "system": "You are a dialectical reasoning engine.",
+                 "hegelion_trace": result.to_dict(),
+             }
+
+             buffer.append(json.dumps(entry, ensure_ascii=False))
+
+             if len(buffer) >= buffer_size:
+                 with open(output_path, "a", encoding="utf-8") as f:
+                     f.write("\n".join(buffer) + "\n")
+                 buffer = []
+
+         except Exception as e:
+             print(f"Error processing item {current_idx}: {e}")
+
+         current_idx += 1
+
+     # Flush remaining
+     if buffer:
+         with open(output_path, "a", encoding="utf-8") as f:
+             f.write("\n".join(buffer) + "\n")
+
+     print(f"Done. Saved to {output_file}")
+
+
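# Editor's illustration (not part of the packaged file): the shape of one JSONL
# record appended by the loop above. The keys match the `entry` dict built in
# generate_dataset(); the values shown here are hypothetical.
example_entry = {
    "instruction": "Is it ever rational to vote in a large election?",
    "input": "",
    "output": "<thought>\nTHESIS:\n...\n\nANTITHESIS:\n...\n\nSYNTHESIS:\n...\n</thought>\n<final answer text>",
    "system": "You are a dialectical reasoning engine.",
    "hegelion_trace": {"query": "...", "thesis": "...", "antithesis": "...", "synthesis": "..."},
}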
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--dataset", default="HuggingFaceH4/ultrafeedback_binarized")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed for shuffling")
+     parser.add_argument("--output", default="hegelion_kimi_data.jsonl")
+     parser.add_argument("--limit", type=int, default=10)
+     parser.add_argument("--split", default="train", help="Dataset split to use")
+     parser.add_argument("--model", help="Teacher model", default="kimi")
+     parser.add_argument(
+         "--prompt-file", help="Optional newline-delimited prompts to bypass HF datasets"
+     )
+     args = parser.parse_args()
+
+     asyncio.run(
+         generate_dataset(
+             dataset_name=args.dataset,
+             output_file=args.output,
+             split=args.split,
+             limit=args.limit,
+             model=args.model,
+             prompt_file=args.prompt_file,
+         )
+     )
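Editor's note: a minimal, hedged usage sketch for the generator above. It assumes the package is installed, that MOONSHOT_API_KEY is set for the default "kimi" teacher, and that "prompts.txt" (a hypothetical newline-delimited prompt file) exists; when a prompt file is given, dataset_name is ignored.

    import asyncio
    from hegelion.training.generator import generate_dataset

    asyncio.run(
        generate_dataset(
            dataset_name="",                      # unused when prompt_file is given
            output_file="hegelion_kimi_data.jsonl",
            limit=5,
            model="kimi",
            prompt_file="prompts.txt",
        )
    )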
hegelion/training/mlx_scu_trainer.py
@@ -0,0 +1,379 @@
+ """
+ Hegelion MLX SCU Trainer
+ Combines Apple Silicon optimization (MLX) with Shannon Control Unit (SCU) adaptive regularization.
+
+ Logic:
+ 1. Train loop in MLX
+ 2. Calculate DataBPT (cross-entropy loss in bits per token)
+ 3. Calculate ParamBPT (LoRA weight complexity in bits per token)
+ 4. Update lambda (regularization strength) via PI (proportional-integral) control
+ 5. Optimize: Loss = CE + lambda * L2
+ """
+
+ import math
+ import json
+ import argparse
+ from pathlib import Path
+ import numpy as np
+
+ # Optional import: psutil is only used for resource monitoring
+ try:
+     import psutil
+ except ImportError:
+     psutil = None
+
+ # MLX imports
+ import mlx.core as mx
+ import mlx.nn as nn
+ import mlx.optimizers as optim
+ from mlx.utils import tree_flatten
+ from mlx_lm import load
+ from mlx_lm.tuner.utils import linear_to_lora_layers
+
+
+ # SCU control logic (re-implemented here for independence)
+ def calculate_data_bpt(loss_nats):
+     return loss_nats / math.log(2)
+
+
+ def calculate_param_bpt(model, sigma=0.01, tokens_per_epoch=1000000):
+     """Calculate ParamBPT, batching operations to avoid per-weight host syncs."""
+     param_squares = []
+     total_params = 0
+
+     # Collect all parameter squares first (stays lazy on the GPU)
+     for name, weight in tree_flatten(model.trainable_parameters()):
+         param_squares.append(mx.sum(weight**2))
+         total_params += weight.size
+
+     if total_params == 0:
+         return 1e-9, 0.0
+
+     # Sum everything, then synchronize with the host once
+     param_sum = mx.sum(mx.stack(param_squares)).item()  # single .item() call
+
+     # Convert to bits: ParamBPT = sum(w^2) / (2 * sigma^2 * N * ln 2)
+     param_bpt = param_sum / (2 * sigma**2 * tokens_per_epoch * math.log(2))
+     return param_bpt, param_sum
+
+
+ def update_lambda(
+     lmbda,
+     S_meas,
+     S_target,
+     integral_term,
+     Kp=0.8,
+     Ki=0.15,
+     deadband=0.002,
+     lmin=1e-4,
+     lmax=10.0,
+     i_min=-0.2,
+     i_max=0.2,
+ ):
+     """PI (proportional-integral) controller for lambda."""
+     error = S_meas - S_target
+
+     if abs(error) <= deadband:
+         return lmbda, integral_term * 0.995  # Leak the integral term
+
+     integral_term = integral_term * 0.995
+     integral_term = max(i_min, min(i_max, integral_term + Ki * error))
+
+     control_effort = Kp * error + integral_term
+     lmbda_new = lmbda * math.exp(control_effort)
+     lmbda_new = max(lmin, min(lmax, lmbda_new))
+
+     return lmbda_new, integral_term
+
+
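# Editor's illustration (not part of the packaged file): one SCU measurement and
# lambda update using the helpers defined above. The numbers are made up.
ce_loss_nats = 2.08                                   # mean cross-entropy per token (nats)
data_bpt = calculate_data_bpt(ce_loss_nats)           # ~3.0 bits per token
param_bpt = 0.03                                      # would come from calculate_param_bpt(model, ...)
S_meas = param_bpt / (data_bpt + param_bpt)           # ~0.0099, i.e. ~1% of the total description length
new_lmbda, new_integral = update_lambda(lmbda=1.0, S_meas=S_meas, S_target=0.01, integral_term=0.0)
# |S_meas - S_target| is within the default deadband of 0.002, so lambda is returned
# unchanged (1.0) and the integral term simply leaks toward zero.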
+ def check_system_resources():
+     """Check CPU and memory usage, warn if high."""
+     try:
+         cpu_percent = psutil.cpu_percent(interval=0.1)
+         memory = psutil.virtual_memory()
+
+         if cpu_percent > 90:
+             print(f"⚠️ WARNING: High CPU usage: {cpu_percent:.1f}%")
+
+         if memory.percent > 85:
+             print(f"⚠️ WARNING: High memory usage: {memory.percent:.1f}%")
+
+         return cpu_percent, memory.percent
+     except Exception:
+         # If psutil is unavailable or fails, return None values silently
+         return None, None
+
+
+ def loss_fn(model, inputs, targets, lmbda, sigma, tokens_per_epoch, lengths):
+     # Forward pass
+     logits = model(inputs)
+     logits = logits.astype(mx.float32)
+
+     # Cross-entropy loss, unreduced so that padding positions can be masked out below.
+     # Padding targets are marked with the standard -100 ignore index when batches are built.
+     ce_loss = nn.losses.cross_entropy(logits, targets, reduction="none")
+
+     # Mask out padding and average over the non-masked positions
+     # (the `lengths` argument is currently unused)
+     mask = targets != -100  # Standard ignore index
+     ce_loss = mx.sum(ce_loss * mask) / mx.sum(mask)
+
+     # SCU regularization term: L2 = sum(w^2) / (2 * sigma^2), computed here for gradients.
+     # ParamBPT reporting is handled separately by calculate_param_bpt.
+
+     # Collect all squares first, then sum once
+     param_squares = [mx.sum(weight**2) for _, weight in tree_flatten(model.trainable_parameters())]
+     l2_sum = mx.sum(mx.stack(param_squares)) if param_squares else mx.array(0.0)
+
+     reg_term = l2_sum / (2 * sigma**2)
+
+     # Total loss = CE + lambda * regularization per token.
+     # In the SCU derivation: reg_per_token = ParamBPT * ln(2) = sum(w^2) / (2 * sigma^2 * N),
+     # so we divide by tokens_per_epoch (N).
+     total_reg = (lmbda * reg_term) / tokens_per_epoch
+
+     return ce_loss + total_reg, ce_loss
+
+
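# Editor's illustration (not part of the packaged file): how the -100 ignore index
# and the mask above interact; padded positions contribute nothing to the mean CE.
# `mx` is mlx.core, imported at the top of the file.
demo_ce = mx.array([2.0, 1.0, 3.0, 0.7])       # hypothetical per-token CE values
demo_targets = mx.array([17, 52, 9, -100])     # last position is padding
demo_mask = demo_targets != -100
demo_mean = mx.sum(demo_ce * demo_mask) / mx.sum(demo_mask)   # (2.0 + 1.0 + 3.0) / 3 = 2.0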
+ def train_scu(args):
+     np.random.seed(args.seed)
+     mx.random.seed(args.seed)
+
+     print(f"Loading model: {args.model}")
+     model, tokenizer = load(args.model)
+
+     # Freeze base model before applying LoRA so quantized weights stay constant
+     model.freeze()
+
+     # Apply LoRA
+     print("Applying LoRA adapters...")
+     # We use the standard config usually passed to mlx_lm
+     lora_config = {
+         "rank": args.lora_rank,
+         "scale": float(args.lora_alpha),
+         "dropout": args.lora_dropout,
+         # "keys": ["q_proj", "v_proj", "k_proj", "o_proj"]  # Let auto-detect work
+     }
+
+     # Apply to all layers
+     # Auto-detect number of layers or use reasonable default
+     try:
+         # Try to detect from model structure
+         if hasattr(model, "layers"):
+             num_layers = len(model.layers)
+         elif hasattr(model, "model") and hasattr(model.model, "layers"):
+             num_layers = len(model.model.layers)
+         else:
+             # Reasonable default for 1.5B models
+             num_layers = 24
+     except Exception:
+         num_layers = 24  # Safe default
+
+     print(f"Applying LoRA to {num_layers} layers")
+     linear_to_lora_layers(model, num_layers=num_layers, config=lora_config)
+
+     # Ensure only LoRA adapters are trainable to avoid gradients through quantized weights
+     model.freeze()  # freeze newly created modules (base weights stay frozen)
+
+     def _unfreeze_lora_params(_, module):
+         if hasattr(module, "lora_a") and hasattr(module, "lora_b"):
+             module.unfreeze(keys=["lora_a", "lora_b"], recurse=False)
+
+     model.apply_to_modules(_unfreeze_lora_params)
+
+     n_trainable = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
+     print(f"Trainable parameters: {n_trainable}")
+     if n_trainable == 0:
+         raise ValueError("No trainable parameters! LoRA failed to apply.")
+
+     # Optimizer
+     optimizer = optim.AdamW(learning_rate=args.lr)
+
+     # SCU State
+     lmbda = args.lambda_init
+     integral_term = 0.0
+     sigma = args.prior_sigma
+
+     # Data Loading
+     print(f"Loading data from {args.data}")
+
+     # Simple dataset loader
+     def load_dataset(path):
+         data = []
+         with open(path, "r") as f:
+             for line in f:
+                 if not line.strip():
+                     continue
+                 obj = json.loads(line)
+                 text = obj.get("text", "")
+                 if not text:
+                     continue
+
+                 # Tokenize and append EOS
+                 ids = tokenizer.encode(text) + [tokenizer.eos_token_id]
+                 data.append(np.array(ids))
+         return data
+
+     dataset = load_dataset(args.data)
+     total_tokens = sum(len(x) for x in dataset)
+     print(f"Loaded {len(dataset)} examples | {total_tokens} tokens (pre-truncation)")
+
+     tokens_per_epoch = args.tokens_per_epoch
+     if tokens_per_epoch <= 0:
+         tokens_per_epoch = max(1, total_tokens)
+         print(f"Auto-setting tokens_per_epoch to {tokens_per_epoch}")
+
+     # Training Loop
+     steps = 0
+     max_steps = args.iters
+     batch_size = args.batch_size
+
+     # Prepare function for grad
+     # We use nn.value_and_grad which handles trainable parameters automatically
+
+     loss_value_and_grad = nn.value_and_grad(model, loss_fn)
+
+     def step(model, inputs, targets, lmbda):
+         (loss, ce_loss), grads = loss_value_and_grad(
+             model, inputs, targets, lmbda, sigma, tokens_per_epoch, None
+         )
+         return loss, ce_loss, grads
+
+     print("Starting SCU Training...")
+
+     # Create batches manually for control
+     # We simply shuffle and slice
+
+     while steps < max_steps:
+         # Shuffle
+         indices = np.random.permutation(len(dataset))
+
+         for i in range(0, len(dataset), batch_size):
+             if steps >= max_steps:
+                 break
+
+             batch_idx = indices[i : i + batch_size]
+             batch_data = [dataset[k] for k in batch_idx]
+
+             # Pad the batch
+             max_len = max(len(x) for x in batch_data)
+             # Truncate to the maximum sequence length
+             if max_len > args.max_seq_length:
+                 max_len = args.max_seq_length
+
+             inputs_np = np.zeros((len(batch_data), max_len), dtype=np.int32)
+             targets_np = np.full(
+                 (len(batch_data), max_len), -100, dtype=np.int32
+             )  # -100 for ignore
+
+             for j, seq in enumerate(batch_data):
+                 L = min(len(seq), max_len)
+                 # Causal LM shift: the input is seq[:-1] and the target is seq[1:];
+                 # positions beyond the sequence keep the -100 ignore index.
+                 if L < 2:
+                     continue
+
+                 inputs_np[j, : L - 1] = seq[: L - 1]
+                 targets_np[j, : L - 1] = seq[1:L]
+
+             inputs = mx.array(inputs_np)
+             targets = mx.array(targets_np)
+
+             # Step
+             (total_loss, ce_loss_val, grads) = step(model, inputs, targets, lmbda)
+
+             optimizer.update(model, grads)
+             mx.eval(model.parameters(), optimizer.state)
+
+             # SCU update (post-step measurement)
+             if steps % args.scu_update_freq == 0:
+                 # Calculate BPTs
+                 data_bpt = calculate_data_bpt(ce_loss_val.item())
+                 param_bpt, _ = calculate_param_bpt(model, sigma, tokens_per_epoch)
+
+                 S_meas = param_bpt / (data_bpt + param_bpt + 1e-9)
+
+                 # Update lambda
+                 lmbda_old = lmbda
+                 lmbda, integral_term = update_lambda(
+                     lmbda, S_meas, args.target_s, integral_term, Kp=args.kp, Ki=args.ki
+                 )
+             else:
+                 # Keep lambda constant between updates
+                 lmbda_old = lmbda
+
+             # System monitoring
+             if steps % 50 == 0:
+                 cpu, mem = check_system_resources()
+                 if cpu is not None and cpu > 95:
+                     print("⚠️ CRITICAL: CPU usage > 95%, consider reducing --scu_update_freq")
+
+             # Logging (DataBPT, ParamBPT and S reflect the most recent SCU update)
+             if steps % 10 == 0:
+                 print(
+                     f"Step {steps}: Loss={ce_loss_val.item():.3f}, DataBPT={data_bpt:.3f}, "
+                     f"ParamBPT={param_bpt:.5f}, S={S_meas:.2%}, λ={lmbda_old:.3f} -> {lmbda:.3f}"
+                 )
+
+             steps += 1
+
+             # Save adapter occasionally
+             if steps % 100 == 0:
+                 print(f"Saving adapter to {args.adapter_path}")
+                 Path(args.adapter_path).mkdir(parents=True, exist_ok=True)
+                 model.save_weights(str(Path(args.adapter_path) / "weights.safetensors"))
+                 with open(Path(args.adapter_path) / "adapter_config.json", "w") as f:
+                     json.dump(lora_config, f, indent=2)
+
+     # Final Save
+     Path(args.adapter_path).mkdir(parents=True, exist_ok=True)
+     model.save_weights(str(Path(args.adapter_path) / "weights.safetensors"))
+     with open(Path(args.adapter_path) / "adapter_config.json", "w") as f:
+         json.dump(lora_config, f, indent=2)
+     print("Training complete.")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
+     parser.add_argument("--data", required=True)
+     parser.add_argument("--adapter_path", default="artifacts/adapters/hegelion_mlx_scu")
+     parser.add_argument("--iters", type=int, default=1000)
+     parser.add_argument("--batch_size", type=int, default=1)
+     parser.add_argument("--lr", type=float, default=1e-5)
+
+     # LoRA Args
+     parser.add_argument("--lora_rank", type=int, default=16)
+     parser.add_argument("--lora_alpha", type=int, default=32)
+     parser.add_argument("--lora_dropout", type=float, default=0.05)
+
+     # SCU Args
+     parser.add_argument("--target_s", type=float, default=0.01)
+     parser.add_argument("--kp", type=float, default=0.8)
+     parser.add_argument("--ki", type=float, default=0.15)
+     parser.add_argument("--lambda_init", type=float, default=1.0)
+     parser.add_argument("--prior_sigma", type=float, default=0.01)
+     parser.add_argument(
+         "--tokens_per_epoch",
+         type=float,
+         default=-1,
+         help="If <=0, auto-compute from dataset token count",
+     )
+     parser.add_argument(
+         "--scu_update_freq",
+         type=int,
+         default=10,
+         help="Update SCU lambda every N steps (default: 10)",
+     )
+     parser.add_argument("--max_seq_length", type=int, default=4096)
+     parser.add_argument("--seed", type=int, default=42)
+
+     args = parser.parse_args()
+     train_scu(args)
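Editor's note: a hedged usage sketch for the trainer above, mirroring the CLI defaults; only --data has no default. It assumes an Apple Silicon machine with mlx and mlx_lm installed. Note that the loader expects a "text" field on each JSONL line, so records produced by generator.py would need to be flattened into that shape first; all paths here are hypothetical.

    from argparse import Namespace
    from hegelion.training.mlx_scu_trainer import train_scu

    args = Namespace(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        data="hegelion_scu_corpus.jsonl",
        adapter_path="artifacts/adapters/hegelion_mlx_scu",
        iters=1000, batch_size=1, lr=1e-5,
        lora_rank=16, lora_alpha=32, lora_dropout=0.05,
        target_s=0.01, kp=0.8, ki=0.15, lambda_init=1.0, prior_sigma=0.01,
        tokens_per_epoch=-1, scu_update_freq=10, max_seq_length=4096, seed=42,
    )
    train_scu(args)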