@fugood/llama.node 0.3.6 → 0.3.8

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (186)
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +3 -1
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
package/src/llama.cpp/tests/test-backend-ops.cpp

@@ -780,7 +780,7 @@ struct test_case {
  }
  }
  if (!any_params) {
- printf("not supported [%s] \n", op_name);
+ printf("not supported [%s] \n", op_desc(out).c_str());
  supported = false;
  }
  if (!supported) {
@@ -1130,6 +1130,59 @@ struct test_get_rows : public test_case {
  }
  };

+ // GGML_OP_GET_ROWS_BACK
+ struct test_get_rows_back : public test_case {
+ const ggml_type type;
+ const int n; // cols
+ const int m; // rows
+ const int r; // rows to get
+ const int b; // batch size
+ const bool v; // view (non-contiguous src1)
+
+ std::string vars() override {
+ return VARS_TO_STR6(type, n, m, r, b, v);
+ }
+
+ test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+ : type(type), n(n), m(m), r(r), b(b), v(v) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);
+ ggml_set_name(in_forward, "in_forward");
+
+ ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+ ggml_set_name(rows, "rows");
+ if (v) {
+ rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+ ggml_set_name(rows, "view_of_rows");
+ }
+
+ ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);
+ ggml_set_name(grad, "grad");
+
+ ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
+ // rows
+ std::vector<int> data(r*b);
+ for (int i = 0; i < r*b; i++) {
+ data[i] = rand() % m;
+ }
+ ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+ };
+
  // GGML_OP_ARGMAX
  struct test_argmax : public test_case {
  const ggml_type type;
@@ -1531,6 +1584,39 @@ struct test_scale : public test_case {
  }
  };

+ // GGML_OP_SILU_BACK
+ struct test_silu_back : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ float eps;
+
+ std::string vars() override {
+ return VARS_TO_STR3(type, ne, eps);
+ }
+
+ test_silu_back(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {64, 5, 4, 3},
+ float eps = 1e-6f)
+ : type(type), ne(ne), eps(eps) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(a, "a");
+
+ ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(grad, "grad");
+
+ ggml_tensor * out = ggml_silu_back(ctx, a, grad);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ bool grad_precise() override {
+ return true;
+ }
+ };
+
  // GGML_OP_NORM
  struct test_norm : public test_case {
  const ggml_type type;
@@ -1583,11 +1669,56 @@ struct test_rms_norm : public test_case {
  return out;
  }

+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, -10.f, 10.f);
+ }
+ }
+
+ float grad_eps() override {
+ return 1.0f;
+ }
+
  bool grad_precise() override {
  return true;
  }
  };

+ // GGML_OP_RMS_NORM_BACK
+ struct test_rms_norm_back : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ float eps;
+
+ std::string vars() override {
+ return VARS_TO_STR3(type, ne, eps);
+ }
+
+ test_rms_norm_back(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {64, 5, 4, 3},
+ float eps = 1e-6f)
+ : type(type), ne(ne), eps(eps) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(a, "a");
+
+ ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(b, "b");
+
+ ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, -10.f, 10.f);
+ }
+ }
+ };
+
  // GGML_OP_SSM_CONV
  struct test_ssm_conv : public test_case {
  const ggml_type type;
@@ -1659,17 +1790,46 @@ struct test_rwkv_wkv6 : public test_case {

  ggml_tensor * build_graph(ggml_context * ctx) override {
  const int64_t n_tokens = n_seq_tokens * n_seqs;
- ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
- ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
- ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+ ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
  ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
- ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+ ggml_tensor * td = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
  ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
  ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
  return out;
  }
  };

+ // GGML_OP_GATED_LINEAR_ATTN
+ struct test_gla : public test_case {
+ const ggml_type type;
+
+ const int64_t head_count;
+ const int64_t head_size;
+ const int64_t n_seq_tokens;
+ const int64_t n_seqs;
+
+ std::string vars() override {
+ return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+ }
+
+ test_gla(ggml_type type = GGML_TYPE_F32,
+ int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ const int64_t n_tokens = n_seq_tokens * n_seqs;
+ ggml_tensor * q = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * g = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+ ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+ return out;
+ }
+ };
+
  // GGML_OP_MUL_MAT
  struct test_mul_mat : public test_case {
  const ggml_type type_a;
@@ -1826,10 +1986,11 @@ struct test_out_prod : public test_case {
  const int64_t n;
  const int64_t k;
  const std::array<int64_t, 2> bs; // dims 3 and 4
+ const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
  const bool trans_b;

  std::string vars() override {
- return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b);
+ return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);
  }

  double max_nmse_err() override {
@@ -1839,8 +2000,9 @@ struct test_out_prod : public test_case {
  test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
  int64_t m = 32, int64_t n = 32, int64_t k = 32,
  std::array<int64_t, 2> bs = {10, 10},
+ std::array<int64_t, 2> nr = {2, 2},
  bool trans_b = false)
- : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {}
+ : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}

  ggml_tensor * build_graph(ggml_context * ctx) override {
  ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
@@ -1848,10 +2010,10 @@ struct test_out_prod : public test_case {

  ggml_tensor * b;
  if (trans_b) {
- b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]);
+ b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
  b = ggml_transpose(ctx, b);
  } else {
- b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]);
+ b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);
  }
  ggml_set_name(b, "b");

@@ -2162,8 +2324,38 @@ struct test_soft_max : public test_case {
  }
  };

+ // GGML_OP_SOFT_MAX_BACK
+ struct test_soft_max_back : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ const float scale;
+ const float max_bias;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, ne, scale, max_bias);
+ }
+
+ test_soft_max_back(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {10, 5, 4, 3},
+ float scale = 1.0f,
+ float max_bias = 0.0f)
+ : type(type), ne(ne), scale(scale), max_bias(max_bias) {}

- // GGML_OP_ROPE
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(a, "a");
+
+ ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(a, "a");
+
+ ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+ };
+
+ // GGML_OP_ROPE + GGML_OP_ROPE_BACK
  struct test_rope : public test_case {
  const ggml_type type;
  const std::array<int64_t, 4> ne_a;
@@ -2175,29 +2367,36 @@ struct test_rope : public test_case {
  float af; // attn_factor
  bool ff;
  int v; // view (1 : non-contiguous a)
+ bool forward;

  std::string vars() override {
+ // forward can be inferred from the op, does not need to be printed
  return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
  }

  test_rope(ggml_type type = GGML_TYPE_F32,
  std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
- int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
- : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
+ int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
+ float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
+ : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}

  ggml_tensor * build_graph(ggml_context * ctx) override {
  ggml_tensor * a;
  if (v & 1) {
  auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
  a = ggml_new_tensor(ctx, type, 4, ne.data());
- ggml_set_param(ctx, a);
+ if (forward) {
+ ggml_set_param(ctx, a);
+ }
  ggml_set_name(a, "a");

  a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
  ggml_set_name(a, "view_of_a");
  } else {
  a = ggml_new_tensor(ctx, type, 4, ne_a.data());
- ggml_set_param(ctx, a);
+ if (forward) {
+ ggml_set_param(ctx, a);
+ }
  ggml_set_name(a, "a");
  }

@@ -2223,14 +2422,26 @@ struct test_rope : public test_case {
  if (is_vision) {
  GGML_ASSERT(n_dims/4 > 0);
  int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
- out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ if (forward) {
+ out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ } else {
+ out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
  } else {
  GGML_ASSERT(n_dims/3 > 0);
  int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
- out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ if (forward) {
+ out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ } else {
+ out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
  }
  } else {
- out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ if (forward) {
+ out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ } else {
+ out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
  }
  ggml_set_name(out, "out");

@@ -2835,9 +3046,10 @@ struct test_flash_attn_ext : public test_case {
  const float logit_softcap; // Gemma 2

  const ggml_type type_KV;
+ std::array<int32_t, 4> permute;

  std::string vars() override {
- return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
+ return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
  }

  double max_nmse_err() override {
@@ -2852,19 +3064,33 @@ struct test_flash_attn_ext : public test_case {
  }

  test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
- bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
- : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
+ bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
+ std::array<int32_t, 4> permute = {0, 1, 2, 3})
+ : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}

  ggml_tensor * build_graph(ggml_context * ctx) override {
  const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));

- ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+ auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
+ int64_t ne[4] = {ne0, ne1, ne2, ne3};
+ int64_t ne_perm[4];
+ for (int i = 0; i < 4; ++i) {
+ ne_perm[permute[i]] = ne[i];
+ }
+ ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
+ if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
+ t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
+ }
+ return t;
+ };
+
+ ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
  ggml_set_name(q, "q");

- ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+ ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
  ggml_set_name(k, "k");

- ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+ ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
  ggml_set_name(v, "v");

  ggml_tensor * m = nullptr;
@@ -2932,6 +3158,40 @@ struct test_cross_entropy_loss : public test_case {
  }
  };

+ // GGML_OP_CROSS_ENTROPY_LOSS_BACK
+ struct test_cross_entropy_loss_back : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+
+ std::string vars() override {
+ return VARS_TO_STR2(type, ne);
+ }
+
+ test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {10, 5, 4, 3})
+ : type(type), ne(ne) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+ ggml_set_name(grad, "grad");
+
+ ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(logits, "logits");
+
+ ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(labels, "labels");
+
+ // Ensure labels add up to 1:
+ labels = ggml_soft_max(ctx, labels);
+ ggml_set_name(labels, "labels_normalized");
+
+ ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+ };
+
  // GGML_OP_OPT_STEP_ADAMW
  struct test_opt_step_adamw : public test_case {
  const ggml_type type;
@@ -3431,6 +3691,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  }
  }

+ test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
+ for (ggml_type type : all_types) {
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
+ }
+ }
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
+ }
+
  for (ggml_type type_input : {GGML_TYPE_F32}) {
  for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
  for (int k0 : {1, 3}) {
@@ -3553,6 +3823,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
  }
  }
+ for (ggml_type type_dst : {GGML_TYPE_F32}) {
+ for (ggml_type type_src : all_types) {
+ test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+ test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+ }
+ }
  for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
  for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
  test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
@@ -3609,10 +3885,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

  test_cases.emplace_back(new test_add1());
  test_cases.emplace_back(new test_scale());
+ test_cases.emplace_back(new test_silu_back());

- for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
- test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
- test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
+ test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
  }

  test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
@@ -3626,6 +3904,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
  test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));

+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
  for (int i = 1; i < 9; ++i) {
  test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
  test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
@@ -3747,22 +4030,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

  for (ggml_type type_a : base_types) {
  for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1, 1}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
-
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}, true));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
- test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+ for (int n : {1, 16}) {
+ for (int k : {1, 16}) {
+ for (int bs2 : {1, 3}) {
+ for (int bs3 : {1, 3}) {
+ for (int nr2 : {1, 2}) {
+ for (int nr3 : {1, 2}) {
+ test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));
+ }
+ }
+ }
+ }
+ }
+ }
  }
  }

@@ -3805,12 +4085,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  }
  }
  }
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
  test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
  test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
  test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));

- {
+ for (float max_bias : {0.0f, 8.0f}) {
+ for (float scale : {1.0f, 0.1f}) {
+ for (int64_t ne0 : {16, 1024}) {
+ for (int64_t ne1 : {16, 1024}) {
+ test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
+ }
+ }
+ }
+ }
+
+ for (bool fw : {true, false}) { // fw == forward
  bool all = true;

  for (float v : { 0, 1 }) {
@@ -3819,29 +4110,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  for (float af : { 1.0f, 1.4245f }) {
  for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
  for (bool ff : {false, true}) { // freq_factors
- test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+ test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B

  if (all) {
- test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
- test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
- test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+ test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
+ test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
+ test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
  }

  if (all) {
- test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
- test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
- test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
- test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
- test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
+ test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+ test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+ test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
  }

  if (all) {
- test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
- test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
- test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
+ test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+ test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+ test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
  }

- test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+ test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
  }
  }

@@ -3891,6 +4182,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  for (int nb : { 1, 3, 32, 35, }) {
  for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
  test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
+ // run fewer test cases permuted
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+ }
  }
  }
  }
@@ -3900,7 +4195,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  }
  }

- test_cases.emplace_back(new test_cross_entropy_loss());
+ test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
+ test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
+ test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3}));
+ test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
+
  test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));

  // these tests are disabled to save execution time, but they can be handy for debugging
@@ -3937,7 +4236,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
  test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
  test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));

- for (int bs : {1, 512}) {
+ for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
  for (ggml_type type_a : all_types) {
  for (ggml_type type_b : {GGML_TYPE_F32}) {
  test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1, 1}, {1, 1}));
@@ -3945,6 +4244,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
  }
  }

+ for (int K : {3, 5}) {
+ for (int IC : {256, 2560}) {
+ for (int IW_IH : {32, 64, 256}) {
+ if (IC == 2560 && IW_IH == 256) {
+ // too big
+ continue;
+ }
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {IW_IH, IW_IH, IC, 1}, {K, K, IC, 1}, 1, 1, 1, 1, 1, 1, true));
+ }
+ }
+ }
+
  return test_cases;
  }