@fugood/llama.node 0.3.16 → 0.3.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +5 -0
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +8 -0
- package/src/LlamaCompletionWorker.h +1 -0
- package/src/LlamaContext.cpp +3 -2
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
- package/src/llama.cpp/.github/workflows/build.yml +70 -27
- package/src/llama.cpp/.github/workflows/docker.yml +6 -6
- package/src/llama.cpp/.github/workflows/server.yml +7 -11
- package/src/llama.cpp/CMakeLists.txt +23 -1
- package/src/llama.cpp/common/CMakeLists.txt +6 -3
- package/src/llama.cpp/common/arg.cpp +809 -105
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +1 -1
- package/src/llama.cpp/common/common.cpp +31 -521
- package/src/llama.cpp/common/common.h +17 -36
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +30 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
- package/src/llama.cpp/common/minja/minja.hpp +119 -93
- package/src/llama.cpp/common/sampling.cpp +3 -0
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +0 -9
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
- package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
- package/src/llama.cpp/examples/llava/clip.h +39 -22
- package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/examples/llava/llava.cpp +64 -52
- package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
- package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
- package/src/llama.cpp/examples/llava/mtmd.h +168 -0
- package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
- package/src/llama.cpp/examples/main/main.cpp +16 -5
- package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
- package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
- package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
- package/src/llama.cpp/examples/run/run.cpp +14 -28
- package/src/llama.cpp/examples/server/httplib.h +313 -247
- package/src/llama.cpp/examples/server/server.cpp +238 -139
- package/src/llama.cpp/examples/server/utils.hpp +51 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +6 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +66 -99
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
- package/src/llama.cpp/ggml/src/ggml.c +141 -245
- package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
- package/src/llama.cpp/include/llama.h +30 -11
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +2 -0
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/src/CMakeLists.txt +3 -2
- package/src/llama.cpp/src/llama-adapter.cpp +37 -1
- package/src/llama.cpp/src/llama-arch.cpp +160 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-chat.cpp +82 -17
- package/src/llama.cpp/src/llama-chat.h +6 -2
- package/src/llama.cpp/src/llama-context.cpp +108 -92
- package/src/llama.cpp/src/llama-context.h +1 -2
- package/src/llama.cpp/src/llama-graph.cpp +189 -119
- package/src/llama.cpp/src/llama-graph.h +26 -6
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
- package/src/llama.cpp/src/llama-kv-cache.h +41 -115
- package/src/llama.cpp/src/llama-memory.h +1 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model.cpp +1760 -534
- package/src/llama.cpp/src/llama-model.h +13 -1
- package/src/llama.cpp/src/llama-quant.cpp +29 -8
- package/src/llama.cpp/src/llama-sampling.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.cpp +44 -6
- package/src/llama.cpp/src/llama.cpp +1 -1
- package/src/llama.cpp/tests/CMakeLists.txt +43 -30
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
- package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
- package/src/llama.cpp/tests/test-chat.cpp +12 -2
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
|
@@ -12,115 +12,125 @@
|
|
|
12
12
|
|
|
13
13
|
#include "im2col.hpp"
|
|
14
14
|
|
|
15
|
+
#include <sycl/sycl.hpp>
|
|
16
|
+
#include <type_traits> // For std::is_same_v
|
|
17
|
+
|
|
18
|
+
#include "ggml.h"
|
|
19
|
+
|
|
15
20
|
template <typename T>
|
|
16
|
-
static void im2col_kernel(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
|
|
20
|
-
const sycl::nd_item<3> &item_ct1) {
|
|
21
|
+
static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW,
|
|
22
|
+
int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
|
|
23
|
+
int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) {
|
|
21
24
|
const int64_t work_group_size = item_ct1.get_local_range(2);
|
|
22
|
-
const int64_t global_id
|
|
25
|
+
const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2));
|
|
23
26
|
|
|
24
27
|
// make each work-item deal with more elements since sycl global range can not exceed max int
|
|
25
|
-
for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
|
|
26
|
-
|
|
28
|
+
for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {
|
|
27
29
|
const int64_t ksize = OW * (KH > 1 ? KW : 1);
|
|
28
|
-
const int64_t kx
|
|
29
|
-
const int64_t kd
|
|
30
|
-
const int64_t ky
|
|
31
|
-
const int64_t ix
|
|
32
|
-
|
|
33
|
-
const int64_t
|
|
34
|
-
const int64_t
|
|
35
|
-
const int64_t
|
|
36
|
-
|
|
37
|
-
const int64_t iiw = ix * s0 + kx * d0 - p0;
|
|
38
|
-
const int64_t iih = oh * s1 + ky * d1 - p1;
|
|
39
|
-
|
|
40
|
-
const int64_t offset_dst =
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
|
|
30
|
+
const int64_t kx = i / ksize;
|
|
31
|
+
const int64_t kd = kx * ksize;
|
|
32
|
+
const int64_t ky = (i - kd) / OW;
|
|
33
|
+
const int64_t ix = i % OW;
|
|
34
|
+
|
|
35
|
+
const int64_t oh = item_ct1.get_group(1);
|
|
36
|
+
const int64_t batch = item_ct1.get_group(0) / IC;
|
|
37
|
+
const int64_t ic = item_ct1.get_group(0) % IC;
|
|
38
|
+
|
|
39
|
+
const int64_t iiw = (ix * s0) + (kx * d0) - p0;
|
|
40
|
+
const int64_t iih = (oh * s1) + (ky * d1) - p1;
|
|
41
|
+
|
|
42
|
+
const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx);
|
|
43
|
+
|
|
44
|
+
const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset);
|
|
45
|
+
const int64_t offset_src = offset_src_base + (iih * IW) + iiw;
|
|
46
|
+
|
|
47
|
+
const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW);
|
|
48
|
+
const float src_val = out_of_bounds ? 0.0f : x[offset_src];
|
|
49
|
+
|
|
50
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
51
|
+
dst[offset_dst] = sycl::half(src_val);
|
|
52
|
+
} else if constexpr (std::is_same_v<T, float>) {
|
|
53
|
+
dst[offset_dst] = src_val;
|
|
53
54
|
}
|
|
54
55
|
}
|
|
55
56
|
}
|
|
56
57
|
|
|
57
58
|
template <typename T>
|
|
58
|
-
static void
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
int s0, int s1, int p0, int p1, int d0, int d1,
|
|
62
|
-
queue_ptr stream) {
|
|
59
|
+
static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
|
|
60
|
+
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
|
|
61
|
+
int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
|
|
63
62
|
const int64_t parallel_elements = OW * KW * KH;
|
|
64
|
-
const int64_t num_blocks
|
|
63
|
+
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
|
|
65
64
|
|
|
66
65
|
// decrease global range when it exceeds the max int
|
|
67
66
|
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
|
|
67
|
+
|
|
68
68
|
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
|
|
69
69
|
sycl::range<3> local_range(1, 1, local_size);
|
|
70
70
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
71
|
+
const int64_t CHW = IC * KH * KW;
|
|
72
|
+
|
|
73
|
+
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
|
|
74
|
+
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
|
|
75
|
+
p0, p1, d0, d1, item_ct1);
|
|
76
|
+
});
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH,
|
|
80
|
+
int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset,
|
|
81
|
+
int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
|
|
82
|
+
if (!stream->get_device().has(sycl::aspect::fp16)) {
|
|
83
|
+
throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported),
|
|
84
|
+
"Device does not support half precision (fp16) operations!");
|
|
82
85
|
}
|
|
86
|
+
im2col_sycl_internal<sycl::half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0,
|
|
87
|
+
p1, d0, d1, stream);
|
|
83
88
|
}
|
|
84
89
|
|
|
85
|
-
void
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
90
|
+
static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
|
|
91
|
+
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0,
|
|
92
|
+
int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
|
|
93
|
+
im2col_sycl_internal<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1,
|
|
94
|
+
d0, d1, stream);
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
98
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
99
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
89
100
|
|
|
90
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
91
101
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
92
102
|
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
|
93
103
|
|
|
94
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
95
|
-
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
|
96
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
|
97
|
-
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
|
98
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
|
99
|
-
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
|
104
|
+
const int32_t s0 = ((const int32_t *) (dst->op_params))[0];
|
|
105
|
+
const int32_t s1 = ((const int32_t *) (dst->op_params))[1];
|
|
106
|
+
const int32_t p0 = ((const int32_t *) (dst->op_params))[2];
|
|
107
|
+
const int32_t p1 = ((const int32_t *) (dst->op_params))[3];
|
|
108
|
+
const int32_t d0 = ((const int32_t *) (dst->op_params))[4];
|
|
109
|
+
const int32_t d1 = ((const int32_t *) (dst->op_params))[5];
|
|
100
110
|
|
|
101
|
-
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
|
111
|
+
const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1;
|
|
102
112
|
|
|
103
113
|
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
|
104
114
|
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
|
105
|
-
const int64_t IW =
|
|
115
|
+
const int64_t IW = src1->ne[0];
|
|
106
116
|
|
|
107
117
|
const int64_t KH = is_2D ? src0->ne[1] : 1;
|
|
108
|
-
const int64_t KW =
|
|
118
|
+
const int64_t KW = src0->ne[0];
|
|
109
119
|
|
|
110
120
|
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
|
111
|
-
const int64_t OW =
|
|
121
|
+
const int64_t OW = dst->ne[1];
|
|
112
122
|
|
|
113
|
-
const size_t
|
|
114
|
-
const int64_t batch
|
|
115
|
-
const size_t
|
|
123
|
+
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float);
|
|
124
|
+
const int64_t batch = src1->ne[is_2D ? 3 : 2];
|
|
125
|
+
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float);
|
|
126
|
+
|
|
127
|
+
queue_ptr stream = ctx.stream();
|
|
116
128
|
|
|
117
129
|
if (dst->type == GGML_TYPE_F16) {
|
|
118
|
-
|
|
130
|
+
im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
|
|
131
|
+
batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
|
|
119
132
|
} else {
|
|
120
|
-
|
|
133
|
+
im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
|
|
134
|
+
batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
|
|
121
135
|
}
|
|
122
|
-
|
|
123
|
-
GGML_UNUSED(src0);
|
|
124
|
-
GGML_UNUSED(src0_dd);
|
|
125
|
-
GGML_UNUSED(ctx);
|
|
126
136
|
}
|
|
@@ -16,8 +16,6 @@
|
|
|
16
16
|
#include "common.hpp"
|
|
17
17
|
|
|
18
18
|
void ggml_sycl_op_im2col(
|
|
19
|
-
ggml_backend_sycl_context & ctx,
|
|
20
|
-
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
|
|
21
|
-
const queue_ptr &main_stream);
|
|
19
|
+
ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
|
22
20
|
|
|
23
21
|
#endif // GGML_SYCL_IM2COL_HPP
|
|
@@ -367,7 +367,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
367
367
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
368
368
|
block_dims),
|
|
369
369
|
[=](sycl::nd_item<3> item_ct1)
|
|
370
|
-
[[
|
|
370
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
371
371
|
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
372
372
|
nullptr, WARP_SIZE);
|
|
373
373
|
});
|
|
@@ -389,7 +389,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
389
389
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
390
390
|
block_dims),
|
|
391
391
|
[=](sycl::nd_item<3> item_ct1)
|
|
392
|
-
[[
|
|
392
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
393
393
|
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
394
394
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
395
395
|
});
|
|
@@ -397,90 +397,78 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
397
397
|
}
|
|
398
398
|
}
|
|
399
399
|
|
|
400
|
-
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx,
|
|
401
|
-
ggml_tensor* dst, const float* src0_dd,
|
|
402
|
-
const float* src1_dd, float* dst_dd,
|
|
403
|
-
const queue_ptr& main_stream) {
|
|
400
|
+
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
404
401
|
|
|
405
|
-
GGML_ASSERT(
|
|
402
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
406
403
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
407
404
|
|
|
408
|
-
const int64_t ne00 =
|
|
409
|
-
const int64_t nrows = ggml_nrows(
|
|
405
|
+
const int64_t ne00 = dst->src[0]->ne[0];
|
|
406
|
+
const int64_t nrows = ggml_nrows(dst->src[0]);
|
|
407
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
408
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
409
|
+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
410
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
410
411
|
|
|
411
412
|
float eps;
|
|
412
413
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
413
414
|
|
|
414
415
|
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
|
415
|
-
|
|
416
|
-
(void)src1;
|
|
417
|
-
(void)dst;
|
|
418
|
-
(void)src1_dd;
|
|
419
416
|
}
|
|
420
417
|
|
|
421
|
-
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx,
|
|
422
|
-
const ggml_tensor* src1, ggml_tensor* dst,
|
|
423
|
-
const float* src0_dd, const float* src1_dd,
|
|
424
|
-
float* dst_dd,
|
|
425
|
-
const queue_ptr& main_stream) {
|
|
418
|
+
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
426
419
|
|
|
427
|
-
GGML_ASSERT(
|
|
420
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
428
421
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
429
422
|
|
|
430
423
|
int num_groups = dst->op_params[0];
|
|
424
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
425
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
426
|
+
|
|
427
|
+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
428
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
431
429
|
|
|
432
430
|
float eps;
|
|
433
431
|
memcpy(&eps, dst->op_params + 1, sizeof(float));
|
|
434
432
|
|
|
435
|
-
int group_size =
|
|
436
|
-
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size,
|
|
437
|
-
|
|
438
|
-
(void)src1;
|
|
439
|
-
(void)dst;
|
|
440
|
-
(void)src1_dd;
|
|
441
|
-
GGML_UNUSED(ctx);
|
|
433
|
+
int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);
|
|
434
|
+
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
|
|
442
435
|
}
|
|
443
436
|
|
|
444
|
-
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx,
|
|
445
|
-
const ggml_tensor* src1, ggml_tensor* dst,
|
|
446
|
-
const float* src0_dd, const float* src1_dd,
|
|
447
|
-
float* dst_dd,
|
|
448
|
-
const queue_ptr& main_stream) {
|
|
437
|
+
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
449
438
|
|
|
450
|
-
GGML_ASSERT(
|
|
439
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
451
440
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
452
441
|
|
|
453
|
-
const int64_t ne00 =
|
|
454
|
-
const int64_t nrows = ggml_nrows(
|
|
442
|
+
const int64_t ne00 = dst->src[0]->ne[0];
|
|
443
|
+
const int64_t nrows = ggml_nrows(dst->src[0]);
|
|
444
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
445
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
446
|
+
|
|
447
|
+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
448
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
455
449
|
|
|
456
450
|
float eps;
|
|
457
451
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
458
452
|
|
|
459
453
|
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
|
460
|
-
|
|
461
|
-
(void)src1;
|
|
462
|
-
(void)dst;
|
|
463
|
-
(void)src1_dd;
|
|
464
454
|
}
|
|
465
455
|
|
|
466
|
-
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx,
|
|
467
|
-
const ggml_tensor* src1, ggml_tensor* dst,
|
|
468
|
-
const float* src0_dd, const float* src1_dd,
|
|
469
|
-
float* dst_dd,
|
|
470
|
-
const queue_ptr& main_stream) {
|
|
456
|
+
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
471
457
|
|
|
472
|
-
GGML_ASSERT(
|
|
458
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
473
459
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
474
460
|
|
|
475
|
-
|
|
476
|
-
|
|
461
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
462
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
463
|
+
|
|
464
|
+
const int64_t ne00 = dst->src[0]->ne[0];
|
|
465
|
+
const int64_t nrows = ggml_nrows(dst->src[0]);
|
|
466
|
+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
467
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
477
468
|
|
|
478
469
|
float eps;
|
|
479
470
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
480
471
|
|
|
481
472
|
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
|
482
473
|
|
|
483
|
-
(void)src1;
|
|
484
|
-
(void)dst;
|
|
485
|
-
(void)src1_dd;
|
|
486
474
|
}
|
|
@@ -15,27 +15,12 @@
|
|
|
15
15
|
|
|
16
16
|
#include "common.hpp"
|
|
17
17
|
|
|
18
|
-
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx,
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
const float* src0_dd, const float* src1_dd,
|
|
26
|
-
float* dst_dd,
|
|
27
|
-
const queue_ptr& main_stream);
|
|
28
|
-
|
|
29
|
-
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|
30
|
-
const ggml_tensor* src1, ggml_tensor* dst,
|
|
31
|
-
const float* src0_dd, const float* src1_dd,
|
|
32
|
-
float* dst_dd,
|
|
33
|
-
const queue_ptr& main_stream);
|
|
34
|
-
|
|
35
|
-
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|
36
|
-
const ggml_tensor* src1, ggml_tensor* dst,
|
|
37
|
-
const float* src0_dd, const float* src1_dd,
|
|
38
|
-
float* dst_dd,
|
|
39
|
-
const queue_ptr& main_stream);
|
|
18
|
+
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
19
|
+
|
|
20
|
+
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
21
|
+
|
|
22
|
+
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
23
|
+
|
|
24
|
+
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
40
25
|
|
|
41
26
|
#endif // GGML_SYCL_NORM_HPP
|
|
@@ -1,8 +1,5 @@
|
|
|
1
|
-
#include <sycl/sycl.hpp>
|
|
2
|
-
#include <oneapi/mkl.hpp>
|
|
3
1
|
#include "outprod.hpp"
|
|
4
2
|
|
|
5
|
-
|
|
6
3
|
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
7
4
|
const ggml_tensor *src0 = dst->src[0];
|
|
8
5
|
const ggml_tensor *src1 = dst->src[1];
|
|
@@ -34,20 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
34
31
|
|
|
35
32
|
// Handle transposition of src1
|
|
36
33
|
const bool src1_T = ggml_is_transposed(src1);
|
|
37
|
-
const oneapi::
|
|
38
|
-
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
|
|
34
|
+
const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
|
|
39
35
|
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
|
40
36
|
|
|
41
37
|
try {
|
|
42
|
-
// Perform matrix multiplication using
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
|
|
46
|
-
ne00, src1_d, ldb, beta, dst_d, ne0);
|
|
47
|
-
#else
|
|
48
|
-
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
|
|
49
|
-
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
|
50
|
-
#endif
|
|
38
|
+
// Perform matrix multiplication using oneMath GEMM
|
|
39
|
+
oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
|
|
40
|
+
ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
|
51
41
|
}
|
|
52
42
|
catch (sycl::exception const& exc) {
|
|
53
43
|
std::cerr << exc.what() << std::endl;
|