@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -0,0 +1,125 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "im2col.hpp"
14
+
15
+ template <typename T>
16
+ static void im2col_kernel(
17
+ const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
18
+ int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
19
+ int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
20
+ const sycl::nd_item<3> &item_ct1) {
21
+ const int64_t work_group_size = item_ct1.get_local_range(2);
22
+ const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
23
+
24
+ // make each work-item deal with more elements since sycl global range can not exceed max int
25
+ for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
26
+
27
+ const int64_t ksize = OW * (KH > 1 ? KW : 1);
28
+ const int64_t kx = i / ksize;
29
+ const int64_t kd = kx * ksize;
30
+ const int64_t ky = (i - kd) / OW;
31
+ const int64_t ix = i % OW;
32
+
33
+ const int64_t oh = item_ct1.get_group(1);
34
+ const int64_t batch = item_ct1.get_group(0) / IC;
35
+ const int64_t ic = item_ct1.get_group(0) % IC;
36
+
37
+ const int64_t iiw = ix * s0 + kx * d0 - p0;
38
+ const int64_t iih = oh * s1 + ky * d1 - p1;
39
+
40
+ const int64_t offset_dst =
41
+ ((batch * OH + oh) * OW + ix) * CHW +
42
+ (ic * (KW * KH) + ky * KW + kx);
43
+
44
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
45
+ dst[offset_dst] =
46
+ sycl::vec<float, 1>(0.0f)
47
+ .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
48
+ } else {
49
+ const int64_t offset_src = ic * offset_delta + batch * batch_offset;
50
+ dst[offset_dst] =
51
+ sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
52
+ .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
53
+ }
54
+ }
55
+ }
56
+
57
+ template <typename T>
58
+ static void im2col_sycl(
59
+ const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
60
+ int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
61
+ int s0, int s1, int p0, int p1, int d0, int d1,
62
+ queue_ptr stream) {
63
+ const int64_t parallel_elements = OW * KW * KH;
64
+ const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
65
+
66
+ // decrease global range when it exceeds the max int
67
+ int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
68
+ sycl::range<3> block_nums(batch * IC, OH, num_blocks);
69
+ sycl::range<3> local_range(1, 1, local_size);
70
+
71
+ {
72
+ dpct::has_capability_or_fail(stream->get_device(),
73
+ {sycl::aspect::fp16});
74
+
75
+ stream->parallel_for(
76
+ sycl::nd_range<3>(block_nums * local_range, local_range),
77
+ [=](sycl::nd_item<3> item_ct1) {
78
+ im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
79
+ parallel_elements, (IC * KH * KW), s0, s1, p0,
80
+ p1, d0, d1, item_ct1);
81
+ });
82
+ }
83
+ }
84
+
85
+ void ggml_sycl_op_im2col(
86
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
87
+ ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
88
+ const queue_ptr &main_stream) {
89
+
90
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
91
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
92
+ GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
93
+
94
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
95
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
96
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
97
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
98
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
99
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
100
+
101
+ const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
102
+
103
+ const int64_t IC = src1->ne[is_2D ? 2 : 1];
104
+ const int64_t IH = is_2D ? src1->ne[1] : 1;
105
+ const int64_t IW = src1->ne[0];
106
+
107
+ const int64_t KH = is_2D ? src0->ne[1] : 1;
108
+ const int64_t KW = src0->ne[0];
109
+
110
+ const int64_t OH = is_2D ? dst->ne[2] : 1;
111
+ const int64_t OW = dst->ne[1];
112
+
113
+ const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
114
+ const int64_t batch = src1->ne[3];
115
+ const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
116
+
117
+ if (dst->type == GGML_TYPE_F16) {
118
+ im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
119
+ } else {
120
+ im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
121
+ }
122
+
123
+ (void) src0;
124
+ (void) src0_dd;
125
+ }
@@ -0,0 +1,23 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_IM2COL_HPP
14
+ #define GGML_SYCL_IM2COL_HPP
15
+
16
+ #include "common.hpp"
17
+
18
+ void ggml_sycl_op_im2col(
19
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
20
+ ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
21
+ const queue_ptr &main_stream);
22
+
23
+ #endif // GGML_SYCL_IM2COL_HPP
@@ -902,7 +902,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
902
902
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
903
903
  [=](sycl::nd_item<3> item_ct1)
904
904
  [[intel::reqd_sub_group_size(WARP_SIZE)]] {
905
- mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
905
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
906
906
  vx, vy, dst, ncols, nrows, item_ct1);
907
907
  });
908
908
  });
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
225
225
  }
226
226
 
227
227
  static void group_norm_f32_sycl(const float* x, float* dst,
228
- const int num_groups, const int group_size,
228
+ const int num_groups, const float eps, const int group_size,
229
229
  const int ne_elements, queue_ptr stream, int device) {
230
- static const float eps = 1e-6f;
231
230
  if (group_size < 1024) {
232
231
  const sycl::range<3> block_dims(1, 1, WARP_SIZE);
233
232
  stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
343
342
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
344
343
 
345
344
  int num_groups = dst->op_params[0];
345
+
346
+ float eps;
347
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
348
+
346
349
  int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
347
- group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
350
+ group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
348
351
 
349
352
  (void)src1;
350
353
  (void)dst;
@@ -41,6 +41,8 @@
41
41
  #define SYCL_ACC_BLOCK_SIZE 256
42
42
  #define SYCL_IM2COL_BLOCK_SIZE 256
43
43
  #define SYCL_POOL2D_BLOCK_SIZE 256
44
+ #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
45
+ #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
44
46
 
45
47
  // dmmv = dequantize_mul_mat_vec
46
48
  #ifndef GGML_SYCL_DMMV_X
@@ -226,7 +226,7 @@ void ggml_sycl_op_rope(
226
226
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
227
227
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
228
228
 
229
- const bool is_neox = mode & 2;
229
+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
230
230
 
231
231
  const int32_t * pos = (const int32_t *) src1_dd;
232
232
 
@@ -0,0 +1,71 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "tsembd.hpp"
14
+
15
+ static void timestep_embedding_f32(
16
+ const float * timesteps, float * dst, const int nb1,
17
+ const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
18
+ // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
19
+ // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
20
+ int i = item_ct1.get_group(1);
21
+ int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
22
+ float * embed_data = (float *)((char *)dst + i*nb1);
23
+
24
+ if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
25
+ embed_data[dim] = 0.f;
26
+ }
27
+
28
+ int half = dim / 2;
29
+ if (j >= half) {
30
+ return;
31
+ }
32
+
33
+ float timestep = timesteps[i];
34
+ float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
35
+ float arg = timestep * freq;
36
+ embed_data[j] = sycl::cos(arg);
37
+ embed_data[j + half] = sycl::sin(arg);
38
+ }
39
+
40
+ static void timestep_embedding_f32_sycl(
41
+ const float * x, float * dst, const int ne00, const int nb1,
42
+ const int dim, const int max_period, const queue_ptr& stream) {
43
+ // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
44
+ int half_ceil = dim / 2;
45
+ int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
46
+ sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
47
+ sycl::range<3> gridDim(1, ne00, num_blocks);
48
+ stream->parallel_for(
49
+ sycl::nd_range<3>(
50
+ gridDim * block_dims, block_dims),
51
+ [=](sycl::nd_item<3> item_ct1) {
52
+ timestep_embedding_f32(
53
+ x, dst, nb1, dim, max_period, item_ct1
54
+ );
55
+ });
56
+ }
57
+
58
+ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
59
+ const ggml_tensor *src1, ggml_tensor * dst) {
60
+ const float * src0_d = (const float *)src0->data;
61
+ float * dst_d = (float *)dst->data;
62
+ dpct::queue_ptr stream = ctx.stream();
63
+
64
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
65
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
66
+
67
+ const int dim = dst->op_params[0];
68
+ const int max_period = dst->op_params[1];
69
+
70
+ timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
71
+ }
@@ -0,0 +1,21 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_TSEMBD_HPP
14
+ #define GGML_SYCL_TSEMBD_HPP
15
+
16
+ #include "common.hpp"
17
+
18
+ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
19
+ const ggml_tensor *src1, ggml_tensor * dst);
20
+
21
+ #endif // GGML_SYCL_TSEMBD_HPP