@fugood/llama.node 0.3.13 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +89 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/CMakeLists.txt +9 -1
  25. package/src/llama.cpp/cmake/common.cmake +2 -0
  26. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  27. package/src/llama.cpp/common/arg.cpp +132 -13
  28. package/src/llama.cpp/common/chat.cpp +960 -266
  29. package/src/llama.cpp/common/chat.h +135 -0
  30. package/src/llama.cpp/common/common.cpp +33 -174
  31. package/src/llama.cpp/common/common.h +27 -67
  32. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  33. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  34. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  35. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  36. package/src/llama.cpp/common/sampling.cpp +45 -7
  37. package/src/llama.cpp/common/speculative.cpp +10 -9
  38. package/src/llama.cpp/common/speculative.h +1 -1
  39. package/src/llama.cpp/docs/build.md +45 -7
  40. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  41. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
  42. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
  43. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  44. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  45. package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
  46. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  48. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  50. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  51. package/src/llama.cpp/examples/llava/clip.h +19 -3
  52. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  53. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  54. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  55. package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
  56. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  57. package/src/llama.cpp/examples/main/main.cpp +79 -34
  58. package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
  59. package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
  60. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  61. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  62. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  63. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +196 -108
  67. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  68. package/src/llama.cpp/examples/server/server.cpp +113 -101
  69. package/src/llama.cpp/examples/server/utils.hpp +94 -105
  70. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  74. package/src/llama.cpp/examples/tts/tts.cpp +263 -151
  75. package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
  76. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  77. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  79. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  80. package/src/llama.cpp/ggml/include/ggml.h +29 -1
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
  82. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  83. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  84. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  85. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  87. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
  88. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  89. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
  90. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  91. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  102. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  103. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  104. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  105. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  106. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  107. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
  108. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
  109. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  110. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
  111. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  112. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  113. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
  117. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  118. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  124. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
  125. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
  127. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  128. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
  129. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  130. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
  132. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  134. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  135. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
  139. package/src/llama.cpp/ggml/src/ggml.c +93 -5
  140. package/src/llama.cpp/include/llama.h +105 -27
  141. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  142. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  143. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  144. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  145. package/src/llama.cpp/requirements.txt +1 -0
  146. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  147. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  148. package/src/llama.cpp/src/llama-adapter.h +11 -9
  149. package/src/llama.cpp/src/llama-arch.cpp +123 -16
  150. package/src/llama.cpp/src/llama-arch.h +19 -0
  151. package/src/llama.cpp/src/llama-batch.h +2 -2
  152. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  153. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  154. package/src/llama.cpp/src/llama-context.h +214 -77
  155. package/src/llama.cpp/src/llama-cparams.h +1 -0
  156. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  157. package/src/llama.cpp/src/llama-grammar.h +12 -3
  158. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  159. package/src/llama.cpp/src/llama-graph.h +574 -0
  160. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  161. package/src/llama.cpp/src/llama-hparams.h +9 -0
  162. package/src/llama.cpp/src/llama-io.cpp +15 -0
  163. package/src/llama.cpp/src/llama-io.h +35 -0
  164. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  165. package/src/llama.cpp/src/llama-kv-cache.h +178 -109
  166. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  167. package/src/llama.cpp/src/llama-memory.h +21 -0
  168. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  169. package/src/llama.cpp/src/llama-model.cpp +8230 -122
  170. package/src/llama.cpp/src/llama-model.h +34 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  172. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  173. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  174. package/src/llama.cpp/src/llama.cpp +51 -9837
  175. package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
  176. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  177. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  178. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  179. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  180. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  181. package/src/llama.cpp/common/chat.hpp +0 -55
  182. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  183. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
  184. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/tests/test-backend-ops.cpp
@@ -18,20 +18,22 @@
 #include <ggml.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
+#include <ggml-cpp.h>
 
 #include <algorithm>
 #include <array>
 #include <cfloat>
+#include <cinttypes>
 #include <cstdint>
+#include <cstdio>
+#include <cstdlib>
 #include <cstring>
-#include <cinttypes>
+#include <future>
 #include <memory>
 #include <random>
-#include <stdio.h>
-#include <stdlib.h>
+#include <regex>
 #include <string>
 #include <thread>
-#include <future>
 #include <vector>
 
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
@@ -257,6 +259,10 @@ static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
 
+static std::string var_to_str(ggml_prec prec) {
+    return prec == GGML_PREC_F32 ? "f32" : "def";
+}
+
 static std::string var_to_str(ggml_op_pool pool) {
     switch (pool) {
         case GGML_OP_POOL_AVG: return "avg";
@@ -467,6 +473,7 @@ struct test_case {
 
         // allocate
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
        if (buf == NULL) {
            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
            ggml_free(ctx);
@@ -588,14 +595,13 @@ struct test_case {
             /* .mem_base = */ NULL,
             /* .no_alloc = */ true,
         };
-        ggml_context * ctx = ggml_init(params);
+        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
         GGML_ASSERT(ctx);
 
-        ggml_tensor * out = build_graph(ctx);
+        ggml_tensor * out = build_graph(ctx.get());
 
         if (op_name != nullptr && op_desc(out) != op_name) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
-            ggml_free(ctx);
             return true;
         }
 
@@ -605,7 +611,6 @@
         // check if backends support op
         if (!ggml_backend_supports_op(backend, out)) {
             printf("not supported\n");
-            ggml_free(ctx);
             return true;
         }
 
@@ -618,22 +623,26 @@
         printf("%*s", last - len, "");
 
         // allocate
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
+
         if (buf == NULL) {
             printf("failed to allocate tensors\n");
-            ggml_free(ctx);
             return false;
         }
 
         // randomize tensors
-        initialize_tensors(ctx);
+        initialize_tensors(ctx.get());
 
         // build graph
-        ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
         ggml_build_forward_expand(gf, out);
 
         // warmup run
-        ggml_backend_graph_compute(backend, gf);
+        ggml_status status = ggml_backend_graph_compute(backend, gf);
+        if (status != GGML_STATUS_SUCCESS) {
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            return false;
+        }
 
         // determine number of runs
         int n_runs;
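
Note on the two patterns introduced in the hunks above: the newly included <ggml-cpp.h> supplies RAII wrappers (ggml_context_ptr, ggml_backend_buffer_ptr), which is why the matching ggml_free/ggml_backend_buffer_free calls vanish from every early-return path in this file, and the ggml_status result of ggml_backend_graph_compute is now checked instead of ignored. A minimal sketch of the RAII side, assuming the wrappers are std::unique_ptr aliases with custom deleters (reconstructed from usage in this diff, not copied from the ggml-cpp.h header):

    // Sketch only: mirrors how ggml_context_ptr is used in this diff.
    #include <ggml.h>   // ggml_context, ggml_free
    #include <memory>

    struct ggml_context_deleter {
        void operator()(ggml_context * ctx) const { ggml_free(ctx); }
    };
    using ggml_context_ptr = std::unique_ptr<ggml_context, ggml_context_deleter>;

    // Usage: the context is freed automatically on every return path,
    // so the scattered ggml_free(ctx) calls become unnecessary:
    //   ggml_context_ptr ctx(ggml_init(params));
    //   ggml_tensor * out = build_graph(ctx.get());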
@@ -684,7 +693,11 @@
         int total_runs = 0;
         do {
             int64_t start_time = ggml_time_us();
-            ggml_backend_graph_compute(backend, gf);
+            ggml_status status = ggml_backend_graph_compute(backend, gf);
+            if (status != GGML_STATUS_SUCCESS) {
+                fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                return false;
+            }
             int64_t end_time = ggml_time_us();
 
             total_time_us += end_time - start_time;
@@ -722,10 +735,6 @@
         }
         printf("\n");
 
-        ggml_backend_buffer_free(buf);
-
-        ggml_free(ctx);
-
         return true;
     }
 
@@ -738,17 +747,16 @@
             /* .mem_base = */ NULL,
             /* .no_alloc = */ true,
         };
-        ggml_context * ctx = ggml_init(params);
+        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
         GGML_ASSERT(ctx);
 
-        gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
-        gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+        gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
+        gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
 
-        ggml_tensor * out = build_graph(ctx);
+        ggml_tensor * out = build_graph(ctx.get());
 
         if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
-            ggml_free(ctx);
             return true;
         }
 
@@ -756,7 +764,6 @@
         fflush(stdout);
 
         if (out->type != GGML_TYPE_F32) {
-            ggml_free(ctx);
             printf("not supported [%s->type != FP32]\n", out->name);
             return true;
         }
@@ -764,7 +771,7 @@
         // check if the backend supports the ops
         bool supported = true;
         bool any_params = false;
-        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
                 supported = false;
@@ -785,40 +792,38 @@
         }
         if (!supported) {
             printf("\n");
-            ggml_free(ctx);
             return true;
         }
 
         int64_t ngrads = 0;
-        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
             if (t->flags & GGML_TENSOR_FLAG_PARAM) {
                 ngrads += ggml_nelements(t);
             }
         }
         if (ngrads > grad_nmax()) {
             printf("skipping large tensors for speed \n");
-            ggml_free(ctx);
             return true;
         }
 
 
         if (!ggml_is_scalar(out)) {
-            out = ggml_sum(ctx, out);
+            out = ggml_sum(ctx.get(), out);
             ggml_set_name(out, "sum_of_out");
         }
         ggml_set_loss(out);
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx, ctx, gb, false);
+        ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
-            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
                 GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
             }
         }
 
-        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
                 supported = false;
@@ -832,27 +837,32 @@
         }
         if (!supported) {
             printf("\n");
-            ggml_free(ctx);
             return true;
         }
 
         // allocate
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
         if (buf == NULL) {
             printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
-            ggml_free(ctx);
             return false;
         }
 
-
-        initialize_tensors(ctx); // Randomizes all tensors (including gradients).
+        initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).
         ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.
 
-        ggml_backend_graph_compute(backend, gf);
-        ggml_backend_graph_compute(backend, gb);
+        ggml_status status = ggml_backend_graph_compute(backend, gf);
+        if (status != GGML_STATUS_SUCCESS) {
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            return false;
+        }
+        status = ggml_backend_graph_compute(backend, gb);
+        if (status != GGML_STATUS_SUCCESS) {
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            return false;
+        }
 
         bool ok = true;
-        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
             if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
                 continue;
             }
@@ -897,20 +907,36 @@
                 float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh
 
                 ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
-                ggml_backend_graph_compute(backend, gf);
+                status = ggml_backend_graph_compute(backend, gf);
+                if (status != GGML_STATUS_SUCCESS) {
+                    fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                    return false;
+                }
                 ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));
 
                 ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
-                ggml_backend_graph_compute(backend, gf);
+                status = ggml_backend_graph_compute(backend, gf);
+                if (status != GGML_STATUS_SUCCESS) {
+                    fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                    return false;
+                }
                 ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));
 
                 if (grad_precise()) {
                     ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
-                    ggml_backend_graph_compute(backend, gf);
+                    status = ggml_backend_graph_compute(backend, gf);
+                    if (status != GGML_STATUS_SUCCESS) {
+                        fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                        return false;
+                    }
                     ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));
 
                     ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
-                    ggml_backend_graph_compute(backend, gf);
+                    status = ggml_backend_graph_compute(backend, gf);
+                    if (status != GGML_STATUS_SUCCESS) {
+                        fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                        return false;
+                    }
                     ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));
 
                     gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
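
Note on the gn[i] line closing the hunk above: assuming xiu/xid are x_i ± eps and xiuh/xidh are x_i ± eps/2 (the variable names and the grad_precise() branch suggest this; their definitions sit outside this hunk), the expression is the standard five-point central-difference stencil with step h = eps/2:

    f'(x) \approx \frac{-f(x+2h) + 8f(x+h) - 8f(x-h) + f(x-2h)}{12h}, \qquad h = \frac{\varepsilon}{2}

Substituting fu = f(x+eps), fuh = f(x+eps/2), fdh = f(x-eps/2), fd = f(x-eps) and 12h = 6*eps reproduces (8.0*fuh + fd - (8.0*fdh + fu)) / (6.0*eps) term for term.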
@@ -936,10 +962,6 @@
             printf("compare failed ");
         }
 
-        ggml_backend_buffer_free(buf);
-
-        ggml_free(ctx);
-
         if (ok) {
             printf("\033[1;32mOK\033[0m\n");
             return true;
@@ -1898,6 +1920,40 @@ struct test_gla : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV7
+struct test_rwkv_wkv7 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN with long seqlen without these normalization
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
+        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2954,6 +3010,32 @@ struct test_group_norm : public test_case {
     }
 };
 
+// GGML_OP_L2_NORM
+struct test_l2_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_l2_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            float eps = 1e-12f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ACC
 struct test_acc : public test_case {
     const ggml_type type;
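
Note: test_l2_norm covers the GGML_OP_L2_NORM operator that the new test_rwkv_wkv7 case relies on above. As a sketch of the expected row-wise semantics (the authoritative definition is the ggml kernel itself; the exact placement of eps in the guard is an assumption):

    y_i = \frac{x_i}{\max\left(\sqrt{\sum_j x_j^2},\ \varepsilon\right)}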
@@ -3119,6 +3201,7 @@ struct test_leaky_relu : public test_case {
 struct test_flash_attn_ext : public test_case {
     const int64_t hs; // head size
     const int64_t nh; // num heads
+    const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
@@ -3127,11 +3210,12 @@ struct test_flash_attn_ext : public test_case {
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
 
+    const ggml_prec prec;
     const ggml_type type_KV;
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3142,13 +3226,13 @@
         GGML_UNUSED(t);
         // Just counting matmul costs:
         // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh * nb * hs * kv;
+        return 2 * 2 * nh*nr * nb * hs * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
-                        std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+                        ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3166,13 +3250,13 @@
            return t;
        };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV,      hs_padded, kv, nh,    1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV,      hs_padded, kv, nh,    1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -3182,6 +3266,7 @@
         }
 
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
         return out;
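
Note: with the new nr parameter, q is built with nh*nr heads while k and v keep nh heads, so nr > 1 exercises grouped-query attention in the flash-attention kernels. A hypothetical helper (not part of the diff) illustrating the head mapping a backend has to implement:

    // Illustration only: under GQA with nh*nr query heads and nh KV heads,
    // consecutive groups of nr query heads share a single KV head.
    static int64_t kv_head_for_q_head(int64_t iq, int64_t nr) {
        return iq / nr; // query head iq in [0, nh*nr) -> KV head in [0, nh)
    }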
@@ -3752,10 +3837,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     std::default_random_engine rng(0);
 
     // unary ops
-    for (int v : {0, 1}) {
-        for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 2, 2, 2 }, v));
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 5, 7, 11, 13 }, v));
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (int v : {0, 1}) {
+            for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
+                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
+            }
         }
     }
 
@@ -3942,37 +4029,38 @@
             test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
         }
     };
-
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2});
-
-    // stable diffusion
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
-    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
-    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+
+        // stable diffusion
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
+        add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
+        add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
+        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    }
 
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
@@ -3984,8 +4072,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
         }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
+    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
+
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
@@ -3997,6 +4088,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
+
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
@@ -4091,7 +4187,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         for (int n_mats : {4, 8}) {
             for (int n_used : {1, 2, 4}) {
                 for (bool b : {false, true}) {
-                    for (int n : {1, 32}) {
+                    for (int n : {1, 32, 129}) {
                         int m = 512;
                         int k = 256;
                         test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
@@ -4136,12 +4232,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    test_cases.emplace_back(new test_sqr());
-    test_cases.emplace_back(new test_sqrt());
-    test_cases.emplace_back(new test_log());
-    test_cases.emplace_back(new test_sin());
-    test_cases.emplace_back(new test_cos());
-    test_cases.emplace_back(new test_clamp());
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        test_cases.emplace_back(new test_sqr(type));
+        test_cases.emplace_back(new test_sqrt(type));
+        test_cases.emplace_back(new test_log(type));
+        test_cases.emplace_back(new test_sin(type));
+        test_cases.emplace_back(new test_cos(type));
+        test_cases.emplace_back(new test_clamp(type));
+    }
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
@@ -4278,14 +4376,23 @@
                 if (!mask && max_bias > 0.0f) continue;
                 for (float logit_softcap : {0.0f, 10.0f}) {
                     if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 32, }) {
-                        for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                    // run fewer test cases permuted
-                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                    for (int nh : { 4, }) {
+                        for (int nr : { 1, 4, 16 }) {
+                            if (nr == 16 && hs != 128) continue;
+                            for (int kv : { 512, 1024, }) {
+                                if (nr != 1 && kv != 512) continue;
+                                for (int nb : { 1, 3, 32, 35, }) {
+                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                            test_cases.emplace_back(new test_flash_attn_ext(
+                                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                            // run fewer test cases permuted
+                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                test_cases.emplace_back(new test_flash_attn_ext(
+                                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                            }
+                                        }
                                     }
                                 }
                             }
@@ -4360,9 +4467,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     return test_cases;
 }
 
-static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
+    auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
+        if (params_filter == nullptr) {
+            return;
+        }
+
+        std::regex params_filter_regex(params_filter);
+
+        for (auto it = test_cases.begin(); it != test_cases.end();) {
+            if (!std::regex_search((*it)->vars(), params_filter_regex)) {
+                it = test_cases.erase(it);
+                continue;
+            }
+
+            it++;
+        }
+    };
+
     if (mode == MODE_TEST) {
         auto test_cases = make_test_cases_eval();
+        filter_test_cases(test_cases, params_filter);
         ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
         if (backend_cpu == NULL) {
             printf("  Failed to initialize CPU backend\n");
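
Note: the filter_test_cases lambda above backs the new -p flag. The argument is compiled as a std::regex and matched with std::regex_search against each case's vars() string, so any substring of a case's printed parameter list selects it. A hypothetical invocation (binary name and parameter string assumed, not taken from the diff): test-backend-ops test -o MUL_MAT -p "type_a=f16".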
@@ -4384,6 +4509,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     if (mode == MODE_GRAD) {
         auto test_cases = make_test_cases_eval();
+        filter_test_cases(test_cases, params_filter);
         size_t n_ok = 0;
         for (auto & test : test_cases) {
             if (test->eval_grad(backend, op_name)) {
@@ -4397,6 +4523,7 @@
 
     if (mode == MODE_PERF) {
         auto test_cases = make_test_cases_perf();
+        filter_test_cases(test_cases, params_filter);
         for (auto & test : test_cases) {
             test->eval_perf(backend, op_name);
         }
@@ -4407,7 +4534,7 @@
 }
 
 static void usage(char ** argv) {
-    printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
+    printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>]\n", argv[0]);
     printf("  valid modes:\n");
     printf("    - test (default, compare with CPU backend for correctness)\n");
     printf("    - grad (compare gradients from backpropagation with method of finite differences)\n");
@@ -4417,8 +4544,9 @@
 
 int main(int argc, char ** argv) {
     test_mode mode = MODE_TEST;
-    const char * op_name_filter = NULL;
-    const char * backend_filter = NULL;
+    const char * op_name_filter = nullptr;
+    const char * backend_filter = nullptr;
+    const char * params_filter  = nullptr;
 
     for (int i = 1; i < argc; i++) {
         if (strcmp(argv[i], "test") == 0) {
@@ -4441,6 +4569,13 @@
                 usage(argv);
                 return 1;
             }
+        } else if (strcmp(argv[i], "-p") == 0) {
+            if (i + 1 < argc) {
+                params_filter = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
         } else {
             usage(argv);
             return 1;
@@ -4487,7 +4622,7 @@
         printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
         printf("\n");
 
-        bool ok = test_backend(backend, mode, op_name_filter);
+        bool ok = test_backend(backend, mode, op_name_filter, params_filter);
 
         printf("  Backend %s: ", ggml_backend_name(backend));
         if (ok) {