@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
@@ -5,15 +5,35 @@ find_package (Threads REQUIRED)
 
 if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    message(STATUS "Enabling coopmat glslc support")
 endif()
 if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    message(STATUS "Enabling coopmat2 glslc support")
 endif()
 if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    message(STATUS "Enabling dot glslc support")
 endif()
+if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    message(STATUS "Enabling bfloat16 glslc support")
+endif()
+
 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
+
+# Configure output directories for MSVC builds
+if (MSVC)
+    # Get the main project's runtime output directory if possible
+    if (DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
+        foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
+            string(TOUPPER ${CONFIG} CONFIG)
+            set_target_properties(${TARGET} PROPERTIES
+                RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
+        endforeach()
+    endif()
+endif()
package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
     "iq3_xxs",
     "iq3_s",
     "iq4_xs",
-    "iq4_nl"
+    "iq4_nl",
+    "bf16",
 };
 
 namespace {
@@ -214,7 +215,7 @@ static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;
 
 void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
-    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
+    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
     std::string out_fname = join_paths(output_dir, name + ".spv");
     std::string in_path = join_paths(input_dir, in_fname);
 
@@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
     std::map<std::string, std::string> base_dict = {
-        {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
         {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
     };
     std::string shader_name = "matmul";
@@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
     const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
+    auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
+        if (t == "bf16") {
+            // scalar path promotes to float
+            if (!coopmat && !coopmat2) {
+                return "float";
+            }
+            return "bfloat16_t";
+        }
+        if (coopmat2 || fp16) {
+            return "float16_t";
+        }
+        return "float";
+    };
+
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
 
-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+
+    // bf16
+    {
+        std::string load_vec_a_unaligned = "1";
+        // For aligned matmul loads
+        std::string load_vec_a = coopmat2 ? "1" : "4";
+
+        // scalar path promotes to float
+        std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
+
+        // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
+#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (!(coopmat || coopmat2))
+#endif
+        {
+            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        }
+    }
 
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
@@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
             load_vec_quant = "4";
 
+        if (tname == "bf16") {
+            continue;
+        }
+
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
         // For aligned matmul loads
-        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
-            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
     }
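A note on the bf16 path added in the two hunks above: bfloat16 is simply the upper 16 bits of an IEEE-754 float32, which is why the scalar fallback can promote to float cheaply and why B_TYPE can be a plain uint16_t when cooperative matrices are unavailable. The GLSL helper named by to_float_type on the scalar path (bf16_to_fp32) is assumed to perform the same widening sketched below; this standalone C++ illustration of both conversion directions uses names of our choosing, not code from the package:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // bf16 -> fp32: widen by placing the 16 bits in the high half of a float32.
    static float bf16_to_fp32(uint16_t h) {
        uint32_t bits = (uint32_t) h << 16;
        float f;
        std::memcpy(&f, &bits, sizeof f);
        return f;
    }

    // fp32 -> bf16 with round-to-nearest-even (NaN handling omitted for brevity).
    static uint16_t fp32_to_bf16(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof bits);
        bits += 0x7FFF + ((bits >> 16) & 1);
        return (uint16_t) (bits >> 16);
    }

    int main() {
        const float x = 3.14159f;
        const uint16_t h = fp32_to_bf16(x);
        std::printf("%.5f -> 0x%04x -> %.5f\n", x, h, bf16_to_fp32(h)); // ~3.14062
    }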
@@ -384,16 +421,18 @@ void process_shaders() {
 #endif
     }
 
-#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     // flash attention
     for (const auto& f16acc : {false, true}) {
         std::string acctype = f16acc ? "float16_t" : "float";
+        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
 
         for (const auto& tname : type_names) {
             if (tname == "f32") {
                 continue;
             }
+            if (tname == "bf16") continue;
 
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
@@ -402,9 +441,27 @@ void process_shaders() {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
+#endif
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+            } else if (tname == "q4_0" || tname == "q8_0") {
+                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+            }
+#endif
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+            } else if (tname == "q4_0" || tname == "q8_0") {
+                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+            }
         }
     }
-#endif
 
     for (const auto& tname : type_names) {
         // mul mat vec
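With this change, flash attention is no longer gated entirely behind coopmat2: the generator now emits up to three tiers per type, flash_attn_cm2.comp (coopmat2), flash_attn_cm1.comp (coopmat, using the _cm1 suffix introduced earlier), and an unconditional scalar flash_attn.comp fallback. The runtime selection lives in ggml-vulkan.cpp and is not shown in this diff; a hedged sketch of the general shape such a dispatch could take:

    // Illustrative only: choose the best flash-attention pipeline variant
    // that was compiled in and that the device supports.
    enum class fa_path { scalar, cm1, cm2 };

    static fa_path select_fa_path(bool device_coopmat2, bool device_coopmat1) {
        if (device_coopmat2) return fa_path::cm2; // flash_attn_cm2.comp
        if (device_coopmat1) return fa_path::cm1; // flash_attn_cm1.comp
        return fa_path::scalar;                   // flash_attn.comp
    }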
@@ -417,12 +474,12 @@ void process_shaders() {
         string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
 
         // Dequant shaders
-        if (tname != "f16") {
+        if (tname != "f16" && tname != "bf16") {
             string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
         }
 
         if (!string_ends_with(tname, "_k")) {
-            shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
+            shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
 
             if (tname == "f16") {
                 string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
@@ -447,9 +504,13 @@ void process_shaders() {
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
     string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 
     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
         string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@@ -457,8 +518,26 @@ void process_shaders() {
         string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }
 
-    string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-    string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+    auto get_type_str = [](bool f16) {
+        return f16 ? "float16_t" : "float";
+    };
+    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+        std::string s;
+        s += std::string(src0_f16 ? "_f16" : "_f32");
+        s += std::string(src1_f16 ? "_f16" : "_f32");
+        s += std::string(dst_f16 ? "_f16" : "_f32");
+        return s;
+    };
+    for (std::string op : {"add", "sub", "mul", "div"}) {
+        for (auto src0_f16 : {false, true}) {
+            for (auto src1_f16 : {false, true}) {
+                for (auto dst_f16 : {false, true}) {
+                    auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
+                    string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+                }
+            }
+        }
+    }
 
     string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
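The two hand-written variants deleted here (add_f32, add_f16_f32_f16) are subsumed by the generated grid: four ops times eight f16/f32 combinations of (src0, src1, dst), 32 pipelines in all. This standalone snippet mirrors get_suffix to list the names the generator now produces:

    #include <cstdio>
    #include <string>

    int main() {
        // Same naming scheme as the generator: op + one _f16/_f32 suffix
        // per operand, in (src0, src1, dst) order.
        for (const char *op : {"add", "sub", "mul", "div"}) {
            for (bool src0_f16 : {false, true})
            for (bool src1_f16 : {false, true})
            for (bool dst_f16  : {false, true}) {
                const std::string name = std::string(op)
                    + (src0_f16 ? "_f16" : "_f32")
                    + (src1_f16 ? "_f16" : "_f32")
                    + (dst_f16  ? "_f16" : "_f32");
                std::printf("%s\n", name.c_str()); // e.g. mul_f16_f32_f16
            }
        }
    }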
@@ -493,14 +572,21 @@ void process_shaders() {
 
     string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
-    string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
+    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
@@ -544,6 +630,9 @@ void process_shaders() {
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+    string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+
     for (auto &c : compiles) {
        c.wait();
     }
@@ -598,7 +687,12 @@ void write_output_files() {
             std::remove(path.c_str());
         }
     }
-
+    for (const char *op : {"add", "sub", "mul", "div"}) {
+        fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
+        fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
+        fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
+        fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+    }
     fclose(hdr);
     fclose(src);
 }
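write_output_files then stitches those 32 shaders into per-op 2x2x2 lookup tables, indexed in the same (src0, src1, dst) order as the name suffixes, so the backend can pick a pipeline from the operand types directly instead of assembling a symbol name at runtime. A sketch of how the emitted declarations would be consumed (the array names come from the fprintf calls above; the selector function is our own illustration):

    #include <cstdint>

    // Declarations as emitted into the generated header by the loop above.
    extern unsigned char *add_data[2][2][2];
    extern uint64_t       add_len [2][2][2];

    // Hypothetical accessor: index by whether each operand is f16.
    static const unsigned char *pick_add_spv(bool src0_f16, bool src1_f16,
                                             bool dst_f16, uint64_t &len) {
        len = add_len[src0_f16][src1_f16][dst_f16];
        return add_data[src0_f16][src1_f16][dst_f16]; // SPIR-V blob
    }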
package/src/llama.cpp/ggml/src/ggml.c
@@ -1299,6 +1299,10 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
     return ggml_is_contiguous_n(tensor, 2);
 }
 
+bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+}
+
 bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
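The new predicate is deliberately weaker than ggml_is_contiguous: it only asks whether the tensor's byte count equals the tightly packed size, nelements * type_size / blck_size, i.e. the data sits in one gap-free allocation even if the view is permuted. A self-contained sketch of the same arithmetic on a toy struct (the field values below are invented):

    #include <cstdint>
    #include <cstdio>

    // Toy stand-ins for ggml_nbytes()/ggml_nelements()/ggml_type_size()/ggml_blck_size().
    struct toy_tensor {
        int64_t nelements;
        size_t  nbytes;     // actual footprint, including any row padding
        size_t  type_size;  // bytes per block (4 for f32, 18 for q4_0, ...)
        int64_t blck_size;  // elements per block (1 for f32, 32 for q4_0)
    };

    static bool is_contiguously_allocated(const toy_tensor &t) {
        return t.nbytes == (size_t) t.nelements * t.type_size / t.blck_size;
    }

    int main() {
        const toy_tensor packed = {4096, 16384, 4, 1}; // 4096 f32, tightly packed
        const toy_tensor padded = {4096, 17408, 4, 1}; // same elements, padded rows
        std::printf("%d %d\n", is_contiguously_allocated(packed),  // 1
                               is_contiguously_allocated(padded)); // 0
    }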
@@ -2728,11 +2732,11 @@ void ggml_mul_mat_set_prec(
    c = ggml_mul_mat_id(ctx, as, b, ids);
 
    as  -> [cols, rows, n_expert]
-   ids -> [n_experts_used, n_tokens] (i32)
    b   -> [cols, n_expert_used, n_tokens]
+   ids -> [n_expert_used, n_tokens] (i32)
    c   -> [rows, n_expert_used, n_tokens]
 
-   in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+   in b, n_expert_used can be broadcasted to match the n_expert_used of ids
 
    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
 */
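For readers decoding the corrected comment: each token t selects n_expert_used experts via ids, each chosen expert's matrix multiplies a column of b, and the i%r term expresses the broadcast over b's used-expert dimension. A literal, non-performant transcription of the documented formula, with nested vectors standing in for ggml tensors (this is a reading aid, not the package's implementation):

    #include <cstdint>
    #include <vector>

    using mat3 = std::vector<std::vector<std::vector<float>>>; // [outer][row][col]

    // c[t][e][:] = as[i] @ b[t][i % r], i = ids[e][t], transcribed from the
    // comment above; r is b's used-expert dimension (1 when broadcast).
    static void mul_mat_id_ref(const mat3 &as, const mat3 &b,
                               const std::vector<std::vector<int32_t>> &ids,
                               mat3 &c) {
        const size_t n_expert_used = ids.size();
        const size_t n_tokens      = ids.empty() ? 0 : ids[0].size();
        for (size_t t = 0; t < n_tokens; ++t) {
            const size_t r = b[t].size();
            for (size_t e = 0; e < n_expert_used; ++e) {
                const int32_t i = ids[e][t];      // chosen expert
                const auto &A = as[i];            // [rows][cols]
                const auto &x = b[t][i % r];      // [cols]
                for (size_t row = 0; row < A.size(); ++row) {
                    float acc = 0.0f;
                    for (size_t col = 0; col < x.size(); ++col) {
                        acc += A[row][col] * x[col];
                    }
                    c[t][e][row] = acc;
                }
            }
        }
    }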
@@ -5495,7 +5499,7 @@ static void ggml_compute_backward(
             // tensor = src0 * 1 + src1 * 0
             if (src0_needs_grads) {
                 // dsrc0 = dtensor * 1
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
             }
             if (src1_needs_grads) {
                 // dsrc1 = dtensor * 0 -> noop
@@ -5776,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
 }
 
 void ggml_build_backward_expand(
-        struct ggml_context *  ctx_static,
-        struct ggml_context *  ctx_compute,
-        struct ggml_cgraph  *  cgraph,
-        bool                   accumulate) {
+        struct ggml_context *  ctx,
+        struct ggml_cgraph  *  cgraph,
+        struct ggml_tensor  ** grad_accs) {
     GGML_ASSERT(cgraph->n_nodes > 0);
     GGML_ASSERT(cgraph->grads);
     GGML_ASSERT(cgraph->grad_accs);
@@ -5852,21 +5855,24 @@ void ggml_build_backward_expand(
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
-        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
-        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
-        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
-        if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
-            cgraph->grads[igrad] = cgraph->grad_accs[igrad];
-            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+        if (grad_accs && grad_accs[i]) {
+            cgraph->grad_accs[ihash] = grad_accs[i];
+            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
+        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+            // loss tensors always need a gradient accumulator
+            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
         }
-        grads_needed[igrad] = true;
+        grads_needed[ihash] = true;
     }
 
     for (int i = n_nodes_f - 1; i >= 0; --i) {
         // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
-        ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+        ggml_compute_backward(ctx, cgraph, i, grads_needed);
     }
 
     free(grads_needed);
@@ -6012,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     }
 }
 
-struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
     ggml_graph_cpy(cgraph, result);
     return result;
 }
@@ -6032,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    if (!cgraph) {
+        return;
+    }
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6341,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
-void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    GGML_UNUSED(ctx); // TODO: remove this parameter
+void ggml_set_param(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 }
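Taken together, the ggml.c hunks above reshape the training-graph API surface: ggml_build_backward_expand takes one context plus an optional array of per-node gradient accumulators instead of two contexts and an accumulate flag, ggml_graph_dup gains a force_grads parameter, ggml_graph_reset tolerates NULL, and ggml_set_param drops its context argument and now asserts that the tensor is a leaf (op == GGML_OP_NONE). A hedged sketch of the new calling convention; the signatures are taken from this diff, while the helper name and setup are illustrative:

    #include "ggml.h"

    // Sketch: wire up a backward graph with the updated API, given an
    // existing context, a forward graph gf, and a leaf weight tensor.
    static struct ggml_cgraph * make_backward_graph(struct ggml_context * ctx,
                                                    struct ggml_cgraph  * gf,
                                                    struct ggml_tensor  * weights) {
        ggml_set_param(weights); // new signature: leaf tensor only, no ctx arg

        // duplicate the forward graph with gradient bookkeeping forced on
        struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf, /*force_grads =*/ true);

        // NULL grad_accs: only loss tensors get F32 accumulators, allocated
        // internally; a non-NULL array supplies one accumulator per forward node.
        ggml_build_backward_expand(ctx, gb, /*grad_accs =*/ NULL);

        ggml_graph_reset(gb); // zero the gradients (now a no-op on NULL graphs)
        return gb;
    }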