node-llama-cpp 3.0.0-beta.11 → 3.0.0-beta.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -4
- package/dist/bindings/Llama.d.ts +1 -0
- package/dist/bindings/Llama.js +7 -1
- package/dist/bindings/Llama.js.map +1 -1
- package/dist/bindings/getLlama.d.ts +7 -1
- package/dist/bindings/getLlama.js +6 -3
- package/dist/bindings/getLlama.js.map +1 -1
- package/dist/bindings/types.d.ts +1 -0
- package/dist/bindings/types.js.map +1 -1
- package/dist/bindings/utils/compileLLamaCpp.js +2 -0
- package/dist/bindings/utils/compileLLamaCpp.js.map +1 -1
- package/dist/bindings/utils/getBuildFolderNameForBuildOptions.js +2 -0
- package/dist/bindings/utils/getBuildFolderNameForBuildOptions.js.map +1 -1
- package/dist/bindings/utils/resolveCustomCmakeOptions.js +2 -0
- package/dist/bindings/utils/resolveCustomCmakeOptions.js.map +1 -1
- package/dist/cli/commands/BuildCommand.d.ts +2 -1
- package/dist/cli/commands/BuildCommand.js +11 -9
- package/dist/cli/commands/BuildCommand.js.map +1 -1
- package/dist/cli/commands/DebugCommand.js +16 -13
- package/dist/cli/commands/DebugCommand.js.map +1 -1
- package/dist/cli/commands/DownloadCommand.d.ts +2 -1
- package/dist/cli/commands/DownloadCommand.js +11 -9
- package/dist/cli/commands/DownloadCommand.js.map +1 -1
- package/dist/cli/utils/logEnabledComputeLayers.d.ts +8 -0
- package/dist/cli/utils/logEnabledComputeLayers.js +11 -0
- package/dist/cli/utils/logEnabledComputeLayers.js.map +1 -0
- package/dist/config.d.ts +1 -0
- package/dist/config.js +5 -2
- package/dist/config.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfArray.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfBoolean.d.ts +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfBoolean.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfBooleanValue.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfGrammar.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfNull.d.ts +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfNull.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfNumber.d.ts +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfNumber.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfNumberValue.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfObjectMap.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfOr.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfString.d.ts +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfString.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfStringValue.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfWhitespace.d.ts +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfWhitespace.js.map +1 -1
- package/dist/utils/getBuildDefaults.d.ts +1 -0
- package/dist/utils/getBuildDefaults.js +3 -2
- package/dist/utils/getBuildDefaults.js.map +1 -1
- package/llama/CMakeLists.txt +20 -0
- package/llama/addon.cpp +34 -3
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llama/gpuInfo/cuda-gpu-info.cu +5 -5
- package/llama/gpuInfo/cuda-gpu-info.h +2 -2
- package/llama/gpuInfo/vulkan-gpu-info.cpp +65 -0
- package/llama/gpuInfo/vulkan-gpu-info.h +7 -0
- package/llama/llama.cpp.info.json +1 -1
- package/llamaBins/linux-arm64/.buildMetadata.json +1 -1
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/.buildMetadata.json +1 -1
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/.buildMetadata.json +1 -1
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/linux-x64-cuda/.buildMetadata.json +1 -1
- package/llamaBins/linux-x64-cuda/llama-addon.node +0 -0
- package/llamaBins/linux-x64-vulkan/.buildMetadata.json +1 -0
- package/llamaBins/linux-x64-vulkan/llama-addon.node +0 -0
- package/llamaBins/mac-arm64-metal/.buildMetadata.json +1 -1
- package/llamaBins/mac-arm64-metal/ggml-metal.metal +540 -9
- package/llamaBins/mac-arm64-metal/llama-addon.node +0 -0
- package/llamaBins/mac-x64/.buildMetadata.json +1 -1
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/.buildMetadata.json +1 -1
- package/llamaBins/win-x64/llama-addon.exp +0 -0
- package/llamaBins/win-x64/llama-addon.lib +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64-cuda/.buildMetadata.json +1 -1
- package/llamaBins/win-x64-cuda/llama-addon.exp +0 -0
- package/llamaBins/win-x64-cuda/llama-addon.lib +0 -0
- package/llamaBins/win-x64-cuda/llama-addon.node +0 -0
- package/llamaBins/win-x64-vulkan/.buildMetadata.json +1 -0
- package/llamaBins/win-x64-vulkan/llama-addon.exp +0 -0
- package/llamaBins/win-x64-vulkan/llama-addon.lib +0 -0
- package/llamaBins/win-x64-vulkan/llama-addon.node +0 -0
- package/package.json +2 -1
|
@@ -392,7 +392,7 @@ kernel void kernel_soft_max(
|
|
|
392
392
|
float lmax = -INFINITY;
|
|
393
393
|
|
|
394
394
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
395
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
395
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
|
396
396
|
}
|
|
397
397
|
|
|
398
398
|
// find the max value in the block
|
|
@@ -417,7 +417,7 @@ kernel void kernel_soft_max(
|
|
|
417
417
|
// parallel sum
|
|
418
418
|
float lsum = 0.0f;
|
|
419
419
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
420
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
420
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
|
421
421
|
lsum += exp_psrc0;
|
|
422
422
|
pdst[i00] = exp_psrc0;
|
|
423
423
|
}
|
|
@@ -495,7 +495,7 @@ kernel void kernel_soft_max_4(
|
|
|
495
495
|
float4 lmax4 = -INFINITY;
|
|
496
496
|
|
|
497
497
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
498
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
498
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
|
499
499
|
}
|
|
500
500
|
|
|
501
501
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -521,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
|
521
521
|
// parallel sum
|
|
522
522
|
float4 lsum4 = 0.0f;
|
|
523
523
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
524
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
524
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
|
525
525
|
lsum4 += exp_psrc4;
|
|
526
526
|
pdst4[i00] = exp_psrc4;
|
|
527
527
|
}
|
|
@@ -2525,12 +2525,32 @@ typedef struct {
|
|
|
2525
2525
|
} block_iq3_xxs;
|
|
2526
2526
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
|
2527
2527
|
|
|
2528
|
+
// 110 bytes / block for QK_K = 256, so 3.4375 bpw
|
|
2529
|
+
#if QK_K == 64
|
|
2530
|
+
#define IQ3S_N_SCALE 2
|
|
2531
|
+
#else
|
|
2532
|
+
#define IQ3S_N_SCALE QK_K/64
|
|
2533
|
+
#endif
|
|
2534
|
+
typedef struct {
|
|
2535
|
+
half d;
|
|
2536
|
+
uint8_t qs[QK_K/4];
|
|
2537
|
+
uint8_t qh[QK_K/32];
|
|
2538
|
+
uint8_t signs[QK_K/8];
|
|
2539
|
+
uint8_t scales[IQ3S_N_SCALE];
|
|
2540
|
+
} block_iq3_s;
|
|
2541
|
+
|
|
2528
2542
|
typedef struct {
|
|
2529
2543
|
half d;
|
|
2530
2544
|
uint8_t qs[QK_K/8];
|
|
2531
2545
|
uint8_t scales[QK_K/16];
|
|
2532
2546
|
} block_iq1_s;
|
|
2533
2547
|
|
|
2548
|
+
// Non-linear quants
|
|
2549
|
+
#define QK4_NL 32
|
|
2550
|
+
typedef struct {
|
|
2551
|
+
half d;
|
|
2552
|
+
uint8_t qs[QK4_NL/2];
|
|
2553
|
+
} block_iq4_nl;
|
|
2534
2554
|
|
|
2535
2555
|
//====================================== dot products =========================
|
|
2536
2556
|
|
|
@@ -3789,6 +3809,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
|
3789
3809
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
|
3790
3810
|
};
|
|
3791
3811
|
|
|
3812
|
+
constexpr constant static uint32_t iq3xs_grid[512] = {
|
|
3813
|
+
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
|
|
3814
|
+
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
|
|
3815
|
+
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
|
|
3816
|
+
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
|
|
3817
|
+
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
|
|
3818
|
+
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
|
|
3819
|
+
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
|
|
3820
|
+
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
|
|
3821
|
+
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
|
|
3822
|
+
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
|
|
3823
|
+
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
|
|
3824
|
+
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
|
|
3825
|
+
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
|
|
3826
|
+
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
|
|
3827
|
+
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
|
|
3828
|
+
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
|
|
3829
|
+
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
|
|
3830
|
+
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
|
|
3831
|
+
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
|
|
3832
|
+
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
|
|
3833
|
+
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
|
|
3834
|
+
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
|
|
3835
|
+
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
|
|
3836
|
+
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
|
|
3837
|
+
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
|
|
3838
|
+
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
|
|
3839
|
+
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
|
|
3840
|
+
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
|
|
3841
|
+
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
|
|
3842
|
+
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
|
|
3843
|
+
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
|
|
3844
|
+
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
|
|
3845
|
+
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
|
|
3846
|
+
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
|
|
3847
|
+
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
|
|
3848
|
+
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
|
|
3849
|
+
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
|
|
3850
|
+
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
|
|
3851
|
+
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
|
|
3852
|
+
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
|
|
3853
|
+
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
|
|
3854
|
+
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
|
|
3855
|
+
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
|
|
3856
|
+
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
|
|
3857
|
+
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
|
|
3858
|
+
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
|
|
3859
|
+
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
|
|
3860
|
+
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
|
|
3861
|
+
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
|
|
3862
|
+
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
|
|
3863
|
+
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
|
|
3864
|
+
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
|
|
3865
|
+
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
|
|
3866
|
+
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
|
|
3867
|
+
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
|
|
3868
|
+
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
|
|
3869
|
+
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
|
|
3870
|
+
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
|
|
3871
|
+
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
|
|
3872
|
+
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
|
|
3873
|
+
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
|
|
3874
|
+
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
|
|
3875
|
+
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
|
|
3876
|
+
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
|
|
3877
|
+
};
|
|
3878
|
+
|
|
3792
3879
|
#define NGRID_IQ1S 512
|
|
3793
3880
|
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
|
3794
3881
|
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
@@ -4027,7 +4114,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
4027
4114
|
y4 += 32 * 32;
|
|
4028
4115
|
}
|
|
4029
4116
|
#else
|
|
4030
|
-
|
|
4117
|
+
(void) x;
|
|
4118
|
+
(void) y;
|
|
4119
|
+
(void) yl;
|
|
4120
|
+
(void) nb32;
|
|
4031
4121
|
#endif
|
|
4032
4122
|
|
|
4033
4123
|
for (int row = 0; row < N_DST; ++row) {
|
|
@@ -4170,7 +4260,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
4170
4260
|
y4 += 32 * 32;
|
|
4171
4261
|
}
|
|
4172
4262
|
#else
|
|
4173
|
-
|
|
4263
|
+
(void) x;
|
|
4264
|
+
(void) y;
|
|
4265
|
+
(void) yl;
|
|
4266
|
+
(void) nb32;
|
|
4174
4267
|
#endif
|
|
4175
4268
|
|
|
4176
4269
|
for (int row = 0; row < N_DST; ++row) {
|
|
@@ -4306,7 +4399,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
4306
4399
|
y4 += 32 * 32;
|
|
4307
4400
|
}
|
|
4308
4401
|
#else
|
|
4309
|
-
|
|
4402
|
+
(void) x;
|
|
4403
|
+
(void) y;
|
|
4404
|
+
(void) yl;
|
|
4405
|
+
(void) nb32;
|
|
4310
4406
|
#endif
|
|
4311
4407
|
|
|
4312
4408
|
for (int row = 0; row < N_DST; ++row) {
|
|
@@ -4346,7 +4442,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
|
4346
4442
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4347
4443
|
}
|
|
4348
4444
|
|
|
4349
|
-
void
|
|
4445
|
+
void kernel_mul_mv_iq3_s_f32_impl(
|
|
4350
4446
|
device const void * src0,
|
|
4351
4447
|
device const float * src1,
|
|
4352
4448
|
device float * dst,
|
|
@@ -4359,6 +4455,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
4359
4455
|
constant int64_t & ne1,
|
|
4360
4456
|
constant uint & r2,
|
|
4361
4457
|
constant uint & r3,
|
|
4458
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4362
4459
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4363
4460
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
4364
4461
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -4376,6 +4473,134 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
4376
4473
|
|
|
4377
4474
|
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4378
4475
|
|
|
4476
|
+
device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
|
|
4477
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4478
|
+
|
|
4479
|
+
float yl[32];
|
|
4480
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4481
|
+
|
|
4482
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4483
|
+
|
|
4484
|
+
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
|
4485
|
+
{
|
|
4486
|
+
int nval = 8;
|
|
4487
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
4488
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
|
|
4489
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4490
|
+
}
|
|
4491
|
+
|
|
4492
|
+
const int ix = tiisg;
|
|
4493
|
+
|
|
4494
|
+
device const float * y4 = y + 32 * ix;
|
|
4495
|
+
|
|
4496
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
4497
|
+
|
|
4498
|
+
for (int i = 0; i < 32; ++i) {
|
|
4499
|
+
yl[i] = y4[i];
|
|
4500
|
+
}
|
|
4501
|
+
|
|
4502
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
4503
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4504
|
+
|
|
4505
|
+
device const block_iq3_s * xr = x + ibl;
|
|
4506
|
+
device const uint8_t * qs = xr->qs + 8 * ib;
|
|
4507
|
+
device const uint8_t * qh = xr->qh + ib;
|
|
4508
|
+
device const uint8_t * sc = xr->scales + (ib/2);
|
|
4509
|
+
device const uint8_t * signs = xr->signs + 4 * ib;
|
|
4510
|
+
device const half * dh = &xr->d;
|
|
4511
|
+
|
|
4512
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4513
|
+
|
|
4514
|
+
const float db = dh[0];
|
|
4515
|
+
const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
|
|
4516
|
+
|
|
4517
|
+
float2 sum = {0};
|
|
4518
|
+
for (int l = 0; l < 4; ++l) {
|
|
4519
|
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
|
4520
|
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
|
4521
|
+
for (int j = 0; j < 4; ++j) {
|
|
4522
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
|
4523
|
+
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
|
4524
|
+
}
|
|
4525
|
+
}
|
|
4526
|
+
sumf[row] += d * (sum[0] + sum[1]);
|
|
4527
|
+
|
|
4528
|
+
dh += nb*sizeof(block_iq3_s)/2;
|
|
4529
|
+
qs += nb*sizeof(block_iq3_s);
|
|
4530
|
+
qh += nb*sizeof(block_iq3_s);
|
|
4531
|
+
sc += nb*sizeof(block_iq3_s);
|
|
4532
|
+
signs += nb*sizeof(block_iq3_s);
|
|
4533
|
+
}
|
|
4534
|
+
|
|
4535
|
+
y4 += 32 * 32;
|
|
4536
|
+
}
|
|
4537
|
+
|
|
4538
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
4539
|
+
all_sum = simd_sum(sumf[row]);
|
|
4540
|
+
if (tiisg == 0) {
|
|
4541
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
|
|
4542
|
+
}
|
|
4543
|
+
}
|
|
4544
|
+
}
|
|
4545
|
+
|
|
4546
|
+
[[host_name("kernel_mul_mv_iq3_s_f32")]]
|
|
4547
|
+
kernel void kernel_mul_mv_iq3_s_f32(
|
|
4548
|
+
device const void * src0,
|
|
4549
|
+
device const float * src1,
|
|
4550
|
+
device float * dst,
|
|
4551
|
+
constant int64_t & ne00,
|
|
4552
|
+
constant int64_t & ne01,
|
|
4553
|
+
constant int64_t & ne02,
|
|
4554
|
+
constant uint64_t & nb00,
|
|
4555
|
+
constant uint64_t & nb01,
|
|
4556
|
+
constant uint64_t & nb02,
|
|
4557
|
+
constant int64_t & ne10,
|
|
4558
|
+
constant int64_t & ne11,
|
|
4559
|
+
constant int64_t & ne12,
|
|
4560
|
+
constant uint64_t & nb10,
|
|
4561
|
+
constant uint64_t & nb11,
|
|
4562
|
+
constant uint64_t & nb12,
|
|
4563
|
+
constant int64_t & ne0,
|
|
4564
|
+
constant int64_t & ne1,
|
|
4565
|
+
constant uint & r2,
|
|
4566
|
+
constant uint & r3,
|
|
4567
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4568
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4569
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4570
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4571
|
+
|
|
4572
|
+
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4573
|
+
}
|
|
4574
|
+
|
|
4575
|
+
void kernel_mul_mv_iq1_s_f32_impl(
|
|
4576
|
+
device const void * src0,
|
|
4577
|
+
device const float * src1,
|
|
4578
|
+
device float * dst,
|
|
4579
|
+
constant int64_t & ne00,
|
|
4580
|
+
constant int64_t & ne01,
|
|
4581
|
+
constant int64_t & ne02,
|
|
4582
|
+
constant int64_t & ne10,
|
|
4583
|
+
constant int64_t & ne12,
|
|
4584
|
+
constant int64_t & ne0,
|
|
4585
|
+
constant int64_t & ne1,
|
|
4586
|
+
constant uint & r2,
|
|
4587
|
+
constant uint & r3,
|
|
4588
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4589
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4590
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4591
|
+
|
|
4592
|
+
const int nb = ne00/QK_K;
|
|
4593
|
+
const int r0 = tgpig.x;
|
|
4594
|
+
const int r1 = tgpig.y;
|
|
4595
|
+
const int im = tgpig.z;
|
|
4596
|
+
|
|
4597
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4598
|
+
const int ib_row = first_row * nb;
|
|
4599
|
+
|
|
4600
|
+
const uint i12 = im%ne12;
|
|
4601
|
+
const uint i13 = im/ne12;
|
|
4602
|
+
|
|
4603
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4379
4604
|
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
|
4380
4605
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4381
4606
|
|
|
@@ -4424,7 +4649,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
4424
4649
|
y4 += 16 * 32;
|
|
4425
4650
|
}
|
|
4426
4651
|
#else
|
|
4427
|
-
|
|
4652
|
+
(void) x;
|
|
4653
|
+
(void) y;
|
|
4654
|
+
(void) yl;
|
|
4655
|
+
(void) nb32;
|
|
4428
4656
|
#endif
|
|
4429
4657
|
|
|
4430
4658
|
for (int row = 0; row < N_DST; ++row) {
|
|
@@ -4435,6 +4663,103 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
4435
4663
|
}
|
|
4436
4664
|
}
|
|
4437
4665
|
|
|
4666
|
+
constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
4667
|
+
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
4668
|
+
};
|
|
4669
|
+
|
|
4670
|
+
void kernel_mul_mv_iq4_nl_f32_impl(
|
|
4671
|
+
device const void * src0,
|
|
4672
|
+
device const float * src1,
|
|
4673
|
+
device float * dst,
|
|
4674
|
+
constant int64_t & ne00,
|
|
4675
|
+
constant int64_t & ne01,
|
|
4676
|
+
constant int64_t & ne02,
|
|
4677
|
+
constant int64_t & ne10,
|
|
4678
|
+
constant int64_t & ne12,
|
|
4679
|
+
constant int64_t & ne0,
|
|
4680
|
+
constant int64_t & ne1,
|
|
4681
|
+
constant uint & r2,
|
|
4682
|
+
constant uint & r3,
|
|
4683
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
4684
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4685
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4686
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4687
|
+
|
|
4688
|
+
const int nb = ne00/QK4_NL;
|
|
4689
|
+
const int r0 = tgpig.x;
|
|
4690
|
+
const int r1 = tgpig.y;
|
|
4691
|
+
const int im = tgpig.z;
|
|
4692
|
+
const int first_row = (r0 * 2 + sgitg) * 2;
|
|
4693
|
+
const int ib_row = first_row * nb;
|
|
4694
|
+
|
|
4695
|
+
const uint i12 = im%ne12;
|
|
4696
|
+
const uint i13 = im/ne12;
|
|
4697
|
+
|
|
4698
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4699
|
+
device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
|
4700
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4701
|
+
|
|
4702
|
+
const int ix = tiisg/2; // 0...15
|
|
4703
|
+
const int it = tiisg%2; // 0 or 1
|
|
4704
|
+
|
|
4705
|
+
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
|
4706
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4707
|
+
|
|
4708
|
+
float4 yl[4];
|
|
4709
|
+
float sumf[2]={0.f}, all_sum;
|
|
4710
|
+
|
|
4711
|
+
device const float * yb = y + ix * QK4_NL + it * 8;
|
|
4712
|
+
|
|
4713
|
+
uint32_t aux32[2];
|
|
4714
|
+
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
|
4715
|
+
|
|
4716
|
+
float4 qf1, qf2;
|
|
4717
|
+
|
|
4718
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
|
4719
|
+
|
|
4720
|
+
device const float4 * y4 = (device const float4 *)yb;
|
|
4721
|
+
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
|
4722
|
+
|
|
4723
|
+
for (int row = 0; row < 2; ++row) {
|
|
4724
|
+
|
|
4725
|
+
device const block_iq4_nl & xb = x[row*nb + ib];
|
|
4726
|
+
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
|
4727
|
+
|
|
4728
|
+
float4 acc1 = {0.f}, acc2 = {0.f};
|
|
4729
|
+
|
|
4730
|
+
aux32[0] = q4[0] | (q4[1] << 16);
|
|
4731
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
|
4732
|
+
aux32[0] &= 0x0f0f0f0f;
|
|
4733
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
4734
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
4735
|
+
acc1 += yl[0] * qf1;
|
|
4736
|
+
acc2 += yl[1] * qf2;
|
|
4737
|
+
|
|
4738
|
+
aux32[0] = q4[2] | (q4[3] << 16);
|
|
4739
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
|
4740
|
+
aux32[0] &= 0x0f0f0f0f;
|
|
4741
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
4742
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
4743
|
+
acc1 += yl[2] * qf1;
|
|
4744
|
+
acc2 += yl[3] * qf2;
|
|
4745
|
+
|
|
4746
|
+
acc1 += acc2;
|
|
4747
|
+
|
|
4748
|
+
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
|
4749
|
+
|
|
4750
|
+
}
|
|
4751
|
+
|
|
4752
|
+
yb += 16 * QK4_NL;
|
|
4753
|
+
}
|
|
4754
|
+
|
|
4755
|
+
for (int row = 0; row < 2; ++row) {
|
|
4756
|
+
all_sum = simd_sum(sumf[row]);
|
|
4757
|
+
if (tiisg == 0) {
|
|
4758
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
4759
|
+
}
|
|
4760
|
+
}
|
|
4761
|
+
}
|
|
4762
|
+
|
|
4438
4763
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
|
4439
4764
|
kernel void kernel_mul_mv_iq1_s_f32(
|
|
4440
4765
|
device const void * src0,
|
|
@@ -4463,6 +4788,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
|
4463
4788
|
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
4464
4789
|
}
|
|
4465
4790
|
|
|
4791
|
+
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
4792
|
+
kernel void kernel_mul_mv_iq4_nl_f32(
|
|
4793
|
+
device const void * src0,
|
|
4794
|
+
device const float * src1,
|
|
4795
|
+
device float * dst,
|
|
4796
|
+
constant int64_t & ne00,
|
|
4797
|
+
constant int64_t & ne01,
|
|
4798
|
+
constant int64_t & ne02,
|
|
4799
|
+
constant uint64_t & nb00,
|
|
4800
|
+
constant uint64_t & nb01,
|
|
4801
|
+
constant uint64_t & nb02,
|
|
4802
|
+
constant int64_t & ne10,
|
|
4803
|
+
constant int64_t & ne11,
|
|
4804
|
+
constant int64_t & ne12,
|
|
4805
|
+
constant uint64_t & nb10,
|
|
4806
|
+
constant uint64_t & nb11,
|
|
4807
|
+
constant uint64_t & nb12,
|
|
4808
|
+
constant int64_t & ne0,
|
|
4809
|
+
constant int64_t & ne1,
|
|
4810
|
+
constant uint & r2,
|
|
4811
|
+
constant uint & r3,
|
|
4812
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
4813
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4814
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4815
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4816
|
+
|
|
4817
|
+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4818
|
+
}
|
|
4466
4819
|
|
|
4467
4820
|
//============================= templates and their specializations =============================
|
|
4468
4821
|
|
|
@@ -4659,6 +5012,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
|
4659
5012
|
const float dl = d * sc[0];
|
|
4660
5013
|
const float ml = min * sc[1];
|
|
4661
5014
|
#else
|
|
5015
|
+
(void) get_scale_min_k4_just2;
|
|
5016
|
+
|
|
4662
5017
|
q = q + 16 * (il&1);
|
|
4663
5018
|
device const uint8_t * s = xb->scales;
|
|
4664
5019
|
device const half2 * dh = (device const half2 *)xb->d;
|
|
@@ -4808,6 +5163,31 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
|
4808
5163
|
}
|
|
4809
5164
|
}
|
|
4810
5165
|
|
|
5166
|
+
template <typename type4x4>
|
|
5167
|
+
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
|
|
5168
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5169
|
+
const float d = xb->d;
|
|
5170
|
+
const int ib32 = il/2;
|
|
5171
|
+
il = il%2;
|
|
5172
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5173
|
+
device const uint8_t * qs = xb->qs + 8*ib32;
|
|
5174
|
+
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
|
5175
|
+
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
|
5176
|
+
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
|
|
5177
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
|
5178
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
|
5179
|
+
for (int i = 0; i < 4; ++i) {
|
|
5180
|
+
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
|
5181
|
+
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
|
5182
|
+
}
|
|
5183
|
+
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
|
5184
|
+
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
|
5185
|
+
for (int i = 0; i < 4; ++i) {
|
|
5186
|
+
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
|
5187
|
+
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
|
5188
|
+
}
|
|
5189
|
+
}
|
|
5190
|
+
|
|
4811
5191
|
template <typename type4x4>
|
|
4812
5192
|
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
|
4813
5193
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
@@ -4824,6 +5204,21 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
|
4824
5204
|
}
|
|
4825
5205
|
}
|
|
4826
5206
|
|
|
5207
|
+
template <typename type4x4>
|
|
5208
|
+
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
|
5209
|
+
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
|
5210
|
+
const float d = xb->d;
|
|
5211
|
+
uint32_t aux32;
|
|
5212
|
+
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
|
5213
|
+
for (int i = 0; i < 4; ++i) {
|
|
5214
|
+
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
|
|
5215
|
+
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
|
5216
|
+
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
|
5217
|
+
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
|
5218
|
+
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
|
5219
|
+
}
|
|
5220
|
+
}
|
|
5221
|
+
|
|
4827
5222
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
4828
5223
|
kernel void kernel_get_rows(
|
|
4829
5224
|
device const void * src0,
|
|
@@ -5366,7 +5761,9 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
|
5366
5761
|
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5367
5762
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5368
5763
|
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5764
|
+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
5369
5765
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5766
|
+
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
5370
5767
|
|
|
5371
5768
|
//
|
|
5372
5769
|
// matrix-matrix multiplication
|
|
@@ -5406,7 +5803,9 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
5406
5803
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5407
5804
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5408
5805
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5806
|
+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
5409
5807
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5808
|
+
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
5410
5809
|
|
|
5411
5810
|
//
|
|
5412
5811
|
// indirect matrix-matrix multiplication
|
|
@@ -5458,7 +5857,9 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
5458
5857
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5459
5858
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5460
5859
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5860
|
+
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
5461
5861
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5862
|
+
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
5462
5863
|
|
|
5463
5864
|
//
|
|
5464
5865
|
// matrix-vector multiplication
|
|
@@ -6427,6 +6828,71 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
|
6427
6828
|
sgitg);
|
|
6428
6829
|
}
|
|
6429
6830
|
|
|
6831
|
+
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
|
6832
|
+
kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6833
|
+
device const char * ids,
|
|
6834
|
+
device const char * src1,
|
|
6835
|
+
device float * dst,
|
|
6836
|
+
constant uint64_t & nbi1,
|
|
6837
|
+
constant int64_t & ne00,
|
|
6838
|
+
constant int64_t & ne01,
|
|
6839
|
+
constant int64_t & ne02,
|
|
6840
|
+
constant uint64_t & nb00,
|
|
6841
|
+
constant uint64_t & nb01,
|
|
6842
|
+
constant uint64_t & nb02,
|
|
6843
|
+
constant int64_t & ne10,
|
|
6844
|
+
constant int64_t & ne11,
|
|
6845
|
+
constant int64_t & ne12,
|
|
6846
|
+
constant int64_t & ne13,
|
|
6847
|
+
constant uint64_t & nb10,
|
|
6848
|
+
constant uint64_t & nb11,
|
|
6849
|
+
constant uint64_t & nb12,
|
|
6850
|
+
constant int64_t & ne0,
|
|
6851
|
+
constant int64_t & ne1,
|
|
6852
|
+
constant uint64_t & nb1,
|
|
6853
|
+
constant uint & r2,
|
|
6854
|
+
constant uint & r3,
|
|
6855
|
+
constant int & idx,
|
|
6856
|
+
device const char * src00,
|
|
6857
|
+
device const char * src01,
|
|
6858
|
+
device const char * src02,
|
|
6859
|
+
device const char * src03,
|
|
6860
|
+
device const char * src04,
|
|
6861
|
+
device const char * src05,
|
|
6862
|
+
device const char * src06,
|
|
6863
|
+
device const char * src07,
|
|
6864
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
6865
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6866
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6867
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6868
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6869
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
6870
|
+
|
|
6871
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
6872
|
+
|
|
6873
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
6874
|
+
|
|
6875
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6876
|
+
|
|
6877
|
+
kernel_mul_mv_iq3_s_f32_impl(
|
|
6878
|
+
src0[id],
|
|
6879
|
+
(device const float *) (src1 + bid*nb11),
|
|
6880
|
+
dst + bid*ne0,
|
|
6881
|
+
ne00,
|
|
6882
|
+
ne01,
|
|
6883
|
+
ne02,
|
|
6884
|
+
ne10,
|
|
6885
|
+
ne12,
|
|
6886
|
+
ne0,
|
|
6887
|
+
ne1,
|
|
6888
|
+
r2,
|
|
6889
|
+
r3,
|
|
6890
|
+
shared_values,
|
|
6891
|
+
tgpig,
|
|
6892
|
+
tiisg,
|
|
6893
|
+
sgitg);
|
|
6894
|
+
}
|
|
6895
|
+
|
|
6430
6896
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
|
6431
6897
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6432
6898
|
device const char * ids,
|
|
@@ -6489,3 +6955,68 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
|
6489
6955
|
tiisg,
|
|
6490
6956
|
sgitg);
|
|
6491
6957
|
}
|
|
6958
|
+
|
|
6959
|
+
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
|
6960
|
+
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6961
|
+
device const char * ids,
|
|
6962
|
+
device const char * src1,
|
|
6963
|
+
device float * dst,
|
|
6964
|
+
constant uint64_t & nbi1,
|
|
6965
|
+
constant int64_t & ne00,
|
|
6966
|
+
constant int64_t & ne01,
|
|
6967
|
+
constant int64_t & ne02,
|
|
6968
|
+
constant uint64_t & nb00,
|
|
6969
|
+
constant uint64_t & nb01,
|
|
6970
|
+
constant uint64_t & nb02,
|
|
6971
|
+
constant int64_t & ne10,
|
|
6972
|
+
constant int64_t & ne11,
|
|
6973
|
+
constant int64_t & ne12,
|
|
6974
|
+
constant int64_t & ne13,
|
|
6975
|
+
constant uint64_t & nb10,
|
|
6976
|
+
constant uint64_t & nb11,
|
|
6977
|
+
constant uint64_t & nb12,
|
|
6978
|
+
constant int64_t & ne0,
|
|
6979
|
+
constant int64_t & ne1,
|
|
6980
|
+
constant uint64_t & nb1,
|
|
6981
|
+
constant uint & r2,
|
|
6982
|
+
constant uint & r3,
|
|
6983
|
+
constant int & idx,
|
|
6984
|
+
device const char * src00,
|
|
6985
|
+
device const char * src01,
|
|
6986
|
+
device const char * src02,
|
|
6987
|
+
device const char * src03,
|
|
6988
|
+
device const char * src04,
|
|
6989
|
+
device const char * src05,
|
|
6990
|
+
device const char * src06,
|
|
6991
|
+
device const char * src07,
|
|
6992
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
6993
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6994
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6995
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6996
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6997
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
6998
|
+
|
|
6999
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
7000
|
+
|
|
7001
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
7002
|
+
|
|
7003
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
7004
|
+
|
|
7005
|
+
kernel_mul_mv_iq4_nl_f32_impl(
|
|
7006
|
+
src0[id],
|
|
7007
|
+
(device const float *) (src1 + bid*nb11),
|
|
7008
|
+
dst + bid*ne0,
|
|
7009
|
+
ne00,
|
|
7010
|
+
ne01,
|
|
7011
|
+
ne02,
|
|
7012
|
+
ne10,
|
|
7013
|
+
ne12,
|
|
7014
|
+
ne0,
|
|
7015
|
+
ne1,
|
|
7016
|
+
r2,
|
|
7017
|
+
r3,
|
|
7018
|
+
shared_values,
|
|
7019
|
+
tgpig,
|
|
7020
|
+
tiisg,
|
|
7021
|
+
sgitg);
|
|
7022
|
+
}
|