@fugood/llama.node 0.3.17 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-quants.c

@@ -19,12 +19,6 @@
 #define GROUP_MAX_EPS_IQ1_M 1e-7f
 #define GROUP_MAX_EPS_IQ1_S 1e-12f
 
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid warnings for hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-#endif
-
 #define UNUSED GGML_UNUSED
 
 // reference implementation for deterministic creation of model files
package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req {
     uint8_t value;
 };
 
+struct rpc_msg_set_tensor_hash_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t hash;
+};
+
 struct rpc_msg_set_tensor_hash_rsp {
     uint8_t result;
 };
@@ -518,6 +524,11 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
     result.view_offs = tensor->view_offs;
     result.data = reinterpret_cast<uint64_t>(tensor->data);
+
+    // Avoid sending uninitialized data over the wire
+    memset(result.name, 0, sizeof(result.name));
+    memset(result.padding, 0, sizeof(result.padding));
+
     snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
     return result;
 }
@@ -543,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_tensor rpc_tensor = serialize_tensor(tensor);
     if (size > HASH_THRESHOLD) {
-        [...]
-        memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-        memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-        memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
+        rpc_msg_set_tensor_hash_req request;
+        request.tensor = rpc_tensor;
+        request.offset = offset;
+        request.hash = fnv_hash((const uint8_t*)data, size);
         rpc_msg_set_tensor_hash_rsp response;
-        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH,
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
         GGML_ASSERT(status);
         if (response.result) {
             // the server has the same data, no need to send it
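For context: `fnv_hash` suggests an FNV-style digest of the tensor payload, used so the server can skip re-transfers of data it already has cached. A minimal sketch of a 64-bit FNV-1a hash; the exact variant and constants used upstream are an assumption here, not confirmed by this diff:

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of a 64-bit FNV-1a content hash over a byte range. The fnv_hash
// called in the hunk above is presumably similar; constants are assumed.
static uint64_t fnv1a_64(const uint8_t * data, size_t size) {
    uint64_t hash = 0xcbf29ce484222325ULL;  // FNV-1a offset basis
    for (size_t i = 0; i < size; i++) {
        hash ^= data[i];                    // xor in the next byte...
        hash *= 0x100000001b3ULL;           // ...then multiply by the FNV prime
    }
    return hash;
}
```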
@@ -859,7 +867,7 @@ public:
     bool free_buffer(const rpc_msg_free_buffer_req & request);
     bool buffer_clear(const rpc_msg_buffer_clear_req & request);
     bool set_tensor(const std::vector<uint8_t> & input);
-    bool set_tensor_hash(const
+    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -1096,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
     return true;
 }
 
-bool rpc_server::set_tensor_hash(const
+bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
 {
-    // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
-    if (input.size() != sizeof(rpc_tensor) + 16) {
-        return false;
-    }
-    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
-    uint64_t offset;
-    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
-    const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
     std::vector<uint8_t> cached_file;
-    if (!get_cached_file(
+    if (!get_cached_file(request.hash, cached_file)) {
         response.result = 0;
         return true;
     }
@@ -1120,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
     ggml_context_ptr ctx_ptr { ggml_init(params) };
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
-    ggml_tensor * tensor = deserialize_tensor(ctx,
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
-    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
+                     __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
 
     // sanitize tensor->data
     {
         const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
         const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
 
-        if (
+        if (request.tensor.data + request.offset < p0
+            || request.tensor.data + request.offset >= p1
+            || size > (p1 - request.tensor.data - request.offset)) {
             GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
-                           __func__,
+                           __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
             return false;
         }
     }
-    ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
+    ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
     response.result = 1;
     return true;
 }
@@ -1498,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 break;
             }
             case RPC_CMD_SET_TENSOR_HASH: {
-
-                if (!recv_msg(sockfd,
+                rpc_msg_set_tensor_hash_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
                     return;
                 }
                 rpc_msg_set_tensor_hash_rsp response;
-                if (!server.set_tensor_hash(
+                if (!server.set_tensor_hash(request, response)) {
                     return;
                 }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
@@ -1589,6 +1592,14 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
 void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
                                    const char * cache_dir,
                                    size_t free_mem, size_t total_mem) {
+    printf("Starting RPC server v%d.%d.%d\n",
+           RPC_PROTO_MAJOR_VERSION,
+           RPC_PROTO_MINOR_VERSION,
+           RPC_PROTO_PATCH_VERSION);
+    printf("  endpoint       : %s\n", endpoint);
+    printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
+    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));
+
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1748,6 +1759,9 @@ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const ch
     if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
         return (void *)ggml_backend_rpc_add_device;
     }
+    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
+        return (void *)ggml_backend_rpc_start_server;
+    }
     return NULL;
 
     GGML_UNUSED(reg);
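With this entry, the server entry point becomes reachable through the generic backend-registry lookup instead of a direct link-time dependency. A hedged sketch of such a lookup, assuming the public ggml-backend registry API (`ggml_backend_reg_by_name`, `ggml_backend_reg_get_proc_address`) and placeholder values for the memory arguments:

```cpp
#include <stddef.h>
#include "ggml-backend.h"

// Function-pointer type mirroring the ggml_backend_rpc_start_server
// signature shown in the hunks above.
typedef void (*rpc_start_server_fn)(ggml_backend_t backend, const char * endpoint,
                                    const char * cache_dir, size_t free_mem, size_t total_mem);

static void start_rpc_server(ggml_backend_t backend, const char * endpoint,
                             size_t free_mem, size_t total_mem) {
    ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
    if (reg == NULL) {
        return;  // RPC backend not available in this build
    }
    rpc_start_server_fn start_server =
        (rpc_start_server_fn) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server");
    if (start_server != NULL) {
        // cache_dir may be NULL; the server then reports "local cache : n/a"
        start_server(backend, endpoint, NULL, free_mem, total_mem);
    }
}
```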
package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt

@@ -49,35 +49,38 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
 
 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(
-[...]
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
+        if (NOT DEFINED DNNL_GPU_VENDOR)
+            # default to intel target
+            set(DNNL_GPU_VENDOR "INTEL")
+            if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+                message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+            endif()
         endif()
-endif()
 
-[...]
+        # Verify oneDNN was compiled for the same target as llama
+        if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+            target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+            set(GGML_SYCL_DNNL 1)
+            get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+            foreach(CONFIG ${CONFIGS})
+                get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+                message(STATUS "Found oneDNN: ${DNNL_LIB}")
+            endforeach()
+        else()
+            message(WARNING
+                "oneDNN must be compiled for the same target as llama.cpp.
+                llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+                Disabling oneDNN support.")
+        endif()
     else()
-        message(
-            "oneDNN must be compiled for the same target as llama.cpp.
-            llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
-            Disabling oneDNN support.")
+        message(STATUS "oneDNN not found, disabling oneDNN support")
     endif()
 else()
-    message(STATUS "oneDNN
+    message(STATUS "oneDNN support disabled by the user")
 endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
 
@@ -108,6 +111,9 @@ endif()
 if (GGML_SYCL_TARGET STREQUAL "INTEL")
     # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
     # See https://github.com/uxlfoundation/oneMath/issues/654
+    if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+        set(SYCL_COMPILER ON)
+    endif()
     find_package(MKL REQUIRED)
     target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
     target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp

@@ -14,23 +14,24 @@
 #define GGML_SYCL_BACKEND_HPP
 
 #include "binbcast.hpp"
-#include "concat.hpp"
 #include "common.hpp"
+#include "concat.hpp"
 #include "conv.hpp"
 #include "convert.hpp"
+#include "cpy.hpp"
 #include "dequantize.hpp"
 #include "dmmv.hpp"
+#include "element_wise.hpp"
+#include "gla.hpp"
+#include "im2col.hpp"
 #include "mmq.hpp"
 #include "mmvq.hpp"
-#include "rope.hpp"
 #include "norm.hpp"
+#include "outprod.hpp"
+#include "quants.hpp"
+#include "rope.hpp"
 #include "softmax.hpp"
 #include "tsembd.hpp"
-#include "im2col.hpp"
 #include "wkv.hpp"
-#include "outprod.hpp"
-#include "element_wise.hpp"
-#include "cpy.hpp"
-#include "gla.hpp"
 
-#endif
+#endif // GGML_SYCL_BACKEND_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp

@@ -1,93 +1,74 @@
 #include "binbcast.hpp"
 
+#include <array>
 #include <cstddef>
 #include <cstdint>
 #include <sycl/sycl.hpp>
 
+#include "dpct/helper.hpp"
 #include "ggml.h"
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void
-[...]
-                              item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
+                                                   dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
+    auto element_id   = it.get_global_id(0);
+    auto global_range = it.get_global_range(0);
+    for (; element_id < num_elements; element_id += global_range) {
+        auto  src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
+        auto  src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
+        float dst_val        = bin_op(src0_float_val[0], src1_float_val[0]);
+        auto  val_to_store   = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
+        dst[element_id] = val_to_store;
     }
 }
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void
-[...]
-        const
-[...]
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
+                                        int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
+                                        int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
+                                        int s11, int s12, int s13, std::size_t num_dst_elements,
+                                        const sycl::nd_item<1> & item_ct1) {
+    auto calculate_logical_index =
+        [](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
+        std::array<int, 4> logical_index;
+#pragma unroll(4)
+        for (int i = 3; i >= 0; i--) {
+            logical_index[i] = element_id % dims[i];
+            element_id /= dims[i];
+        }
+        return logical_index;
+    };
+
+    auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
+                              const std::array<int, 4> & indices) __attribute__((always_inline))
+        ->std::size_t {
+        std::size_t index = 0;
+#pragma unroll(4)
+        for (int i = 0; i < 4; i++) {
+            auto index_i = indices[i];
+            if (indices[i] >= dims[i]) {
+                index_i = indices[i] % dims[i];
+            }
+            index += strides[i] * index_i;
+        }
+        return index;
+    };
+
+    auto element_id = item_ct1.get_global_id(0);
+    for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
+        auto  logical_index  = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
+        auto  src_0_index    = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
+        auto  src_1_index    = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
+        auto  dst_index      = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
+        auto  src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
+        auto  src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
+        float dst_val        = bin_op(src0_float_val[0], src1_float_val[0]);
+        auto  val_to_store   = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
+        dst[dst_index] = val_to_store;
     }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
-
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
+template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
     template <typename src0_t, typename src1_t, typename dst_t>
     void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
                     const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -96,165 +77,73 @@ struct bin_bcast_sycl {
                     const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
                     const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
                     const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
-        [...]
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne[] = {ne0, ne1, ne2, ne3};
-        int64_t cne0[] = {ne00, ne01, ne02, ne03};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb[] = {nb0, nb1, nb2, nb3};
-        size_t cnb0[] = {nb00, nb01, nb02, nb03};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
+        auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
+                                       const std::array<int64_t, 4> & dst_dims) -> bool {
             for (int i = 0; i < 4; i++) {
-                if (
-
-                }
-                if (i > 0) {
-                    collapse_nb(cnb, cne);
-                    collapse_nb(cnb0, cne0);
-                    collapse_nb(cnb1, cne1);
-                    collapse(cne);
-                    collapse(cne0);
-                    collapse(cne1);
+                if (dst_dims[i] > src_dims[i]) {
+                    return true;
                 }
             }
-
-            [...]
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                s03, s11, s12, s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s01, s02, s03, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
+            return false;
+        };
+
+        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+        // dst strides in number of elements
+        size_t s0 = nb0 / sizeof(dst_t);
+        size_t s1 = nb1 / sizeof(dst_t);
+        size_t s2 = nb2 / sizeof(dst_t);
+        size_t s3 = nb3 / sizeof(dst_t);
+
+        // src1 strides in number of elements
+        size_t s10 = nb10 / sizeof(src0_t);
+        size_t s11 = nb11 / sizeof(src1_t);
+        size_t s12 = nb12 / sizeof(src1_t);
+        size_t s13 = nb13 / sizeof(src1_t);
+
+        // src0 strides in number of elements
+        size_t s00 = nb00 / sizeof(src0_t);
+        size_t s01 = nb01 / sizeof(src0_t);
+        size_t s02 = nb02 / sizeof(src0_t);
+        size_t s03 = nb03 / sizeof(src0_t);
+
+        std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
+                                       static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
+        std::size_t local_range  = 256;
+        std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
+
+        bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
+                                  check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
+        bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
+
+        if (! needs_broadcasting && all_contiguous) {
+            stream->submit([&](sycl::handler & cgh) {
+                cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
+                    k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
+                });
+            });
+        } else {
+            stream->submit([&](sycl::handler & cgh) {
+                cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
+                    k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
+                                        s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
+                });
+            });
         }
     }
 };
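The rewritten binbcast kernels drop the old 3D work-group decomposition in favor of one flat global index per destination element, which is unraveled into 4D coordinates and remapped through per-tensor strides; broadcast dimensions repeat the source via a modulo. A standalone sketch of that index arithmetic, with hypothetical helper names and plain host-side C++ for clarity:

```cpp
#include <array>
#include <cstddef>

// Unravel a flat element id into 4D coordinates (outermost dimension first),
// mirroring calculate_logical_index in the new k_bin_bcast.
static std::array<int, 4> unravel4d(const std::array<int, 4> & dims, std::size_t id) {
    std::array<int, 4> idx{};
    for (int i = 3; i >= 0; i--) {
        idx[i] = static_cast<int>(id % dims[i]);
        id /= dims[i];
    }
    return idx;
}

// Map 4D coordinates to a linear offset. A coordinate that reaches past a
// source dimension wraps with %, which is exactly how broadcasting repeats
// the smaller tensor along that axis.
static std::size_t strided_offset(const std::array<int, 4> & dims, const std::array<int, 4> & strides,
                                  const std::array<int, 4> & idx) {
    std::size_t off = 0;
    for (int i = 0; i < 4; i++) {
        off += static_cast<std::size_t>(strides[i]) * static_cast<std::size_t>(idx[i] % dims[i]);
    }
    return off;
}
```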