@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
@@ -1,5 +1,17 @@
+cmake_minimum_required(VERSION 3.19)
+project("vulkan-shaders-gen" C CXX)
+
 find_package (Threads REQUIRED)
 
+if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+endif()
+if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+endif()
+if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+endif()
 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -295,7 +295,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
+    std::map<std::string, std::string> base_dict = {
+        {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
+        {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
+    };
     std::string shader_name = "matmul";
 
     if (matmul_id) {
@@ -313,9 +316,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         base_dict["COOPMAT"] = "1";
     }
 
-    base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-
-    std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
+    const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
     // Shaders with f16 B_TYPE
     string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
@@ -339,14 +340,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+        }
+#endif
     }
 }
 
@@ -426,8 +433,9 @@ void process_shaders() {
         }
     }
 
-    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
 
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
@@ -445,6 +453,7 @@ void process_shaders() {
 
     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
        string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+       string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
        string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }
 
@@ -456,6 +465,8 @@ void process_shaders() {
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
+    string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
+    string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
 
     string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
package/src/llama.cpp/ggml/src/ggml.c
@@ -4,6 +4,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"
+#include "ggml-cpu.h"
 #include "ggml.h"
 
 // FIXME: required here for quantization functions
@@ -382,58 +383,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
     }
 }
 
-// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
-// currently, the ggml_cpu_has_* functions are entirely compile-time
 void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__F16C__)
-    //if (ggml_cpu_has_f16c()) {
-        for (; i + 7 < n; i += 8) {
-            __m256 x_vec = _mm256_loadu_ps(x + i);
-            __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storeu_si128((__m128i *)(y + i), y_vec);
-        }
-        for(; i + 3 < n; i += 4) {
-            __m128 x_vec = _mm_loadu_ps(x + i);
-            __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storel_epi64((__m128i *)(y + i), y_vec);
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_FP32_TO_FP16(x[i]);
     }
 }
 
 void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__AVX512F__)
-    //if (ggml_cpu_has_avx512()) {
-        for (; i + 16 <= n; i += 16) {
-            _mm512_storeu_ps(y + i,
-                             _mm512_castsi512_ps(
-                                 _mm512_slli_epi32(
-                                     _mm512_cvtepu16_epi32(
-                                         _mm256_loadu_si256(
-                                             (const __m256i *)(x + i))),
-                                     16)));
-        }
-    //}
-#endif
-#if defined(__AVX2__)
-    //if (ggml_cpu_has_avx2()) {
-        for (; i + 8 <= n; i += 8) {
-            _mm256_storeu_ps(y + i,
-                             _mm256_castsi256_ps(
-                                 _mm256_slli_epi32(
-                                     _mm256_cvtepu16_epi32(
-                                         _mm_loadu_si128(
-                                             (const __m128i *)(x + i))),
-                                     16)));
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_BF16_TO_FP32(x[i]);
     }
 }
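
Note on the hunk above: the F16C/AVX fast paths are not lost, they are relocated into the CPU backend sources added in this release (see the new ggml-cpu/ files in the list), which is consistent with the new #include "ggml-cpu.h"; core ggml keeps only the scalar fallback. A minimal round-trip sketch using the public row helpers (my own illustration, not part of the diff):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        const float src[4] = { 0.5f, -1.25f, 3.0f, 65504.0f };
        ggml_fp16_t h[4];
        float dst[4];

        ggml_fp32_to_fp16_row(src, h, 4);  // now always the scalar path in core ggml
        ggml_fp16_to_fp32_row(h, dst, 4);  // exact for values representable in fp16

        for (int i = 0; i < 4; ++i) {
            printf("%g -> %g\n", src[i], dst[i]);
        }
        return 0;
    }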
@@ -956,6 +915,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -982,23 +942,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1055,6 +1010,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1081,23 +1037,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1159,6 +1110,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
     size_t nbytes;
     const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
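
The new guard makes ggml_nbytes() return 0 for any tensor reporting a non-positive extent, instead of a value derived from (ne[i] - 1) strides, which could underflow. A small sketch of the observable behaviour (illustrative only):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // metadata only, no data buffers
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
        printf("%zu\n", ggml_nbytes(t));  // 4 * 8 * sizeof(float) = 128

        t->ne[1] = 0;                     // simulate an empty/corrupted dimension
        printf("%zu\n", ggml_nbytes(t));  // 0 with this change

        ggml_free(ctx);
        return 0;
    }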
@@ -1348,6 +1305,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
+bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+    return
+        tensor->nb[0] > tensor->nb[2] &&
+        tensor->nb[1] > tensor->nb[0] &&
+        tensor->nb[2] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
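ggml_is_contiguous_channels() recognizes tensors whose channel dimension (ne[2]) is the contiguous one in memory, i.e. CWHN storage viewed as WHCN. A sketch of a view that satisfies it, assuming the predicate is exported through include/ggml.h (that header also changes in this release):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = { 1024 * 1024, NULL, /*no_alloc=*/ true };
        struct ggml_context * ctx = ggml_init(ip);

        // contiguous storage in CWHN order: channels are the innermost dimension
        const int C = 32, W = 28, H = 28, N = 1;
        struct ggml_tensor * cwhn = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, C, W, H, N);

        // view it as {W, H, C, N}: nb[2] stays equal to the type size
        struct ggml_tensor * whcn = ggml_permute(ctx, cwhn, 2, 0, 1, 3);

        printf("%s\n", ggml_is_contiguous_channels(whcn) ? "channels contiguous" : "not");
        ggml_free(ctx);
        return 0;
    }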
@@ -4054,6 +4018,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
     return result;
 }
 
+// ggml_conv_2d_dw_direct
+
+struct ggml_tensor * ggml_conv_2d_dw_direct(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        int stride0,
+        int stride1,
+        int pad0,
+        int pad1,
+        int dilation0,
+        int dilation1) {
+    GGML_ASSERT(a->ne[2] == 1);
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+    ne[2] = b->ne[2];
+    ne[3] = b->ne[3];
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+    if (ggml_is_contiguous_channels(b)) {
+        // Result will be permuted the same way as input (CWHN order)
+        const int64_t type_size = ggml_type_size(result->type);
+        GGML_ASSERT(ggml_blck_size(result->type) == 1);
+        result->nb[0] = result->ne[2] * type_size;
+        result->nb[1] = result->ne[0] * result->nb[0];
+        result->nb[2] = type_size;
+    }
+
+    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_CONV_2D_DW;
+    result->src[0] = a;
+    result->src[1] = b;
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
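
A hedged usage sketch for the new direct depth-wise convolution: per the asserts above, the kernel a has shape {KW, KH, 1, C} and the input b has shape {W, H, C, N}. The wrapper name and parameter values below are illustrative, not from the diff:

    #include "ggml.h"

    // Hypothetical wrapper: depth-wise 3x3 conv, stride 1, padding 1, dilation 1.
    static struct ggml_tensor * dw_conv_3x3(struct ggml_context * ctx,
                                            struct ggml_tensor  * kernel,  // {3, 3, 1, C}
                                            struct ggml_tensor  * input) { // {W, H, C, N}
        return ggml_conv_2d_dw_direct(ctx, kernel, input,
                                      /*stride0   =*/ 1, /*stride1   =*/ 1,
                                      /*pad0      =*/ 1, /*pad1      =*/ 1,
                                      /*dilation0 =*/ 1, /*dilation1 =*/ 1);
    }

If the input is a channels-contiguous view (see ggml_is_contiguous_channels above), the result keeps the same CWHN-permuted layout rather than being forced back to contiguous WHCN.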
@@ -4178,7 +4182,8 @@ static struct ggml_tensor * ggml_upscale_impl(
         int ne0,
         int ne1,
         int ne2,
-        int ne3) {
+        int ne3,
+        enum ggml_scale_mode mode) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
@@ -4186,6 +4191,8 @@ static struct ggml_tensor * ggml_upscale_impl(
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    ggml_set_op_params_i32(result, 0, mode);
+
     result->op     = GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4195,8 +4202,9 @@ static struct ggml_tensor * ggml_upscale_impl(
 struct ggml_tensor * ggml_upscale(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int scale_factor) {
-    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+        int scale_factor,
+        enum ggml_scale_mode mode) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4205,8 +4213,9 @@ struct ggml_tensor * ggml_upscale_ext(
         int ne0,
         int ne1,
         int ne2,
-        int ne3) {
-    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+        int ne3,
+        enum ggml_scale_mode mode) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // ggml_pad
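
ggml_upscale() and ggml_upscale_ext() now take an explicit ggml_scale_mode, so every call site gains one argument. A hypothetical migration of a 2x call site, assuming GGML_SCALE_MODE_NEAREST is the enum value that preserves the old nearest-neighbour behaviour:

    #include "ggml.h"

    // Hypothetical call site: doubles the two spatial dims of x.
    static struct ggml_tensor * upscale_2x(struct ggml_context * ctx, struct ggml_tensor * x) {
        // before this change: return ggml_upscale(ctx, x, 2);
        return ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
    }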
@@ -4369,7 +4378,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     }
 
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     float params[] = { scale, max_bias, logit_softcap };
@@ -4836,179 +4845,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun,
-        bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_custom1_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_custom1_op_f32_t fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_custom1_op_f32_t fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_custom2_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_custom2_op_f32_t fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_custom2_op_f32_t fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        const ggml_custom3_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        const ggml_custom3_op_f32_t fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        const ggml_custom3_op_f32_t fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -5027,7 +4863,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5072,7 +4908,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5121,7 +4957,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5153,6 +4989,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type type,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        struct ggml_tensor ** args,
+        int n_args,
+        ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor ** args,
+        int n_args,
+        ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
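
A sketch of wiring a user callback through the new ggml_custom_4d() entry point. The callback parameter list (dst, ith, nth, userdata) is an assumption inferred from the ggml_custom_op_params usage above; the zero-fill body and helper names are invented for illustration:

    #include <string.h>
    #include "ggml.h"

    // Assumed ggml_custom_op_t shape: (dst, thread index, thread count, userdata).
    static void fill_zero(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
        (void) nth; (void) userdata;
        if (ith == 0) {               // let thread 0 do all the work, for brevity
            memset(dst->data, 0, ggml_nbytes(dst));
        }
    }

    static struct ggml_tensor * zeros_like_via_custom(struct ggml_context * ctx,
                                                      struct ggml_tensor  * a) {
        struct ggml_tensor * args[1] = { a };
        return ggml_custom_4d(ctx, GGML_TYPE_F32,
                              a->ne[0], a->ne[1], a->ne[2], a->ne[3],
                              args, /*n_args=*/ 1, fill_zero,
                              /*n_tasks=*/ 1, /*userdata=*/ NULL);
    }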
package/src/llama.cpp/ggml/src/gguf.cpp
@@ -932,6 +932,7 @@ static void gguf_check_reserved_keys(const std::string & key, const T val) {
     if constexpr (std::is_same<T, uint32_t>::value) {
         GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2");
     } else {
+        GGML_UNUSED(val);
         GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32");
     }
 }