whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-cpu/ops.cpp

@@ -3,6 +3,7 @@
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
 
@@ -108,7 +109,7 @@ static void ggml_compute_forward_dup_f16(
                 for (int i01 = ir0; i01 < ir1; i01++) {
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
-                        dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                        dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                         id++;
                     }
                 }
@@ -130,7 +131,7 @@ static void ggml_compute_forward_dup_f16(
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
                     for (int i00 = 0; i00 < ne00; i00++) {
-                        src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                        src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                     }
 
                     quantize_row_q(src0_f32, dst_ptr + id, ne00);
@@ -156,7 +157,7 @@ static void ggml_compute_forward_dup_f16(
                 for (int i00 = 0; i00 < ne00; i00++) {
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                    dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                    dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
                     id++;
                 }
             }
@@ -267,7 +268,7 @@ static void ggml_compute_forward_dup_f16(
                 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+                *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
 
                 if (++i10 == ne0) {
                     i10 = 0;
@@ -372,7 +373,7 @@ static void ggml_compute_forward_dup_bf16(
                 for (int i01 = ir0; i01 < ir1; i01++) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
-                        dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
+                        dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
                         id++;
                     }
                 }
@@ -473,7 +474,7 @@ static void ggml_compute_forward_dup_bf16(
                 for (int i00 = 0; i00 < ne00; i00++) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                    dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
+                    dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
                     id++;
                 }
             }
@@ -566,7 +567,7 @@ static void ggml_compute_forward_dup_bf16(
                 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
+                *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
 
                 if (++i10 == ne0) {
                     i10 = 0;
@@ -696,24 +697,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +709,7 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            from_float(src0_ptr, dst_ptr + id, ne00);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
@@ -765,7 +750,7 @@ static void ggml_compute_forward_dup_f32(
                 for (int i00 = 0; i00 < ne00; i00++) {
                     const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                    dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                    dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
                     id++;
                 }
             }
@@ -878,7 +863,7 @@ static void ggml_compute_forward_dup_f32(
                 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+                *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
 
                 if (++i10 == ne0) {
                     i10 = 0;
@@ -1419,7 +1404,7 @@ static void ggml_compute_forward_add1_f16_f32(
         ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
         for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
 }
@@ -1435,7 +1420,7 @@ static void ggml_compute_forward_add1_f16_f16(
     GGML_ASSERT(ggml_is_scalar(src1));
 
     // scalar to add
-    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
+    const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -1467,7 +1452,7 @@ static void ggml_compute_forward_add1_f16_f16(
         ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
         for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
 }
@@ -1889,7 +1874,7 @@ static void ggml_compute_forward_sum_f16(
             }
         }
     }
-    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
+    ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
 }
 
 static void ggml_compute_forward_sum_bf16(
@@ -2300,6 +2285,12 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
+        // TODO: templateify the implemenation and support for I64
+        // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+        //case GGML_TYPE_I64:
+        //    {
+        //        ggml_compute_forward_repeat_i64(params, dst);
+        //    } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -2660,7 +2651,7 @@ static void ggml_compute_forward_gelu_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = GGML_FP16_TO_FP32(x);
+            const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2763,7 +2754,7 @@ static void ggml_compute_forward_gelu_erf_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = GGML_FP16_TO_FP32(x);
+            const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2866,7 +2857,7 @@ static void ggml_compute_forward_gelu_quick_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = GGML_FP16_TO_FP32(x);
+            const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2969,7 +2960,7 @@ static void ggml_compute_forward_silu_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
-            const float v = GGML_FP16_TO_FP32(x);
+            const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -3163,7 +3154,7 @@ static void ggml_compute_forward_silu_back_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = GGML_FP16_TO_FP32(x);
+            const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -3194,6 +3185,435 @@ void ggml_compute_forward_silu_back(
     }
 }
 
+// ggml_compute_forward_reglu
+
+static void ggml_compute_forward_reglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_reglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_reglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_reglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_reglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu
+
+static void ggml_compute_forward_geglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_swiglu
+
+static void ggml_compute_forward_swiglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_swiglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -4470,6 +4890,74 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ne01;
+
+    assert(ne0 == nc);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = std::min(ir0 + dr, nr);
+
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+                from_float(
+                        (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
+                        ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back
 
 static void ggml_compute_forward_get_rows_back_f32_f16(
@@ -4500,7 +4988,7 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
 
         for (int j = 0; j < nc; ++j) {
             ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
         }
     }
 }
@@ -4792,7 +5280,7 @@ static void ggml_compute_forward_soft_max_f32(
         if (mp_f32) {
             if (use_f16) {
                 for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
                 }
             } else {
                 for (int i = 0; i < nc; ++i) {
@@ -5018,8 +5506,8 @@ static void ggml_compute_forward_clamp_f16(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
 
         for (int i = 0; i < nc; i++) {
-            float v = GGML_FP16_TO_FP32(src0_ptr[i]);
-            dst_ptr[i] = GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
+            float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
+            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
         }
     }
 }
@@ -5476,11 +5964,11 @@ static void ggml_compute_forward_rope_f16(
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                     ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
 
-                    const float x0 = GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+                    const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
+                    const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
 
-                    dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    dst_data[0]      = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
                 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -5492,11 +5980,11 @@ static void ggml_compute_forward_rope_f16(
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                     ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
 
-                    const float x0 = GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                    const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
+                    const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
 
-                    dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    dst_data[0]        = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             }
         } else {
@@ -5507,11 +5995,11 @@ static void ggml_compute_forward_rope_f16(
                 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                const float x0 = GGML_FP16_TO_FP32(src[0]);
-                const float x1 = GGML_FP16_TO_FP32(src[1]);
+                const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
+                const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
 
-                dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
             }
         }
 
@@ -5525,11 +6013,11 @@ static void ggml_compute_forward_rope_f16(
                 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
 
-                const float x0 = GGML_FP16_TO_FP32(src[0]);
-                const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+                const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
+                const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
 
-                dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                dst_data[0]      = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
             }
         } else {
             for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
@@ -5640,7 +6128,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
             for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                    dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
                 }
             }
         }
@@ -5933,7 +6421,7 @@ static void ggml_compute_forward_im2col_f16(
                             if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                                 dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
                             } else {
-                                dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
                             }
                         }
                     }
                 }
@@ -6058,6 +6546,186 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, float * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(type);
+    struct ggml_tensor src1 = {};
+    src1.type = type;
+    src1.ne[0] = k;
+    src1.ne[1] = m;
+    src1.ne[2] = 1;
+    src1.ne[3] = 1;
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
+    src1.nb[2] = src1.nb[1];
+    src1.nb[3] = src1.nb[2];
+    src1.data = a;
+
+    struct ggml_tensor src0 = {};
+    src0.type = type;
+    src0.ne[0] = k;
+    src0.ne[1] = n;
+    src0.ne[2] = 1;
+    src0.ne[3] = 1;
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
+    src0.nb[2] = src0.nb[1];
+    src0.nb[3] = src0.nb[2];
+    src0.data = b;
+
+    struct ggml_tensor dst = {};
+    dst.ne[0] = n;
+    dst.ne[1] = m;
+    dst.ne[2] = 1;
+    dst.ne[3] = 1;
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = n * sizeof(float);
+    dst.nb[2] = dst.nb[1];
+    dst.nb[3] = dst.nb[2];
+    dst.data = c;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    ggml_compute_forward_mul_mat(params, &dst);
+}
+
+// ggml_compute_forward_conv_2d
+
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor * kernel,  // [KW, KH, IC, OC]
+                                              const ggml_tensor * src,     // [W, H, C, N]
+                                              ggml_tensor * dst,           // [OW, OH, OC, N]
+                                              ggml_type kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
+
+    const int64_t c_in  = src->ne[2];
+    const int64_t c_out = kernel->ne[3];
+    GGML_ASSERT(c_in == kernel->ne[2]);
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+
+    const float * src_data = (float *) src->data;
+    void * knl_data = kernel->data;
+    float * dst_data = (float *) dst->data;
+
+    const int64_t knl_n       = knl_w * knl_h * c_in;
+    const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
+    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
+                                                   patch_total);
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        //im2col for a patch
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
+
+            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+            for (int64_t ic = 0; ic < c_in; ++ic) {
+                for (int64_t ky = 0; ky < knl_h; ++ky) {
+                    for (int64_t kx = 0; kx < knl_w; ++kx) {
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+                        float src_val;
+                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                            src_val = 0.0f;
+                        } else {
+                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                        }
+                    }
+                }
+            }
+        }   // patches handled by this thread
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
+
+        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+
+        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
+        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t permute_start = params->ith * permute_per_thread;
+        const int64_t permute_end   = std::min(permute_start + permute_per_thread, patch_n);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
+            for (int64_t oc = 0; oc < c_out; ++oc) {
+                const float value = gemm_output[i * c_out + oc];
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
@@ -6109,7 +6777,7 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
6109
6777
|
const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
|
6110
6778
|
ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
|
6111
6779
|
for (int i10 = 0; i10 < ne10; i10++) {
|
6112
|
-
dst_data[i10*ne12 + i12] =
|
6780
|
+
dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
|
6113
6781
|
}
|
6114
6782
|
}
|
6115
6783
|
}
|
@@ -6358,7 +7026,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|
6358
7026
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
6359
7027
|
}
|
6360
7028
|
for (int ki = 0; ki < k; ++ki) {
|
6361
|
-
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] :
|
7029
|
+
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
6362
7030
|
switch (op) {
|
6363
7031
|
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
|
6364
7032
|
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
|
@@ -6450,7 +7118,7 @@ void ggml_compute_forward_pool_2d(
|
|
6450
7118
|
for (int kx = 0; kx < k0; ++kx) {
|
6451
7119
|
int j = ix + kx;
|
6452
7120
|
if (j < 0 || j >= src->ne[0]) continue;
|
6453
|
-
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] :
|
7121
|
+
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
6454
7122
|
switch (op) {
|
6455
7123
|
case GGML_OP_POOL_AVG: *out += srow_j; break;
|
6456
7124
|
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
|
@@ -6538,7 +7206,7 @@ void ggml_compute_forward_pool_2d_back(
|
|
6538
7206
|
}
|
6539
7207
|
|
6540
7208
|
const float val = dst->type == GGML_TYPE_F32 ?
|
6541
|
-
((const float *) drowf)[j] :
|
7209
|
+
((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
|
6542
7210
|
if (val <= maxval) {
|
6543
7211
|
continue;
|
6544
7212
|
}
|
@@ -6558,7 +7226,7 @@ void ggml_compute_forward_pool_2d_back(
                     if (dst->type == GGML_TYPE_F32) {
                         ((float *) drow)[j] += grad0;
                     } else {
-                        ((ggml_fp16_t *) drow)[j] =
+                        ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
                     }
                 } else if (op == GGML_OP_POOL_AVG) {
                     const float grad = grad0 / ka;
@@ -6577,7 +7245,7 @@ void ggml_compute_forward_pool_2d_back(
                     if (dst->type == GGML_TYPE_F32) {
                         ((float *) drow)[j] += grad;
                     } else {
-                        ((ggml_fp16_t *) drow)[j] +=
+                        ((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);
                     }
                 }
             }
@@ -6608,12 +7276,13 @@ static void ggml_compute_forward_upscale_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-
-
-
-
+    float sf0 = (float)ne0/src0->ne[0];
+    float sf1 = (float)ne1/src0->ne[1];
+    float sf2 = (float)ne2/src0->ne[2];
+    float sf3 = (float)ne3/src0->ne[3];
 
-    const
+    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
+    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
 
     if (mode == GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -6634,8 +7303,12 @@ static void ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == GGML_SCALE_MODE_BILINEAR) {
-
-
+        float pixel_offset = 0.5f;
+        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+            pixel_offset = 0.0f;
+            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+        }
 
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
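With GGML_SCALE_FLAG_ALIGN_CORNERS set, the hunk above zeroes the pixel offset and rescales sf0/sf1 to (ne - 1)/(src - 1), so the outermost destination pixels map exactly onto the source corners. A standalone sketch of the two coordinate mappings, assuming the usual half-pixel convention that pixel_offset = 0.5 implements (src_coord is a hypothetical helper):

    #include <cstdint>
    #include <cstdio>

    // Map destination pixel x back to a source coordinate under the two
    // bilinear conventions used above.
    static float src_coord(int64_t x, int64_t src_n, int64_t dst_n, bool align_corners) {
        if (align_corners) {
            const float sf = (float)(dst_n - 1) / (src_n - 1); // matches sf0/sf1 above
            return x / sf;                                     // corners map exactly
        }
        const float sf = (float)dst_n / src_n;
        return (x + 0.5f) / sf - 0.5f;                         // half-pixel centers
    }

    int main() {
        // upscaling 4 -> 8, last destination pixel:
        std::printf("half-pixel:    %.2f\n", src_coord(7, 4, 8, false)); // 3.25
        std::printf("align-corners: %.2f\n", src_coord(7, 4, 8, true));  // 3.00
        return 0;
    }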
@@ -6793,6 +7466,73 @@ void ggml_compute_forward_pad_reflect_1d(
     }
 }
 
+// ggml_compute_forward_roll
+
+static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
+    if (i < 0) {
+        return i + ne;
+    } else if (i >= ne) {
+        return i - ne;
+    }
+    return i;
+}
+
+static void ggml_compute_forward_roll_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src_data = (const float *) src0->data;
+    float * dst_data = (float *) dst->data;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int s0 = ggml_get_op_params_i32(dst, 0);
+    const int s1 = ggml_get_op_params_i32(dst, 1);
+    const int s2 = ggml_get_op_params_i32(dst, 2);
+    const int s3 = ggml_get_op_params_i32(dst, 3);
+
+    const int64_t total = ne1 * ne2 * ne3;
+    const int64_t per_thread = (total + params->nth) / params->nth;
+    const int64_t start = params->ith * per_thread;
+    const int64_t end = std::min(start + per_thread, total);
+
+    for (int64_t i = start; i < end; ++i) {
+        const int64_t i1 = i % ne1;
+        const int64_t i2 = (i / ne1) % ne2;
+        const int64_t i3 = i / (ne2 * ne1);
+        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
+
+        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
+        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
+        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
+        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
+
+        const int64_t s = ggml_wrap_index(-s0, ne00);
+        const int64_t n = ne00 - s;
+        ggml_vec_cpy_f32(n, dst_row, src_row + s);
+        ggml_vec_cpy_f32(s, dst_row + n, src_row);
+    }
+}
+
+void ggml_compute_forward_roll(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_roll_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_arange
 
 static void ggml_compute_forward_arange_f32(
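ggml_compute_forward_roll_f32 above shifts each innermost row by s0 with wraparound using two contiguous copies instead of a per-element modulo. A standalone sketch of the same split, with std::memcpy standing in for ggml_vec_cpy_f32:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Same wrap rule as ggml_wrap_index above (valid for |i| < 2*ne).
    static int64_t wrap_index(int64_t i, int64_t ne) {
        if (i < 0)   return i + ne;
        if (i >= ne) return i - ne;
        return i;
    }

    // Roll a row of ne floats right by s0: copy the tail first, then the head.
    static void roll_row(float * dst, const float * src, int64_t ne, int64_t s0) {
        const int64_t s = wrap_index(-s0, ne); // first source element to emit
        const int64_t n = ne - s;
        std::memcpy(dst,     src + s, n * sizeof(float));
        std::memcpy(dst + n, src,     s * sizeof(float));
    }

    int main() {
        const float src[5] = {0, 1, 2, 3, 4};
        float dst[5];
        roll_row(dst, src, 5, 2); // -> {3, 4, 0, 1, 2}
        assert(dst[0] == 3 && dst[2] == 0 && dst[4] == 2);
        return 0;
    }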
@@ -7075,7 +7815,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 // loop over n_kv and n_head_kv
                 // ref: https://arxiv.org/pdf/2112.05682.pdf
                 for (int64_t ic = 0; ic < nek1; ++ic) {
-                    const float mv = mp ? slope*
+                    const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
                     if (mv == -INFINITY) {
                         continue;
                     }
@@ -7143,7 +7883,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
         if (v->type == GGML_TYPE_F16) {
             for (int64_t d = 0; d < DV; ++d) {
-                VKQ32[d] =
+                VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
            }
        }
 
@@ -7633,39 +8373,83 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir = ir1 - ir0;
 
-
-    for (int
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+                    svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
             }
-            y[i1] = sumf;
         }
     }
-
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }
 
 void ggml_compute_forward_ssm_scan(
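Both branches of the ssm_scan hunk above implement the same per-element selective state update: dt goes through a guarded softplus, the state decays by exp(dt*A) while B*x is injected, and the output is the dot product of the updated state row with C. A scalar sketch of one step, mirroring the non-SVE branch (ssm_step is a hypothetical helper):

    #include <cmath>

    // One selective-state-update step for a row of nc state elements.
    static float ssm_step(float * state, const float * A, const float * B,
                          const float * C, float x, float dt, int nc) {
        const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus, overflow-safe
        const float x_dt  = x * dt_sp;
        float y = 0.0f;
        for (int i0 = 0; i0 < nc; ++i0) {
            const float s = state[i0] * expf(dt_sp * A[i0]) + B[i0] * x_dt; // decay + input
            y += s * C[i0]; // readout, as sumf accumulates in the scalar branch
            state[i0] = s;
        }
        return y;
    }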
@@ -7883,6 +8667,34 @@ void ggml_compute_forward_unary(
     }
 }
 
+//ggml_compute_forward_glu
+
+void ggml_compute_forward_glu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_glu_op op = ggml_get_glu_op(dst);
+
+    switch (op) {
+        case GGML_GLU_OP_REGLU:
+            {
+                ggml_compute_forward_reglu(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU:
+            {
+                ggml_compute_forward_geglu(params, dst);
+            } break;
+        case GGML_GLU_OP_SWIGLU:
+            {
+                ggml_compute_forward_swiglu(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_get_rel_pos
 
 static void ggml_compute_forward_get_rel_pos_f16(
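The new dispatcher above routes a gated-linear-unit tensor to one of three kernels. Those kernels are outside this hunk; assuming the standard definitions of the variants, each combines a value a and a gate g as sketched below (the GELU here uses the common tanh approximation, which may differ from ggml's exact kernel):

    #include <cmath>

    static float reglu(float a, float g)  { return (a > 0.0f ? a : 0.0f) * g; }   // relu(a) * g
    static float swiglu(float a, float g) { return (a / (1.0f + expf(-a))) * g; } // silu(a) * g
    static float geglu(float a, float g)  {                                       // gelu(a) * g
        const float gelu = 0.5f * a * (1.0f + tanhf(0.7978845608f * (a + 0.044715f * a * a * a)));
        return gelu * g;
    }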
@@ -8070,6 +8882,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32xt
+    #define GGML_F32X_SET1 GGML_F32xt_SET1
+    #define GGML_F32X_LOAD GGML_F32xt_LOAD
+    #define GGML_F32X_STORE GGML_F32xt_STORE
+    #define GGML_F32X_MUL GGML_F32xt_MUL
+    #define GGML_F32X_FMA GGML_F32xt_FMA
+    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8081,7 +8901,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
 #endif
 
     #ifdef WKV_VECTOR_SIZE
-
+        int wkv_vector_size;
+        #if defined(__ARM_FEATURE_SVE)
+            wkv_vector_size = svcntw();
+        #else
+            wkv_vector_size = WKV_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / wkv_vector_size;
 
        for (int64_t t = 0; t < T; t++) {
            size_t t_offset = t * t_stride;
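On SVE the vector width is only known at run time, so the hunk above replaces the compile-time WKV_VECTOR_SIZE with svcntw() when deciding how many full vectors fit in a head. A standalone sketch of that partition into full vectors plus a scalar tail (partition is a hypothetical helper; 8 lanes corresponds to 256-bit SVE):

    #include <cstdint>
    #include <cstdio>

    // Split head_size into vec_count full vectors of vec_width lanes, plus a
    // scalar tail starting at tail_start (handled by the remainder loop).
    static void partition(int64_t head_size, int64_t vec_width,
                          int64_t * vec_count, int64_t * tail_start) {
        *vec_count  = head_size / vec_width;
        *tail_start = *vec_count * vec_width;
    }

    int main() {
        int64_t vc, tail;
        partition(/*head_size=*/100, /*vec_width=*/8, &vc, &tail);
        std::printf("vectors: %lld, tail from: %lld\n",
                    (long long)vc, (long long)tail); // 12, 96
        return 0;
    }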
@@ -8111,7 +8937,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
 
             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j *
+                size_t base_j = j * wkv_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 
@@ -8136,7 +8962,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             }
 
             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count *
+            for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8272,6 +9098,14 @@ static void ggml_compute_forward_gla_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32xt
+    #define GGML_F32X_SET1 GGML_F32xt_SET1
+    #define GGML_F32X_LOAD GGML_F32xt_LOAD
+    #define GGML_F32X_STORE GGML_F32xt_STORE
+    #define GGML_F32X_MUL GGML_F32xt_MUL
+    #define GGML_F32X_FMA GGML_F32xt_FMA
+    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8283,7 +9117,13 @@ static void ggml_compute_forward_gla_f32(
 #endif
 
    #ifdef GLA_VECTOR_SIZE
-
+        int gla_vector_size;
+        #if defined(__ARM_FEATURE_SVE)
+            gla_vector_size = svcntw();
+        #else
+            gla_vector_size = GLA_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / gla_vector_size;
 
        for (int64_t t = 0; t < T; t++) {
            size_t t_offset = t * t_stride;
@@ -8310,7 +9150,7 @@ static void ggml_compute_forward_gla_f32(
             GGML_F32X g_vec = GGML_F32X_SET1(g_val);
 
             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j *
+                size_t base_j = j * gla_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 
@@ -8334,7 +9174,7 @@ static void ggml_compute_forward_gla_f32(
             }
 
             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count *
+            for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
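The wkv7 hunk that follows routes SVE builds to a scalar loop and vectorizes the generic GGML_SIMD path, but both compute the same per-element recurrence. A scalar sketch of one update (wkv7_step is a hypothetical helper), with a note on how the two fused multiply-adds in the vector path map onto it, assuming ggml's FMA(a, b, c) = a + b*c convention:

    // state' = state * w + v * k + sa * b;  y += state' * r
    // Vector-path equivalent, with k_vec pre-multiplied by v:
    //   state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);  // v*k + state*w
    //   state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); // ... + sa*b
    static float wkv7_step(float * state, float r, float w, float k,
                           float v, float sa, float b) {
        *state = *state * w + v * k + sa * b;
        return *state * r; // caller accumulates this into the output row
    }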
@@ -8443,83 +9283,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
 #if defined(GGML_SIMD)
-
-
-    int64_t
-
-
-
-
-
-    int64_t
-
-
-
-
-    int64_t
-
-
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar Route to scalar implementation //TODO: Write SVE code
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
 
-
-
-
-
-
-
-
-
-
-
-
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
                     }
-
+                    dst_data[t_h_i_offset] = result;
                 }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
 
-
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
 
-
-
-
-
-
-
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
 
-
-
-
-
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
 
-
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
 
-
-
-
-
-
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
 
-
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
-                }
-                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                // There shouldn't be left-overs though.
-                for (; j < head_size; j++) {
-                    int64_t t_h_j_offset = t_h_offset + j;
-                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float r_val = r[t_h_j_offset];
-                    float w_val = w[t_h_j_offset];
-                    float k_val = k[t_h_j_offset];
-                    float b_val = b[t_h_j_offset];
-                    float kv_val = v[t_h_i_offset] * k_val;
-
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                 }
             }
         }
-
+    #endif
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;