@fugood/llama.node 0.3.13 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +60 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  25. package/src/llama.cpp/common/arg.cpp +112 -11
  26. package/src/llama.cpp/common/chat.cpp +960 -266
  27. package/src/llama.cpp/common/chat.h +135 -0
  28. package/src/llama.cpp/common/common.cpp +27 -171
  29. package/src/llama.cpp/common/common.h +27 -67
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  31. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  32. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  33. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  34. package/src/llama.cpp/common/sampling.cpp +45 -7
  35. package/src/llama.cpp/common/speculative.cpp +6 -5
  36. package/src/llama.cpp/common/speculative.h +1 -1
  37. package/src/llama.cpp/docs/build.md +45 -7
  38. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  39. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  40. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
  42. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  43. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  44. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  45. package/src/llama.cpp/examples/llava/clip.h +19 -3
  46. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  47. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  48. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  49. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  50. package/src/llama.cpp/examples/main/main.cpp +73 -28
  51. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  52. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  53. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  54. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  55. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  56. package/src/llama.cpp/examples/run/run.cpp +110 -67
  57. package/src/llama.cpp/examples/server/server.cpp +82 -87
  58. package/src/llama.cpp/examples/server/utils.hpp +94 -107
  59. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  60. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  61. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  62. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  63. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  64. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  65. package/src/llama.cpp/ggml/include/ggml.h +5 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  68. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  69. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  70. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  71. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  72. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  73. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  74. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  75. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  76. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  77. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  78. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
  79. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
  80. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  81. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  82. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  83. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  84. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  85. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  86. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  87. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  89. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  90. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  91. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  92. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
  93. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  94. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  95. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  96. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  97. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  98. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  99. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  100. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  101. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  102. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  103. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  104. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  105. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  106. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  107. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
  108. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  109. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  111. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  112. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
  113. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  114. package/src/llama.cpp/ggml/src/ggml.c +8 -3
  115. package/src/llama.cpp/include/llama.h +19 -5
  116. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  117. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  118. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  119. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  120. package/src/llama.cpp/requirements.txt +1 -0
  121. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  122. package/src/llama.cpp/src/llama-arch.h +1 -0
  123. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  124. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  125. package/src/llama.cpp/src/llama-grammar.h +12 -3
  126. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  127. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  128. package/src/llama.cpp/src/llama-model.cpp +69 -5
  129. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  130. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  131. package/src/llama.cpp/src/llama.cpp +147 -0
  132. package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
  133. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  134. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  135. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  136. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  137. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  138. package/src/llama.cpp/common/chat.hpp +0 -55
  139. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/tests/test-backend-ops.cpp
@@ -18,20 +18,22 @@
  #include <ggml.h>
  #include <ggml-alloc.h>
  #include <ggml-backend.h>
+ #include <ggml-cpp.h>

  #include <algorithm>
  #include <array>
  #include <cfloat>
+ #include <cinttypes>
  #include <cstdint>
+ #include <cstdio>
+ #include <cstdlib>
  #include <cstring>
- #include <cinttypes>
+ #include <future>
  #include <memory>
  #include <random>
- #include <stdio.h>
- #include <stdlib.h>
+ #include <regex>
  #include <string>
  #include <thread>
- #include <future>
  #include <vector>

  static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
@@ -467,6 +469,7 @@ struct test_case {

  // allocate
  ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
  if (buf == NULL) {
  printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
  ggml_free(ctx);
@@ -588,14 +591,13 @@ struct test_case {
  /* .mem_base = */ NULL,
  /* .no_alloc = */ true,
  };
- ggml_context * ctx = ggml_init(params);
+ ggml_context_ptr ctx(ggml_init(params)); // smart ptr
  GGML_ASSERT(ctx);

- ggml_tensor * out = build_graph(ctx);
+ ggml_tensor * out = build_graph(ctx.get());

  if (op_name != nullptr && op_desc(out) != op_name) {
  //printf(" %s: skipping\n", op_desc(out).c_str());
- ggml_free(ctx);
  return true;
  }

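Note on the "smart ptr" lines above and throughout the rest of this file: `ggml_context_ptr` and `ggml_backend_buffer_ptr` come from the newly included ggml-cpp.h, which is why the explicit `ggml_free(ctx)` calls on the early-return paths can be deleted. A minimal sketch of the idiom, assuming (as is conventional for such C++ wrappers) a std::unique_ptr with a custom deleter over the opaque C handle:

    // sketch only; the real typedefs live in ggml-cpp.h
    #include <memory>

    struct ggml_context;                             // opaque C handle
    extern "C" void ggml_free(struct ggml_context * ctx);

    struct ggml_context_deleter {
        void operator()(ggml_context * ctx) const { ggml_free(ctx); }
    };
    using ggml_context_ptr = std::unique_ptr<ggml_context, ggml_context_deleter>;

With this, `ctx.get()` passes the raw pointer to the C API, and the context is released on every return path automatically.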
@@ -605,7 +607,6 @@ struct test_case {
  // check if backends support op
  if (!ggml_backend_supports_op(backend, out)) {
  printf("not supported\n");
- ggml_free(ctx);
  return true;
  }

@@ -618,22 +619,26 @@ struct test_case {
  printf("%*s", last - len, "");

  // allocate
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+ ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
+
  if (buf == NULL) {
  printf("failed to allocate tensors\n");
- ggml_free(ctx);
  return false;
  }

  // randomize tensors
- initialize_tensors(ctx);
+ initialize_tensors(ctx.get());

  // build graph
- ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
  ggml_build_forward_expand(gf, out);

  // warmup run
- ggml_backend_graph_compute(backend, gf);
+ ggml_status status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }

  // determine number of runs
  int n_runs;
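Every graph compute call in this file now checks the ggml_status return value with the same report-and-bail pattern seen in the warmup run above (both ggml_backend_graph_compute and ggml_status_to_string are used this way throughout the diff). A hypothetical helper that would express the repeated pattern once, shown here purely as a sketch:

    static bool compute_ok(ggml_backend_t backend, ggml_cgraph * gf) {
        ggml_status status = ggml_backend_graph_compute(backend, gf);
        if (status != GGML_STATUS_SUCCESS) {
            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s\n",
                    __func__, ggml_status_to_string(status));
            return false; // callers bail out; the smart pointers above clean up
        }
        return true;
    }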
@@ -684,7 +689,11 @@ struct test_case {
  int total_runs = 0;
  do {
  int64_t start_time = ggml_time_us();
- ggml_backend_graph_compute(backend, gf);
+ ggml_status status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }
  int64_t end_time = ggml_time_us();

  total_time_us += end_time - start_time;
@@ -722,10 +731,6 @@ struct test_case {
  }
  printf("\n");

- ggml_backend_buffer_free(buf);
-
- ggml_free(ctx);
-
  return true;
  }

@@ -738,17 +743,16 @@ struct test_case {
  /* .mem_base = */ NULL,
  /* .no_alloc = */ true,
  };
- ggml_context * ctx = ggml_init(params);
+ ggml_context_ptr ctx(ggml_init(params)); // smart ptr
  GGML_ASSERT(ctx);

- gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
- gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+ gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
+ gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);

- ggml_tensor * out = build_graph(ctx);
+ ggml_tensor * out = build_graph(ctx.get());

  if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
  //printf(" %s: skipping\n", op_desc(out).c_str());
- ggml_free(ctx);
  return true;
  }

@@ -756,7 +760,6 @@ struct test_case {
  fflush(stdout);

  if (out->type != GGML_TYPE_F32) {
- ggml_free(ctx);
  printf("not supported [%s->type != FP32]\n", out->name);
  return true;
  }
@@ -764,7 +767,7 @@ struct test_case {
  // check if the backend supports the ops
  bool supported = true;
  bool any_params = false;
- for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
  if (!ggml_backend_supports_op(backend, t)) {
  printf("not supported [%s] ", ggml_backend_name(backend));
  supported = false;
@@ -785,40 +788,38 @@ struct test_case {
  }
  if (!supported) {
  printf("\n");
- ggml_free(ctx);
  return true;
  }

  int64_t ngrads = 0;
- for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
  if (t->flags & GGML_TENSOR_FLAG_PARAM) {
  ngrads += ggml_nelements(t);
  }
  }
  if (ngrads > grad_nmax()) {
  printf("skipping large tensors for speed \n");
- ggml_free(ctx);
  return true;
  }


  if (!ggml_is_scalar(out)) {
- out = ggml_sum(ctx, out);
+ out = ggml_sum(ctx.get(), out);
  ggml_set_name(out, "sum_of_out");
  }
  ggml_set_loss(out);

  ggml_build_forward_expand(gf, out);
  ggml_graph_cpy(gf, gb);
- ggml_build_backward_expand(ctx, ctx, gb, false);
+ ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
  if (expect.size() != 1 || expect[0] != 0.0f) {
  GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
- for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
  GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
  }
  }

- for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
  if (!ggml_backend_supports_op(backend, t)) {
  printf("not supported [%s] ", ggml_backend_name(backend));
  supported = false;
@@ -832,27 +833,32 @@ struct test_case {
  }
  if (!supported) {
  printf("\n");
- ggml_free(ctx);
  return true;
  }

  // allocate
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+ ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
  if (buf == NULL) {
  printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
- ggml_free(ctx);
  return false;
  }

-
- initialize_tensors(ctx); // Randomizes all tensors (including gradients).
+ initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).
  ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise.

- ggml_backend_graph_compute(backend, gf);
- ggml_backend_graph_compute(backend, gb);
+ ggml_status status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }
+ status = ggml_backend_graph_compute(backend, gb);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }

  bool ok = true;
- for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+ for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
  if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
  continue;
  }
@@ -897,20 +903,36 @@ struct test_case {
  float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh

  ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
- ggml_backend_graph_compute(backend, gf);
+ status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }
  ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));

  ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
- ggml_backend_graph_compute(backend, gf);
+ status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }
  ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));

  if (grad_precise()) {
  ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
- ggml_backend_graph_compute(backend, gf);
+ status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }
  ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));

  ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
- ggml_backend_graph_compute(backend, gf);
+ status = ggml_backend_graph_compute(backend, gf);
+ if (status != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+ return false;
+ }
  ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));

  gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
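A quick check of the `gn[i]` line that closes this hunk (not part of the diff): with xiu = x+eps, xiuh = x+eps/2, xidh = x-eps/2 and xid = x-eps, the expression is the standard five-point central-difference stencil evaluated with step h = eps/2, accurate to O(eps^4):

    f'(x) ≈ (-f(x+2h) + 8*f(x+h) - 8*f(x-h) + f(x-2h)) / (12h),   h = eps/2
          = (8*fuh + fd - (8*fdh + fu)) / (6*eps)

which is exactly the formula computed into gn[i].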
@@ -936,10 +958,6 @@ struct test_case {
  printf("compare failed ");
  }

- ggml_backend_buffer_free(buf);
-
- ggml_free(ctx);
-
  if (ok) {
  printf("\033[1;32mOK\033[0m\n");
  return true;
@@ -3119,6 +3137,7 @@ struct test_leaky_relu : public test_case {
  struct test_flash_attn_ext : public test_case {
  const int64_t hs; // head size
  const int64_t nh; // num heads
+ const int64_t nr; // repeat in Q, tests for grouped-query attention
  const int64_t kv; // kv size
  const int64_t nb; // batch size

@@ -3131,7 +3150,7 @@ struct test_flash_attn_ext : public test_case {
  std::array<int32_t, 4> permute;

  std::string vars() override {
- return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+ return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
  }

  double max_nmse_err() override {
@@ -3142,13 +3161,13 @@ struct test_flash_attn_ext : public test_case {
  GGML_UNUSED(t);
  // Just counting matmul costs:
  // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
- return 2 * 2 * nh * nb * hs * kv;
+ return 2 * 2 * nh*nr * nb * hs * kv;
  }

- test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+ test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
  bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
  std::array<int32_t, 4> permute = {0, 1, 2, 3})
- : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+ : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}

  ggml_tensor * build_graph(ggml_context * ctx) override {
  const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3166,13 +3185,13 @@ struct test_flash_attn_ext : public test_case {
  return t;
  };

- ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
+ ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
  ggml_set_name(q, "q");

- ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
+ ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
  ggml_set_name(k, "k");

- ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
+ ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
  ggml_set_name(v, "v");

  ggml_tensor * m = nullptr;
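The net effect of the hunks above: Q is now built with nh*nr heads while K and V keep nh heads, i.e. the grouped-query-attention layout in which nr query heads share one KV head (nr = 1 reproduces the previous multi-head behavior, and the FLOP estimate scales by nr accordingly). As a concrete illustration using the constructor signature from this diff, a case with 16 query heads over 4 KV heads would be:

    // hs=128, nh=4 KV heads, nr=4 query heads per KV head (16 Q heads total),
    // kv=512, batch nb=8; values mirror the test loops added further below
    test_cases.emplace_back(new test_flash_attn_ext(128, 4, 4, 512, 8));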
@@ -3752,10 +3771,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  std::default_random_engine rng(0);

  // unary ops
- for (int v : {0, 1}) {
- for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
- test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 2, 2, 2 }, v));
- test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 5, 7, 11, 13 }, v));
+ for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ for (int v : {0, 1}) {
+ for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+ test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
+ test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
+ }
  }
  }

@@ -3942,37 +3963,38 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
  }
  };
-
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2});
- add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2});
-
- // stable diffusion
- add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
- //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
- //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+ for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+
+ // stable diffusion
+ add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
+ add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
+ add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
+ add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
+ add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
+ add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
+ //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+ //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+ }

  test_cases.emplace_back(new test_add1());
  test_cases.emplace_back(new test_scale());
@@ -4091,7 +4113,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  for (int n_mats : {4, 8}) {
  for (int n_used : {1, 2, 4}) {
  for (bool b : {false, true}) {
- for (int n : {1, 32}) {
+ for (int n : {1, 32, 129}) {
  int m = 512;
  int k = 256;
  test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
@@ -4136,12 +4158,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  }
  }

- test_cases.emplace_back(new test_sqr());
- test_cases.emplace_back(new test_sqrt());
- test_cases.emplace_back(new test_log());
- test_cases.emplace_back(new test_sin());
- test_cases.emplace_back(new test_cos());
- test_cases.emplace_back(new test_clamp());
+ for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ test_cases.emplace_back(new test_sqr(type));
+ test_cases.emplace_back(new test_sqrt(type));
+ test_cases.emplace_back(new test_log(type));
+ test_cases.emplace_back(new test_sin(type));
+ test_cases.emplace_back(new test_cos(type));
+ test_cases.emplace_back(new test_clamp(type));
+ }

  test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
  test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
@@ -4278,14 +4302,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  if (!mask && max_bias > 0.0f) continue;
  for (float logit_softcap : {0.0f, 10.0f}) {
  if (hs != 128 && logit_softcap != 0.0f) continue;
- for (int nh : { 32, }) {
- for (int kv : { 512, 1024, }) {
- for (int nb : { 1, 3, 32, 35, }) {
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
- // run fewer test cases permuted
- if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+ for (int nh : { 4, }) {
+ for (int nr : { 1, 4, 16 }) {
+ if (nr == 16 && hs != 128) continue;
+ for (int kv : { 512, 1024, }) {
+ if (nr != 1 && kv != 512) continue;
+ for (int nb : { 1, 3, 32, 35, }) {
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
+ // run fewer test cases permuted
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+ }
  }
  }
  }
@@ -4360,9 +4388,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
  return test_cases;
  }

- static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
+ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
+ auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
+ if (params_filter == nullptr) {
+ return;
+ }
+
+ std::regex params_filter_regex(params_filter);
+
+ for (auto it = test_cases.begin(); it != test_cases.end();) {
+ if (!std::regex_search((*it)->vars(), params_filter_regex)) {
+ it = test_cases.erase(it);
+ continue;
+ }
+
+ it++;
+ }
+ };
+
  if (mode == MODE_TEST) {
  auto test_cases = make_test_cases_eval();
+ filter_test_cases(test_cases, params_filter);
  ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
  if (backend_cpu == NULL) {
  printf(" Failed to initialize CPU backend\n");
@@ -4384,6 +4430,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

  if (mode == MODE_GRAD) {
  auto test_cases = make_test_cases_eval();
+ filter_test_cases(test_cases, params_filter);
  size_t n_ok = 0;
  for (auto & test : test_cases) {
  if (test->eval_grad(backend, op_name)) {
@@ -4397,6 +4444,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

  if (mode == MODE_PERF) {
  auto test_cases = make_test_cases_perf();
+ filter_test_cases(test_cases, params_filter);
  for (auto & test : test_cases) {
  test->eval_perf(backend, op_name);
  }
@@ -4407,7 +4455,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
  }

  static void usage(char ** argv) {
- printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
+ printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>]\n", argv[0]);
  printf(" valid modes:\n");
  printf(" - test (default, compare with CPU backend for correctness)\n");
  printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
@@ -4417,8 +4465,9 @@ static void usage(char ** argv) {

  int main(int argc, char ** argv) {
  test_mode mode = MODE_TEST;
- const char * op_name_filter = NULL;
- const char * backend_filter = NULL;
+ const char * op_name_filter = nullptr;
+ const char * backend_filter = nullptr;
+ const char * params_filter = nullptr;

  for (int i = 1; i < argc; i++) {
  if (strcmp(argv[i], "test") == 0) {
@@ -4441,6 +4490,13 @@ int main(int argc, char ** argv) {
  usage(argv);
  return 1;
  }
+ } else if (strcmp(argv[i], "-p") == 0) {
+ if (i + 1 < argc) {
+ params_filter = argv[++i];
+ } else {
+ usage(argv);
+ return 1;
+ }
  } else {
  usage(argv);
  return 1;
@@ -4487,7 +4543,7 @@ int main(int argc, char ** argv) {
  printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
  printf("\n");

- bool ok = test_backend(backend, mode, op_name_filter);
+ bool ok = test_backend(backend, mode, op_name_filter, params_filter);

  printf(" Backend %s: ", ggml_backend_name(backend));
  if (ok) {