omnius 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. package/README.md +4959 -0
  2. package/dist/index.d.ts +6 -0
  3. package/dist/index.js +630665 -0
  4. package/dist/launcher.cjs +78 -0
  5. package/dist/postinstall-daemon.cjs +776 -0
  6. package/dist/preinstall.cjs +92 -0
  7. package/dist/scripts/autoresearch-prepare.py +459 -0
  8. package/dist/scripts/autoresearch-train.py +661 -0
  9. package/dist/scripts/crawlee-scraper.py +358 -0
  10. package/dist/scripts/live-nemotron.py +478 -0
  11. package/dist/scripts/live-whisper.py +242 -0
  12. package/dist/scripts/ocr-advanced.py +571 -0
  13. package/dist/scripts/start-moondream.py +112 -0
  14. package/dist/scripts/tor/UPSTREAM-README.md +148 -0
  15. package/dist/scripts/tor/destroy_tor.sh +29 -0
  16. package/dist/scripts/tor/tor_setup.sh +163 -0
  17. package/dist/scripts/transcribe-file.py +63 -0
  18. package/dist/scripts/web_scrape.py +1295 -0
  19. package/npm-shrinkwrap.json +7412 -0
  20. package/package.json +142 -0
  21. package/prompts/agentic/system-large.md +569 -0
  22. package/prompts/agentic/system-medium.md +211 -0
  23. package/prompts/agentic/system-small.md +114 -0
  24. package/prompts/compaction/context-compaction.md +44 -0
  25. package/prompts/personality/level-1-minimal.md +3 -0
  26. package/prompts/personality/level-2-concise.md +3 -0
  27. package/prompts/personality/level-4-explanatory.md +3 -0
  28. package/prompts/personality/level-5-thorough.md +3 -0
  29. package/prompts/personality/level-autist.md +3 -0
  30. package/prompts/personality/level-stark.md +3 -0
  31. package/prompts/runners/dispatcher.md +24 -0
  32. package/prompts/runners/editor.md +44 -0
  33. package/prompts/runners/evaluator.md +30 -0
  34. package/prompts/runners/merge-summary.md +9 -0
  35. package/prompts/runners/normalizer.md +23 -0
  36. package/prompts/runners/planner.md +33 -0
  37. package/prompts/runners/scout.md +39 -0
  38. package/prompts/runners/verifier.md +36 -0
  39. package/prompts/skill-builder/seed-analysis.md +30 -0
  40. package/prompts/skill-builder/skill-expansion.md +76 -0
  41. package/prompts/skill-builder/skill-validation.md +31 -0
  42. package/prompts/templates/analysis.md +14 -0
  43. package/prompts/templates/code-review.md +16 -0
  44. package/prompts/templates/code.md +13 -0
  45. package/prompts/templates/document.md +13 -0
  46. package/prompts/templates/error-diagnosis.md +14 -0
  47. package/prompts/templates/general.md +9 -0
  48. package/prompts/templates/plan.md +15 -0
  49. package/prompts/templates/system.md +16 -0
  50. package/prompts/tui/dmn-gather.md +128 -0
  51. package/prompts/tui/dream-consolidate.md +48 -0
  52. package/prompts/tui/dream-lucid-eval.md +17 -0
  53. package/prompts/tui/dream-lucid-implement.md +14 -0
  54. package/prompts/tui/dream-stages.md +19 -0
  55. package/prompts/tui/emotion-behavioral.md +2 -0
  56. package/prompts/tui/emotion-center.md +12 -0
  57. package/voices/personaplex/OverBarn.pt +0 -0
  58. package/voices/personaplex/clone-voice.py +384 -0
  59. package/voices/personaplex/dequant-loader.py +174 -0
  60. package/voices/personaplex/quantize-weights.py +167 -0
@@ -0,0 +1,384 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ clone-voice.py — High-fidelity PersonaPlex voice cloning.
4
+
5
+ Applies LuxTTS-inspired preprocessing before embedding extraction:
6
+ 1. Resample to 24kHz mono
7
+ 2. Noise reduction (spectral gating)
8
+ 3. Silence trimming (energy-based VAD)
9
+ 4. LUFS normalization to -20 dBFS (tuned for PersonaPlex)
10
+ 5. Duration optimization (trim to 4-8s sweet spot)
11
+ 6. Multi-segment embedding averaging for long clips
12
+
13
+ Usage:
14
+ python clone-voice.py --input voice.wav --name MyVoice [--device cuda]
15
+ python clone-voice.py --input voice.wav --name MyVoice --segments 3 # multi-segment averaging
16
+ """
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ import logging
22
+ import numpy as np
23
+
24
# Emit bare messages (no level/logger prefix) at INFO and above; the progress lines
# below carry their own indentation/markers.
logging.basicConfig(format="%(message)s", level=logging.INFO)
log = logging.getLogger(name=__name__)
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Audio preprocessing (LuxTTS-inspired)
30
+ # ---------------------------------------------------------------------------
31
+
32
def _windowed_rms(audio_np, window, step):
    """Return ``(starts, rms)`` for full windows of ``window`` samples every ``step`` samples.

    The end-aligned window (starting at ``len(audio_np) - window``) is always included,
    so trailing audio is never ignored. RMS comes from a prefix sum of squared samples,
    so the cost is O(len(audio_np)) regardless of window size. Returns two empty arrays
    when the signal is shorter than one window or ``window <= 0``.
    """
    n = len(audio_np)
    if window <= 0 or n < window:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float64)
    step = max(1, step)
    starts = np.arange(0, n - window + 1, step, dtype=np.int64)
    if starts[-1] != n - window:
        starts = np.append(starts, n - window)
    # float64 prefix sums limit cancellation error on long clips.
    sq = np.concatenate(([0.0], np.cumsum(np.asarray(audio_np, dtype=np.float64) ** 2)))
    power = (sq[starts + window] - sq[starts]) / window
    # Clamp tiny negative values caused by floating-point cancellation before sqrt.
    return starts, np.sqrt(np.maximum(power, 0.0))


def _trim_silence(audio_np, sr, orig_duration):
    """Trim leading/trailing low-energy audio using 25 ms frames with a 10 ms hop.

    A frame counts as voiced when its RMS exceeds the larger of the 10th-percentile
    frame RMS and 5% of the peak frame RMS. 100 ms of margin is kept on both sides.
    The input is returned unchanged when no frame fits or none is voiced.
    """
    frame_length = int(0.025 * sr)  # 25ms frames
    hop = int(0.010 * sr)  # 10ms hop
    starts, energy = _windowed_rms(audio_np, frame_length, hop)
    if len(energy) == 0:
        return audio_np

    threshold = max(np.percentile(energy, 10), np.max(energy) * 0.05)
    voiced = np.flatnonzero(energy > threshold)
    if len(voiced) == 0:
        return audio_np

    margin = int(0.1 * sr)  # 100ms margin
    start_sample = max(0, int(starts[voiced[0]]) - margin)
    # The last voiced frame ends at its start + frame_length (not start + hop).
    end_sample = min(len(audio_np), int(starts[voiced[-1]]) + frame_length + margin)
    audio_np = audio_np[start_sample:end_sample]
    trimmed_dur = len(audio_np) / sr
    if trimmed_dur < orig_duration - 0.2:
        log.info(f" Trimmed silence: {orig_duration:.1f}s → {trimmed_dur:.1f}s")
    return audio_np


def _normalize_loudness(audio_np, sr, target_lufs):
    """Normalize integrated loudness to ``target_lufs``.

    Uses pyloudnorm when it can be imported and measures the clip. If the measured
    loudness is not finite (e.g. -inf), the clip is returned unchanged. If pyloudnorm
    is missing or raises, the failure is logged and the clip's RMS is scaled to
    ``10 ** (target_lufs / 20)`` instead, a rough approximation of the LUFS target.
    """
    try:
        import pyloudnorm as pyln
        meter = pyln.Meter(sr)
        current_lufs = meter.integrated_loudness(audio_np)
        if np.isfinite(current_lufs):
            audio_np = pyln.normalize.loudness(audio_np, current_lufs, target_lufs)
            log.info(f" Normalized: {current_lufs:.1f} → {target_lufs:.1f} LUFS")
        return audio_np
    except Exception as exc:  # best effort: any pyloudnorm problem falls back to RMS
        # Report why we fell back instead of silently swallowing the error.
        log.info(f" pyloudnorm normalization unavailable ({type(exc).__name__}: {exc})")

    if audio_np.size == 0:
        return audio_np
    rms = np.sqrt(np.mean(audio_np ** 2))
    if rms > 0:
        target_rms = 10 ** (target_lufs / 20)  # approximate
        audio_np = audio_np * (target_rms / rms)
        log.info(" RMS-normalized (fallback)")
    return audio_np


def _select_loudest_window(audio_np, sr, max_duration):
    """Return the ``max_duration``-second window with the highest RMS.

    Candidate windows start every 0.5 s, plus the end-aligned window. The input is
    returned unchanged if no full window fits.
    """
    duration = len(audio_np) / sr
    segment_samples = int(max_duration * sr)
    starts, rms = _windowed_rms(audio_np, segment_samples, int(0.5 * sr))  # 500ms steps
    if len(rms) == 0:
        return audio_np
    # np.argmax returns the first maximum, i.e. the earliest of equally loud windows.
    best_start = int(starts[int(np.argmax(rms))])
    log.info(f" Selected best {max_duration:.0f}s segment (from {duration:.1f}s)")
    return audio_np[best_start:best_start + segment_samples]


def preprocess_audio(input_path: str, target_sr: int = 24000,
                     target_lufs: float = -20.0,
                     min_duration: float = 2.0,
                     max_duration: float = 8.0,
                     denoise: bool = True) -> "torch.Tensor":
    """Load an audio file and condition it for voice-prompt embedding extraction.

    Pipeline:
      1. Load with torchaudio, downmix to mono, resample to ``target_sr``.
      2. If ``denoise``: spectral-gating noise reduction via ``noisereduce``
         (skipped with a log line if the package is not installed).
      3. Energy-based leading/trailing silence trim, keeping 100 ms of margin.
      4. Loudness normalization to ``target_lufs`` (pyloudnorm, with an RMS fallback).
      5. If longer than ``max_duration``: keep the loudest ``max_duration``-second
         window. If shorter than ``min_duration``: only a warning is logged.

    Args:
        input_path: path to any audio file torchaudio can load.
        target_sr: output sample rate in Hz.
        target_lufs: loudness target in LUFS.
        min_duration: duration (s) below which a warning is logged.
        max_duration: maximum output duration (s).
        denoise: whether to attempt noise reduction.

    Returns:
        float32 tensor of shape [1, T] sampled at ``target_sr``.
    """
    import torch
    import torchaudio

    log.info(f" Loading: {input_path}")
    wav, sr = torchaudio.load(input_path)

    # Stereo → mono
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
        log.info(" Converted stereo → mono")

    # Resample
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
        sr = target_sr
        log.info(f" Resampled to {target_sr}Hz")

    # wav is [1, T]; index the channel rather than squeeze() so a 1-sample clip stays 1-D.
    audio_np = wav[0].numpy()
    orig_duration = len(audio_np) / sr

    # Step 1: Noise reduction
    if denoise:
        try:
            import noisereduce as nr
        except ImportError:
            log.info(" noisereduce not available, skipping denoise")
        else:
            log.info(" Denoising (spectral gating)...")
            audio_np = nr.reduce_noise(
                y=audio_np,
                sr=sr,
                prop_decrease=0.7,  # moderate — preserve voice character
                n_fft=2048,
                hop_length=512,
            )
            log.info(" Denoised")

    # Step 2: Silence trimming (energy-based)
    audio_np = _trim_silence(audio_np, sr, orig_duration)

    # Step 3: Loudness normalization
    audio_np = _normalize_loudness(audio_np, sr, target_lufs)

    # Step 4: Duration clipping
    duration = len(audio_np) / sr
    if duration > max_duration:
        audio_np = _select_loudest_window(audio_np, sr, max_duration)
    elif duration < min_duration:
        log.warning(f" Audio too short ({duration:.1f}s < {min_duration:.0f}s minimum)")

    final_duration = len(audio_np) / sr
    log.info(f" Final: {final_duration:.1f}s, {sr}Hz mono")

    return torch.from_numpy(np.ascontiguousarray(audio_np)).float().unsqueeze(0)  # [1, T]
142
+
143
+
144
+ # ---------------------------------------------------------------------------
145
+ # Multi-segment embedding averaging
146
+ # ---------------------------------------------------------------------------
147
+
148
def _split_into_segments(preprocessed, sr, n_segments, max_seg_duration):
    """Cut a [1, T] waveform into up to ``n_segments`` overlapping windows.

    Each window is ``max_seg_duration`` seconds long. Consecutive windows overlap by 1 s,
    but the stride is never shorter than 2 s. The whole clip is returned as the only
    segment when ``n_segments <= 1``, when the clip is not longer than one window, or
    when no full window fits.
    """
    total_samples = preprocessed.shape[1]
    if n_segments <= 1 or total_samples / sr <= max_seg_duration:
        return [preprocessed]

    seg_samples = int(max_seg_duration * sr)
    overlap = int(1.0 * sr)  # 1s overlap
    stride = max(seg_samples - overlap, int(2.0 * sr))
    starts = range(0, total_samples - seg_samples + 1, stride)[:n_segments]
    segments = [preprocessed[:, s:s + seg_samples] for s in starts]
    if not segments:
        # If we didn't get any full window, just use the whole clip.
        segments = [preprocessed]
    log.info(f" Split into {len(segments)} segments ({max_seg_duration:.0f}s each, 1s overlap)")
    return segments


def _average_embeddings(all_embeddings):
    """Average per-segment embedding tensors frame by frame along dim 0.

    Tensors are truncated to the shortest frame count, stacked on a new leading axis,
    and averaged. A single tensor is returned unchanged.
    """
    import torch

    if len(all_embeddings) == 1:
        return all_embeddings[0]
    log.info(f"\nStep 4: Averaging {len(all_embeddings)} segment embeddings...")
    # Frames past the shortest segment have no counterpart in the others, so drop them.
    min_frames = min(e.shape[0] for e in all_embeddings)
    stacked = torch.stack([e[:min_frames] for e in all_embeddings], dim=0)
    averaged = stacked.mean(dim=0)
    log.info(f" Averaged: {averaged.shape[0]} frames from {len(all_embeddings)} segments")
    return averaged


def clone_voice_multiseg(input_wav: str, output_name: str, device: str = "cuda",
                         hf_repo: str = "nvidia/personaplex-7b-v1",
                         cpu_offload: bool = False,
                         n_segments: int = 1,
                         target_lufs: float = -20.0,
                         max_seg_duration: float = 8.0,
                         denoise: bool = True):
    """Clone a voice from ``input_wav`` into ``custom_voices/<output_name>.pt``.

    Steps:
      1. Preprocess the reference with ``preprocess_audio``. A copy is saved as
         ``<output_name>_preprocessed.wav`` for inspection.
      2. Optionally split it into up to ``n_segments`` overlapping windows of
         ``max_seg_duration`` seconds.
      3. Run each window through ``LMGen``'s voice-prompt path and collect the
         ``embeddings`` and ``cache`` it writes.
      4. Average the embeddings frame-wise across segments.
      5. Save ``{"embeddings", "cache"}``. The cache comes from the last segment that
         produced embeddings.

    Args:
        input_wav: reference audio path.
        output_name: base name of the saved voice.
        device: torch device for the models.
        hf_repo: Hugging Face repo with the Mimi and Moshi weights.
        cpu_offload: forwarded to ``loaders.get_moshi_lm``.
        n_segments: number of segments to average (values <= 1 mean a single pass).
        target_lufs: loudness target for preprocessing.
        max_seg_duration: length of each segment in seconds.
        denoise: run noise reduction during preprocessing. The default True matches
            the earlier behaviour.

    Returns:
        Path to the ``.pt`` file, or ``None`` if no segment produced embeddings.
        If the file already exists, nothing is recomputed and its path is returned.
    """
    import torch
    import torchaudio
    from huggingface_hub import hf_hub_download
    from moshi.models import loaders
    from moshi.models.lm import LMGen

    # Sample rate requested from preprocess_audio and used for every WAV written here.
    sr = 24000

    voices_dir = os.path.join(os.path.dirname(__file__), "custom_voices")
    os.makedirs(voices_dir, exist_ok=True)
    output_pt = os.path.join(voices_dir, f"{output_name}.pt")

    if os.path.exists(output_pt):
        log.info(f"Voice '{output_name}' already exists at {output_pt}")
        log.info("Delete it first if you want to re-clone.")
        return output_pt

    # ── Preprocessing ────────────────────────────────────────────────────
    log.info("Step 1: Preprocessing audio...")
    # For multi-segment, keep enough audio for all segments (non-positive counts act as 1).
    effective_max = max_seg_duration * max(n_segments, 1)
    preprocessed = preprocess_audio(
        input_wav,
        target_sr=sr,
        target_lufs=target_lufs,
        max_duration=effective_max,
        denoise=denoise,
    )

    # Save preprocessed audio for inspection
    prep_path = os.path.join(voices_dir, f"{output_name}_preprocessed.wav")
    torchaudio.save(prep_path, preprocessed, sr)
    log.info(f" Saved preprocessed audio: {prep_path}")

    # ── Split into segments if requested ─────────────────────────────────
    segments = _split_into_segments(preprocessed, sr, n_segments, max_seg_duration)

    # ── Load models ──────────────────────────────────────────────────────
    log.info("\nStep 2: Loading models...")
    mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(mimi_weight, device)
    mimi.streaming_forever(1)

    moshi_weight = hf_hub_download(hf_repo, loaders.MOSHI_NAME)
    lm = loaders.get_moshi_lm(moshi_weight, device=device, cpu_offload=cpu_offload)
    lm.eval()

    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    other_mimi = loaders.get_mimi(mimi_weight, device)
    other_mimi.streaming_forever(1)

    log.info(" Warming up...")
    from moshi.offline import warmup
    lm_gen = LMGen(
        lm,
        audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
        sample_rate=mimi.sample_rate,
        device=device,
        frame_rate=mimi.frame_rate,
        save_voice_prompt_embeddings=True,
        use_sampling=False,
        temp=0.8,
        temp_text=0.7,
        top_k=250,
        top_k_text=25,
    )
    lm_gen.streaming_forever(1)
    warmup(mimi, other_mimi, lm_gen, device, frame_size)

    # ── Extract embeddings per segment ───────────────────────────────────
    all_embeddings = []
    final_cache = None

    for seg_idx, seg_audio in enumerate(segments):
        seg_dur = seg_audio.shape[1] / sr
        log.info(f"\nStep 3.{seg_idx + 1}: Extracting embeddings from segment {seg_idx + 1}/{len(segments)} ({seg_dur:.1f}s)...")

        tmp_seg = os.path.join(voices_dir, f"_tmp_seg_{seg_idx}.wav")
        # With save_voice_prompt_embeddings=True, this code expects LMGen to write the
        # extracted state next to the prompt path with a .pt suffix.
        auto_saved = os.path.join(voices_dir, f"_tmp_seg_{seg_idx}.pt")
        # A leftover from an interrupted earlier run must not be mistaken for fresh output.
        if os.path.exists(auto_saved):
            os.remove(auto_saved)

        torchaudio.save(tmp_seg, seg_audio, sr)
        try:
            # Reset state for each segment
            mimi.reset_streaming()
            other_mimi.reset_streaming()
            lm_gen.reset_streaming()

            lm_gen.load_voice_prompt(tmp_seg)
            # Point the save path at our temp file so the .pt lands in voices_dir.
            lm_gen.voice_prompt = tmp_seg
            lm_gen._step_voice_prompt(mimi)

            if os.path.exists(auto_saved):
                state = torch.load(auto_saved, map_location="cpu", weights_only=False)
                all_embeddings.append(state["embeddings"])
                final_cache = state["cache"]
                log.info(f" Extracted {state['embeddings'].shape[0]} frames")
            else:
                log.warning(f" Segment {seg_idx + 1} failed to produce embeddings")
        finally:
            # Always clean up temporaries, even if extraction raised.
            for path in (tmp_seg, auto_saved):
                if os.path.exists(path):
                    os.remove(path)

    if not all_embeddings:
        log.error("No embeddings extracted!")
        return None

    # ── Average embeddings across segments ───────────────────────────────
    averaged = _average_embeddings(all_embeddings)

    # ── Save final .pt ───────────────────────────────────────────────────
    # NOTE(review): the zeros (1, 17, 4) int64 placeholder cache is kept from the
    # original code; presumably it matches what LMGen expects — verify.
    torch.save({
        "embeddings": averaged.detach().cpu(),
        "cache": final_cache.detach().cpu() if final_cache is not None else torch.zeros(1, 17, 4, dtype=torch.int64),
    }, output_pt)

    # Verify
    state = torch.load(output_pt, map_location="cpu", weights_only=False)
    emb_shape = state["embeddings"].shape
    log.info("\nVoice cloned successfully!")
    log.info(f" Output: {output_pt}")
    log.info(f" Embeddings: {emb_shape} ({emb_shape[0]} frames, ~{emb_shape[0] / mimi.frame_rate:.1f}s)")
    log.info(f" Segments averaged: {len(all_embeddings)}")
    steps = "denoise + silence trim" if denoise else "silence trim"
    log.info(f" Preprocessing: {steps} + {target_lufs:.0f} LUFS normalize")

    return output_pt


def main():
    """CLI entry point: parse arguments, validate them, and run ``clone_voice_multiseg``.

    Exits with status 0 on success and 1 on failure. Invalid option values exit with
    argparse's usage-error status.
    """
    parser = argparse.ArgumentParser(
        description="High-fidelity PersonaPlex voice cloning"
    )
    parser.add_argument(
        "--input", "-i", required=True,
        help="Input audio file (WAV, MP3, FLAC — any format torchaudio supports)"
    )
    parser.add_argument(
        "--name", "-n", required=True,
        help="Name for the cloned voice (e.g. 'MyVoice')"
    )
    parser.add_argument(
        "--device", "-d", default="cuda",
        help="Device to run on (default: cuda)"
    )
    parser.add_argument(
        "--cpu-offload", action="store_true",
        help="Offload to CPU if GPU memory is insufficient"
    )
    parser.add_argument(
        "--hf-repo", default="nvidia/personaplex-7b-v1",
        help="HuggingFace model repo"
    )
    parser.add_argument(
        "--segments", "-s", type=int, default=1,
        help="Number of segments to average (1 = single pass, 3+ = multi-segment averaging for longer clips)"
    )
    parser.add_argument(
        "--lufs", type=float, default=-20.0,
        help="Target LUFS normalization (default: -20, PersonaPlex default is -24)"
    )
    parser.add_argument(
        "--max-duration", type=float, default=8.0,
        help="Max duration per segment in seconds (default: 8, built-in voices use ~4)"
    )
    parser.add_argument(
        "--no-denoise", action="store_true",
        help="Skip noise reduction (if reference is already clean)"
    )

    args = parser.parse_args()

    if args.segments < 1:
        parser.error("--segments must be at least 1")
    if args.max_duration <= 0:
        parser.error("--max-duration must be positive")

    if not os.path.isfile(args.input):
        print(f"Error: Input file not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    import torch
    with torch.no_grad():
        result = clone_voice_multiseg(
            input_wav=args.input,
            output_name=args.name,
            device=args.device,
            hf_repo=args.hf_repo,
            cpu_offload=args.cpu_offload,
            n_segments=args.segments,
            target_lufs=args.lufs,
            max_seg_duration=args.max_duration,
            denoise=not args.no_denoise,
        )

    sys.exit(0 if result else 1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,174 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ dequant-loader.py — Pre-dequantize quantized PersonaPlex weights to bf16 cache.
4
+
5
+ For NF4 (INT4) or TurboQuant 2-bit weights, dequantizes to a temporary
6
+ bf16 safetensors file that moshi.server can load natively.
7
+
8
+ Usage:
9
+ python dequant-loader.py --input model-nf4.safetensors --output /tmp/model-bf16.safetensors
10
+ python dequant-loader.py --input model-turbo2bit.safetensors --output /tmp/model-bf16.safetensors
11
+
12
+ The output file can then be passed to moshi.server via --moshi-weight.
13
+ """
14
+
15
+ import os, sys, math, time
16
+ import torch
17
+ from safetensors.torch import load_file, save_file
18
+
19
# 2-bit codebook: a code k in 0..3 decodes to NF2_CENTROIDS[k]. The four levels are
# symmetric about zero (±0.4528, ±1.5104); these appear to be the 4-level Lloyd-Max
# levels for a unit Gaussian — confirm against the quantizer.
NF2_CENTROIDS = torch.tensor((-1.5104, -0.4528, 0.4528, 1.5104))
20
+
21
+
22
def fast_wht(x):
    """Return the orthonormal Walsh-Hadamard transform of ``x`` along its last dimension.

    The result is scaled by ``1/sqrt(n)``, so the transform is its own inverse:
    ``fast_wht(fast_wht(x))`` equals ``x`` up to rounding. Coefficients are in natural
    (Sylvester) order. The input is not modified and does not need to be contiguous.

    Args:
        x: tensor whose last dimension ``n`` is a power of two.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        ValueError: if the last dimension is not a positive power of two.
    """
    n = x.shape[-1]
    if n < 1 or n & (n - 1):
        raise ValueError(f"fast_wht needs a power-of-two last dimension, got {n}")
    lead = x.shape[:-1]
    h = 1
    while h < n:
        # Treat the last dim as blocks of 2*h and butterfly element j with element j + h.
        # Built out-of-place so the caller's tensor is never written to.
        pairs = x.reshape(*lead, n // (2 * h), 2, h)
        a, b = pairs[..., 0, :], pairs[..., 1, :]
        x = torch.stack((a + b, a - b), dim=-2).reshape(*lead, n)
        h *= 2
    return x / math.sqrt(n)
35
+
36
+
37
def detect_format(state):
    """Classify a loaded state dict by the metadata keys it carries.

    Returns ``"turbo2bit"`` if any key ends in ``.packed``. That check wins over the
    next one. Otherwise returns ``"nf4"`` if any key ends in ``.__scales__``, and
    ``"plain"`` if neither marker is present.
    """
    names = list(state)
    if any(key.endswith(".packed") for key in names):
        return "turbo2bit"
    return "nf4" if any(key.endswith(".__scales__") for key in names) else "plain"
46
+
47
+
48
def dequant_nf4(state, group_size=64):
    """Dequantize the 4-bit grouped ("nf4") format to bfloat16.

    A quantized tensor ``name`` is stored as four keys:
      * ``name``: uint8 data with two 4-bit codes per byte, low nibble first.
      * ``name.__scales__``: one scale per group of ``group_size`` values.
      * ``name.__shape__``: the original shape.
      * ``name.__numel__``: the original element count.

    A code ``q`` in 0..15 decodes to ``(q - 8) * scale``. That is a symmetric linear
    INT4 grid, not an NF4 codebook, despite the format name.

    Args:
        state: mapping of tensor name -> tensor, e.g. from ``safetensors.torch.load_file``.
        group_size: number of consecutive values sharing one scale. It must match the
            quantizer; 64 is the value previously hard-coded here.

    Returns:
        dict of every non-metadata name to its tensor. Dequantized and other
        floating-point tensors are bfloat16. Non-floating tensors (e.g. integer
        buffers) keep their dtype. All work stays on the device the tensors were
        loaded on.
    """
    meta_suffixes = (".__scales__", ".__shape__", ".__numel__")
    result = {}

    for name, tensor in state.items():
        if name.endswith(meta_suffixes):
            continue

        scales_key = f"{name}.__scales__"
        if scales_key not in state:
            # Only cast floating-point tensors; bf16 would corrupt integer buffers.
            result[name] = tensor.to(torch.bfloat16) if tensor.is_floating_point() else tensor
            continue

        packed = tensor.reshape(-1)  # tolerate packed data saved with extra dims
        scales = state[scales_key].float().reshape(-1, 1).to(packed.device)
        shape = state[f"{name}.__shape__"].tolist()
        numel = int(state[f"{name}.__numel__"].item())

        lo = (packed & 0x0F).to(torch.int8) - 8
        hi = ((packed >> 4) & 0x0F).to(torch.int8) - 8
        # Interleave so byte i yields values 2i (low nibble) and 2i+1 (high nibble).
        unpacked = torch.stack((lo, hi), dim=1).reshape(-1).float()

        n_groups = scales.shape[0]
        groups = unpacked[:n_groups * group_size].reshape(n_groups, group_size)
        deq = (groups * scales).reshape(-1)[:numel]

        # Non-positive shape entries are dropped, as before (presumably padding — verify).
        orig_shape = [s for s in shape if s > 0]
        result[name] = deq.reshape(orig_shape).to(torch.bfloat16)

    return result
85
+
86
+
87
def dequant_turbo2bit(state):
    """Dequantize TurboQuant 2-bit weights (NF2 codebook + Walsh-Hadamard rotation) to bf16.

    A quantized tensor ``name`` is identified by its ``name.packed`` key and uses these
    keys:
      * ``name.packed``: four 2-bit codes per byte, lowest bits first.
      * ``name.scales``: one scale per group.
      * ``name.gs``: the group size.
      * ``name.np2``: the group size padded to a power of two.
      * ``name.numel`` and ``name.shape``: the original element count and shape.

    Each group is decoded through ``NF2_CENTROIDS``, zero-padded to ``np2`` values,
    passed through ``fast_wht``, truncated back to ``gs`` values, and multiplied by
    its scale.

    Base names are derived from the ``.packed`` keys, so quantized tensors are recovered
    whether or not a bare ``name`` entry also exists. Only metadata keys that belong to
    a quantized tensor are dropped. Every other tensor passes through: floating-point
    ones are cast to bfloat16, others keep their dtype. New tensors are created on the
    packed data's device, so a GPU ``--device`` does not mix devices.
    """
    meta = ("packed", "scales", "shape", "numel", "gs", "np2")
    quantized = {key[: -len(".packed")] for key in state if key.endswith(".packed")}
    metadata_keys = {f"{base}.{suffix}" for base in quantized for suffix in meta}

    result = {}
    for name, tensor in state.items():
        if name in metadata_keys or name in quantized:
            continue
        # Only cast floating-point tensors; bf16 would corrupt integer buffers.
        result[name] = tensor.to(torch.bfloat16) if tensor.is_floating_point() else tensor

    for name in sorted(quantized):
        packed = state[f"{name}.packed"]
        device = packed.device
        gs = int(state[f"{name}.gs"].item())
        gs_pow2 = int(state[f"{name}.np2"].item())
        numel = int(state[f"{name}.numel"].item())
        shape = [s for s in state[f"{name}.shape"].tolist() if s > 0]
        scales = state[f"{name}.scales"].float().reshape(-1, 1).to(device)
        n_groups = scales.shape[0]

        # Unpack 2-bit codes: byte j of a group holds codes 4j..4j+3, lowest bits first.
        p = packed.reshape(n_groups, gs // 4)
        codes = torch.zeros(n_groups, gs, dtype=torch.long, device=device)
        for i in range(4):
            codes[:, i::4] = (p >> (2 * i)) & 0x03

        dequant = NF2_CENTROIDS.to(device)[codes]

        # Inverse WHT over the power-of-two padded group, then drop the padding.
        if gs_pow2 > gs:
            pad = torch.zeros(n_groups, gs_pow2 - gs, dtype=dequant.dtype, device=device)
            dequant = torch.cat([dequant, pad], dim=1)
        dequant = fast_wht(dequant)[:, :gs]

        dequant = dequant * scales
        result[name] = dequant.reshape(-1)[:numel].reshape(shape).to(torch.bfloat16)

    return result
130
+
131
+
132
def main():
    """CLI entry point: dequantize ``--input`` into a bf16 safetensors file at ``--output``.

    Does nothing when ``--output`` exists and is newer than ``--input``. The output is
    written to a ``.partial`` sibling and then renamed over ``--output`` with
    ``os.replace``. An interrupted run therefore cannot leave a truncated file that the
    freshness check would later accept as a valid cache.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Dequantize PersonaPlex weights to bf16")
    parser.add_argument("--input", "-i", required=True, help="Quantized safetensors file")
    parser.add_argument("--output", "-o", required=True, help="Output bf16 safetensors file")
    parser.add_argument("--device", "-d", default="cpu", help="Device for dequantization")
    args = parser.parse_args()

    if not os.path.exists(args.input):
        print(f"Error: {args.input} not found", file=sys.stderr)
        sys.exit(1)

    # Skip if output already exists and is newer than input
    if os.path.exists(args.output) and os.path.getmtime(args.output) > os.path.getmtime(args.input):
        print(f"Cached: {args.output} is up to date")
        sys.exit(0)

    print(f"Loading {args.input}...")
    t0 = time.perf_counter()
    state = load_file(args.input, device=args.device)

    fmt = detect_format(state)
    print(f"Format: {fmt}")

    if fmt == "nf4":
        result = dequant_nf4(state)
    elif fmt == "turbo2bit":
        result = dequant_turbo2bit(state)
    else:
        print("Already plain bf16/fp16 — copying")
        # Only cast floating-point tensors; bf16 would corrupt integer buffers.
        result = {
            k: v.to(torch.bfloat16) if v.is_floating_point() else v
            for k, v in state.items()
        }

    t1 = time.perf_counter()
    print(f"Dequantized {len(result)} tensors in {t1-t0:.1f}s")

    print(f"Saving to {args.output}...")
    tmp_output = f"{args.output}.partial"
    try:
        save_file(result, tmp_output)
        os.replace(tmp_output, args.output)  # atomic on the same filesystem
    except BaseException:
        # Don't leave a half-written temp file behind (including on Ctrl-C).
        if os.path.exists(tmp_output):
            os.remove(tmp_output)
        raise
    size_gb = os.path.getsize(args.output) / 1024**3
    print(f"Done: {size_gb:.2f} GB")


if __name__ == "__main__":
    main()