@fugood/llama.node 0.3.17 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -49,6 +49,8 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -193,13 +195,25 @@ static void ggml_check_sycl() try {
 
     if (!initialized) {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
+        g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
+        GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
         GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
         GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
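The new `GGML_SYCL_DISABLE_DNN` and `GGML_SYCL_PRIORITIZE_DMMV` toggles are read through the same `get_sycl_env` helper as the existing flags. That helper is not shown in this diff; as a rough sketch only (the name `get_sycl_env_sketch` and the parsing details are assumptions, not the library's code), an integer env-var reader with a default looks like this:

```cpp
// Hedged sketch of an integer env-var reader with a default value, in the
// spirit of ggml-sycl's get_sycl_env (which this diff does not show).
#include <cstdlib>
#include <string>

static int get_sycl_env_sketch(const char * env_name, int default_val) {
    const char * value = std::getenv(env_name);  // nullptr when unset
    if (value == nullptr) {
        return default_val;
    }
    try {
        return std::stoi(value);                 // e.g. "1" -> 1
    } catch (...) {
        return default_val;                      // non-numeric input falls back
    }
}
```

Usage would then be `GGML_SYCL_DISABLE_DNN=1 ./llama-cli ...` to skip the oneDNN gemm path at runtime without rebuilding.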
@@ -338,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -1982,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-
-    const int64_t ne0 = dst->ne[0];
+
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2030,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
                                      : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-#if !GGML_SYCL_DNNL
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
 #endif
+        {
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16 = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2069,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if !GGML_SYCL_DNNL
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
-            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
-            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
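Both gemm hunks above follow the same shape: the oneDNN branch is compiled in behind `GGML_SYCL_DNNL` and then gated again at runtime, with an `else` that falls through to a plain brace block when the macro is off. A minimal self-contained illustration of that control flow (all names are local to this sketch, not ggml's):

```cpp
// Self-contained illustration of the "#if + runtime gate + shared fallback
// block" pattern used in ggml_sycl_op_mul_mat_sycl above.
#include <cstdio>

#define SKETCH_HAVE_DNNL 1           // pretend oneDNN was compiled in
static int g_disable_dnn_sketch = 0; // stand-in for g_ggml_sycl_disable_dnn

static void gemm_dispatch_sketch() {
#if SKETCH_HAVE_DNNL
    if (!g_disable_dnn_sketch) {
        std::puts("oneDNN gemm path");
    }
    else
#endif
    {   // always present: the oneMath/dpct fallback
        std::puts("oneMath gemm path");
    }
}

int main() {
    gemm_dispatch_sketch();      // prints "oneDNN gemm path"
    g_disable_dnn_sketch = 1;
    gemm_dispatch_sketch();      // prints "oneMath gemm path"
    return 0;
}
```

When the macro is off, the preprocessor removes the `if`/`else` entirely and the braces become an ordinary block, so the fallback always runs.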
@@ -2694,139 +2715,180 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13,
-                                   size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (char *) dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }
 
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                           const ggml_tensor *src0,
-                                           const ggml_tensor *src1,
-                                           ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
+    queue_ptr queue = ctx.stream();
 
-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
-    // convert src1 to fp16
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float * dst_ddf = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t s11 = nb11 / type_size_src1;
+    int64_t s12 = nb12 / type_size_src1;
+    int64_t s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11 * s11;
+        s13 = ne12 * s12;
     }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
 
-    char * dst_t;
+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
 
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type    = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
 
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const float alpha_f32 = 1.0f;
-    const float beta_f32  = 0.0f;
+    const float beta_f32 = 0.0f;
 
     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;
 
-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-    [ ~39 lines of the previous batched gemm implementation lost in extraction ]
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
+        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                        oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                                                        mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+        } else {
+            const int ne23 = ne12 * ne13;
+
+            ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+            ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+            sycl::range<3> block_dims(1, ne12, ne13);
+            queue->submit([&](sycl::handler & cgh) {
+                const void ** ptrs_src_get = ptrs_src.get();
+                void ** ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+                });
             });
+
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
         }
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
     }
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+
+enum class mul_mat_algo {
+    DMMV = 0,
+    MMVQ = 1,
+    MUL_MAT_SYCL = 2,
+};
 
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     // TODO: accuracy issues in MMQ
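A recurring detail in the batched hunk is the switch from byte strides (`nb1x`) to element strides (`s1x`), and their collapse to purely shape-derived values once `src1` has been repacked into a contiguous fp16 buffer. A standalone sketch of that arithmetic (struct and function names are hypothetical):

```cpp
// Sketch of the src1 stride bookkeeping from ggml_sycl_mul_mat_batched_sycl:
// ggml tensors carry byte strides nb1x; the fp16 conversion wants element
// strides s1x, and a freshly packed buffer has contiguous strides.
#include <cstddef>
#include <cstdint>

struct src1_strides { int64_t s11, s12, s13; };

// element strides from byte strides (type_size = ggml_type_size(src1->type))
static src1_strides element_strides(size_t nb11, size_t nb12, size_t nb13, size_t type_size) {
    return { int64_t(nb11 / type_size), int64_t(nb12 / type_size), int64_t(nb13 / type_size) };
}

// strides after packing into a contiguous ne10 x ne11 x ne12 x ne13 buffer,
// mirroring the s11 = ne10; s12 = ne11 * s11; s13 = ne12 * s12; lines above
static src1_strides packed_strides(int64_t ne10, int64_t ne11, int64_t ne12) {
    src1_strides s;
    s.s11 = ne10;            // one row
    s.s12 = ne11 * s.s11;    // one plane
    s.s13 = ne12 * s.s12;    // one batch
    return s;
}
```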
@@ -2834,6 +2896,36 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
         return false;
 }
 
+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        case GGML_TYPE_Q4_K:
+            return !g_ggml_sycl_prioritize_dmmv;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2853,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2876,48 +2968,119 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
             *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
         }
         *(d_ptr + ib) = x[ib].d;
-    });
+    }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char * data_device = (char *) src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
-    [ 3 lines lost in extraction ]
-static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
-                            ggml_tensor * dst) {
-    if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
-        ctx->opt_feature.reorder &&      //allow this device due to good perf, skip the devices with bad perf.
-        dst->op == GGML_OP_MUL_MAT &&    //limit to some supported cases of Q4_0, to do for more cases.
-        src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2]==1 && src1->ne[3]==1) {
+static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+    return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+           ctx.opt_feature.reorder &&       //allow this device due to good perf, skip the devices with bad perf.
+           dst->op == GGML_OP_MUL_MAT &&    //limit to some supported cases of Q4_0, to do for more cases.
+           dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
+}
 
-        [ 2 lines lost in extraction ]
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+                            ggml_tensor * dst, mul_mat_algo mm_algorithm) {
+    if (!should_reorder_tensor(*ctx, dst)) {
+        return;
+    }
 
-        [ 1 line lost in extraction ]
+    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+    if (!extra || extra->optimized_feature.reorder) {
+        return;  // Skip permutations and already reordered tensors
+    }
 
-        [ 2 lines lost in extraction ]
+    switch (mm_algorithm) {
+        case mul_mat_algo::DMMV:
+            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MMVQ:
+            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MUL_MAT_SYCL:
+            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+                return;
+            }
+            break;
     }
+
+    reorder_qw(src0, ctx->stream());
+    extra->optimized_feature.reorder = true;  // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
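The new `reorder_qw_q4_k` splits an array of interleaved `block_q4_K` structs into three planar segments — quants first, then scales, then the per-block `dm` half2 pairs — so kernels can read each stream contiguously. A standalone sketch of the resulting offsets (the constants mirror ggml's `QK_K`/`K_SCALE_SIZE`; the struct is hypothetical):

```cpp
// Offset arithmetic behind the reordered Q4_K buffer laid out above:
// [ qs: nblocks*QK_K/2 | scales: nblocks*K_SCALE_SIZE | dm: nblocks*sizeof(half2) ]
#include <cstddef>

constexpr size_t kQK_K         = 256; // mirrors ggml's QK_K
constexpr size_t kK_SCALE_SIZE = 12;  // mirrors ggml's K_SCALE_SIZE

struct q4k_reordered_layout {
    size_t qs_offset;     // start of the packed 4-bit quants
    size_t scales_offset; // start of the packed scale bytes
    size_t dm_offset;     // start of the per-block d/dmin pairs
};

static q4k_reordered_layout make_q4k_layout(size_t nblocks) {
    q4k_reordered_layout l;
    l.qs_offset     = 0;
    l.scales_offset = nblocks * (kQK_K / 2);
    l.dm_offset     = l.scales_offset + nblocks * kK_SCALE_SIZE;
    return l;
}
```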
@@ -2930,17 +3093,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability    = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -2952,9 +3111,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
+
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
+        // is enabled takes precedence over DMMV, the current if-else implementation
+        // requires disabling DMMV if both conditions are met
+        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }
 
     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
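The rewritten condition means MMVQ now takes precedence over DMMV not only on the CUDA backend but whenever the reorder optimization applies, unless `GGML_SYCL_PRIORITIZE_DMMV` is set. A toy model of that precedence (all names local to the sketch, not ggml's):

```cpp
// Toy model of the DMMV/MMVQ precedence introduced above.
#include <cstdio>

static bool keep_dmmv(bool use_dmmv, bool use_mmvq, bool cuda_backend,
                      bool reorder_applies, bool prioritize_dmmv) {
    if (!prioritize_dmmv && (cuda_backend || reorder_applies)) {
        use_dmmv = use_dmmv && !use_mmvq;  // MMVQ wins when both are eligible
    }
    return use_dmmv;
}

int main() {
    // reorder applies, both kernels eligible: MMVQ wins...
    std::printf("%d\n", keep_dmmv(true, true, false, true, false)); // prints 0
    // ...unless GGML_SYCL_PRIORITIZE_DMMV=1 keeps DMMV eligible
    std::printf("%d\n", keep_dmmv(true, true, false, true, true));  // prints 1
    return 0;
}
```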
@@ -2966,24 +3131,30 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // The kernel from the if path is faster for that specific case, but does not support all mul mats.
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        // MUL_MAT_SYCL supports reorder
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
 
@@ -3651,7 +3822,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         return GGML_STATUS_SUCCESS;
     }
 
-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();