whispercpp 1.3.2 → 1.3.3
This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0

data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal

@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };

+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,176 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
     }
 }

+void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max = 0.0f;
+
+    for (int j = 0; j < QK4_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max = v;
+        }
+    }
+
+    const float d = max / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = src[0 + j]*id;
+        const float x1 = src[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+        dst.qs[j] = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
+#pragma METAL fp math_mode(safe)
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; j++) {
+        const float v = src[j];
+        if (min > v) min = v;
+        if (max < v) max = v;
+    }
+
+    const float d = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (src[0 + j] - min)*id;
+        const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+        dst.qs[j] = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max = 0.0f;
+
+    for (int j = 0; j < QK5_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max = v;
+        }
+    }
+
+    const float d = max / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = src[0 + j]*id;
+        const float x1 = src[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
+#pragma METAL fp math_mode(safe)
+    float max = src[0];
+    float min = src[0];
+
+    for (int j = 1; j < QK5_1; j++) {
+        const float v = src[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (src[0 + j] - min)*id;
+        const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
+#pragma METAL fp math_mode(safe)
+    float amax = 0.0f; // absolute max
+    float max = 0.0f;
+
+    for (int j = 0; j < QK4_NL; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max = v;
+        }
+    }
+
+    const float d = max / kvalues_iq4nl_f[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = src[0 + j]*id;
+        const float x1 = src[QK4_NL/2 + j]*id;
+
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+        dst.qs[j] = xi0 | (xi1 << 4);
+
+        const float v0 = kvalues_iq4nl_f[xi0];
+        const float v1 = kvalues_iq4nl_f[xi1];
+        const float w0 = src[0 + j]*src[0 + j];
+        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+
+    }
+
+    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +460,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }

+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -993,31 +1194,125 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }

+kernel void kernel_reglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        dst_row[i0] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i0] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i0] = silu*x1;
+    }
+}
+
+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
        device const float * src0,
        device       float * dst,
-
-        uint3
-
-
-
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;

     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }

+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

-    float
+    float sumf = 0;

-    for (int64_t i0 =
-
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }

-
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }

+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
@@ -2502,6 +2797,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
 template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
 #endif

+template<typename T04, typename T14, typename args_t>
+void kernel_mul_mv_c4_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x*32 + tiisg;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    if (r0 >= args.ne01) {
+        return;
+    }
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T04 * x = (device const T04 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    for (int row = 0; row < N_MV_T_T; ++row) {
+        int r1 = rb + row;
+        if (r1 >= args.ne11) {
+            break;
+        }
+
+        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+        device const T14 * y = (device const T14 *) (src1 + offset1);
+
+        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+    }
+}
+
+template<typename T04, typename T14>
+kernel void kernel_mul_mv_c4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
+        src0,
+        src1,
+        dst,
+        tgpig,
+        tiisg);
+}
+
+typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+
+template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
+template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+#endif
+
 template<typename T, typename T4>
 kernel void kernel_mul_mv_1row(
         constant ggml_metal_kargs_mul_mv & args,
@@ -3328,14 +3687,12 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW = N_SIMDWIDTH;
     constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)

-    const short TS = nsg*SH;
-    const short T  = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH; // shared memory size per query in (s_t == float)
+    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)

-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
-    threadgroup
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix

     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3711,7 @@ kernel void kernel_flash_attn_ext(
                 if (iq1 + j < args.ne01) {
                     sq4[j*DK4 + i] = (q4_t) q4[i];
                 } else {
-                    sq4[j*DK4 + i] =
+                    sq4[j*DK4 + i] = 0;
                 }
             }
         }
@@ -3548,20 +3905,20 @@ kernel void kernel_flash_attn_ext(

         // O = diag(ms)*O
         {
-            s8x8_t
-            simdgroup_load(
+            s8x8_t ms;
+            simdgroup_load(ms, ss + 2*C, TS, 0, false);

 #pragma unroll(DV8)
             for (short i = 0; i < DV8; ++i) {
-                simdgroup_multiply(lo[i],
+                simdgroup_multiply(lo[i], ms, lo[i]);
             }
         }

         // O = O + (Q*K^T)*V
         {
             for (short cc = 0; cc < C/8; ++cc) {
-                s8x8_t
-                simdgroup_load(
+                s8x8_t vs;
+                simdgroup_load(vs, ss + 8*cc, TS, 0, false);

                 if (is_same<vd4x4_t, v4x4_t>::value) {
                     // we can read directly from global memory
@@ -3572,7 +3929,7 @@ kernel void kernel_flash_attn_ext(
                         v8x8_t mv;
                         simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20

-                        simdgroup_multiply_accumulate(lo[i],
+                        simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
                     }
                 } else {
                     for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3950,10 @@ kernel void kernel_flash_attn_ext(
                             v8x8_t mv;

                             simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                            simdgroup_multiply_accumulate(lo[2*(ii + k) + 0],
+                            simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);

                             simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                            simdgroup_multiply_accumulate(lo[2*(ii + k) + 1],
+                            simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                         }
                     } else {
                         if (ii + tx < DV16) {
@@ -3611,10 +3968,10 @@ kernel void kernel_flash_attn_ext(
                                 v8x8_t mv;

                                 simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 0],
+                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);

                                 simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 1],
+                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                             }
                         }
                     }
@@ -3624,93 +3981,89 @@ kernel void kernel_flash_attn_ext(
        }

         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j =
-
-
-        ss[j*TS + 1] = M[j];
-        }
+        for (short j = tiisg; j < Q; j += NW) {
+            ss[j*TS + 0] = S[j];
+            ss[j*TS + 1] = M[j];
         }
     }

-
-    for (ushort sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -__FLT_MAX__/2 };
+    threadgroup_barrier(mem_flags::mem_threadgroup);

-
+    threadgroup float  * so  = (threadgroup float  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+    threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);

-
-
-
-
-
+    // store result to shared memory in F32
+    if (sgitg == 0) {
+        for (short i = 0; i < DV8; ++i) {
+            //simdgroup_store(lo[i], so + i*8, DV, 0, false);
+            simdgroup_float8x8 t(1.0f);
+            simdgroup_multiply(t, lo[i], t);
+            simdgroup_store(t, so + i*8, DV, 0, false);
         }
+    }

-
+    threadgroup_barrier(mem_flags::mem_threadgroup);

-
-
-
-
-    const float
+    // reduce the warps sequentially
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        if (sgitg == sg) {
+            for (short j = tiisg; j < Q; j += NW) {
+                const float S0 = ss[j*TS - 1*SH + 0];
+                const float S1 = ss[j*TS        + 0];

-                const float M0 = ss[j*TS +
-                const float M1 = ss[j*TS
+                const float M0 = ss[j*TS - 1*SH + 1];
+                const float M1 = ss[j*TS        + 1];

-                M = max(M0, M1);
+                const float M = max(M0, M1);

-
-
+                float ms0 = exp(M0 - M);
+                float ms1 = exp(M1 - M);

-                S = S0*ms0 + S1*ms1;
+                const float S = S0*ms0 + S1*ms1;

-
-
-                ss[j*TS + 1] = M;
+                ss[j*TS + 0] = S;
+                ss[j*TS + 1] = M;

-
-
-            }
+                ss[j*TS + 2*C + j - 1*SH] = ms0;
+                ss[j*TS + 2*C + j       ] = ms1;
             }

+            //simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             {
                 s8x8_t ms0;
                 s8x8_t ms1;

-                simdgroup_load(ms0, ss + 2*C,
-                simdgroup_load(ms1, ss + 2*C
+                simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C,        TS, 0, false);

 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-
+                    simdgroup_float8x8 t;

                     simdgroup_load    (t, so + i*8, DV, 0, false);
-                    simdgroup_multiply(t,
+                    simdgroup_multiply(t, ms0, t);

-                    simdgroup_multiply_accumulate(
+                    simdgroup_multiply_accumulate(t, ms1, lo[i], t);
+                    simdgroup_store(t, so + i*8, DV, 0, false);
                 }
             }
         }
-        }

-
-        if (sgitg == 0) {
-            for (short i = 0; i < DV8; ++i) {
-                simdgroup_store(lo[i], so + i*8, DV, 0, false);
-            }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }

-
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);

     // final rescale with 1/S and store to global memory
-
-
-        const float S = ss[j*TS + 0];
+    for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+        const float S = 1.0f/sf[j*TS + 0];

-
-
-
+        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[i] = (float4) so4[j*DV4 + i]*S;
         }
     }
 }
@@ -3719,12 +4072,22 @@ kernel void kernel_flash_attn_ext(
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-
-    half,
-    half,
-    float,
-    float,
-    half,
+    float, float4,   simdgroup_float8x8, \
+    half,  half4x4,  simdgroup_half8x8,  \
+    half,  half4x4,  simdgroup_half8x8,  \
+    float,           simdgroup_float8x8, \
+    float,           simdgroup_float8x8, \
+    half,  half4,    simdgroup_half8x8
+  //float, float4,   simdgroup_float8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4,   simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float,             simdgroup_float8x8, \
+    float,             simdgroup_float8x8, \
+    half,   half4,     simdgroup_half8x8
+  //float,  float4,    simdgroup_float8x8

 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;

@@ -3739,15 +4102,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;

 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif

 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +4164,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;

 #undef FA_TYPES
+#undef FA_TYPES_BF

 template<
     typename q4_t,  // query types in shared memory
@@ -3847,12 +4211,12 @@ kernel void kernel_flash_attn_ext_vec(

     const short T = DK + nsg*SH; // shared memory size per query in (half)

-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +
-    threadgroup float * sm  = (threadgroup float *) (shmem_f16 +
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results

     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -4157,7 +4521,7 @@ kernel void kernel_flash_attn_ext_vec(
     half4,  \
     float,  \
     float, float4, \
-
+    float4

 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;

@@ -4271,11 +4635,16 @@ kernel void kernel_cpy(
         device const char * src0,
         device       char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
+        uint    tiitg[[thread_index_in_threadgroup]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3
+        ushort3  tptg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
+
+    if (i01 >= args.ne01) {
+        return;
+    }

     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

@@ -4286,7 +4655,7 @@ kernel void kernel_cpy(

     device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

-    for (int64_t i00 =
+    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
         device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
         dst_data[i00] = (T1) src[0];
     }
@@ -4306,6 +4675,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
 template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 #endif

+// TODO: templetify these kernels
 kernel void kernel_cpy_f32_q8_0(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
@@ -4329,23 +4699,7 @@ kernel void kernel_cpy_f32_q8_0(
     for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

-
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = src[j];
-            amax = MAX(amax, fabs(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK8_0].d = d;
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = src[j]*id;
-
-            dst_data[i00/QK8_0].qs[j] = round(x0);
-        }
+        quantize_q8_0(src, dst_data[i00/QK8_0]);
     }
 }

@@ -4372,32 +4726,7 @@ kernel void kernel_cpy_f32_q4_0(
     for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

-
-        float max = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max = v;
-            }
-        }
-
-        const float d = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_0].d = d;
-
-        for (int j = 0; j < QK4_0/2; ++j) {
-            const float x0 = src[0 + j]*id;
-            const float x1 = src[QK4_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            dst_data[i00/QK4_0].qs[j] = xi0;
-            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_0(src, dst_data[i00/QK4_0]);
     }
 }

@@ -4424,31 +4753,7 @@ kernel void kernel_cpy_f32_q4_1(
     for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

-
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < QK4_1; j++) {
-            const float v = src[j];
-            if (min > v) min = v;
-            if (max < v) max = v;
-        }
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_1].d = d;
-        dst_data[i00/QK4_1].m = min;
-
-        for (int j = 0; j < QK4_1/2; ++j) {
-            const float x0 = (src[0 + j] - min)*id;
-            const float x1 = (src[QK4_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            dst_data[i00/QK4_1].qs[j] = xi0;
-            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_1(src, dst_data[i00/QK4_1]);
     }
 }

@@ -4475,38 +4780,7 @@ kernel void kernel_cpy_f32_q5_0(
     for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

-
-        float max = 0.0f;
-
-        for (int j = 0; j < QK5_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max = v;
-            }
-        }
-
-        const float d = max / -16;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_0].d = d;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_0/2; ++j) {
-            const float x0 = src[0 + j]*id;
-            const float x1 = src[QK5_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
-            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
-            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_0].qh[j] = qh8[j];
-        }
+        quantize_q5_0(src, dst_data[i00/QK5_0]);
     }
 }

@@ -4533,49 +4807,8 @@ kernel void kernel_cpy_f32_q5_1(
     for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

-
-        float min = src[0];
-
-        for (int j = 1; j < QK5_1; j++) {
-            const float v = src[j];
-            min = v < min ? v : min;
-            max = v > max ? v : max;
-        }
-
-        const float d = (max - min) / 31;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_1].d = d;
-        dst_data[i00/QK5_1].m = min;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_1/2; ++j) {
-            const float x0 = (src[0 + j] - min)*id;
-            const float x1 = (src[QK5_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_1].qh[j] = qh8[j];
-        }
-    }
-}
-
-static inline int best_index_int8(int n, constant float * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
+        quantize_q5_1(src, dst_data[i00/QK5_1]);
     }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }

 kernel void kernel_cpy_f32_iq4_nl(
@@ -4601,40 +4834,7 @@ kernel void kernel_cpy_f32_iq4_nl(
     for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

-
-        float max = 0.0f;
-
-        for (int j = 0; j < QK4_NL; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max = v;
-            }
-        }
-
-        const float d = max / kvalues_iq4nl_f[0];
-        const float id = d ? 1.0f/d : 0.0f;
-
-        float sumqx = 0, sumq2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            const float x0 = src[0 + j]*id;
-            const float x1 = src[QK4_NL/2 + j]*id;
-
-            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
-            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
-            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
-            const float v0 = kvalues_iq4nl_f[xi0];
-            const float v1 = kvalues_iq4nl_f[xi1];
-            const float w0 = src[0 + j]*src[0 + j];
-            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
-            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
-            sumq2 += w0*v0*v0 + w1*v1*v1;
-
-        }
-
-        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
     }
 }

@@ -6315,10 +6515,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(

 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device      float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint   tiitg[[thread_index_in_threadgroup]],
         uint3  tptg [[threads_per_threadgroup]]) {
@@ -6338,10 +6538,10 @@ kernel void kernel_get_rows_q(

 template<typename T>
 kernel void kernel_get_rows_f(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device      float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint   tiitg[[thread_index_in_threadgroup]],
         uint3  tptg [[threads_per_threadgroup]]) {
@@ -6359,10 +6559,10 @@ kernel void kernel_get_rows_f(
 }

 kernel void kernel_get_rows_i32(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device    int32_t * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint   tiitg[[thread_index_in_threadgroup]],
         uint3  tptg [[threads_per_threadgroup]]) {
@@ -6379,6 +6579,67 @@ kernel void kernel_get_rows_i32(
     }
 }

+template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_set_rows_q32(
+        constant ggml_metal_kargs_set_rows & args,
+        device const void * src0,
+        device const void * src1,
+        device      float * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint   tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        quantize_func(src_row + 32*ind, dst_row[ind]);
+    }
+}
+
+template<typename T>
+kernel void kernel_set_rows_f(
+        constant ggml_metal_kargs_set_rows & args,
+        device const void * src0,
+        device const void * src1,
+        device      float * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint   tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        dst_row[ind] = (T) src_row[ind];
+    }
+}

 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6802,6 +7063,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
 template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;

+//
+// set rows
+//
+
+typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+
+template [[host_name("kernel_set_rows_f32")]]  kernel set_rows_f_t kernel_set_rows_f<float>;
+template [[host_name("kernel_set_rows_f16")]]  kernel set_rows_f_t kernel_set_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+
 //
 // matrix-matrix multiplication
 //