whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,8 +1,12 @@
|
|
1
1
|
#include "cpy.hpp"
|
2
2
|
|
3
3
|
#include <float.h>
|
4
|
+
#include <string>
|
4
5
|
|
5
6
|
#include "dequantize.hpp"
|
7
|
+
#include "ggml-sycl/common.hpp"
|
8
|
+
#include "ggml-sycl/presets.hpp"
|
9
|
+
#include "ggml.h"
|
6
10
|
|
7
11
|
static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
|
8
12
|
if (x <= val[0]) {
|
@@ -116,6 +120,15 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
|
116
120
|
}
|
117
121
|
}
|
118
122
|
|
123
|
+
/* quantized type same copy */
|
124
|
+
template<typename T>
|
125
|
+
static void cpy_blck_q_q(const char * cxi, char * cdsti) {
|
126
|
+
const T * xi = (const T *) cxi;
|
127
|
+
T * dsti = (T *) cdsti;
|
128
|
+
*dsti = *xi;
|
129
|
+
}
|
130
|
+
|
131
|
+
|
119
132
|
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
120
133
|
float * cdstf = (float *) (cdsti);
|
121
134
|
|
@@ -311,6 +324,34 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
|
|
311
324
|
}
|
312
325
|
}
|
313
326
|
|
327
|
+
|
328
|
+
template <typename T, int qk>
|
329
|
+
static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
|
330
|
+
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
|
331
|
+
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
|
332
|
+
const sycl::nd_item<3> & item_ct1) {
|
333
|
+
const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
|
334
|
+
|
335
|
+
if (i >= ne) {
|
336
|
+
return;
|
337
|
+
}
|
338
|
+
|
339
|
+
const int i03 = i / (ne00 * ne01 * ne02);
|
340
|
+
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
341
|
+
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
|
342
|
+
const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
|
343
|
+
const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
|
344
|
+
|
345
|
+
|
346
|
+
const int i13 = i / (ne10 * ne11 * ne12);
|
347
|
+
const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
|
348
|
+
const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
|
349
|
+
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
|
350
|
+
const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
|
351
|
+
|
352
|
+
cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
|
353
|
+
}
|
354
|
+
|
314
355
|
template <cpy_kernel_t cpy_blck, int qk>
|
315
356
|
static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
|
316
357
|
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
|
@@ -322,6 +363,7 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00
|
|
322
363
|
return;
|
323
364
|
}
|
324
365
|
|
366
|
+
|
325
367
|
const int i03 = i / (ne00 * ne01 * ne02);
|
326
368
|
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
327
369
|
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
|
@@ -371,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
|
|
371
413
|
{
|
372
414
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
373
415
|
|
374
|
-
|
416
|
+
sycl_parallel_for(
|
417
|
+
stream,
|
375
418
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
376
419
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
377
420
|
[=](sycl::nd_item<3> item_ct1) {
|
@@ -389,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
|
|
389
432
|
{
|
390
433
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
391
434
|
|
392
|
-
|
435
|
+
sycl_parallel_for(
|
436
|
+
stream,
|
393
437
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
394
438
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
395
439
|
[=](sycl::nd_item<3> item_ct1) {
|
@@ -407,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
|
|
407
451
|
{
|
408
452
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
409
453
|
|
410
|
-
|
454
|
+
sycl_parallel_for(
|
455
|
+
stream,
|
411
456
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
412
457
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
413
458
|
[=](sycl::nd_item<3> item_ct1) {
|
@@ -423,11 +468,11 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
|
|
423
468
|
const int nb12, const int nb13, queue_ptr stream) {
|
424
469
|
GGML_ASSERT(ne % QK8_0 == 0);
|
425
470
|
const int num_blocks = ne / QK8_0;
|
426
|
-
stream
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
471
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
472
|
+
[=](sycl::nd_item<3> item_ct1) {
|
473
|
+
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
474
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
475
|
+
});
|
431
476
|
}
|
432
477
|
|
433
478
|
static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -435,11 +480,11 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
|
|
435
480
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
436
481
|
const int nb12, const int nb13, queue_ptr stream) {
|
437
482
|
const int num_blocks = ne;
|
438
|
-
stream
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
483
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
484
|
+
[=](sycl::nd_item<3> item_ct1) {
|
485
|
+
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
486
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
487
|
+
});
|
443
488
|
}
|
444
489
|
|
445
490
|
static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -448,11 +493,11 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
|
|
448
493
|
const int nb12, const int nb13, queue_ptr stream) {
|
449
494
|
GGML_ASSERT(ne % QK4_0 == 0);
|
450
495
|
const int num_blocks = ne / QK4_0;
|
451
|
-
stream
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
496
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
497
|
+
[=](sycl::nd_item<3> item_ct1) {
|
498
|
+
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
499
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
500
|
+
});
|
456
501
|
}
|
457
502
|
|
458
503
|
static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -460,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
|
|
460
505
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
461
506
|
const int nb12, const int nb13, queue_ptr stream) {
|
462
507
|
const int num_blocks = ne;
|
463
|
-
|
464
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
508
|
+
sycl_parallel_for(
|
509
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
510
|
+
[=](sycl::nd_item<3> item_ct1) {
|
465
511
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
466
512
|
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
467
513
|
item_ct1);
|
@@ -474,11 +520,11 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
|
|
474
520
|
const int nb12, const int nb13, queue_ptr stream) {
|
475
521
|
GGML_ASSERT(ne % QK4_1 == 0);
|
476
522
|
const int num_blocks = ne / QK4_1;
|
477
|
-
stream
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
523
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
524
|
+
[=](sycl::nd_item<3> item_ct1) {
|
525
|
+
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
526
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
527
|
+
});
|
482
528
|
}
|
483
529
|
|
484
530
|
static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -486,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
|
|
486
532
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
487
533
|
const int nb12, const int nb13, queue_ptr stream) {
|
488
534
|
const int num_blocks = ne;
|
489
|
-
|
490
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
535
|
+
sycl_parallel_for(
|
536
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
537
|
+
[=](sycl::nd_item<3> item_ct1) {
|
491
538
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
492
539
|
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
493
540
|
item_ct1);
|
@@ -500,11 +547,11 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
|
|
500
547
|
const int nb12, const int nb13, queue_ptr stream) {
|
501
548
|
GGML_ASSERT(ne % QK5_0 == 0);
|
502
549
|
const int num_blocks = ne / QK5_0;
|
503
|
-
stream
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
550
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
551
|
+
[=](sycl::nd_item<3> item_ct1) {
|
552
|
+
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
553
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
554
|
+
});
|
508
555
|
}
|
509
556
|
|
510
557
|
static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -512,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
|
|
512
559
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
513
560
|
const int nb12, const int nb13, queue_ptr stream) {
|
514
561
|
const int num_blocks = ne;
|
515
|
-
|
516
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
562
|
+
sycl_parallel_for(
|
563
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
564
|
+
[=](sycl::nd_item<3> item_ct1) {
|
517
565
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
518
566
|
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
519
567
|
item_ct1);
|
@@ -526,11 +574,11 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
|
|
526
574
|
const int nb12, const int nb13, queue_ptr stream) {
|
527
575
|
GGML_ASSERT(ne % QK5_1 == 0);
|
528
576
|
const int num_blocks = ne / QK5_1;
|
529
|
-
stream
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
577
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
578
|
+
[=](sycl::nd_item<3> item_ct1) {
|
579
|
+
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
580
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
581
|
+
});
|
534
582
|
}
|
535
583
|
|
536
584
|
static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -538,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
|
|
538
586
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
539
587
|
const int nb12, const int nb13, queue_ptr stream) {
|
540
588
|
const int num_blocks = ne;
|
541
|
-
|
542
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
589
|
+
sycl_parallel_for(
|
590
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
591
|
+
[=](sycl::nd_item<3> item_ct1) {
|
543
592
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
544
593
|
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
545
594
|
item_ct1);
|
@@ -552,11 +601,11 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
|
|
552
601
|
const int nb12, const int nb13, queue_ptr stream) {
|
553
602
|
GGML_ASSERT(ne % QK4_NL == 0);
|
554
603
|
const int num_blocks = ne / QK4_NL;
|
555
|
-
stream
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
604
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
605
|
+
[=](sycl::nd_item<3> item_ct1) {
|
606
|
+
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
607
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
608
|
+
});
|
560
609
|
}
|
561
610
|
|
562
611
|
static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
@@ -567,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
|
|
567
616
|
{
|
568
617
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
569
618
|
|
570
|
-
|
619
|
+
sycl_parallel_for(
|
620
|
+
stream,
|
571
621
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
572
622
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
573
623
|
[=](sycl::nd_item<3> item_ct1) {
|
@@ -586,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
|
|
586
636
|
// dpct::has_capability_or_fail(stream->get_device(),
|
587
637
|
// {sycl::aspect::fp16});
|
588
638
|
|
589
|
-
|
639
|
+
sycl_parallel_for(
|
640
|
+
stream,
|
590
641
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
591
642
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
592
643
|
[=](sycl::nd_item<3> item_ct1) {
|
@@ -605,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
|
|
605
656
|
// dpct::has_capability_or_fail(stream->get_device(),
|
606
657
|
// {sycl::aspect::fp16});
|
607
658
|
|
608
|
-
|
659
|
+
sycl_parallel_for(
|
660
|
+
stream,
|
609
661
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
610
662
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
611
663
|
[=](sycl::nd_item<3> item_ct1) {
|
@@ -615,10 +667,85 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
|
|
615
667
|
}
|
616
668
|
}
|
617
669
|
|
670
|
+
static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
671
|
+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
672
|
+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
673
|
+
const int nb12, const int nb13, queue_ptr stream) {
|
674
|
+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
675
|
+
sycl_parallel_for(stream,
|
676
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
677
|
+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
678
|
+
[=](sycl::nd_item<3> item_ct1) {
|
679
|
+
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
680
|
+
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
681
|
+
});
|
682
|
+
}
|
683
|
+
|
684
|
+
|
685
|
+
static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
686
|
+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
687
|
+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
688
|
+
const int nb12, const int nb13, queue_ptr stream) {
|
689
|
+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
690
|
+
sycl_parallel_for(stream,
|
691
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
692
|
+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
693
|
+
[=](sycl::nd_item<3> item_ct1) {
|
694
|
+
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
695
|
+
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
696
|
+
});
|
697
|
+
}
|
698
|
+
|
699
|
+
|
700
|
+
static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
701
|
+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
702
|
+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
703
|
+
const int nb12, const int nb13, queue_ptr stream) {
|
704
|
+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
705
|
+
|
706
|
+
sycl_parallel_for(stream,
|
707
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
708
|
+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
709
|
+
[=](sycl::nd_item<3> item_ct1) {
|
710
|
+
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
711
|
+
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
712
|
+
});
|
713
|
+
}
|
714
|
+
|
715
|
+
|
716
|
+
static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
717
|
+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
718
|
+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
719
|
+
const int nb12, const int nb13, queue_ptr stream) {
|
720
|
+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
721
|
+
sycl_parallel_for(stream,
|
722
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
723
|
+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
724
|
+
[=](sycl::nd_item<3> item_ct1) {
|
725
|
+
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
726
|
+
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
727
|
+
});
|
728
|
+
}
|
729
|
+
|
730
|
+
|
731
|
+
static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
732
|
+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
733
|
+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
734
|
+
const int nb12, const int nb13, queue_ptr stream) {
|
735
|
+
|
736
|
+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
737
|
+
sycl_parallel_for(stream,
|
738
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
739
|
+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
740
|
+
[=](sycl::nd_item<3> item_ct1) {
|
741
|
+
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
742
|
+
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
743
|
+
});
|
744
|
+
}
|
745
|
+
|
618
746
|
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
|
619
747
|
// Unlike other operators, ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and relying on its src field
|
620
|
-
scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0,
|
621
|
-
std::string(" src0 type=") + ggml_type_name(src0->type));
|
748
|
+
scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str("\tsrc0", src0));
|
622
749
|
const int64_t ne = ggml_nelements(src0);
|
623
750
|
GGML_ASSERT(ne == ggml_nelements(src1));
|
624
751
|
|
@@ -632,8 +759,10 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
|
|
632
759
|
|
633
760
|
char * src0_ddc = (char *) src0->data;
|
634
761
|
char * src1_ddc = (char *) src1->data;
|
635
|
-
|
636
|
-
|
762
|
+
if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
|
763
|
+
GGML_SYCL_DEBUG("%s: memcpy path\n", __func__);
|
764
|
+
main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));
|
765
|
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
637
766
|
ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
|
638
767
|
nb11, nb12, nb13, main_stream);
|
639
768
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
@@ -684,6 +813,16 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
|
|
684
813
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
685
814
|
ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
|
686
815
|
nb10, nb11, nb12, nb13, main_stream);
|
816
|
+
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
|
817
|
+
ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
818
|
+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {
|
819
|
+
ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
820
|
+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {
|
821
|
+
ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
822
|
+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {
|
823
|
+
ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
824
|
+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {
|
825
|
+
ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
687
826
|
} else {
|
688
827
|
GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
|
689
828
|
ggml_type_name(src1->type));
|
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
|
538
538
|
#endif
|
539
539
|
}
|
540
540
|
|
541
|
+
template <typename dst_t>
|
542
|
+
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
543
|
+
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
|
544
|
+
const int64_t ib = item_ct1.get_group(2);
|
545
|
+
|
546
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
547
|
+
const int64_t ip = tid / 32; // ip is 0 or 1
|
548
|
+
const int64_t il = tid - 32 * ip; // 0...32
|
549
|
+
const int64_t is = 8 * ip + il / 16;
|
550
|
+
|
551
|
+
const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
|
552
|
+
const auto ql_offset = ib * (QK_K / 2);
|
553
|
+
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
|
554
|
+
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
|
555
|
+
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
|
556
|
+
const uint8_t * ql_ptr = base_ptr + ql_offset;
|
557
|
+
const uint8_t * qh_ptr = base_ptr + qh_offset;
|
558
|
+
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
|
559
|
+
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
|
560
|
+
|
561
|
+
dst_t * y = yy + ib * QK_K + 128 * ip + il;
|
562
|
+
|
563
|
+
const uint8_t * ql = ql_ptr + 64 * ip + il;
|
564
|
+
const uint8_t qh = *(qh_ptr + 32 * ip + il);
|
565
|
+
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
|
566
|
+
|
567
|
+
y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
568
|
+
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
569
|
+
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
570
|
+
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
571
|
+
}
|
572
|
+
|
541
573
|
template<typename dst_t>
|
542
574
|
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
543
575
|
const sycl::nd_item<3> &item_ct1,
|