@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -0,0 +1,162 @@
1
+
2
+ find_package(Vulkan COMPONENTS glslc REQUIRED) # glslc is used below to compile the .comp shaders to SPIR-V
3
+ find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc) # NOTE(review): HINTS expects directory paths; Vulkan::glslc looks like an imported target name - confirm the hint has any effect
4
+
5
+ if (NOT glslc_executable)
6
+ message(FATAL_ERROR "glslc not found")
7
+ endif()
8
+
9
+ add_library(ggml-kompute
10
+ ggml-kompute.cpp
11
+ ../../include/ggml-kompute.h
12
+ )
13
+
14
+ target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
15
+ target_include_directories(ggml-kompute PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR}) # the binary dir is on the include path; presumably so the generated shaderop_*.h headers can be found
16
+
17
+ add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) # NOTE(review): directory-scoped, so it applies to every target defined here, not only ggml-kompute
18
+
19
+ function(compile_shader) # compile_shader(SOURCES <file>...): per shader, adds a rule that builds <name>.spv with glslc and a rule that wraps it in a shader<name>.h header
20
+ set(options)
21
+ set(oneValueArgs)
22
+ set(multiValueArgs SOURCES)
23
+ cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
24
+ foreach(source ${compile_shader_SOURCES})
25
+ get_filename_component(filename ${source} NAME)
26
+ set(spv_file ${filename}.spv)
27
+ add_custom_command(
28
+ OUTPUT ${spv_file}
29
+ DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
30
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
31
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
32
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
33
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
34
+ COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
35
+ COMMENT "Compiling ${source} to ${spv_file}"
36
+ )
37
+
38
+ get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
39
+ set(FILE_NAME "shader${RAW_FILE_NAME}") # e.g. op_scale.comp.spv -> shaderop_scale.comp.spv
40
+ string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME}) # e.g. shaderop_scale.comp.spv -> shaderop_scale.h
41
+ string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
42
+ string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}") # include-guard macro name, e.g. SHADEROP_SCALE_H
43
+ set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
44
+ message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
45
+ if(CMAKE_GENERATOR MATCHES "Visual Studio") # this branch runs xxd from bin/$<CONFIG>/; the else branch runs it from bin/
46
+ add_custom_command(
47
+ OUTPUT ${OUTPUT_HEADER_FILE}
48
+ COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
49
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
50
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
51
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
52
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
53
+ COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
54
+ COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
55
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
56
+ DEPENDS ${spv_file} xxd
57
+ COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
58
+ )
59
+ else()
60
+ add_custom_command(
61
+ OUTPUT ${OUTPUT_HEADER_FILE}
62
+ COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
63
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
64
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
65
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
66
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
67
+ COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
68
+ COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
69
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
70
+ DEPENDS ${spv_file} xxd
71
+ COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
72
+ )
73
+ endif()
74
+ endforeach()
75
+ endfunction()
76
+
77
+ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
78
+ message(STATUS "Kompute found")
79
+ set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
80
+ add_subdirectory(kompute)
81
+
82
+ # Compile our shaders
83
+ compile_shader(SOURCES
84
+ kompute-shaders/op_scale.comp
85
+ kompute-shaders/op_scale_8.comp
86
+ kompute-shaders/op_add.comp
87
+ kompute-shaders/op_addrow.comp
88
+ kompute-shaders/op_mul.comp
89
+ kompute-shaders/op_silu.comp
90
+ kompute-shaders/op_relu.comp
91
+ kompute-shaders/op_gelu.comp
92
+ kompute-shaders/op_softmax.comp
93
+ kompute-shaders/op_norm.comp
94
+ kompute-shaders/op_rmsnorm.comp
95
+ kompute-shaders/op_diagmask.comp
96
+ kompute-shaders/op_mul_mat_mat_f32.comp
97
+ kompute-shaders/op_mul_mat_f16.comp
98
+ kompute-shaders/op_mul_mat_q8_0.comp
99
+ kompute-shaders/op_mul_mat_q4_0.comp
100
+ kompute-shaders/op_mul_mat_q4_1.comp
101
+ kompute-shaders/op_mul_mat_q4_k.comp
102
+ kompute-shaders/op_mul_mat_q6_k.comp
103
+ kompute-shaders/op_getrows_f32.comp
104
+ kompute-shaders/op_getrows_f16.comp
105
+ kompute-shaders/op_getrows_q4_0.comp
106
+ kompute-shaders/op_getrows_q4_1.comp
107
+ kompute-shaders/op_getrows_q6_k.comp
108
+ kompute-shaders/op_rope_f16.comp
109
+ kompute-shaders/op_rope_f32.comp
110
+ kompute-shaders/op_cpy_f16_f16.comp
111
+ kompute-shaders/op_cpy_f16_f32.comp
112
+ kompute-shaders/op_cpy_f32_f16.comp
113
+ kompute-shaders/op_cpy_f32_f32.comp
114
+ )
115
+
116
+ # Create a custom target for our generated shaders
117
+ add_custom_target(generated_shaders DEPENDS
118
+ shaderop_scale.h
119
+ shaderop_scale_8.h
120
+ shaderop_add.h
121
+ shaderop_addrow.h
122
+ shaderop_mul.h
123
+ shaderop_silu.h
124
+ shaderop_relu.h
125
+ shaderop_gelu.h
126
+ shaderop_softmax.h
127
+ shaderop_norm.h
128
+ shaderop_rmsnorm.h
129
+ shaderop_diagmask.h
130
+ shaderop_mul_mat_mat_f32.h
131
+ shaderop_mul_mat_f16.h
132
+ shaderop_mul_mat_q8_0.h
133
+ shaderop_mul_mat_q4_0.h
134
+ shaderop_mul_mat_q4_1.h
135
+ shaderop_mul_mat_q4_k.h
136
+ shaderop_mul_mat_q6_k.h
137
+ shaderop_getrows_f32.h
138
+ shaderop_getrows_f16.h
139
+ shaderop_getrows_q4_0.h
140
+ shaderop_getrows_q4_1.h
141
+ shaderop_getrows_q6_k.h
142
+ shaderop_rope_f16.h
143
+ shaderop_rope_f32.h
144
+ shaderop_cpy_f16_f16.h
145
+ shaderop_cpy_f16_f32.h
146
+ shaderop_cpy_f32_f16.h
147
+ shaderop_cpy_f32_f32.h
148
+ )
149
+
150
+ # Create a custom command that depends on the generated_shaders
151
+ add_custom_command(
152
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
153
+ COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
154
+ DEPENDS generated_shaders
155
+ COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
156
+ )
157
+
158
+ # Add the stamp to the main sources to ensure dependency tracking
159
+ target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
160
+ else()
161
+ message(WARNING "Kompute not found")
162
+ endif()
@@ -20,6 +20,7 @@
20
20
  #include "shaderop_mul_mat_q8_0.h"
21
21
  #include "shaderop_mul_mat_q4_0.h"
22
22
  #include "shaderop_mul_mat_q4_1.h"
23
+ #include "shaderop_mul_mat_q4_k.h"
23
24
  #include "shaderop_mul_mat_q6_k.h"
24
25
  #include "shaderop_mul_mat_mat_f32.h"
25
26
  #include "shaderop_getrows_f32.h"
@@ -42,6 +43,7 @@
42
43
  #include <cstring>
43
44
  #include <iostream>
44
45
  #include <memory>
46
+ #include <mutex>
45
47
  #include <stdexcept>
46
48
  #include <string>
47
49
  #include <unordered_map>
@@ -273,18 +275,9 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
273
275
  return results;
274
276
  }
275
277
 
276
- // public API returns a C-style array
277
- ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
278
- auto devices = ggml_vk_available_devices_internal(memoryRequired);
279
- *count = devices.size();
280
- if (devices.empty()) {
281
- return nullptr;
282
- }
283
-
284
- size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
285
- auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
286
- memcpy(arr, devices.data(), nbytes);
287
- return arr;
278
+ static std::vector<ggml_vk_device>& ggml_vk_available_devices() { // returns the cached device list by reference
279
+ static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0); // function-local static: enumerated once, on first call, with an argument of 0
280
+ return devices;
288
281
  }
289
282
 
290
283
  static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
@@ -341,7 +334,7 @@ ggml_vk_device ggml_vk_current_device() {
341
334
  if (!komputeManager()->hasDevice())
342
335
  return ggml_vk_device();
343
336
 
344
- auto devices = ggml_vk_available_devices_internal(0);
337
+ auto devices = ggml_vk_available_devices();
345
338
  ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
346
339
  GGML_ASSERT(!devices.empty());
347
340
  return devices.front();
@@ -1075,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
1075
1068
  ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1076
1069
  }
1077
1070
 
1071
+ static void ggml_vk_mul_mat_q4_k( // records a dispatch of the op_mul_mat_q4_k compute shader on seq
1072
+ kp::Sequence& seq,
1073
+ const std::shared_ptr<kp::Tensor>& inA,
1074
+ const std::shared_ptr<kp::Tensor>& inB,
1075
+ const std::shared_ptr<kp::Tensor>& out,
1076
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1077
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
1078
+ int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
1079
+ int32_t ne1, int32_t r2, int32_t r3
1080
+ ) {
1081
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
1082
+ kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
1083
+
1084
+ struct PushConstants {
1085
+ uint32_t inAOff, inBOff, outOff;
1086
+ int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
1087
+ } pushConsts {
1088
+ 0, 0, 0, // NOTE(review): the inAOff/inBOff/outOff arguments are never read; the shader always receives zero offsets - confirm this is intended
1089
+ ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
1090
+ };
1091
+
1092
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1093
+ if (!komputeManager()->hasAlgorithm(__func__)) { // no algorithm stored under this function's name yet: create one
1094
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
1095
+ } else { // otherwise fetch the existing algorithm and rebind tensors, workgroup and push constants for this call
1096
+ s_algo = komputeManager()->getAlgorithm(__func__);
1097
+ s_algo->setTensors({inA, inB, out});
1098
+ s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}); // x = ceil(ne01 / 4), y = ne11, z = ne12 * ne13
1099
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1100
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1101
+ }
1102
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1103
+ }
1104
+
1078
1105
  static void ggml_vk_mul_mat_q6_k(
1079
1106
  kp::Sequence& seq,
1080
1107
  const std::shared_ptr<kp::Tensor>& inA,
@@ -1323,17 +1350,7 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
1323
1350
  ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
1324
1351
  }
1325
1352
 
1326
- static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
1327
- switch (op->type) {
1328
- case GGML_TYPE_F16:
1329
- case GGML_TYPE_F32:
1330
- case GGML_TYPE_Q4_0:
1331
- case GGML_TYPE_Q4_1:
1332
- break;
1333
- default:
1334
- return false;
1335
- }
1336
-
1353
+ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1337
1354
  switch (op->op) {
1338
1355
  case GGML_OP_UNARY:
1339
1356
  switch (ggml_get_unary_op(op)) {
@@ -1402,6 +1419,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
1402
1419
  case GGML_TYPE_Q8_0:
1403
1420
  case GGML_TYPE_Q4_0:
1404
1421
  case GGML_TYPE_Q4_1:
1422
+ case GGML_TYPE_Q4_K:
1405
1423
  return true;
1406
1424
  default:
1407
1425
  ;
@@ -1410,6 +1428,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
1410
1428
  ;
1411
1429
  }
1412
1430
  return false;
1431
+
1432
+ GGML_UNUSED(dev);
1413
1433
  }
1414
1434
 
1415
1435
  static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
@@ -1458,11 +1478,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1458
1478
 
1459
1479
  any_commands_recorded = true;
1460
1480
 
1461
- if (!ggml_vk_supports_op(dst)) {
1462
- fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
1463
- GGML_ABORT("unsupported op");
1464
- }
1465
-
1466
1481
  const int32_t ne00 = src0 ? src0->ne[0] : 0;
1467
1482
  const int32_t ne01 = src0 ? src0->ne[1] : 0;
1468
1483
  const int32_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1656,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1656
1671
  ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1657
1672
  );
1658
1673
  break;
1674
+ case GGML_TYPE_Q4_K:
1675
+ ggml_vk_mul_mat_q4_k(
1676
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1677
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
1678
+ );
1679
+ break;
1659
1680
  case GGML_TYPE_Q6_K:
1660
1681
  ggml_vk_mul_mat_q6_k(
1661
1682
  seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
@@ -1820,11 +1841,6 @@ static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
1820
1841
  }
1821
1842
  }
1822
1843
 
1823
- static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
1824
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
1825
- return ctx->name.c_str();
1826
- }
1827
-
1828
1844
  static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1829
1845
  auto * memory = (ggml_vk_memory *)buffer->context;
1830
1846
  if (ggml_vk_has_device()) {
@@ -1868,7 +1884,6 @@ static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint
1868
1884
  }
1869
1885
 
1870
1886
  static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
1871
- /* .get_name = */ ggml_backend_kompute_buffer_get_name,
1872
1887
  /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
1873
1888
  /* .get_base = */ ggml_backend_kompute_buffer_get_base,
1874
1889
  /* .init_tensor = */ NULL,
@@ -1913,25 +1928,31 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
1913
1928
  };
1914
1929
 
1915
1930
  ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
1916
- static std::vector<ggml_backend_buffer_type> bufts = []() {
1917
- std::vector<ggml_backend_buffer_type> vec;
1918
- auto devices = ggml_vk_available_devices_internal(0);
1919
- vec.reserve(devices.size());
1920
-
1921
- for (const auto & dev : devices) {
1922
- vec.push_back({
1923
- /* .iface = */ ggml_backend_kompute_buffer_type_interface,
1924
- /* .device = */ nullptr,
1925
- /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
1926
- });
1931
+ static std::mutex mutex;
1932
+ std::lock_guard<std::mutex> lock(mutex);
1933
+
1934
+ auto devices = ggml_vk_available_devices();
1935
+ int32_t device_count = (int32_t) devices.size();
1936
+ GGML_ASSERT(device < device_count);
1937
+ GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
1938
+
1939
+ static ggml_backend_buffer_type
1940
+ ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
1941
+
1942
+ static bool ggml_backend_kompute_buffer_type_initialized = false;
1943
+
1944
+ if (!ggml_backend_kompute_buffer_type_initialized) {
1945
+ for (int32_t i = 0; i < device_count; i++) {
1946
+ ggml_backend_kompute_buffer_types[i] = {
1947
+ /* .iface = */ ggml_backend_kompute_buffer_type_interface,
1948
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
1949
+ /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
1950
+ };
1927
1951
  }
1928
- return vec;
1929
- }();
1952
+ ggml_backend_kompute_buffer_type_initialized = true;
1953
+ }
1930
1954
 
1931
- auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
1932
- return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
1933
- });
1934
- return it < bufts.end() ? &*it : nullptr;
1955
+ return &ggml_backend_kompute_buffer_types[device];
1935
1956
  }
1936
1957
 
1937
1958
  // backend
@@ -1953,31 +1974,15 @@ static void ggml_backend_kompute_free(ggml_backend_t backend) {
1953
1974
  delete backend;
1954
1975
  }
1955
1976
 
1956
- static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
1957
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1958
- return ggml_backend_kompute_buffer_type(ctx->device);
1959
- }
1960
-
1961
1977
  static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1962
1978
  auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1963
1979
  ggml_vk_graph_compute(ctx, cgraph);
1964
1980
  return GGML_STATUS_SUCCESS;
1965
1981
  }
1966
1982
 
1967
- static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
1968
- GGML_UNUSED(backend);
1969
- return ggml_vk_supports_op(op);
1970
- }
1971
-
1972
- static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
1973
- GGML_UNUSED(backend);
1974
- return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
1975
- }
1976
-
1977
1983
  static struct ggml_backend_i kompute_backend_i = {
1978
1984
  /* .get_name = */ ggml_backend_kompute_name,
1979
1985
  /* .free = */ ggml_backend_kompute_free,
1980
- /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
1981
1986
  /* .set_tensor_async = */ NULL,
1982
1987
  /* .get_tensor_async = */ NULL,
1983
1988
  /* .cpy_tensor_async = */ NULL,
@@ -1987,9 +1992,6 @@ static struct ggml_backend_i kompute_backend_i = {
1987
1992
  /* .graph_plan_update = */ NULL,
1988
1993
  /* .graph_plan_compute = */ NULL,
1989
1994
  /* .graph_compute = */ ggml_backend_kompute_graph_compute,
1990
- /* .supports_op = */ ggml_backend_kompute_supports_op,
1991
- /* .supports_buft = */ ggml_backend_kompute_supports_buft,
1992
- /* .offload_op = */ NULL,
1993
1995
  /* .event_record = */ NULL,
1994
1996
  /* .event_wait = */ NULL,
1995
1997
  };
@@ -2006,7 +2008,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
2006
2008
  ggml_backend_t kompute_backend = new ggml_backend {
2007
2009
  /* .guid = */ ggml_backend_kompute_guid(),
2008
2010
  /* .interface = */ kompute_backend_i,
2009
- /* .device = */ nullptr,
2011
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
2010
2012
  /* .context = */ s_kompute_context,
2011
2013
  };
2012
2014
 
@@ -2016,3 +2018,167 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
2016
2018
  bool ggml_backend_is_kompute(ggml_backend_t backend) {
2017
2019
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
2018
2020
  }
2021
+
2022
+ static size_t ggml_backend_kompute_get_device_count() {
2023
+ auto devices = ggml_vk_available_devices();
2024
+ return devices.size();
2025
+ }
2026
+
2027
+ static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
2028
+ auto devices = ggml_vk_available_devices();
2029
+ GGML_ASSERT((size_t) device < devices.size());
2030
+ snprintf(description, description_size, "%s", devices[device].name);
2031
+ }
2032
+
2033
+ static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
2034
+ auto devices = ggml_vk_available_devices();
2035
+ GGML_ASSERT((size_t) device < devices.size());
2036
+ *total = devices[device].heapSize;
2037
+ *free = devices[device].heapSize;
2038
+ }
2039
+
2040
+ //////////////////////////
2041
+
2042
+ struct ggml_backend_kompute_device_context {
2043
+ int device;
2044
+ std::string name;
2045
+ std::string description;
2046
+ };
2047
+
2048
+ static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
2049
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2050
+ return ctx->name.c_str();
2051
+ }
2052
+
2053
+ static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
2054
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2055
+ return ctx->description.c_str();
2056
+ }
2057
+
2058
+ static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2059
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2060
+ ggml_backend_kompute_get_device_memory(ctx->device, free, total);
2061
+ }
2062
+
2063
+ static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
2064
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2065
+ return ggml_backend_kompute_buffer_type(ctx->device);
2066
+ }
2067
+
2068
+ static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2069
+ if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
2070
+ return false;
2071
+ }
2072
+
2073
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2074
+ ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
2075
+
2076
+ return buft_ctx->device == ctx->device;
2077
+ }
2078
+
2079
+ static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
2080
+ GGML_UNUSED(dev);
2081
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
2082
+ }
2083
+
2084
+ static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2085
+ props->name = ggml_backend_kompute_device_get_name(dev);
2086
+ props->description = ggml_backend_kompute_device_get_description(dev);
2087
+ props->type = ggml_backend_kompute_device_get_type(dev);
2088
+ ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
2089
+ props->caps = {
2090
+ /* async = */ false,
2091
+ /* host_buffer = */ false,
2092
+ /* .buffer_from_host_ptr = */ false,
2093
+ /* events = */ false,
2094
+ };
2095
+ }
2096
+
2097
+ static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
2098
+ GGML_UNUSED(params);
2099
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2100
+ return ggml_backend_kompute_init(ctx->device);
2101
+ }
2102
+
2103
+ static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2104
+ const int min_batch_size = 32;
2105
+
2106
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
2107
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
2108
+
2109
+ GGML_UNUSED(dev);
2110
+ }
2111
+
2112
+ static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
2113
+ /* .get_name = */ ggml_backend_kompute_device_get_name,
2114
+ /* .get_description = */ ggml_backend_kompute_device_get_description,
2115
+ /* .get_memory = */ ggml_backend_kompute_device_get_memory,
2116
+ /* .get_type = */ ggml_backend_kompute_device_get_type,
2117
+ /* .get_props = */ ggml_backend_kompute_device_get_props,
2118
+ /* .init_backend = */ ggml_backend_kompute_device_init,
2119
+ /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
2120
+ /* .get_host_buffer_type = */ NULL,
2121
+ /* .buffer_from_host_ptr = */ NULL,
2122
+ /* .supports_op = */ ggml_backend_kompute_device_supports_op,
2123
+ /* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
2124
+ /* .offload_op = */ ggml_backend_kompute_device_offload_op,
2125
+ /* .event_new = */ NULL,
2126
+ /* .event_free = */ NULL,
2127
+ /* .event_synchronize = */ NULL,
2128
+ };
2129
+
2130
+ static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
2131
+ GGML_UNUSED(reg);
2132
+ return "Kompute";
2133
+ }
2134
+
2135
+ static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
2136
+ GGML_UNUSED(reg);
2137
+ return ggml_backend_kompute_get_device_count();
2138
+ }
2139
+
2140
+ static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
2141
+ static std::vector<ggml_backend_dev_t> devices;
2142
+
2143
+ static bool initialized = false;
2144
+
2145
+ {
2146
+ static std::mutex mutex;
2147
+ std::lock_guard<std::mutex> lock(mutex);
2148
+ if (!initialized) {
2149
+ for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
2150
+ ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
2151
+ char desc[256];
2152
+ ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
2153
+ ctx->device = i;
2154
+ ctx->name = "Kompute" + std::to_string(i);
2155
+ ctx->description = desc;
2156
+ devices.push_back(new ggml_backend_device {
2157
+ /* .iface = */ ggml_backend_kompute_device_i,
2158
+ /* .reg = */ reg,
2159
+ /* .context = */ ctx,
2160
+ });
2161
+ }
2162
+ initialized = true;
2163
+ }
2164
+ }
2165
+
2166
+ GGML_ASSERT(device < devices.size());
2167
+ return devices[device];
2168
+ }
2169
+
2170
+ static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
2171
+ /* .get_name = */ ggml_backend_kompute_reg_get_name,
2172
+ /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
2173
+ /* .get_device = */ ggml_backend_kompute_reg_get_device,
2174
+ /* .get_proc_address = */ NULL,
2175
+ };
2176
+
2177
+ ggml_backend_reg_t ggml_backend_kompute_reg() {
2178
+ static ggml_backend_reg reg = {
2179
+ /* .iface = */ ggml_backend_kompute_reg_i,
2180
+ /* .context = */ nullptr,
2181
+ };
2182
+
2183
+ return &reg;
2184
+ }