@fugood/llama.node 0.3.12 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +2 -1
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +110 -79
  21. package/src/LlamaContext.h +1 -1
  22. package/src/common.hpp +1 -2
  23. package/src/llama.cpp/.github/workflows/build.yml +95 -13
  24. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  25. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  27. package/src/llama.cpp/common/CMakeLists.txt +23 -6
  28. package/src/llama.cpp/common/arg.cpp +292 -14
  29. package/src/llama.cpp/common/chat.cpp +1128 -315
  30. package/src/llama.cpp/common/chat.h +135 -0
  31. package/src/llama.cpp/common/common.cpp +27 -171
  32. package/src/llama.cpp/common/common.h +41 -73
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  34. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  35. package/src/llama.cpp/common/llguidance.cpp +3 -3
  36. package/src/llama.cpp/common/log.cpp +1 -0
  37. package/src/llama.cpp/common/log.h +2 -1
  38. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
  39. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
  40. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  41. package/src/llama.cpp/common/sampling.cpp +93 -49
  42. package/src/llama.cpp/common/speculative.cpp +6 -5
  43. package/src/llama.cpp/common/speculative.h +1 -1
  44. package/src/llama.cpp/docs/build.md +47 -9
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  47. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  48. package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
  49. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  50. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  52. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  54. package/src/llama.cpp/examples/llava/clip.h +19 -3
  55. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  56. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  57. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  58. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  59. package/src/llama.cpp/examples/main/main.cpp +73 -28
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +115 -79
  67. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/server/httplib.h +381 -292
  69. package/src/llama.cpp/examples/server/server.cpp +134 -128
  70. package/src/llama.cpp/examples/server/utils.hpp +95 -106
  71. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  72. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  73. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  74. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  75. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  76. package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
  77. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  79. package/src/llama.cpp/ggml/include/ggml.h +6 -2
  80. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  81. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  82. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  83. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  84. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  85. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  86. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  87. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  88. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  89. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  90. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
  96. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
  102. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  103. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  104. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  105. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  106. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  107. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
  109. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  110. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  111. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  112. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  115. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  116. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  117. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  121. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
  124. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  125. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  128. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
  129. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
  130. package/src/llama.cpp/ggml/src/ggml.c +9 -4
  131. package/src/llama.cpp/include/llama.h +32 -14
  132. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  133. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  134. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  135. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  136. package/src/llama.cpp/requirements.txt +1 -0
  137. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  138. package/src/llama.cpp/src/llama-arch.h +1 -0
  139. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  140. package/src/llama.cpp/src/llama-grammar.cpp +183 -183
  141. package/src/llama.cpp/src/llama-grammar.h +13 -4
  142. package/src/llama.cpp/src/llama-impl.h +6 -6
  143. package/src/llama.cpp/src/llama-kv-cache.h +2 -1
  144. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  145. package/src/llama.cpp/src/llama-mmap.h +1 -0
  146. package/src/llama.cpp/src/llama-model.cpp +70 -6
  147. package/src/llama.cpp/src/llama-sampling.cpp +174 -67
  148. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  149. package/src/llama.cpp/src/llama.cpp +154 -5
  150. package/src/llama.cpp/src/unicode.cpp +9 -2
  151. package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
  152. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  153. package/src/llama.cpp/tests/test-chat.cpp +691 -325
  154. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  155. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  156. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  157. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
  158. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  159. package/src/llama.cpp/common/chat.hpp +0 -52
package/src/llama.cpp/tests/test-backend-ops.cpp

@@ -18,20 +18,22 @@
  #include <ggml.h>
  #include <ggml-alloc.h>
  #include <ggml-backend.h>
+ #include <ggml-cpp.h>

  #include <algorithm>
  #include <array>
  #include <cfloat>
+ #include <cinttypes>
  #include <cstdint>
+ #include <cstdio>
+ #include <cstdlib>
  #include <cstring>
- #include <cinttypes>
+ #include <future>
  #include <memory>
  #include <random>
- #include <stdio.h>
- #include <stdlib.h>
+ #include <regex>
  #include <string>
  #include <thread>
- #include <future>
  #include <vector>

  static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
@@ -467,6 +469,7 @@ struct test_case {

      // allocate
      ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
      if (buf == NULL) {
          printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
          ggml_free(ctx);
@@ -588,14 +591,13 @@ struct test_case {
          /* .mem_base = */ NULL,
          /* .no_alloc = */ true,
      };
-     ggml_context * ctx = ggml_init(params);
+     ggml_context_ptr ctx(ggml_init(params)); // smart ptr
      GGML_ASSERT(ctx);

-     ggml_tensor * out = build_graph(ctx);
+     ggml_tensor * out = build_graph(ctx.get());

      if (op_name != nullptr && op_desc(out) != op_name) {
          //printf(" %s: skipping\n", op_desc(out).c_str());
-         ggml_free(ctx);
          return true;
      }

@@ -605,7 +607,6 @@ struct test_case {
      // check if backends support op
      if (!ggml_backend_supports_op(backend, out)) {
          printf("not supported\n");
-         ggml_free(ctx);
          return true;
      }

@@ -618,22 +619,26 @@ struct test_case {
      printf("%*s", last - len, "");

      // allocate
-     ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+     ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
+
      if (buf == NULL) {
          printf("failed to allocate tensors\n");
-         ggml_free(ctx);
          return false;
      }

      // randomize tensors
-     initialize_tensors(ctx);
+     initialize_tensors(ctx.get());

      // build graph
-     ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
+     ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
      ggml_build_forward_expand(gf, out);

      // warmup run
-     ggml_backend_graph_compute(backend, gf);
+     ggml_status status = ggml_backend_graph_compute(backend, gf);
+     if (status != GGML_STATUS_SUCCESS) {
+         fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+         return false;
+     }

      // determine number of runs
      int n_runs;
@@ -684,7 +689,11 @@ struct test_case {
      int total_runs = 0;
      do {
          int64_t start_time = ggml_time_us();
-         ggml_backend_graph_compute(backend, gf);
+         ggml_status status = ggml_backend_graph_compute(backend, gf);
+         if (status != GGML_STATUS_SUCCESS) {
+             fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+             return false;
+         }
          int64_t end_time = ggml_time_us();

          total_time_us += end_time - start_time;
@@ -722,10 +731,6 @@ struct test_case {
      }
      printf("\n");

-     ggml_backend_buffer_free(buf);
-
-     ggml_free(ctx);
-
      return true;
  }

@@ -738,17 +743,16 @@ struct test_case {
          /* .mem_base = */ NULL,
          /* .no_alloc = */ true,
      };
-     ggml_context * ctx = ggml_init(params);
+     ggml_context_ptr ctx(ggml_init(params)); // smart ptr
      GGML_ASSERT(ctx);

-     gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
-     gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+     gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
+     gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);

-     ggml_tensor * out = build_graph(ctx);
+     ggml_tensor * out = build_graph(ctx.get());

      if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
          //printf(" %s: skipping\n", op_desc(out).c_str());
-         ggml_free(ctx);
          return true;
      }

@@ -756,7 +760,6 @@ struct test_case {
      fflush(stdout);

      if (out->type != GGML_TYPE_F32) {
-         ggml_free(ctx);
          printf("not supported [%s->type != FP32]\n", out->name);
          return true;
      }
@@ -764,7 +767,7 @@ struct test_case {
      // check if the backend supports the ops
      bool supported = true;
      bool any_params = false;
-     for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+     for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
          if (!ggml_backend_supports_op(backend, t)) {
              printf("not supported [%s] ", ggml_backend_name(backend));
              supported = false;
@@ -785,40 +788,38 @@ struct test_case {
      }
      if (!supported) {
          printf("\n");
-         ggml_free(ctx);
          return true;
      }

      int64_t ngrads = 0;
-     for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+     for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
          if (t->flags & GGML_TENSOR_FLAG_PARAM) {
              ngrads += ggml_nelements(t);
          }
      }
      if (ngrads > grad_nmax()) {
          printf("skipping large tensors for speed \n");
-         ggml_free(ctx);
          return true;
      }


      if (!ggml_is_scalar(out)) {
-         out = ggml_sum(ctx, out);
+         out = ggml_sum(ctx.get(), out);
          ggml_set_name(out, "sum_of_out");
      }
      ggml_set_loss(out);

      ggml_build_forward_expand(gf, out);
      ggml_graph_cpy(gf, gb);
-     ggml_build_backward_expand(ctx, ctx, gb, false);
+     ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
      if (expect.size() != 1 || expect[0] != 0.0f) {
          GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
-         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+         for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
              GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
          }
      }

-     for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+     for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
          if (!ggml_backend_supports_op(backend, t)) {
              printf("not supported [%s] ", ggml_backend_name(backend));
              supported = false;
@@ -832,27 +833,32 @@ struct test_case {
      }
      if (!supported) {
          printf("\n");
-         ggml_free(ctx);
          return true;
      }

      // allocate
-     ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+     ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
      if (buf == NULL) {
          printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
-         ggml_free(ctx);
          return false;
      }

-
-     initialize_tensors(ctx); // Randomizes all tensors (including gradients).
+     initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).
      ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise.

-     ggml_backend_graph_compute(backend, gf);
-     ggml_backend_graph_compute(backend, gb);
+     ggml_status status = ggml_backend_graph_compute(backend, gf);
+     if (status != GGML_STATUS_SUCCESS) {
+         fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+         return false;
+     }
+     status = ggml_backend_graph_compute(backend, gb);
+     if (status != GGML_STATUS_SUCCESS) {
+         fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+         return false;
+     }

      bool ok = true;
-     for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+     for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
          if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
              continue;
          }
@@ -897,20 +903,36 @@ struct test_case {
      float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh

      ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
-     ggml_backend_graph_compute(backend, gf);
+     status = ggml_backend_graph_compute(backend, gf);
+     if (status != GGML_STATUS_SUCCESS) {
+         fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+         return false;
+     }
      ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));

      ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
-     ggml_backend_graph_compute(backend, gf);
+     status = ggml_backend_graph_compute(backend, gf);
+     if (status != GGML_STATUS_SUCCESS) {
+         fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+         return false;
+     }
      ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));

      if (grad_precise()) {
          ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
-         ggml_backend_graph_compute(backend, gf);
+         status = ggml_backend_graph_compute(backend, gf);
+         if (status != GGML_STATUS_SUCCESS) {
+             fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+             return false;
+         }
          ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));

          ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
-         ggml_backend_graph_compute(backend, gf);
+         status = ggml_backend_graph_compute(backend, gf);
+         if (status != GGML_STATUS_SUCCESS) {
+             fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+             return false;
+         }
          ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));

          gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
@@ -936,10 +958,6 @@ struct test_case {
          printf("compare failed ");
      }

-     ggml_backend_buffer_free(buf);
-
-     ggml_free(ctx);
-
      if (ok) {
          printf("\033[1;32mOK\033[0m\n");
          return true;
@@ -1254,7 +1272,7 @@ struct test_count_equal : public test_case {
      ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
      ggml_set_name(b, "b");

-     ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+     ggml_tensor * b_argmax = ggml_argmax(ctx, b);
      ggml_set_name(b_argmax, "b_argmax");

      ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -1511,6 +1529,7 @@ struct test_cont : public test_case {
  };

  // GGML_OP_ADD
+ // GGML_OP_SUB
  // GGML_OP_MUL
  // GGML_OP_DIV
  struct test_bin_bcast : public test_case {
@@ -3118,6 +3137,7 @@ struct test_leaky_relu : public test_case {
  struct test_flash_attn_ext : public test_case {
      const int64_t hs; // head size
      const int64_t nh; // num heads
+     const int64_t nr; // repeat in Q, tests for grouped-query attention
      const int64_t kv; // kv size
      const int64_t nb; // batch size

@@ -3130,7 +3150,7 @@ struct test_flash_attn_ext : public test_case {
      std::array<int32_t, 4> permute;

      std::string vars() override {
-         return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+         return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
      }

      double max_nmse_err() override {
@@ -3141,13 +3161,13 @@ struct test_flash_attn_ext : public test_case {
          GGML_UNUSED(t);
          // Just counting matmul costs:
          // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-         return 2 * 2 * nh * nb * hs * kv;
+         return 2 * 2 * nh*nr * nb * hs * kv;
      }

-     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
              bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
              std::array<int32_t, 4> permute = {0, 1, 2, 3})
-         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+         : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}

      ggml_tensor * build_graph(ggml_context * ctx) override {
          const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3165,13 +3185,13 @@ struct test_flash_attn_ext : public test_case {
              return t;
          };

-         ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
+         ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
          ggml_set_name(q, "q");

-         ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
+         ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
          ggml_set_name(k, "k");

-         ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
+         ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
          ggml_set_name(v, "v");

          ggml_tensor * m = nullptr;
@@ -3751,10 +3771,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
      std::default_random_engine rng(0);

      // unary ops
-     for (int v : {0, 1}) {
-         for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
-             test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 2, 2, 2 }, v));
-             test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 5, 7, 11, 13 }, v));
+     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+         for (int v : {0, 1}) {
+             for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+                 test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
+                 test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
+             }
          }
      }

@@ -3860,7 +3882,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
      test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
      test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

-     test_cases.emplace_back(new test_count_equal());
+     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
+     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));

      test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
      test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
@@ -3885,8 +3908,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
          test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
          test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
          test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
-         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
-         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
      }

      test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
@@ -3938,41 +3959,42 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
      test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));

      auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
-         for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+         for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
              test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
          }
      };
-
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2});
-     add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2});
-
-     // stable diffusion
-     add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
-     add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
-     //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
-     //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+         add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
+         add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+
+         // stable diffusion
+         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
+         add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
+         add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
+         add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
+         add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
+         add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
+         //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+         //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+     }

      test_cases.emplace_back(new test_add1());
      test_cases.emplace_back(new test_scale());
@@ -4091,7 +4113,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
      for (int n_mats : {4, 8}) {
          for (int n_used : {1, 2, 4}) {
              for (bool b : {false, true}) {
-                 for (int n : {1, 32}) {
+                 for (int n : {1, 32, 129}) {
                      int m = 512;
                      int k = 256;
                      test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
@@ -4136,12 +4158,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
          }
      }

-     test_cases.emplace_back(new test_sqr());
-     test_cases.emplace_back(new test_sqrt());
-     test_cases.emplace_back(new test_log());
-     test_cases.emplace_back(new test_sin());
-     test_cases.emplace_back(new test_cos());
-     test_cases.emplace_back(new test_clamp());
+     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+         test_cases.emplace_back(new test_sqr(type));
+         test_cases.emplace_back(new test_sqrt(type));
+         test_cases.emplace_back(new test_log(type));
+         test_cases.emplace_back(new test_sin(type));
+         test_cases.emplace_back(new test_cos(type));
+         test_cases.emplace_back(new test_clamp(type));
+     }

      test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
      test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
@@ -4278,14 +4302,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
              if (!mask && max_bias > 0.0f) continue;
              for (float logit_softcap : {0.0f, 10.0f}) {
                  if (hs != 128 && logit_softcap != 0.0f) continue;
-                 for (int nh : { 32, }) {
-                     for (int kv : { 512, 1024, }) {
-                         for (int nb : { 1, 3, 32, 35, }) {
-                             for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                 test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                 // run fewer test cases permuted
-                                 if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                 for (int nh : { 4, }) {
+                     for (int nr : { 1, 4, 16 }) {
+                         if (nr == 16 && hs != 128) continue;
+                         for (int kv : { 512, 1024, }) {
+                             if (nr != 1 && kv != 512) continue;
+                             for (int nb : { 1, 3, 32, 35, }) {
+                                 for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                                     // run fewer test cases permuted
+                                     if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                         test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                                     }
                                  }
                              }
                          }
@@ -4360,9 +4388,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
      return test_cases;
  }

- static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
+ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
+     auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
+         if (params_filter == nullptr) {
+             return;
+         }
+
+         std::regex params_filter_regex(params_filter);
+
+         for (auto it = test_cases.begin(); it != test_cases.end();) {
+             if (!std::regex_search((*it)->vars(), params_filter_regex)) {
+                 it = test_cases.erase(it);
+                 continue;
+             }
+
+             it++;
+         }
+     };
+
      if (mode == MODE_TEST) {
          auto test_cases = make_test_cases_eval();
+         filter_test_cases(test_cases, params_filter);
          ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
          if (backend_cpu == NULL) {
              printf(" Failed to initialize CPU backend\n");
@@ -4384,6 +4430,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

      if (mode == MODE_GRAD) {
          auto test_cases = make_test_cases_eval();
+         filter_test_cases(test_cases, params_filter);
          size_t n_ok = 0;
          for (auto & test : test_cases) {
              if (test->eval_grad(backend, op_name)) {
@@ -4397,6 +4444,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

      if (mode == MODE_PERF) {
          auto test_cases = make_test_cases_perf();
+         filter_test_cases(test_cases, params_filter);
          for (auto & test : test_cases) {
              test->eval_perf(backend, op_name);
          }
@@ -4407,7 +4455,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
  }

  static void usage(char ** argv) {
-     printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
+     printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>]\n", argv[0]);
      printf(" valid modes:\n");
      printf(" - test (default, compare with CPU backend for correctness)\n");
      printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
@@ -4417,8 +4465,9 @@ static void usage(char ** argv) {

  int main(int argc, char ** argv) {
      test_mode mode = MODE_TEST;
-     const char * op_name_filter = NULL;
-     const char * backend_filter = NULL;
+     const char * op_name_filter = nullptr;
+     const char * backend_filter = nullptr;
+     const char * params_filter = nullptr;

      for (int i = 1; i < argc; i++) {
          if (strcmp(argv[i], "test") == 0) {
@@ -4441,6 +4490,13 @@ int main(int argc, char ** argv) {
                  usage(argv);
                  return 1;
              }
+         } else if (strcmp(argv[i], "-p") == 0) {
+             if (i + 1 < argc) {
+                 params_filter = argv[++i];
+             } else {
+                 usage(argv);
+                 return 1;
+             }
          } else {
              usage(argv);
              return 1;
@@ -4487,7 +4543,7 @@ int main(int argc, char ** argv) {
          printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
          printf("\n");

-         bool ok = test_backend(backend, mode, op_name_filter);
+         bool ok = test_backend(backend, mode, op_name_filter, params_filter);
          printf(" Backend %s: ", ggml_backend_name(backend));
          if (ok) {
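
The new `-p` option added in this diff filters the generated test cases by running std::regex_search against each case's vars() string before any case is evaluated, so any substring of the printed parameter list can serve as the pattern. A minimal sketch of an invocation, assuming the tool is built as `test-backend-ops` as in upstream llama.cpp (the op name and parameter values below are illustrative, not taken from this diff):

    ./test-backend-ops perf -o FLASH_ATTN_EXT -p "nr=4,kv=512"
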