whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl (new file)
@@ -0,0 +1,121 @@
+kernel void kernel_upscale(
+    global const void * p_src0,
+    ulong off_src0,
+    global void * p_dst,
+    ulong off_dst,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    int ne10,
+    int ne11,
+    int ne12,
+    int ne13,
+    float sf0,
+    float sf1,
+    float sf2,
+    float sf3
+) {
+    global const char * src_base = (global const char *)p_src0 + off_src0;
+    global float * dst_base = (global float *)((global char *)p_dst + off_dst);
+
+    int index = get_global_id(0);
+    int dst_total_elements = ne10 * ne11 * ne12 * ne13;
+
+    if (index >= dst_total_elements) {
+        return;
+    }
+
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = index / (ne10 * ne11 * ne12);
+
+    int i00 = (int)(i10 / sf0);
+    int i01 = (int)(i11 / sf1);
+    int i02 = (int)(i12 / sf2);
+    int i03 = (int)(i13 / sf3);
+
+    ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00;
+    global const float * src_element_ptr = (global const float *)(src_base + offset_src_element);
+
+    dst_base[index] = *src_element_ptr;
+}
+
+kernel void kernel_upscale_bilinear(
+    global const void * p_src0,
+    ulong off_src0,
+    global void * p_dst,
+    ulong off_dst,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    int ne00_src,
+    int ne01_src,
+    int ne10_dst,
+    int ne11_dst,
+    int ne12_dst,
+    int ne13_dst,
+    float sf0,
+    float sf1,
+    float sf2,
+    float sf3
+) {
+    global const char * src_base = (global const char *)p_src0 + off_src0;
+    global float * dst_base = (global float *)((global char *)p_dst + off_dst);
+
+    int index = get_global_id(0);
+    int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+    if (index >= dst_total_elements) {
+        return;
+    }
+
+    int i10_dst = index % ne10_dst;
+    int i11_dst = (index / ne10_dst) % ne11_dst;
+    int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+    int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+    int i02_src = (int)(i12_dst / sf2);
+    int i03_src = (int)(i13_dst / sf3);
+
+    const float pixel_offset = 0.5f;
+
+    float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
+    long y0_src = (long)floor(y_src_f);
+    long y1_src = y0_src + 1;
+
+    y0_src = max(0L, min(y0_src, (long)ne01_src - 1));
+    y1_src = max(0L, min(y1_src, (long)ne01_src - 1));
+
+    float dy = y_src_f - (float)y0_src;
+    dy = max(0.0f, min(dy, 1.0f));
+
+    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
+    long x0_src = (long)floor(x_src_f);
+    long x1_src = x0_src + 1;
+
+    x0_src = max(0L, min(x0_src, (long)ne00_src - 1));
+    x1_src = max(0L, min(x1_src, (long)ne00_src - 1));
+
+    float dx = x_src_f - (float)x0_src;
+    dx = max(0.0f, min(dx, 1.0f));
+
+    global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
+    global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
+    global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
+    global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
+
+    const float val_a = *p_a;
+    const float val_b = *p_b;
+    const float val_c = *p_c;
+    const float val_d = *p_d;
+
+    float result = val_a * (1.0f - dx) * (1.0f - dy) +
+                   val_b * dx * (1.0f - dy) +
+                   val_c * (1.0f - dx) * dy +
+                   val_d * dx * dy;
+
+    dst_base[index] = result;
+}
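The bilinear kernel samples with a half-pixel offset and clamps both the neighbor indices and the interpolation weights, so out-of-range taps collapse onto edge pixels. A host-side sketch of the same mapping for one output coordinate over a contiguous float plane (function and variable names are ours, not part of the kernel):

    #include <algorithm>
    #include <cmath>

    // Mirrors kernel_upscale_bilinear's index math for a single (x_dst, y_dst)
    // output pixel over a contiguous ne00 x ne01 float plane. Illustrative only.
    static float sample_bilinear(const float * src, long ne00, long ne01,
                                 long x_dst, long y_dst, float sf0, float sf1) {
        const float pixel_offset = 0.5f;

        float x_f = ((float) x_dst + pixel_offset) / sf0 - pixel_offset;
        float y_f = ((float) y_dst + pixel_offset) / sf1 - pixel_offset;

        long x0 = std::clamp((long) std::floor(x_f),     0L, ne00 - 1);
        long x1 = std::clamp((long) std::floor(x_f) + 1, 0L, ne00 - 1);
        long y0 = std::clamp((long) std::floor(y_f),     0L, ne01 - 1);
        long y1 = std::clamp((long) std::floor(y_f) + 1, 0L, ne01 - 1);

        // like dx/dy in the kernel, the weights are clamped to [0, 1]
        float dx = std::clamp(x_f - (float) x0, 0.0f, 1.0f);
        float dy = std::clamp(y_f - (float) y0, 0.0f, 1.0f);

        float a = src[y0 * ne00 + x0], b = src[y0 * ne00 + x1];
        float c = src[y1 * ne00 + x0], d = src[y1 * ne00 + x1];

        return a * (1.0f - dx) * (1.0f - dy) + b * dx * (1.0f - dy)
             + c * (1.0f - dx) * dy          + d * dx * dy;
    }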
data/ext/sources/ggml/src/ggml-quants.c
@@ -568,14 +568,14 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
     }
     float iscale = nmax/(max - min);
     float scale = 1/iscale;
-    float best_mad = 0;
+    float best_error = 0;
     for (int i = 0; i < n; ++i) {
         int l = nearest_int(iscale*(x[i] - min));
         L[i] = MAX(0, MIN(nmax, l));
         float diff = scale * L[i] + min - x[i];
         diff = use_mad ? fabsf(diff) : diff * diff;
         float w = weights[i];
-        best_mad += w * diff;
+        best_error += w * diff;
     }
     if (nstep < 1) {
         *the_min = -min;
@@ -601,18 +601,18 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
             this_min = 0;
             this_scale = sum_xl / sum_l2;
         }
-        float mad = 0;
+        float cur_error = 0;
         for (int i = 0; i < n; ++i) {
             float diff = this_scale * Laux[i] + this_min - x[i];
             diff = use_mad ? fabsf(diff) : diff * diff;
             float w = weights[i];
-            mad += w * diff;
+            cur_error += w * diff;
         }
-        if (mad < best_mad) {
+        if (cur_error < best_error) {
             for (int i = 0; i < n; ++i) {
                 L[i] = Laux[i];
             }
-            best_mad = mad;
+            best_error = cur_error;
             scale = this_scale;
             min = this_min;
         }
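Both hunks are a rename: the accumulator compared across candidate (scale, min) pairs is now best_error/cur_error, which matches what the loop computes, a weighted absolute or squared error depending on use_mad, rather than specifically a mean absolute deviation. A standalone sketch of that metric (names and signature ours):

    #include <cmath>

    // Weighted reconstruction error of quantized levels L[i] against x[i], as
    // accumulated in make_qkx2_quants: |err| terms when use_mad, else err^2.
    static float weighted_error(const float * x, const int * L, const float * weights,
                                int n, float scale, float min, bool use_mad) {
        float err = 0.0f;
        for (int i = 0; i < n; ++i) {
            float diff = scale * (float) L[i] + min - x[i];
            err += weights[i] * (use_mad ? std::fabs(diff) : diff * diff);
        }
        return err; // smaller is better; the search keeps the best candidate seen
    }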
@@ -2425,8 +2425,6 @@ void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_REST
     }
 }
 
-static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
 void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK4_NL == 0);
     const int64_t nb = k / QK4_NL;
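The deleted table is the IQ4_NL codebook; the same file-local copy also disappears from the SYCL common.hpp further down. ggml-common.h gains four lines in this release, which would be consistent with the values now being defined once in the shared header, though that hunk is not shown here. For reference, IQ4_NL dequantization maps each 4-bit index through the non-linear codebook and scales by the block scale; a minimal sketch (not the real block layout):

    // One IQ4_NL value: a 4-bit index selects a codebook entry, scaled by the
    // block's scale d. Sketch of the idea only.
    static float iq4nl_dequant_one(unsigned idx4, float d) {
        static const signed char kvalues_iq4nl[16] = {
            -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
        };
        return d * (float) kvalues_iq4nl[idx4 & 0x0F];
    }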
data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -53,6 +53,9 @@ struct socket_t {
     }
 };
 
+// macro for nicer error messages on server crash
+#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")
+
 // all RPC structures must be packed
 #pragma pack(push, 1)
 // ggml_tensor is serialized into rpc_tensor
@@ -425,7 +428,7 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_hello_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
         fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
         return false;
@@ -481,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_msg_free_buffer_req request = {ctx->remote_ptr};
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     delete ctx;
 }
 
@@ -493,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
     rpc_msg_buffer_get_base_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
     return ctx->base_ptr;
 }
@@ -545,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
         request.tensor = serialize_tensor(tensor);
 
         bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
-        GGML_ASSERT(status);
+        RPC_STATUS_ASSERT(status);
     }
     return GGML_STATUS_SUCCESS;
 }
@@ -560,7 +563,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
         request.hash = fnv_hash((const uint8_t*)data, size);
         rpc_msg_set_tensor_hash_rsp response;
         bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
-        GGML_ASSERT(status);
+        RPC_STATUS_ASSERT(status);
         if (response.result) {
             // the server has the same data, no need to send it
             return;
@@ -573,7 +576,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
 }
 
 static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -583,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
     request.offset = offset;
     request.size = size;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
 }
 
 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -601,7 +604,7 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
     request.dst = serialize_tensor(dst);
     rpc_msg_copy_tensor_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return response.result;
 }
 
@@ -609,7 +612,7 @@ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
 }
 
 static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -635,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
@@ -650,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_get_alignment_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return response.alignment;
 }
 
@@ -662,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
 static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_get_max_size_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return response.max_size;
 }
 
@@ -683,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
 
         rpc_msg_get_alloc_size_rsp response;
         bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
-        GGML_ASSERT(status);
+        RPC_STATUS_ASSERT(status);
 
         return response.alloc_size;
     } else {
@@ -761,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     rpc_msg_graph_compute_rsp response;
     auto sock = get_socket(rpc_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return (enum ggml_status)response.result;
 }
 
@@ -835,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
 static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
     rpc_msg_get_device_memory_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     *free = response.free_mem;
     *total = response.total_mem;
 }
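Every status assertion after a send_rpc_cmd round-trip now goes through RPC_STATUS_ASSERT, so a dead or misbehaving server produces one recognizable message instead of a bare assertion failure. The macro is simple enough to exercise outside ggml; in this standalone sketch, abort() stands in for GGML_ABORT and the transport is faked:

    #include <cstdio>
    #include <cstdlib>

    // Stand-in for GGML_ABORT so the macro body can be tried standalone.
    #define RPC_STATUS_ASSERT(x) \
        if (!(x)) { std::fprintf(stderr, "Remote RPC server crashed or returned malformed response\n"); std::abort(); }

    static bool send_rpc_cmd_stub(bool server_alive) { return server_alive; } // fake transport

    int main() {
        bool status = send_rpc_cmd_stub(true);
        RPC_STATUS_ASSERT(status); // passes silently when the round-trip succeeded
        std::puts("rpc ok");
        return 0;
    }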
data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt
@@ -13,7 +13,7 @@ elseif(SUPPORTS_SYCL)
     If you expected the oneAPI Release compiler, please install oneAPI & source it, like:
     source /opt/intel/oneapi/setvars.sh")
 else()
-    message(FATAL_ERROR
+    message(FATAL_ERROR "C++ compiler lacks SYCL support.")
 endif()
 message(STATUS "SYCL found")
 #todo: AOT
@@ -142,7 +142,7 @@ else()
         FetchContent_Declare(
             ONEMATH
             GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git
-            GIT_TAG
+            GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142
         )
         FetchContent_MakeAvailable(ONEMATH)
         # Create alias to match with find_package targets name
@@ -170,7 +170,7 @@ else()
         target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA)
     elseif (GGML_SYCL_TARGET STREQUAL "AMD")
         if (NOT GGML_SYCL_DEVICE_ARCH)
-            message(
+            message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
         endif()
         target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas)
         target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
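The binbcast, concat, and conv hunks that follow all replace direct stream->parallel_for(...) launches with a sycl_parallel_for(stream, range, kernel) helper. The helper's definition is not among the hunks shown here; judging from the call sites alone it is shaped roughly like the sketch below, though the real one (presumably in a shared ggml-sycl header) may add error handling:

    #include <sycl/sycl.hpp>
    #include <utility>

    // Hypothetical shape of the helper, inferred from its call sites only:
    // forward the nd_range and kernel to the queue's parallel_for shortcut.
    template <typename Kernel>
    static void sycl_parallel_for(sycl::queue * stream, sycl::nd_range<3> range, Kernel && kernel) {
        stream->parallel_for(range, std::forward<Kernel>(kernel));
    }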
data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
                 dpct::has_capability_or_fail(stream->get_device(),
                                              {sycl::aspect::fp16});
 
-
-
-
+                sycl_parallel_for(
+                    stream,
+                    sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
                                       sycl::range<3>(1, 1, block_size)),
                     [=](sycl::nd_item<3> item_ct1) {
                         k_bin_bcast_unravel<bin_op>(
@@ -246,9 +246,8 @@ struct bin_bcast_sycl {
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});
 
-
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                 k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
                                     ne2, ne3, ne10, ne11, ne12, ne13,
                                     s1, s2, s3, s01, s02, s03, s11, s12, s13,
data/ext/sources/ggml/src/ggml-sycl/common.hpp
@@ -149,8 +149,6 @@ typedef sycl::float2 dfloat2;
 
 #define MMVQ_MAX_BATCH_SIZE 8
 
-static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
 static int g_all_sycl_device_count = -1;
 static bool g_ggml_backend_sycl_buffer_type_initialized = false;
 
@@ -201,7 +199,7 @@ struct sycl_device_info {
     // size_t smpb;    // max. shared memory per block
     bool    vmm;       // virtual memory support
     size_t  total_vram;
-    sycl_hw_info hw_info;
+    //sycl_hw_info hw_info; \\ device id and aarch, currently not used
     optimize_feature opt_feature;
 };
 
@@ -288,29 +286,6 @@ struct ggml_tensor_extra_gpu {
 
 void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
 
-inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
-    optimize_feature opt;
-
-    opt.reorder =
-        (arch == syclex::architecture::intel_gpu_dg1 ||
-         arch == syclex::architecture::intel_gpu_acm_g10 ||
-         arch == syclex::architecture::intel_gpu_acm_g11 ||
-         arch == syclex::architecture::intel_gpu_acm_g12 ||
-         arch == syclex::architecture::intel_gpu_pvc ||
-         arch == syclex::architecture::intel_gpu_pvc_vg ||
-         arch == syclex::architecture::intel_gpu_mtl_u ||
-         arch == syclex::architecture::intel_gpu_mtl_s ||
-         arch == syclex::architecture::intel_gpu_mtl_h ||
-         arch == syclex::architecture::intel_gpu_arl_u ||
-         arch == syclex::architecture::intel_gpu_arl_s ||
-         arch == syclex::architecture::intel_gpu_arl_h ||
-         arch == syclex::architecture::intel_gpu_bmg_g21 ||
-         arch == syclex::architecture::intel_gpu_lnl_m
-        );
-
-    return opt;
-}
-
 namespace sycl_ex = sycl::ext::oneapi::experimental;
 struct ggml_backend_sycl_context {
     int device;
@@ -515,9 +490,9 @@ constexpr size_t ceil_div(const size_t m, const size_t n) {
 
 bool gpu_has_xmx(sycl::device &dev);
 
-template <int N, class T> void debug_print_array(const std::string & prefix, const T array[N]) {
+template <int N, class T> std::string debug_get_array_str(const std::string & prefix, const T array[N]) {
     if (LIKELY(!g_ggml_sycl_debug)) {
-        return;
+        return "";
     }
     std::stringstream ss;
     ss << prefix << "=[";
@@ -528,29 +503,26 @@ template <int N, class T> void debug_print_array(const std::string & prefix, con
         ss << array[N - 1];
     }
     ss << "]";
-
+    return ss.str();
 }
 
-inline
-
-
-
-
-    GGML_SYCL_DEBUG("%s=", prefix.c_str());
+inline std::string debug_get_tensor_str(const std::string &prefix,
+                                        const ggml_tensor *tensor, const std::string &suffix = "") {
+    std::stringstream ss;
+    if (LIKELY(!g_ggml_sycl_debug)) { return ss.str(); }
+    ss << prefix.c_str() << "=";
     if (tensor) {
-
-
-
-
-
-    }
-    if (ggml_is_permuted(tensor)) {
-        GGML_SYCL_DEBUG(";permuted");
-    }
+        ss << "'" << tensor->name << "':type=" << ggml_type_name(tensor->type);
+        ss << debug_get_array_str<GGML_MAX_DIMS>(";ne", tensor->ne);
+        ss << debug_get_array_str<GGML_MAX_DIMS>(";nb", tensor->nb);
+
+        if (!ggml_is_contiguous(tensor)) { ss << ";strided"; }
+        if (ggml_is_permuted(tensor)) { ss << ";permuted"; }
     } else {
-
+        ss << "nullptr";
     }
-
+    ss << suffix;
+    return ss.str();
 }
 
 // Use scope_op_debug_print to log operations coming from running a model
@@ -566,10 +538,10 @@ struct scope_op_debug_print {
             return;
         }
         GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data());
-
+        GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" dst", dst).c_str());
         if (dst) {
             for (std::size_t i = 0; i < num_src; ++i) {
-
+                GGML_SYCL_DEBUG("%s", debug_get_tensor_str("\tsrc" + std::to_string(i), dst->src[i]).c_str());
            }
        }
        GGML_SYCL_DEBUG("%s\n", suffix.data());
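The debug helpers switch from printing fragments directly to returning strings (debug_print_array becomes debug_get_array_str, joined by the new debug_get_tensor_str), so scope_op_debug_print can route each fragment through a single GGML_SYCL_DEBUG call, as the last hunk above shows. A reduced illustration of the pattern (tensor data invented):

    #include <cstdio>
    #include <sstream>
    #include <string>

    // Build a log fragment as a string instead of printing it immediately;
    // the caller decides how the assembled line is emitted.
    static std::string tensor_str(const std::string & prefix, const std::string & name,
                                  const std::string & type) {
        std::ostringstream ss;
        ss << prefix << "='" << name << "':type=" << type;
        return ss.str();
    }

    int main() {
        std::string line = tensor_str(" dst", "Kq", "f32") + tensor_str("\tsrc0", "K", "f16");
        std::printf("%s\n", line.c_str()); // one print call per assembled fragment set
        return 0;
    }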
data/ext/sources/ggml/src/ggml-sycl/concat.cpp
@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
     sycl::range<3> gridDim(ne2, ne1, num_blocks);
     switch (dim) {
     case 0:
-
-
-
-
-
-            concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
-        });
-        break;
+        sycl_parallel_for(stream,
+                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+                          [=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
+        break;
     case 1:
-
-
-
-
-
-            concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
-        });
-        break;
+        sycl_parallel_for(stream,
+                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+                          [=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
+        break;
     // dim >=2 will be dispatched to the default path
     default:
-
-
-
-
-
-            concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
-        });
-        break;
+        sycl_parallel_for(stream,
+                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+                          [=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
+        break;
     }
 }
 
@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
     int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
     uint64_t nb3, int32_t dim) {
     sycl::range<3> gridDim(ne3, ne2, ne1);
-    stream
-
-
-
-        int64_t i2 = item_ct1.get_group(1);
-        int64_t i1 = item_ct1.get_group(2);
+    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
+        int64_t i3 = item_ct1.get_group(0);
+        int64_t i2 = item_ct1.get_group(1);
+        int64_t i1 = item_ct1.get_group(2);
 
-
-
+        int64_t o[4] = { 0, 0, 0, 0 };
+        o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-
+        const float * x;
 
-
-             i0 += item_ct1.get_local_range(2)) {
+        for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
             if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-
-                    (i0)*nb00);
+                x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
             } else {
-
-
+                x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
+                                     (i0 - o[0]) * nb10);
             }
 
             float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
 
             *y = *x;
-
-
+        }
+    });
 }
 
 void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
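In the non-contiguous concat kernel, the new o[4] array replaces per-dimension branching: only the concatenated dimension gets a nonzero offset, and an output index at or past src0's extent along that dimension is shifted back into src1's coordinate space. Worked on the host with invented extents:

    #include <cstdio>

    int main() {
        const long ne00 = 4, ne01 = 3, ne02 = 2, ne03 = 1; // src0 extents (invented)
        for (int dim = 0; dim < 4; ++dim) {
            long o[4] = { 0, 0, 0, 0 };
            o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
            // an output index i along `dim` reads src1 at i - o[dim] once i >= o[dim]
            std::printf("dim=%d -> src1 shift {%ld,%ld,%ld,%ld}\n", dim, o[0], o[1], o[2], o[3]);
        }
        return 0;
    }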
data/ext/sources/ggml/src/ggml-sycl/conv.cpp
@@ -59,16 +59,10 @@ static void conv_transpose_1d_f32_f32_sycl(
     const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
     const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
     const sycl::range<3> block_nums(1, 1, num_blocks);
-    stream
-
-
-
-            conv_transpose_1d_kernel(
-                s0, output_size,
-                src0_ne0, src0_ne1, src0_ne2,
-                src1_ne0, dst_ne0,
-                src0, src1, dst, item_ct1);
-        });
+    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+        conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
+                                 item_ct1);
+    });
 }
 
 void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|