open-agents-ai 0.185.28 → 0.185.30
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +242 -22
- package/package.json +1 -1
- package/voices/personaplex/quantize-weights.py +167 -0
package/dist/index.js
CHANGED
|
@@ -26415,29 +26415,18 @@ If you're stuck, try a completely different approach. Do NOT repeat what failed
|
|
|
26415
26415
|
}
|
|
26416
26416
|
let imageRecovered = false;
|
|
26417
26417
|
if (/invalid image|image.*invalid|image_url.*unsupported|does not support.*image|image.*not supported/i.test(errMsg)) {
|
|
26418
|
-
this.
|
|
26419
|
-
|
|
26420
|
-
|
|
26421
|
-
|
|
26422
|
-
|
|
26423
|
-
|
|
26424
|
-
|
|
26425
|
-
|
|
26426
|
-
|
|
26427
|
-
|
|
26428
|
-
} else {
|
|
26429
|
-
msg2.content = "[Image was here but backend doesn't support images]";
|
|
26430
|
-
}
|
|
26418
|
+
imageRecovered = await this._recoverFromImageError(messages, chatRequest, turn);
|
|
26419
|
+
if (imageRecovered) {
|
|
26420
|
+
try {
|
|
26421
|
+
const imgRetry = this.options.streamEnabled && this.hasStreamingSupport() ? await this.streamingRequest(chatRequest, turn) : await this.backend.chatCompletion(chatRequest);
|
|
26422
|
+
response = imgRetry;
|
|
26423
|
+
} catch (imgRetryErr) {
|
|
26424
|
+
const msg2 = imgRetryErr instanceof Error ? imgRetryErr.message : String(imgRetryErr);
|
|
26425
|
+
this.emit({ type: "error", content: `Retry after image fallback also failed: ${msg2}`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
26426
|
+
imageRecovered = false;
|
|
26427
|
+
break;
|
|
26431
26428
|
}
|
|
26432
|
-
}
|
|
26433
|
-
chatRequest.messages = messages;
|
|
26434
|
-
try {
|
|
26435
|
-
const imgRetry = this.options.streamEnabled && this.hasStreamingSupport() ? await this.streamingRequest(chatRequest, turn) : await this.backend.chatCompletion(chatRequest);
|
|
26436
|
-
response = imgRetry;
|
|
26437
|
-
imageRecovered = true;
|
|
26438
|
-
} catch (imgRetryErr) {
|
|
26439
|
-
const msg2 = imgRetryErr instanceof Error ? imgRetryErr.message : String(imgRetryErr);
|
|
26440
|
-
this.emit({ type: "error", content: `Retry without images also failed: ${msg2}`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
26429
|
+
} else {
|
|
26441
26430
|
break;
|
|
26442
26431
|
}
|
|
26443
26432
|
}
|
|
@@ -28475,6 +28464,172 @@ ${transcript}`
|
|
|
28475
28464
|
return true;
|
|
28476
28465
|
return false;
|
|
28477
28466
|
}
|
|
28467
|
+
/**
|
|
28468
|
+
* Graceful image error recovery chain:
|
|
28469
|
+
* 1. Downconvert images (resize to ≤512px, JPEG compress) and retry inline
|
|
28470
|
+
* 2. Describe images via moondream/Ollama vision → replace image_url with text description
|
|
28471
|
+
* 3. Last resort: strip images, keep text context
|
|
28472
|
+
*
|
|
28473
|
+
* Mutates messages in-place. Returns true if messages were successfully transformed.
|
|
28474
|
+
*/
|
|
28475
|
+
async _recoverFromImageError(messages, chatRequest, turn) {
|
|
28476
|
+
const imageEntries = [];
|
|
28477
|
+
for (let mi = 0; mi < messages.length; mi++) {
|
|
28478
|
+
const msg = messages[mi];
|
|
28479
|
+
if (!Array.isArray(msg.content))
|
|
28480
|
+
continue;
|
|
28481
|
+
for (let pi = 0; pi < msg.content.length; pi++) {
|
|
28482
|
+
const part = msg.content[pi];
|
|
28483
|
+
if (part.type === "image_url" && part.image_url?.url) {
|
|
28484
|
+
imageEntries.push({ msgIdx: mi, partIdx: pi, dataUrl: part.image_url.url });
|
|
28485
|
+
}
|
|
28486
|
+
}
|
|
28487
|
+
}
|
|
28488
|
+
if (imageEntries.length === 0)
|
|
28489
|
+
return false;
|
|
28490
|
+
this.emit({ type: "status", content: `Image rejected \u2014 trying downconversion (${imageEntries.length} image(s))...`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28491
|
+
let downconverted = false;
|
|
28492
|
+
try {
|
|
28493
|
+
for (const entry of imageEntries) {
|
|
28494
|
+
const { dataUrl } = entry;
|
|
28495
|
+
if (!dataUrl.startsWith("data:"))
|
|
28496
|
+
continue;
|
|
28497
|
+
const commaIdx = dataUrl.indexOf(",");
|
|
28498
|
+
if (commaIdx < 0)
|
|
28499
|
+
continue;
|
|
28500
|
+
const rawBase64 = dataUrl.slice(commaIdx + 1);
|
|
28501
|
+
const buffer = Buffer.from(rawBase64, "base64");
|
|
28502
|
+
let resizedBase64 = null;
|
|
28503
|
+
try {
|
|
28504
|
+
const { execSync: execSync35 } = await import("node:child_process");
|
|
28505
|
+
const { writeFileSync: writeFileSync30, readFileSync: readFileSync45, unlinkSync: unlinkSync13 } = await import("node:fs");
|
|
28506
|
+
const { join: join77 } = await import("node:path");
|
|
28507
|
+
const { tmpdir: tmpdir11 } = await import("node:os");
|
|
28508
|
+
const tmpIn = join77(tmpdir11(), `oa_img_in_${Date.now()}.png`);
|
|
28509
|
+
const tmpOut = join77(tmpdir11(), `oa_img_out_${Date.now()}.jpg`);
|
|
28510
|
+
writeFileSync30(tmpIn, buffer);
|
|
28511
|
+
execSync35(`python3 -c "
|
|
28512
|
+
from PIL import Image
|
|
28513
|
+
img = Image.open('${tmpIn}')
|
|
28514
|
+
img.thumbnail((512, 512), Image.LANCZOS)
|
|
28515
|
+
img = img.convert('RGB')
|
|
28516
|
+
img.save('${tmpOut}', 'JPEG', quality=75)
|
|
28517
|
+
"`, { timeout: 1e4, stdio: "pipe" });
|
|
28518
|
+
const resizedBuf = readFileSync45(tmpOut);
|
|
28519
|
+
resizedBase64 = `data:image/jpeg;base64,${resizedBuf.toString("base64")}`;
|
|
28520
|
+
try {
|
|
28521
|
+
unlinkSync13(tmpIn);
|
|
28522
|
+
} catch {
|
|
28523
|
+
}
|
|
28524
|
+
try {
|
|
28525
|
+
unlinkSync13(tmpOut);
|
|
28526
|
+
} catch {
|
|
28527
|
+
}
|
|
28528
|
+
} catch {
|
|
28529
|
+
}
|
|
28530
|
+
if (resizedBase64) {
|
|
28531
|
+
const msg = messages[entry.msgIdx];
|
|
28532
|
+
const parts = msg.content;
|
|
28533
|
+
parts[entry.partIdx] = { type: "image_url", image_url: { url: resizedBase64 } };
|
|
28534
|
+
downconverted = true;
|
|
28535
|
+
}
|
|
28536
|
+
}
|
|
28537
|
+
} catch {
|
|
28538
|
+
}
|
|
28539
|
+
if (downconverted) {
|
|
28540
|
+
chatRequest.messages = messages;
|
|
28541
|
+
this.emit({ type: "status", content: `Downconverted images to 512px JPEG \u2014 retrying`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28542
|
+
return true;
|
|
28543
|
+
}
|
|
28544
|
+
this.emit({ type: "status", content: `Downconversion unavailable \u2014 describing images via vision model...`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28545
|
+
const ollamaHost = process.env["OLLAMA_HOST"] || "http://127.0.0.1:11434";
|
|
28546
|
+
let described = false;
|
|
28547
|
+
for (const entry of imageEntries) {
|
|
28548
|
+
const { dataUrl } = entry;
|
|
28549
|
+
if (!dataUrl.startsWith("data:"))
|
|
28550
|
+
continue;
|
|
28551
|
+
const commaIdx = dataUrl.indexOf(",");
|
|
28552
|
+
if (commaIdx < 0)
|
|
28553
|
+
continue;
|
|
28554
|
+
const rawBase64 = dataUrl.slice(commaIdx + 1);
|
|
28555
|
+
try {
|
|
28556
|
+
const model = process.env["OLLAMA_VISION_MODEL"] || "moondream";
|
|
28557
|
+
let res = await fetch(`${ollamaHost}/api/generate`, {
|
|
28558
|
+
method: "POST",
|
|
28559
|
+
headers: { "Content-Type": "application/json" },
|
|
28560
|
+
body: JSON.stringify({
|
|
28561
|
+
model,
|
|
28562
|
+
prompt: "Describe this image in detail. Include text content, UI elements, code, errors, and any relevant visual information.",
|
|
28563
|
+
images: [rawBase64],
|
|
28564
|
+
stream: false
|
|
28565
|
+
}),
|
|
28566
|
+
signal: AbortSignal.timeout(6e4)
|
|
28567
|
+
});
|
|
28568
|
+
if (!res.ok && model === "moondream" && res.status === 404) {
|
|
28569
|
+
this.emit({ type: "status", content: `Pulling moondream vision model...`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28570
|
+
try {
|
|
28571
|
+
const { execSync: execSync35 } = await import("node:child_process");
|
|
28572
|
+
execSync35("ollama pull moondream", { timeout: 3e5, stdio: "pipe" });
|
|
28573
|
+
res = await fetch(`${ollamaHost}/api/generate`, {
|
|
28574
|
+
method: "POST",
|
|
28575
|
+
headers: { "Content-Type": "application/json" },
|
|
28576
|
+
body: JSON.stringify({
|
|
28577
|
+
model,
|
|
28578
|
+
prompt: "Describe this image in detail. Include text content, UI elements, code, errors, and any relevant visual information.",
|
|
28579
|
+
images: [rawBase64],
|
|
28580
|
+
stream: false
|
|
28581
|
+
}),
|
|
28582
|
+
signal: AbortSignal.timeout(6e4)
|
|
28583
|
+
});
|
|
28584
|
+
} catch {
|
|
28585
|
+
}
|
|
28586
|
+
}
|
|
28587
|
+
if (res.ok) {
|
|
28588
|
+
const data = await res.json();
|
|
28589
|
+
const description = data.response?.trim();
|
|
28590
|
+
if (description && description.length > 20) {
|
|
28591
|
+
const msg = messages[entry.msgIdx];
|
|
28592
|
+
const parts = msg.content;
|
|
28593
|
+
parts[entry.partIdx] = {
|
|
28594
|
+
type: "text",
|
|
28595
|
+
text: `[Image description from vision model]:
|
|
28596
|
+
${description}`
|
|
28597
|
+
};
|
|
28598
|
+
described = true;
|
|
28599
|
+
this.emit({ type: "status", content: `Image described (${description.length} chars) \u2014 replacing inline`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28600
|
+
}
|
|
28601
|
+
}
|
|
28602
|
+
} catch {
|
|
28603
|
+
}
|
|
28604
|
+
}
|
|
28605
|
+
if (described) {
|
|
28606
|
+
for (const msg of messages) {
|
|
28607
|
+
if (Array.isArray(msg.content)) {
|
|
28608
|
+
const parts = msg.content;
|
|
28609
|
+
const allText = parts.every((p) => p.type === "text");
|
|
28610
|
+
if (allText && parts.length === 1 && parts[0].text) {
|
|
28611
|
+
msg.content = parts[0].text;
|
|
28612
|
+
}
|
|
28613
|
+
}
|
|
28614
|
+
}
|
|
28615
|
+
chatRequest.messages = messages;
|
|
28616
|
+
this.emit({ type: "status", content: `Images replaced with descriptions \u2014 retrying`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28617
|
+
return true;
|
|
28618
|
+
}
|
|
28619
|
+
this.emit({ type: "status", content: `No vision model available \u2014 stripping images (text context preserved)`, timestamp: (/* @__PURE__ */ new Date()).toISOString() });
|
|
28620
|
+
for (const msg of messages) {
|
|
28621
|
+
if (Array.isArray(msg.content)) {
|
|
28622
|
+
const textParts = msg.content.filter((p) => p.type !== "image_url");
|
|
28623
|
+
if (textParts.length > 0) {
|
|
28624
|
+
msg.content = textParts.length === 1 && textParts[0].text ? textParts[0].text : textParts;
|
|
28625
|
+
} else {
|
|
28626
|
+
msg.content = "[Image was provided but could not be processed \u2014 no vision model available]";
|
|
28627
|
+
}
|
|
28628
|
+
}
|
|
28629
|
+
}
|
|
28630
|
+
chatRequest.messages = messages;
|
|
28631
|
+
return true;
|
|
28632
|
+
}
|
|
28478
28633
|
/**
|
|
28479
28634
|
* Retry a failed model request up to 3 times with exponential backoff.
|
|
28480
28635
|
* Returns the response on success, or null if all retries failed.
|
|
@@ -41142,6 +41297,18 @@ function detectPersonaPlexCapability() {
|
|
|
41142
41297
|
const [gpuName, vramMB] = nvsmi.split("\n")[0].split(", ");
|
|
41143
41298
|
const vramGB = parseInt(vramMB ?? "0", 10) / 1024;
|
|
41144
41299
|
if (vramGB < 16) {
|
|
41300
|
+
const isJetson = /orin|tegra|jetson/i.test(gpuName ?? "");
|
|
41301
|
+
if (isJetson) {
|
|
41302
|
+
try {
|
|
41303
|
+
const memInfo = execSync27("grep MemTotal /proc/meminfo", { encoding: "utf8", timeout: 3e3, stdio: "pipe" });
|
|
41304
|
+
const memKB = parseInt(memInfo.match(/(\d+)/)?.[1] ?? "0", 10);
|
|
41305
|
+
const totalGB = memKB / 1024 / 1024;
|
|
41306
|
+
if (totalGB >= 32) {
|
|
41307
|
+
return { supported: true, reason: `Jetson unified memory (${totalGB.toFixed(0)}GB total)`, gpuName: gpuName ?? "", vramGB: totalGB };
|
|
41308
|
+
}
|
|
41309
|
+
} catch {
|
|
41310
|
+
}
|
|
41311
|
+
}
|
|
41145
41312
|
return { supported: false, reason: `GPU has ${vramGB.toFixed(1)}GB VRAM (need \u226516GB)`, gpuName: gpuName ?? "", vramGB };
|
|
41146
41313
|
}
|
|
41147
41314
|
try {
|
|
@@ -41197,6 +41364,14 @@ async function installPersonaPlex(onInfo) {
|
|
|
41197
41364
|
}
|
|
41198
41365
|
const pip = process.platform === "win32" ? join54(venvDir, "Scripts", "pip.exe") : join54(venvDir, "bin", "pip");
|
|
41199
41366
|
const python = process.platform === "win32" ? join54(venvDir, "Scripts", "python.exe") : join54(venvDir, "bin", "python3");
|
|
41367
|
+
let arch2 = "";
|
|
41368
|
+
try {
|
|
41369
|
+
arch2 = execSync27("uname -m", { encoding: "utf8", timeout: 3e3, stdio: "pipe" }).trim();
|
|
41370
|
+
} catch {
|
|
41371
|
+
}
|
|
41372
|
+
const isAarch64 = arch2 === "aarch64" || arch2 === "arm64";
|
|
41373
|
+
if (isAarch64)
|
|
41374
|
+
log(`Detected ARM64 platform (${arch2}) \u2014 Jetson/ARM install path`);
|
|
41200
41375
|
log("Checking system dependencies (libopus)...");
|
|
41201
41376
|
try {
|
|
41202
41377
|
if (process.platform === "linux") {
|
|
@@ -41206,12 +41381,43 @@ async function installPersonaPlex(onInfo) {
|
|
|
41206
41381
|
}
|
|
41207
41382
|
} catch {
|
|
41208
41383
|
}
|
|
41384
|
+
if (isAarch64) {
|
|
41385
|
+
log("ARM64: Checking Rust toolchain for sphn build...");
|
|
41386
|
+
try {
|
|
41387
|
+
execSync27("rustc --version", { timeout: 5e3, stdio: "pipe" });
|
|
41388
|
+
} catch {
|
|
41389
|
+
log("ARM64: Installing Rust toolchain (needed for sphn audio codec)...");
|
|
41390
|
+
try {
|
|
41391
|
+
execSync27("curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y", { timeout: 12e4, stdio: "pipe" });
|
|
41392
|
+
} catch (e) {
|
|
41393
|
+
log(`Rust install failed: ${e instanceof Error ? e.message : String(e)}`);
|
|
41394
|
+
log("Install Rust manually: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh");
|
|
41395
|
+
return false;
|
|
41396
|
+
}
|
|
41397
|
+
}
|
|
41398
|
+
try {
|
|
41399
|
+
execSync27(`"${pip}" install --quiet maturin`, { timeout: 6e4, stdio: "pipe" });
|
|
41400
|
+
} catch {
|
|
41401
|
+
}
|
|
41402
|
+
}
|
|
41209
41403
|
log("Installing PersonaPlex (moshi package)...");
|
|
41210
41404
|
const repoDir = join54(PERSONAPLEX_DIR, "personaplex-repo");
|
|
41211
41405
|
try {
|
|
41212
41406
|
if (!existsSync37(repoDir)) {
|
|
41213
41407
|
execSync27(`git clone https://github.com/NVIDIA/personaplex.git "${repoDir}"`, { timeout: 12e4, stdio: "pipe" });
|
|
41214
41408
|
}
|
|
41409
|
+
if (isAarch64) {
|
|
41410
|
+
log("ARM64: Building sphn from source (Opus codec bindings)...");
|
|
41411
|
+
try {
|
|
41412
|
+
const rustEnv = `export PATH="$HOME/.cargo/bin:$PATH" &&`;
|
|
41413
|
+
execSync27(`${rustEnv} "${pip}" install --quiet --no-binary sphn sphn`, { timeout: 3e5, stdio: "pipe", shell: "/bin/bash" });
|
|
41414
|
+
log("ARM64: sphn built successfully");
|
|
41415
|
+
} catch (e) {
|
|
41416
|
+
log(`ARM64: sphn build failed \u2014 ${e instanceof Error ? e.message : String(e)}`);
|
|
41417
|
+
log("Ensure Rust, libopus-dev, and cmake are installed.");
|
|
41418
|
+
return false;
|
|
41419
|
+
}
|
|
41420
|
+
}
|
|
41215
41421
|
execSync27(`"${pip}" install --quiet "${join54(repoDir, "moshi")}/."`, { timeout: 3e5, stdio: "pipe" });
|
|
41216
41422
|
} catch (err) {
|
|
41217
41423
|
log(`Moshi install failed: ${err instanceof Error ? err.message : String(err)}`);
|
|
@@ -41239,7 +41445,21 @@ async function installPersonaPlex(onInfo) {
|
|
|
41239
41445
|
}
|
|
41240
41446
|
} catch {
|
|
41241
41447
|
}
|
|
41448
|
+
if (isAarch64) {
|
|
41449
|
+
log("ARM64: Installing bitsandbytes for INT4 inference...");
|
|
41450
|
+
try {
|
|
41451
|
+
execSync27(`"${pip}" install --quiet bitsandbytes`, { timeout: 12e4, stdio: "pipe" });
|
|
41452
|
+
} catch {
|
|
41453
|
+
}
|
|
41454
|
+
}
|
|
41455
|
+
try {
|
|
41456
|
+
execSync27(`"${pip}" install --quiet pyloudnorm noisereduce torchaudio`, { timeout: 12e4, stdio: "pipe" });
|
|
41457
|
+
} catch {
|
|
41458
|
+
}
|
|
41242
41459
|
log("PersonaPlex installed. Model will download on first launch (~14GB).");
|
|
41460
|
+
if (isAarch64) {
|
|
41461
|
+
log("ARM64: On first run, weights will load in INT4 mode for real-time performance.");
|
|
41462
|
+
}
|
|
41243
41463
|
writeFileSync16(join54(PERSONAPLEX_DIR, "model_ready"), (/* @__PURE__ */ new Date()).toISOString());
|
|
41244
41464
|
log("PersonaPlex installed successfully.");
|
|
41245
41465
|
return true;
|
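package/package.json
CHANGED
(hunk not captured in this extract; the summary above lists +1 -1)
package/voices/personaplex/quantize-weights.py
ADDED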
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
quantize-weights.py — Quantize PersonaPlex 7B weights to INT4 (NF4) for edge devices.
|
|
4
|
+
|
|
5
|
+
Creates a ~3.5GB quantized checkpoint from the ~14GB bf16 weights.
|
|
6
|
+
The quantized model runs 3-4x faster on memory-bandwidth-limited devices
|
|
7
|
+
like Jetson AGX Orin while maintaining voice quality.
|
|
8
|
+
|
|
9
|
+
Usage:
|
|
10
|
+
python quantize-weights.py [--device cuda] [--output personaplex-7b-nf4.safetensors]
|
|
11
|
+
|
|
12
|
+
Requirements:
|
|
13
|
+
pip install bitsandbytes safetensors torch
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import logging
|
|
20
|
+
|
|
21
|
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
|
22
|
+
log = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def quantize_model(device: str = "cuda", output_path: str = None):
    """Quantize the PersonaPlex 7B checkpoint to 4-bit and save it as safetensors.

    Downloads ``model.safetensors`` from the ``nvidia/personaplex-7b-v1`` repo,
    quantizes every tensor with at least 2 dimensions and at least 1024
    elements whose name contains none of the skip markers, and stores all
    other tensors as float16.

    Two quantization back-ends are used:

    * bitsandbytes NF4, when bitsandbytes imports and ``device`` is a usable
      CUDA device. Per tensor the file contains the packed data under
      ``name``, the block-wise absmax under ``name.__absmax__`` and the NF4
      codebook under ``name.__quant_map__``.
    * Manual symmetric INT4 (block size 64) otherwise. Per tensor the file
      contains the packed nibbles under ``name`` and the float16 per-block
      scales under ``name.__scales__``.

    In both cases ``name.__quant_state__`` is an int64 tensor holding the
    original shape zero-padded to 4 entries (the manual path appends the block
    size and the unpadded element count).

    Args:
        device: Device used for the bitsandbytes kernels (e.g. ``"cuda"``).
            Ignored by the manual fallback, which runs on CPU.
        output_path: Destination file. Defaults to ``model-nf4.safetensors``
            next to the downloaded weights.

    Returns:
        The path the quantized checkpoint was written to.
    """
    import torch
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file, save_file

    hf_repo = "nvidia/personaplex-7b-v1"
    # Substrings marking tensors that stay in fp16 (norms, biases, embeddings, ...).
    skip_markers = (
        "norm", "bias", "embed", "positional", "rope",
        "depformer_emb", "depformer_in",
    )

    # 1) Download original weights
    log.info("Downloading PersonaPlex 7B weights...")
    weight_path = hf_hub_download(hf_repo, "model.safetensors")
    log.info(f"  Weights: {weight_path}")
    log.info(f"  Size: {os.path.getsize(weight_path) / 1024**3:.1f} GB")

    # 2) Load state dict
    log.info("Loading state dict...")
    state_dict = load_file(weight_path, device="cpu")
    log.info(f"  Loaded {len(state_dict)} tensors")

    # 3) Pick the quantization back-end
    try:
        import bitsandbytes as bnb
        has_bnb = True
    except ImportError:
        has_bnb = False
        log.info("  bitsandbytes not available — using manual INT4 quantization")

    # The bitsandbytes 4-bit kernels are only used when the requested device is
    # CUDA and it is actually present; otherwise fall back rather than crash.
    use_bnb = has_bnb and str(device).startswith("cuda") and torch.cuda.is_available()
    if has_bnb and not use_bnb:
        log.info(f"  CUDA not usable (device={device}) — using manual INT4 quantization")

    quantized_state = {}
    quant_meta = {}  # Shape / block metadata needed for dequantization
    skipped = 0
    total_original = 0

    for name, tensor in state_dict.items():
        total_original += tensor.numel() * tensor.element_size()

        # Only quantize large weight tensors (>=2 dims, >=1024 elements);
        # skip biases, norms, embeddings and other small tensors.
        should_quantize = (
            tensor.ndim >= 2
            and tensor.numel() >= 1024
            and not any(marker in name for marker in skip_markers)
        )

        if not should_quantize:
            # NOTE(review): bf16 -> fp16 narrows the exponent range; presumably
            # fine for these weights — confirm no tensor exceeds the fp16 range.
            quantized_state[name] = tensor.to(torch.float16).contiguous()
            skipped += 1
            continue

        orig_shape = tensor.shape
        flat = tensor.reshape(-1).float()
        shape_meta = list(orig_shape) + [0] * (4 - len(orig_shape))

        if use_bnb:
            # compress_statistics=False keeps absmax as plain floats. With
            # nested statistics the stored absmax would itself be quantized and
            # could not be reconstructed from what this file saves.
            quant_tensor, quant_state = bnb.functional.quantize_4bit(
                flat.to(device), quant_type="nf4", compress_statistics=False,
            )
            quantized_state[name] = quant_tensor.cpu().contiguous()
            quant_meta[f"{name}.__quant_state__"] = torch.tensor(
                shape_meta, dtype=torch.int64,
            )
            absmax = getattr(quant_state, "absmax", None)
            if absmax is not None:
                quantized_state[f"{name}.__absmax__"] = absmax.cpu().contiguous()
            # The codebook attribute is `code` on the QuantState object;
            # `quant_map` is checked too for other layouts.
            codebook = getattr(quant_state, "code", None)
            if codebook is None:
                codebook = getattr(quant_state, "quant_map", None)
            if codebook is not None:
                quantized_state[f"{name}.__quant_map__"] = codebook.cpu().contiguous()
        else:
            # Manual symmetric INT4 quantization, one scale per 64-value block.
            block_size = 64
            n_blocks = (flat.numel() + block_size - 1) // block_size
            padded = torch.zeros(n_blocks * block_size)
            padded[:flat.numel()] = flat

            blocks = padded.reshape(n_blocks, block_size)
            scales = blocks.abs().max(dim=1).values / 7.0  # INT4 range: -8 to 7
            scales = scales.clamp(min=1e-8)

            levels = torch.round(blocks / scales.unsqueeze(1)).clamp(-8, 7).to(torch.int8)

            # Shift to 0..15 and pack two values per byte: even column in the
            # low nibble, odd column in the high nibble.
            unsigned = (levels + 8).to(torch.uint8)
            packed = unsigned[:, 0::2] | (unsigned[:, 1::2] << 4)

            quantized_state[name] = packed.reshape(-1).contiguous()
            quantized_state[f"{name}.__scales__"] = scales.to(torch.float16).contiguous()
            quant_meta[f"{name}.__quant_state__"] = torch.tensor(
                shape_meta + [block_size, flat.numel()], dtype=torch.int64,
            )

    # Add metadata tensors
    quantized_state.update(quant_meta)

    # 4) Save quantized weights
    if output_path is None:
        output_path = os.path.join(os.path.dirname(weight_path), "model-nf4.safetensors")

    log.info(f"\nSaving quantized weights to: {output_path}")
    save_file(quantized_state, output_path)

    final_size = os.path.getsize(output_path)
    compression = total_original / max(final_size, 1)
    scheme = "NF4" if use_bnb else "INT4"

    log.info(f"\nQuantization complete!")
    log.info(f"  Original:    {total_original / 1024**3:.1f} GB (bf16)")
    log.info(f"  Quantized:   {final_size / 1024**3:.1f} GB ({scheme})")
    log.info(f"  Compression: {compression:.1f}x")
    log.info(f"  Tensors quantized: {len(state_dict) - skipped}/{len(state_dict)}")
    log.info(f"  Tensors kept fp16: {skipped} (norms, biases, embeddings)")
    log.info(f"\nUse --quantized flag with PersonaPlex server for INT4 inference")

    return output_path
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def main():
    """Command-line entry point.

    Parses ``--device`` and ``--output``/``-o`` and runs ``quantize_model``
    with autograd disabled.
    """
    cli = argparse.ArgumentParser(description="Quantize PersonaPlex 7B to INT4 NF4")
    cli.add_argument("--device", default="cuda", help="Device for quantization")
    cli.add_argument("--output", "-o", default=None, help="Output path for quantized weights")
    options = cli.parse_args()

    # torch is imported lazily so that `--help` works without it installed.
    from torch import no_grad

    with no_grad():
        quantize_model(device=options.device, output_path=options.output)


if __name__ == "__main__":
    main()
|