@novastera-oss/llamarn 0.2.7 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/common/arg.cpp +7 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +1 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
- package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -3
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
- package/cpp/llama.cpp/src/llama-batch.h +98 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
- package/cpp/llama.cpp/src/llama-graph.h +44 -32
- package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-hparams.h +8 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
- package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.h +18 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
- package/cpp/llama.cpp/src/llama-model.h +22 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/common.h +1 -0
- package/ios/include/llama.h +8 -3
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
|
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
|
35
35
|
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
36
36
|
};
|
|
37
37
|
|
|
38
|
+
// Returns the index of the entry of `val[0..n)` that is nearest to `x`.
// Bisection assumes `val` is sorted ascending (true for kvalues_iq4nl_f).
// Inputs at or beyond either end of the table clamp to index 0 / n-1; on an
// exact tie between the two bracketing entries the upper index wins.
static inline int best_index_int8(int n, constant float * val, float x) {
    const int last = n - 1;

    // clamp inputs that fall outside the table range
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[last]) {
        return last;
    }

    // narrow [lo, hi] until the two indices are adjacent
    int lo = 0;
    int hi = last;
    while (hi - lo > 1) {
        const int mid = (lo + hi) / 2;
        if (val[mid] > x) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // choose the nearer of the two neighbours (ties go to hi)
    const float dist_lo = x - val[hi - 1];
    const float dist_hi = val[hi] - x;
    return (dist_lo < dist_hi) ? hi - 1 : hi;
}
|
|
48
|
+
|
|
38
49
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
39
50
|
template <typename type4x4>
|
|
40
51
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
97
108
|
}
|
|
98
109
|
}
|
|
99
110
|
|
|
111
|
+
// Quantizes QK4_0 consecutive floats from `src` into one q4_0 block.
// The element with the largest magnitude (sign kept) sets the scale so that it
// maps to -8. Every element is stored as trunc(x/d + 8.5) clamped to 15, in a
// 4-bit nibble: the first half of the block goes into the low nibbles of `qs`,
// the second half into the high nibbles.
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
    // locate the signed value with the largest magnitude
    float peak_abs = 0.0f;
    float peak     = 0.0f;
    for (int i = 0; i < QK4_0; ++i) {
        const float v = src[i];
        const float a = fabs(v);
        if (a > peak_abs) {
            peak_abs = a;
            peak     = v;
        }
    }

    const float scale     = peak / -8;
    const float inv_scale = (scale != 0.0f) ? 1.0f/scale : 0.0f;

    dst.d = scale;

    const int half_block = QK4_0/2;
    for (int i = 0; i < half_block; ++i) {
        const float s0 = src[i]             *inv_scale;
        const float s1 = src[i + half_block]*inv_scale;

        const uint8_t q0 = MIN(15, (int8_t)(s0 + 8.5f));
        const uint8_t q1 = MIN(15, (int8_t)(s1 + 8.5f));

        // low nibble: element i, high nibble: element i + half_block
        dst.qs[i] = q0 | (q1 << 4);
    }
}
|
|
139
|
+
|
|
140
|
+
// Quantizes QK4_1 consecutive floats from `src` into one q4_1 block.
// Stores the block minimum in `m` and the step d = (max - min)/15 in `d`.
// Each element is encoded as trunc((x - min)/d + 0.5) clamped to 15 (a 4-bit
// nibble): the first half of the block goes into the low nibbles of `qs`, the
// second half into the high nibbles.
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
    float lo_val =  FLT_MAX;
    float hi_val = -FLT_MAX;
    for (int i = 0; i < QK4_1; ++i) {
        const float v = src[i];
        if (v < lo_val) {
            lo_val = v;
        }
        if (v > hi_val) {
            hi_val = v;
        }
    }

    const int   levels   = (1 << 4) - 1;
    const float step     = (hi_val - lo_val) / levels;
    const float inv_step = (step != 0.0f) ? 1.0f/step : 0.0f;

    dst.d = step;
    dst.m = lo_val;

    const int half_block = QK4_1/2;
    for (int i = 0; i < half_block; ++i) {
        const float s0 = (src[i]              - lo_val)*inv_step;
        const float s1 = (src[i + half_block] - lo_val)*inv_step;

        const uint8_t q0 = MIN(15, (int8_t)(s0 + 0.5f));
        const uint8_t q1 = MIN(15, (int8_t)(s1 + 0.5f));

        dst.qs[i] = q0 | (q1 << 4);
    }
}
|
|
167
|
+
|
|
168
|
+
// Quantizes QK5_0 consecutive floats from `src` into one q5_0 block.
// The element with the largest magnitude (sign kept) sets the scale so that it
// maps to -16. Each element becomes trunc(x/d + 16.5) clamped to 31 (5 bits).
// The low 4 bits are packed into `qs` (first half of the block in the low
// nibbles, second half in the high nibbles). Bit 4 of element i is collected
// into bit i of a 32-bit mask, which is copied into `qh` byte by byte.
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
    float peak_abs = 0.0f;
    float peak     = 0.0f;
    for (int i = 0; i < QK5_0; ++i) {
        const float v = src[i];
        const float a = fabs(v);
        if (a > peak_abs) {
            peak_abs = a;
            peak     = v;
        }
    }

    const float scale     = peak / -16;
    const float inv_scale = (scale != 0.0f) ? 1.0f/scale : 0.0f;

    dst.d = scale;

    const int half_block = QK5_0/2;
    uint32_t high_bits = 0;
    for (int i = 0; i < half_block; ++i) {
        const float s0 = src[i]             *inv_scale;
        const float s1 = src[i + half_block]*inv_scale;

        const uint8_t q0 = MIN(31, (int8_t)(s0 + 16.5f));
        const uint8_t q1 = MIN(31, (int8_t)(s1 + 16.5f));

        dst.qs[i] = (q0 & 0xf) | ((q1 & 0xf) << 4);

        // 5th bit of each value goes into the mask at the element's position
        const uint32_t bit0 = (q0 & 0x10u) >> 4;
        const uint32_t bit1 = (q1 & 0x10u) >> 4;
        high_bits |= bit0 << i;
        high_bits |= bit1 << (i + half_block);
    }

    // copy the mask out in the thread's in-memory byte order
    thread const uint8_t * high_bytes = (thread const uint8_t *)&high_bits;
    for (int k = 0; k < 4; ++k) {
        dst.qh[k] = high_bytes[k];
    }
}
|
|
204
|
+
|
|
205
|
+
// Quantizes QK5_1 consecutive floats from `src` into one q5_1 block.
// Stores the block minimum in `m` and the step d = (max - min)/31 in `d`.
// Each element is encoded as trunc((x - min)/d + 0.5) (5 bits). The low 4
// bits are packed into `qs` (first half in the low nibbles, second half in the
// high nibbles). Bit 4 of element i is collected into bit i of a 32-bit mask,
// which is copied into `qh` byte by byte.
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
    float hi_val = src[0];
    float lo_val = src[0];
    for (int i = 1; i < QK5_1; ++i) {
        const float v = src[i];
        if (v < lo_val) {
            lo_val = v;
        }
        if (v > hi_val) {
            hi_val = v;
        }
    }

    const float step     = (hi_val - lo_val) / 31;
    const float inv_step = (step != 0.0f) ? 1.0f/step : 0.0f;

    dst.d = step;
    dst.m = lo_val;

    const int half_block = QK5_1/2;
    uint32_t high_bits = 0;
    for (int i = 0; i < half_block; ++i) {
        const float s0 = (src[i]              - lo_val)*inv_step;
        const float s1 = (src[i + half_block] - lo_val)*inv_step;

        const uint8_t q0 = (uint8_t)(s0 + 0.5f);
        const uint8_t q1 = (uint8_t)(s1 + 0.5f);

        dst.qs[i] = (q0 & 0xf) | ((q1 & 0xf) << 4);

        const uint32_t bit0 = (q0 & 0x10u) >> 4;
        const uint32_t bit1 = (q1 & 0x10u) >> 4;
        high_bits |= bit0 << i;
        high_bits |= bit1 << (i + half_block);
    }

    // copy the mask out in the thread's in-memory byte order
    thread const uint8_t * high_bytes = (thread const uint8_t *)&high_bits;
    for (int k = 0; k < 4; ++k) {
        dst.qh[k] = high_bytes[k];
    }
}
|
|
240
|
+
|
|
241
|
+
// Quantizes QK4_NL consecutive floats from `src` into one iq4_nl block.
// An initial scale maps the largest-magnitude element (sign kept) onto
// kvalues_iq4nl_f[0]. Each scaled element is then snapped to the nearest entry
// of kvalues_iq4nl_f, and that entry's 4-bit index is stored: the first half
// of the block in the low nibbles, the second half in the high nibbles.
// The stored scale is finally refit as sum(w*q*x) / sum(w*q*q), with weights
// w = x*x and q the chosen table values. If that denominator is not positive,
// the initial scale is kept.
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
    float peak_abs = 0.0f;
    float peak     = 0.0f;
    for (int i = 0; i < QK4_NL; ++i) {
        const float v = src[i];
        const float a = fabs(v);
        if (a > peak_abs) {
            peak_abs = a;
            peak     = v;
        }
    }

    const float scale     = peak / kvalues_iq4nl_f[0];
    const float inv_scale = (scale != 0.0f) ? 1.0f/scale : 0.0f;

    const int half_block = QK4_NL/2;
    float sum_wqx = 0;
    float sum_wqq = 0;
    for (int i = 0; i < half_block; ++i) {
        const float a = src[i];
        const float b = src[i + half_block];

        const float s0 = a*inv_scale;
        const float s1 = b*inv_scale;

        const uint8_t q0 = best_index_int8(16, kvalues_iq4nl_f, s0);
        const uint8_t q1 = best_index_int8(16, kvalues_iq4nl_f, s1);

        dst.qs[i] = q0 | (q1 << 4);

        // accumulate the weighted least-squares terms for the final scale
        const float v0 = kvalues_iq4nl_f[q0];
        const float v1 = kvalues_iq4nl_f[q1];
        const float w0 = a*a;
        const float w1 = b*b;
        sum_wqx += w0*v0*a + w1*v1*b;
        sum_wqq += w0*v0*v0 + w1*v1*v1;
    }

    dst.d = (sum_wqq > 0) ? sum_wqx/sum_wqq : scale;
}
|
|
277
|
+
|
|
100
278
|
template <typename type4x4>
|
|
101
279
|
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
|
102
280
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|
|
279
457
|
}
|
|
280
458
|
}
|
|
281
459
|
|
|
460
|
+
// Quantizes QK8_0 consecutive floats from `src` into one q8_0 block.
// The scale is d = max|x| / 127, and each element is stored as round(x * (1/d)).
// When d == 0, every element is stored as 0.
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
    float peak_abs = 0.0f;
    for (int i = 0; i < QK8_0; ++i) {
        const float v = src[i];
        peak_abs = MAX(peak_abs, fabs(v));
    }

    const int   qmax      = (1 << 7) - 1;
    const float scale     = peak_abs / qmax;
    const float inv_scale = (scale != 0.0f) ? 1.0f/scale : 0.0f;

    dst.d = scale;

    for (int i = 0; i < QK8_0; ++i) {
        const float s = src[i]*inv_scale;
        dst.qs[i] = round(s);
    }
}
|
|
479
|
+
|
|
282
480
|
template <typename type4x4>
|
|
283
481
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
|
284
482
|
const float d = xb->d;
|
|
@@ -2532,6 +2730,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
|
|
|
2532
2730
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
|
2533
2731
|
#endif
|
|
2534
2732
|
|
|
2733
|
+
// Matrix-vector product where each row holds a single 4-wide vector: every
// output element is dot(x[0], y[0]) for one src0 row and one src1 row.
// Presumably dispatched only when ne00 == 4 — confirm against the host code.
//
// Work split, as computed below:
//   - tgpig.x*32 + tiisg selects the src0 row, one per simdgroup lane
//     (assumes one 32-lane simdgroup per threadgroup — TODO confirm)
//   - tgpig.y selects a group of N_MV_T_T consecutive src1 rows
//   - tgpig.z is the flattened batch index, split into (i12, i13); src0 is
//     indexed with i12/r2 and i13/r3
// Results are written as float into dst.
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig,
        ushort tiisg) {
    const int r0 = tgpig.x*32 + tiisg;
    if (r0 >= args.ne01) {
        return; // this lane has no src0 row to process
    }

    const int r1_first = tgpig.y*N_MV_T_T;
    const int batch    = tgpig.z;

    const uint i12 = batch%args.ne12;
    const uint i13 = batch/args.ne12;

    const uint64_t x_offs = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    device const T04 * x = (device const T04 *) (src0 + x_offs);

    device float * out = (device float *) dst + (uint64_t)batch*args.ne0*args.ne1;

    for (int r1 = r1_first; r1 < r1_first + N_MV_T_T; ++r1) {
        if (r1 >= args.ne11) {
            break; // past the last src1 row
        }

        const uint64_t y_offs = r1*args.nb11 + i12*args.nb12 + i13*args.nb13;
        device const T14 * y = (device const T14 *) (src1 + y_offs);

        out[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
    }
}
|
|
2771
|
+
|
|
2772
|
+
// Kernel entry point for the one-vector-per-row matrix-vector product.
// It forwards everything to kernel_mul_mv_c4_impl, passing the kernel
// arguments by reference in the constant address space.
template<typename T04, typename T14>
kernel void kernel_mul_mv_c4(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, tgpig, tiisg);
}

// Function type shared by the instantiations below. No host_name
// instantiation here uses <half4, half4>; it only serves to name the type.
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;

// Host-visible instantiations: <src0 vector type, src1 vector type>.
template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
#endif
|
|
2796
|
+
|
|
2535
2797
|
template<typename T, typename T4>
|
|
2536
2798
|
kernel void kernel_mul_mv_1row(
|
|
2537
2799
|
constant ggml_metal_kargs_mul_mv & args,
|
|
@@ -4306,11 +4568,16 @@ kernel void kernel_cpy(
|
|
|
4306
4568
|
device const char * src0,
|
|
4307
4569
|
device char * dst,
|
|
4308
4570
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4571
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4309
4572
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
4310
|
-
ushort3
|
|
4573
|
+
ushort3 tptg[[threads_per_threadgroup]]) {
|
|
4311
4574
|
const int i03 = tgpig[2];
|
|
4312
4575
|
const int i02 = tgpig[1];
|
|
4313
|
-
const int i01 = tgpig[0];
|
|
4576
|
+
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
|
|
4577
|
+
|
|
4578
|
+
if (i01 >= args.ne01) {
|
|
4579
|
+
return;
|
|
4580
|
+
}
|
|
4314
4581
|
|
|
4315
4582
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
4316
4583
|
|
|
@@ -4321,7 +4588,7 @@ kernel void kernel_cpy(
|
|
|
4321
4588
|
|
|
4322
4589
|
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
4323
4590
|
|
|
4324
|
-
for (int64_t i00 =
|
|
4591
|
+
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
4325
4592
|
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4326
4593
|
dst_data[i00] = (T1) src[0];
|
|
4327
4594
|
}
|
|
@@ -4341,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
|
|
|
4341
4608
|
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
|
4342
4609
|
#endif
|
|
4343
4610
|
|
|
4611
|
+
// TODO: templatize these kernels
|
|
4344
4612
|
kernel void kernel_cpy_f32_q8_0(
|
|
4345
4613
|
constant ggml_metal_kargs_cpy & args,
|
|
4346
4614
|
device const char * src0,
|
|
@@ -4364,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
|
|
|
4364
4632
|
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
|
4365
4633
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4366
4634
|
|
|
4367
|
-
|
|
4368
|
-
|
|
4369
|
-
for (int j = 0; j < QK8_0; j++) {
|
|
4370
|
-
const float v = src[j];
|
|
4371
|
-
amax = MAX(amax, fabs(v));
|
|
4372
|
-
}
|
|
4373
|
-
|
|
4374
|
-
const float d = amax / ((1 << 7) - 1);
|
|
4375
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4376
|
-
|
|
4377
|
-
dst_data[i00/QK8_0].d = d;
|
|
4378
|
-
|
|
4379
|
-
for (int j = 0; j < QK8_0; ++j) {
|
|
4380
|
-
const float x0 = src[j]*id;
|
|
4381
|
-
|
|
4382
|
-
dst_data[i00/QK8_0].qs[j] = round(x0);
|
|
4383
|
-
}
|
|
4635
|
+
quantize_q8_0(src, dst_data[i00/QK8_0]);
|
|
4384
4636
|
}
|
|
4385
4637
|
}
|
|
4386
4638
|
|
|
@@ -4407,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
|
|
|
4407
4659
|
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
|
4408
4660
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4409
4661
|
|
|
4410
|
-
|
|
4411
|
-
float max = 0.0f;
|
|
4412
|
-
|
|
4413
|
-
for (int j = 0; j < QK4_0; j++) {
|
|
4414
|
-
const float v = src[j];
|
|
4415
|
-
if (amax < fabs(v)) {
|
|
4416
|
-
amax = fabs(v);
|
|
4417
|
-
max = v;
|
|
4418
|
-
}
|
|
4419
|
-
}
|
|
4420
|
-
|
|
4421
|
-
const float d = max / -8;
|
|
4422
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4423
|
-
|
|
4424
|
-
dst_data[i00/QK4_0].d = d;
|
|
4425
|
-
|
|
4426
|
-
for (int j = 0; j < QK4_0/2; ++j) {
|
|
4427
|
-
const float x0 = src[0 + j]*id;
|
|
4428
|
-
const float x1 = src[QK4_0/2 + j]*id;
|
|
4429
|
-
|
|
4430
|
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
4431
|
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
4432
|
-
|
|
4433
|
-
dst_data[i00/QK4_0].qs[j] = xi0;
|
|
4434
|
-
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
|
4435
|
-
}
|
|
4662
|
+
quantize_q4_0(src, dst_data[i00/QK4_0]);
|
|
4436
4663
|
}
|
|
4437
4664
|
}
|
|
4438
4665
|
|
|
@@ -4459,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
|
|
|
4459
4686
|
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
|
4460
4687
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4461
4688
|
|
|
4462
|
-
|
|
4463
|
-
float max = -FLT_MAX;
|
|
4464
|
-
|
|
4465
|
-
for (int j = 0; j < QK4_1; j++) {
|
|
4466
|
-
const float v = src[j];
|
|
4467
|
-
if (min > v) min = v;
|
|
4468
|
-
if (max < v) max = v;
|
|
4469
|
-
}
|
|
4470
|
-
|
|
4471
|
-
const float d = (max - min) / ((1 << 4) - 1);
|
|
4472
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4473
|
-
|
|
4474
|
-
dst_data[i00/QK4_1].d = d;
|
|
4475
|
-
dst_data[i00/QK4_1].m = min;
|
|
4476
|
-
|
|
4477
|
-
for (int j = 0; j < QK4_1/2; ++j) {
|
|
4478
|
-
const float x0 = (src[0 + j] - min)*id;
|
|
4479
|
-
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
4480
|
-
|
|
4481
|
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
4482
|
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
4483
|
-
|
|
4484
|
-
dst_data[i00/QK4_1].qs[j] = xi0;
|
|
4485
|
-
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
|
4486
|
-
}
|
|
4689
|
+
quantize_q4_1(src, dst_data[i00/QK4_1]);
|
|
4487
4690
|
}
|
|
4488
4691
|
}
|
|
4489
4692
|
|
|
@@ -4510,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
|
|
|
4510
4713
|
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
|
4511
4714
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4512
4715
|
|
|
4513
|
-
|
|
4514
|
-
float max = 0.0f;
|
|
4515
|
-
|
|
4516
|
-
for (int j = 0; j < QK5_0; j++) {
|
|
4517
|
-
const float v = src[j];
|
|
4518
|
-
if (amax < fabs(v)) {
|
|
4519
|
-
amax = fabs(v);
|
|
4520
|
-
max = v;
|
|
4521
|
-
}
|
|
4522
|
-
}
|
|
4523
|
-
|
|
4524
|
-
const float d = max / -16;
|
|
4525
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4526
|
-
|
|
4527
|
-
dst_data[i00/QK5_0].d = d;
|
|
4528
|
-
|
|
4529
|
-
uint32_t qh = 0;
|
|
4530
|
-
for (int j = 0; j < QK5_0/2; ++j) {
|
|
4531
|
-
const float x0 = src[0 + j]*id;
|
|
4532
|
-
const float x1 = src[QK5_0/2 + j]*id;
|
|
4533
|
-
|
|
4534
|
-
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
|
4535
|
-
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
|
4536
|
-
|
|
4537
|
-
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
4538
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
4539
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
|
4540
|
-
}
|
|
4541
|
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
4542
|
-
for (int j = 0; j < 4; ++j) {
|
|
4543
|
-
dst_data[i00/QK5_0].qh[j] = qh8[j];
|
|
4544
|
-
}
|
|
4716
|
+
quantize_q5_0(src, dst_data[i00/QK5_0]);
|
|
4545
4717
|
}
|
|
4546
4718
|
}
|
|
4547
4719
|
|
|
@@ -4568,49 +4740,8 @@ kernel void kernel_cpy_f32_q5_1(
|
|
|
4568
4740
|
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
|
4569
4741
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4570
4742
|
|
|
4571
|
-
|
|
4572
|
-
float min = src[0];
|
|
4573
|
-
|
|
4574
|
-
for (int j = 1; j < QK5_1; j++) {
|
|
4575
|
-
const float v = src[j];
|
|
4576
|
-
min = v < min ? v : min;
|
|
4577
|
-
max = v > max ? v : max;
|
|
4578
|
-
}
|
|
4579
|
-
|
|
4580
|
-
const float d = (max - min) / 31;
|
|
4581
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4582
|
-
|
|
4583
|
-
dst_data[i00/QK5_1].d = d;
|
|
4584
|
-
dst_data[i00/QK5_1].m = min;
|
|
4585
|
-
|
|
4586
|
-
uint32_t qh = 0;
|
|
4587
|
-
for (int j = 0; j < QK5_1/2; ++j) {
|
|
4588
|
-
const float x0 = (src[0 + j] - min)*id;
|
|
4589
|
-
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
|
4590
|
-
|
|
4591
|
-
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
|
4592
|
-
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
|
4593
|
-
|
|
4594
|
-
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
4595
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
4596
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
|
4597
|
-
}
|
|
4598
|
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
4599
|
-
for (int j = 0; j < 4; ++j) {
|
|
4600
|
-
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
|
4601
|
-
}
|
|
4602
|
-
}
|
|
4603
|
-
}
|
|
4604
|
-
|
|
4605
|
-
static inline int best_index_int8(int n, constant float * val, float x) {
|
|
4606
|
-
if (x <= val[0]) return 0;
|
|
4607
|
-
if (x >= val[n-1]) return n-1;
|
|
4608
|
-
int ml = 0, mu = n-1;
|
|
4609
|
-
while (mu-ml > 1) {
|
|
4610
|
-
int mav = (ml+mu)/2;
|
|
4611
|
-
if (x < val[mav]) mu = mav; else ml = mav;
|
|
4743
|
+
quantize_q5_1(src, dst_data[i00/QK5_1]);
|
|
4612
4744
|
}
|
|
4613
|
-
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
|
4614
4745
|
}
|
|
4615
4746
|
|
|
4616
4747
|
kernel void kernel_cpy_f32_iq4_nl(
|
|
@@ -4636,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
|
|
|
4636
4767
|
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
|
4637
4768
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4638
4769
|
|
|
4639
|
-
|
|
4640
|
-
float max = 0.0f;
|
|
4641
|
-
|
|
4642
|
-
for (int j = 0; j < QK4_NL; j++) {
|
|
4643
|
-
const float v = src[j];
|
|
4644
|
-
if (amax < fabs(v)) {
|
|
4645
|
-
amax = fabs(v);
|
|
4646
|
-
max = v;
|
|
4647
|
-
}
|
|
4648
|
-
}
|
|
4649
|
-
|
|
4650
|
-
const float d = max / kvalues_iq4nl_f[0];
|
|
4651
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4652
|
-
|
|
4653
|
-
float sumqx = 0, sumq2 = 0;
|
|
4654
|
-
for (int j = 0; j < QK4_NL/2; ++j) {
|
|
4655
|
-
const float x0 = src[0 + j]*id;
|
|
4656
|
-
const float x1 = src[QK4_NL/2 + j]*id;
|
|
4657
|
-
|
|
4658
|
-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
|
4659
|
-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
|
4660
|
-
|
|
4661
|
-
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
|
|
4662
|
-
|
|
4663
|
-
const float v0 = kvalues_iq4nl_f[xi0];
|
|
4664
|
-
const float v1 = kvalues_iq4nl_f[xi1];
|
|
4665
|
-
const float w0 = src[0 + j]*src[0 + j];
|
|
4666
|
-
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
|
4667
|
-
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
|
4668
|
-
sumq2 += w0*v0*v0 + w1*v1*v1;
|
|
4669
|
-
|
|
4670
|
-
}
|
|
4671
|
-
|
|
4672
|
-
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
|
4770
|
+
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
|
|
4673
4771
|
}
|
|
4674
4772
|
}
|
|
4675
4773
|
|
|
@@ -6350,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
|
6350
6448
|
|
|
6351
6449
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
6352
6450
|
kernel void kernel_get_rows_q(
|
|
6451
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6353
6452
|
device const void * src0,
|
|
6354
6453
|
device const void * src1,
|
|
6355
6454
|
device float * dst,
|
|
6356
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6357
6455
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6358
6456
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6359
6457
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6373,10 +6471,10 @@ kernel void kernel_get_rows_q(
|
|
|
6373
6471
|
|
|
6374
6472
|
template<typename T>
|
|
6375
6473
|
kernel void kernel_get_rows_f(
|
|
6474
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6376
6475
|
device const void * src0,
|
|
6377
6476
|
device const void * src1,
|
|
6378
6477
|
device float * dst,
|
|
6379
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6380
6478
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6381
6479
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6382
6480
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6394,10 +6492,10 @@ kernel void kernel_get_rows_f(
|
|
|
6394
6492
|
}
|
|
6395
6493
|
|
|
6396
6494
|
kernel void kernel_get_rows_i32(
|
|
6495
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6397
6496
|
device const void * src0,
|
|
6398
6497
|
device const void * src1,
|
|
6399
6498
|
device int32_t * dst,
|
|
6400
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6401
6499
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6402
6500
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6403
6501
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6414,6 +6512,67 @@ kernel void kernel_get_rows_i32(
|
|
|
6414
6512
|
}
|
|
6415
6513
|
}
|
|
6416
6514
|
|
|
6515
|
+
template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
|
6516
|
+
kernel void kernel_set_rows_q32(
|
|
6517
|
+
constant ggml_metal_kargs_set_rows & args,
|
|
6518
|
+
device const void * src0,
|
|
6519
|
+
device const void * src1,
|
|
6520
|
+
device float * dst,
|
|
6521
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6522
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6523
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
6524
|
+
const int32_t i03 = tgpig.z;
|
|
6525
|
+
const int32_t i02 = tgpig.y;
|
|
6526
|
+
|
|
6527
|
+
const int32_t i12 = i03%args.ne12;
|
|
6528
|
+
const int32_t i11 = i02%args.ne11;
|
|
6529
|
+
|
|
6530
|
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
|
6531
|
+
if (i01 >= args.ne01) {
|
|
6532
|
+
return;
|
|
6533
|
+
}
|
|
6534
|
+
|
|
6535
|
+
const int32_t i10 = i01;
|
|
6536
|
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
6537
|
+
|
|
6538
|
+
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
6539
|
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
6540
|
+
|
|
6541
|
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
|
6542
|
+
quantize_func(src_row + 32*ind, dst_row[ind]);
|
|
6543
|
+
}
|
|
6544
|
+
}
|
|
6545
|
+
|
|
6546
|
+
template<typename T>
|
|
6547
|
+
kernel void kernel_set_rows_f(
|
|
6548
|
+
constant ggml_metal_kargs_set_rows & args,
|
|
6549
|
+
device const void * src0,
|
|
6550
|
+
device const void * src1,
|
|
6551
|
+
device float * dst,
|
|
6552
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6553
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6554
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
6555
|
+
const int32_t i03 = tgpig.z;
|
|
6556
|
+
const int32_t i02 = tgpig.y;
|
|
6557
|
+
|
|
6558
|
+
const int32_t i12 = i03%args.ne12;
|
|
6559
|
+
const int32_t i11 = i02%args.ne11;
|
|
6560
|
+
|
|
6561
|
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
|
6562
|
+
if (i01 >= args.ne01) {
|
|
6563
|
+
return;
|
|
6564
|
+
}
|
|
6565
|
+
|
|
6566
|
+
const int32_t i10 = i01;
|
|
6567
|
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
6568
|
+
|
|
6569
|
+
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
6570
|
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
6571
|
+
|
|
6572
|
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
|
6573
|
+
dst_row[ind] = (T) src_row[ind];
|
|
6574
|
+
}
|
|
6575
|
+
}
|
|
6417
6576
|
|
|
6418
6577
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
6419
6578
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
@@ -6837,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
|
|
|
6837
6996
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
6838
6997
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
6839
6998
|
|
|
6999
|
+
//
|
|
7000
|
+
// set rows
|
|
7001
|
+
//
|
|
7002
|
+
|
|
7003
|
+
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
|
|
7004
|
+
|
|
7005
|
+
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
|
|
7006
|
+
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
|
|
7007
|
+
#if defined(GGML_METAL_USE_BF16)
|
|
7008
|
+
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
|
|
7009
|
+
#endif
|
|
7010
|
+
|
|
7011
|
+
typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
|
7012
|
+
|
|
7013
|
+
template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
|
|
7014
|
+
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
|
|
7015
|
+
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
|
|
7016
|
+
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
|
|
7017
|
+
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
|
|
7018
|
+
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
|
|
7019
|
+
|
|
6840
7020
|
//
|
|
6841
7021
|
// matrix-matrix multiplication
|
|
6842
7022
|
//
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
|
-
#include "
|
|
4
|
-
#include "
|
|
3
|
+
#include "ggml-cuda/common.cuh"
|
|
4
|
+
#include "ggml.h"
|
|
5
5
|
|
|
6
6
|
// Asynchronously copies data from src tensor to dst tensor using the provided context.
|
|
7
7
|
// Returns a musaError_t indicating success or failure.
|