@fugood/llama.node 0.2.0 → 0.2.2

This diff compares the content of publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (94)
  1. package/CMakeLists.txt +9 -0
  2. package/README.md +1 -1
  3. package/bin/darwin/arm64/default.metallib +0 -0
  4. package/bin/darwin/arm64/llama-node.node +0 -0
  5. package/bin/darwin/x64/default.metallib +0 -0
  6. package/bin/darwin/x64/llama-node.node +0 -0
  7. package/bin/linux/arm64/llama-node.node +0 -0
  8. package/bin/linux/x64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  10. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  11. package/bin/win32/arm64/llama-node.node +0 -0
  12. package/bin/win32/arm64/node.lib +0 -0
  13. package/bin/win32/x64/llama-node.node +0 -0
  14. package/bin/win32/x64/node.lib +0 -0
  15. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/arm64/node.lib +0 -0
  17. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  18. package/bin/win32-vulkan/x64/node.lib +0 -0
  19. package/lib/binding.ts +1 -1
  20. package/package.json +2 -1
  21. package/patches/llama.patch +22 -0
  22. package/src/LlamaContext.cpp +2 -2
  23. package/src/TokenizeWorker.cpp +1 -1
  24. package/src/llama.cpp/CMakeLists.txt +82 -54
  25. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +16 -0
  26. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +6 -0
  27. package/src/llama.cpp/common/common.cpp +748 -754
  28. package/src/llama.cpp/common/common.h +49 -41
  29. package/src/llama.cpp/common/grammar-parser.cpp +10 -1
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +6 -6
  31. package/src/llama.cpp/common/log.h +5 -5
  32. package/src/llama.cpp/common/sampling.cpp +92 -10
  33. package/src/llama.cpp/common/sampling.h +6 -1
  34. package/src/llama.cpp/common/train.cpp +2 -2
  35. package/src/llama.cpp/examples/CMakeLists.txt +3 -0
  36. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  37. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  38. package/src/llama.cpp/examples/embedding/embedding.cpp +13 -4
  39. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +2 -2
  40. package/src/llama.cpp/examples/finetune/finetune.cpp +4 -3
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -2
  42. package/src/llama.cpp/examples/infill/infill.cpp +8 -8
  43. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +57 -8
  44. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +55 -0
  45. package/src/llama.cpp/examples/llama.android/{app → llama}/src/main/cpp/CMakeLists.txt +7 -8
  46. package/src/llama.cpp/examples/llama.android/{app → llama}/src/main/cpp/llama-android.cpp +14 -14
  47. package/src/llama.cpp/examples/llava/clip.h +1 -1
  48. package/src/llama.cpp/examples/llava/llava-cli.cpp +27 -7
  49. package/src/llama.cpp/examples/llava/llava.cpp +0 -15
  50. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  51. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  52. package/src/llama.cpp/examples/main/main.cpp +29 -17
  53. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  54. package/src/llama.cpp/examples/perplexity/perplexity.cpp +9 -9
  55. package/src/llama.cpp/examples/quantize/quantize.cpp +2 -2
  56. package/src/llama.cpp/examples/retrieval/retrieval.cpp +2 -2
  57. package/src/llama.cpp/examples/rpc/CMakeLists.txt +2 -0
  58. package/src/llama.cpp/examples/rpc/rpc-server.cpp +134 -0
  59. package/src/llama.cpp/examples/server/server.cpp +33 -25
  60. package/src/llama.cpp/examples/server/utils.hpp +1 -1
  61. package/src/llama.cpp/examples/tokenize/tokenize.cpp +359 -9
  62. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +4 -3
  63. package/src/llama.cpp/ggml-backend.c +2 -3
  64. package/src/llama.cpp/ggml-common.h +0 -54
  65. package/src/llama.cpp/ggml-cuda.h +1 -0
  66. package/src/llama.cpp/ggml-impl.h +51 -0
  67. package/src/llama.cpp/ggml-kompute.cpp +13 -3
  68. package/src/llama.cpp/ggml-opencl.cpp +4 -1
  69. package/src/llama.cpp/ggml-quants.c +3715 -2050
  70. package/src/llama.cpp/ggml-rpc.cpp +1155 -0
  71. package/src/llama.cpp/ggml-rpc.h +24 -0
  72. package/src/llama.cpp/ggml-sycl.cpp +119 -673
  73. package/src/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
  74. package/src/llama.cpp/ggml-vulkan.cpp +203 -224
  75. package/src/llama.cpp/ggml.c +1208 -1483
  76. package/src/llama.cpp/ggml.h +71 -46
  77. package/src/llama.cpp/llama.cpp +1374 -938
  78. package/src/llama.cpp/llama.h +22 -6
  79. package/src/llama.cpp/requirements.txt +0 -2
  80. package/src/llama.cpp/tests/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/tests/test-backend-ops.cpp +120 -57
  82. package/src/llama.cpp/tests/test-chat-template.cpp +16 -4
  83. package/src/llama.cpp/tests/test-grad0.cpp +43 -83
  84. package/src/llama.cpp/tests/test-grammar-integration.cpp +46 -0
  85. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +27 -3
  86. package/src/llama.cpp/unicode-data.cpp +6969 -2169
  87. package/src/llama.cpp/unicode-data.h +15 -12
  88. package/src/llama.cpp/unicode.cpp +89 -111
  89. package/src/llama.cpp/unicode.h +44 -12
  90. package/src/llama.cpp/build.zig +0 -172
  91. package/src/llama.cpp/ggml-mpi.c +0 -216
  92. package/src/llama.cpp/ggml-mpi.h +0 -39
  93. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +0 -2
  94. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +0 -2
package/src/llama.cpp/llama.h
@@ -81,9 +81,11 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_GPT2       = 7,
         LLAMA_VOCAB_PRE_TYPE_REFACT     = 8,
         LLAMA_VOCAB_PRE_TYPE_COMMAND_R  = 9,
-        LLAMA_VOCAB_PRE_TYPE_QWEN2      = 10,
-        LLAMA_VOCAB_PRE_TYPE_OLMO       = 11,
-        LLAMA_VOCAB_PRE_TYPE_DBRX       = 12,
+        LLAMA_VOCAB_PRE_TYPE_STABLELM2  = 10,
+        LLAMA_VOCAB_PRE_TYPE_QWEN2      = 11,
+        LLAMA_VOCAB_PRE_TYPE_OLMO       = 12,
+        LLAMA_VOCAB_PRE_TYPE_DBRX       = 13,
+        LLAMA_VOCAB_PRE_TYPE_SMAUG      = 14,
     };

     // note: these values should be synchronized with ggml_rope
@@ -242,6 +244,9 @@ extern "C" {
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;

+        // comma separated list of RPC servers to use for offloading
+        const char * rpc_servers;
+
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
         // If the provided progress_callback returns true, model loading continues.
         // If it returns false, model loading is immediately aborted.
@@ -260,6 +265,8 @@ extern "C" {
         bool check_tensors; // validate model tensor data
     };

+    // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
+    // https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
         uint32_t seed;  // RNG seed, -1 for random
         uint32_t n_ctx; // text context, 0 = from model
@@ -286,14 +293,14 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;

-        enum ggml_type type_k; // data type for K cache
-        enum ggml_type type_v; // data type for V cache
+        enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
+        enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]

         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention
+        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -755,6 +762,12 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);

+    // Get the number of threads used for generation of a single token.
+    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+
+    // Get the number of threads used for prompt and batch processing (multiple token).
+    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@@ -813,6 +826,9 @@ extern "C" {
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);

+    // Identify if Token Id is a control token or a render-able token
+    LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
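For orientation, the entry points added above (rpc_servers in llama_model_params, llama_n_threads, llama_n_threads_batch, llama_token_is_control) can be exercised from the C API roughly as sketched below. This is an illustrative sketch, not code shipped in this package: the model path and RPC endpoint are placeholders, error handling is omitted, and rpc_servers only has an effect when llama.cpp is built with RPC support.

// Sketch only: exercises the APIs added in this llama.h revision.
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.rpc_servers = "192.168.1.10:50052";   // placeholder endpoint; ignored unless built with RPC

    llama_model * model = llama_load_model_from_file("model.gguf", mparams);  // placeholder path
    if (model == nullptr) return 1;

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    llama_set_n_threads(ctx, 8, 16);
    printf("gen threads: %u, batch threads: %u\n", llama_n_threads(ctx), llama_n_threads_batch(ctx));

    // new helper: distinguish control tokens from renderable tokens
    printf("BOS is control token: %d\n", (int) llama_token_is_control(model, llama_token_bos(model)));

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}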
package/src/llama.cpp/requirements.txt
@@ -9,5 +9,3 @@
 -r ./requirements/requirements-convert-hf-to-gguf.txt
 -r ./requirements/requirements-convert-hf-to-gguf-update.txt
 -r ./requirements/requirements-convert-llama-ggml-to-gguf.txt
--r ./requirements/requirements-convert-lora-to-ggml.txt
--r ./requirements/requirements-convert-persimmon-to-gguf.txt
package/src/llama.cpp/tests/CMakeLists.txt
@@ -92,7 +92,7 @@ target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
 install(TARGETS test-tokenizer-1-bpe RUNTIME)

 # TODO: disabled due to slowness
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
+#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
 #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
 #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
package/src/llama.cpp/tests/test-backend-ops.cpp
@@ -2,6 +2,7 @@
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
 #include <ggml-backend-impl.h>
+
 #include <algorithm>
 #include <array>
 #include <cfloat>
@@ -15,6 +16,7 @@
 #include <thread>
 #include <vector>

+
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     // static RNG initialization (revisit if n_threads stops being constant)
     static const size_t n_threads = std::thread::hardware_concurrency();
@@ -48,6 +50,22 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         t.join();
     }

+#if 0
+    const char * val_str = getenv("GGML_TEST_EPS");
+    float val = 1e-9f;
+    if (val_str != nullptr) {
+        val = std::stof(val_str);
+        printf("GGML_TEST_EPS=%e\n", val);
+    }
+
+    // test quantization with very small values that may result in nan scales due to division by zero
+    if (ggml_is_quantized(tensor->type)) {
+        for (int i = 0; i < 256; i++) {
+            data[i] = val;
+        }
+    }
+#endif
+
     if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
@@ -63,6 +81,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
             }
         }
         ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
+        GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
@@ -1111,11 +1130,7 @@ struct test_soft_max : public test_case {
         if (this->mask) {
             mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
         }
-        ggml_tensor * pos = nullptr;
-        if (max_bias > 0.0f) {
-            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
-        }
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
         return out;
     }
 };
@@ -1127,20 +1142,22 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    bool ff;

     std::string vars() override {
-        return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
     }

     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
-        ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
         return out;
     }

@@ -1154,7 +1171,12 @@ struct test_rope : public test_case {
             }
             ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
         } else {
-            init_tensor_uniform(t);
+            if (t->ne[0] == n_dims/2) {
+                // frequency factors in the range [0.9f, 1.1f]
+                init_tensor_uniform(t, 0.9f, 1.1f);
+            } else {
+                init_tensor_uniform(t);
+            }
         }
     }
 }
@@ -1237,22 +1259,26 @@ struct test_im2col : public test_case {
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
-    const std::array<int64_t, 4> ne;
-    const int64_t b_ne2;
+    const std::array<int64_t, 4> ne_a;
+    const int64_t ne_b_d;
+    const int dim;

     std::string vars() override {
-        return VARS_TO_STR3(type, ne, b_ne2);
+        return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
     }

     test_concat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
-            int64_t b_ne2 = 10)
-        : type(type), ne(ne), b_ne2(b_ne2) {}
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            int64_t ne_b_d = 10,
+            int dim = 2)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]);
-        ggml_tensor * out = ggml_concat(ctx, a, b);
+        auto ne_b = ne_a;
+        ne_b[dim] = ne_b_d;
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_concat(ctx, a, b, dim);
         return out;
     }
 };
@@ -1332,23 +1358,47 @@ struct test_upscale : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
+    const bool transpose;

     std::string vars() override {
-        return VARS_TO_STR3(type, ne, scale_factor);
+        return VARS_TO_STR4(type, ne, scale_factor, transpose);
     }

     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2)
-        : type(type), ne(ne), scale_factor(scale_factor) {}
+            int32_t scale_factor = 2, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        if (transpose) a = ggml_transpose(ctx, a);
         ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
         return out;
     }
 };

+// GGML_OP_UPSCALE (ext)
+struct test_upscale_ext : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> ne_tgt;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, ne_tgt);
+    }
+
+    test_upscale_ext(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {2, 5, 7, 11},
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
+        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        return out;
+    }
+};
+
 // GGML_OP_GROUP_NORM
 struct test_group_norm : public test_case {
     const ggml_type type;
@@ -1490,23 +1540,27 @@ struct test_flash_attn_ext : public test_case {
     const int64_t kv; // kv size
     const int64_t nb; // batch size

+    const bool mask; // use mask
+
+    const float max_bias; // ALiBi
+
     std::string vars() override {
-        return VARS_TO_STR4(hs, nh, kv, nb);
+        return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
         return out;
     }
 };
@@ -1611,7 +1665,7 @@ public:

         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);

         // split cached v into n_head heads
         struct ggml_tensor * v =
@@ -1720,14 +1774,14 @@ struct test_llama : public test_llm {
         struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
         struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);

-        Qcur = ggml_rope_custom(
-            ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
+        Qcur = ggml_rope_ext(
+            ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
             hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
             ext_factor, attn_factor, beta_fast, beta_slow
         );

-        Kcur = ggml_rope_custom(
-            ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
+        Kcur = ggml_rope_ext(
+            ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
             hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
             ext_factor, attn_factor, beta_fast, beta_slow
         );
@@ -1846,13 +1900,13 @@ struct test_falcon : public test_llm {
         Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);

         // using mode = 2 for neox mode
-        Qcur = ggml_rope_custom(
-            ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+        Qcur = ggml_rope_ext(
+            ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
         );

-        Kcur = ggml_rope_custom(
-            ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+        Kcur = ggml_rope_ext(
+            ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
         );

@@ -2128,6 +2182,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2141,24 +2196,29 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));

     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512)); // llama 65B
-        test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
-        test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
+        // TODO: ff not supported yet for !neox
+        test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, false)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, false)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, false)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, false)); // llama 65B
+
+        for (bool ff : {false, true}) { // freq_factors
+            test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, ff)); // neox (stablelm)
+            test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, ff)); // neox (phi-2)
+        }
     }

-    test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
-    test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
+    for (int dim : { 0, 1, 2, 3, }) {
+        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
+        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
+    }

     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
@@ -2168,6 +2228,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_upscale());
+    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
+    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm());
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
@@ -2175,15 +2237,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());

-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    for (int hs : { 64, 128, }) { // other head sizes not implemented
-#else
     for (int hs : { 64, 80, 128, 256, }) {
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        for (int nh : { 32, }) {
-            for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, }) {
-                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+        for (bool mask : { true, false } ) {
+            for (float max_bias : { 0.0f, 8.0f }) {
+                if (!mask && max_bias > 0.0f) continue;
+                for (int nh : { 32, }) {
+                    for (int kv : { 512, 1024, }) {
+                        for (int nb : { 1, 2, 4, 8, }) {
+                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias));
+                        }
+                    }
                 }
             }
         }
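The test changes above track several ggml API updates: ggml_concat() now takes an explicit concatenation dimension, ggml_soft_max_ext() no longer takes a separate pos tensor (ALiBi is driven by max_bias alone), and ggml_upscale_ext() upscales to an arbitrary target shape. A minimal sketch of the updated signatures, with illustrative shapes and no backend-specific setup (not code from this package):

#include "ggml.h"

int main() {
    // small CPU-side context; the size is arbitrary but large enough for the tensors below
    ggml_init_params params = { /*mem_size*/ 64*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 11, 12, 13, 14);
    ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 11, 12,  7, 14);

    // concat along an explicit dim (here dim 2), matching the new test_concat cases
    ggml_tensor * cat = ggml_concat(ctx, a, b, /*dim*/ 2);

    // soft_max_ext: mask may be null; a max_bias > 0.0f would enable ALiBi
    ggml_tensor * sm = ggml_soft_max_ext(ctx, cat, /*mask*/ nullptr, /*scale*/ 1.0f, /*max_bias*/ 0.0f);

    // upscale to an explicit target shape, as test_upscale_ext does
    ggml_tensor * up = ggml_upscale_ext(ctx, sm, 22, 24, 20, 14);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, up);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 4);

    ggml_free(ctx);
    return 0;
}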
package/src/llama.cpp/tests/test-chat-template.cpp
@@ -49,8 +49,14 @@ int main(void) {
         "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
         // Llama-3
         "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
-        // Phi-3
-        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + ' ' + message['content'] + '<|end|> ' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> ' }}{% else %}{{ eos_token }}{% endif %}"
+        //Phi-3-mini
+        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        //Phi-3-small
+        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+        //Phi-3-medium
+        "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        //Phi-3-vision
+        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -79,8 +85,14 @@ int main(void) {
         "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
         // Llama 3
         "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-        // Phi 3
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\nI am an assistant<|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-mini
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-small
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-medium
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-vision
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
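The template strings added above are exercised through llama_chat_apply_template(), which can also be called directly with a template string and no model. A rough sketch of that usage, mirroring the test (the conversation content here is made up for illustration):

#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    // Phi-3-small template string from the test above
    const char * tmpl =
        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}"
        "{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}";

    llama_chat_message chat[] = {
        { "system",    "You are a helpful assistant" },
        { "user",      "Hello" },
        { "assistant", "Hi there" },
    };

    std::vector<char> buf(1024);
    // model == nullptr: format using the given template string instead of the model's built-in one
    int32_t res = llama_chat_apply_template(nullptr, tmpl, chat, 3, /*add_ass*/ true, buf.data(), (int32_t) buf.size());
    if (res > (int32_t) buf.size()) {
        // buffer too small: resize to res and call again, as the test does
    }
    if (res > 0) printf("%.*s\n", res, buf.data());
    return 0;
}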
package/src/llama.cpp/tests/test-grad0.cpp
@@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) {
         }

         // flash_attn f32
-        {
-            srand(seed);
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int max_nrep = (ndims >= 3) ? 2 : 1;
-                    for (int nrep = 1; nrep < max_nrep; ++nrep) {
-                        int64_t neq[4] = { D, N, B*nrep, ne[3] };
-                        int64_t nek[4] = { D, M, B, ne[3] };
-                        int64_t nev[4] = { M, D, B, ne[3] };
-                        if (ndims == 2) {
-                            neq[2] = 1; neq[3] = 1;
-                            nek[2] = 1; nek[3] = 1;
-                            nev[2] = 1; nev[3] = 1;
-                        } else if (ndims == 3) {
-                            neq[3] = 1;
-                            nek[3] = 1;
-                            nev[3] = 1;
-                        }
-                        x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                        x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                        x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                        ggml_set_param(ctx0, x[0]);
-                        ggml_set_param(ctx0, x[1]);
-                        ggml_set_param(ctx0, x[2]);
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
-                        check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
-                    }
-                }
-            }
-        }
-
-        // flash_attn f16, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
-
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+        // TODO: adapt to ggml_flash_attn_ext() changes
+        //{
+        //    srand(seed);
+        //    const int nargs = 3;
+
+        //    int64_t ne2[4];
+
+        //    get_random_dims(ne2, 4);
+        //    int64_t D = ne2[0];
+        //    int64_t N = ne2[1];
+        //    int64_t M = ne2[2] + N;
+        //    int64_t B = ne2[3];
+
+        //    for (int masked = 0; masked <= 1; ++masked) {
+        //        for (int ndims = 2; ndims <= 4; ++ndims) {
+        //            int max_nrep = (ndims >= 3) ? 2 : 1;
+        //            for (int nrep = 1; nrep < max_nrep; ++nrep) {
+        //                int64_t neq[4] = { D, N, B*nrep, ne[3] };
+        //                int64_t nek[4] = { D, M, B, ne[3] };
+        //                int64_t nev[4] = { M, D, B, ne[3] };
+        //                if (ndims == 2) {
+        //                    neq[2] = 1; neq[3] = 1;
+        //                    nek[2] = 1; nek[3] = 1;
+        //                    nev[2] = 1; nev[3] = 1;
+        //                } else if (ndims == 3) {
+        //                    neq[3] = 1;
+        //                    nek[3] = 1;
+        //                    nev[3] = 1;
+        //                }
+        //                x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+        //                x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+        //                x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+        //                ggml_set_param(ctx0, x[0]);
+        //                ggml_set_param(ctx0, x[1]);
+        //                ggml_set_param(ctx0, x[2]);
+
+        //                struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+        //            }
+        //        }
+        //    }
+        //}

-                    check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
-                }
-            }
-        }
         ggml_free(ctx0);
     }
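The gradient test disabled above is waiting on the ggml_flash_attn_ext() interface, whose updated signature (optional mask plus an ALiBi max_bias parameter) is what test-backend-ops.cpp now covers. A minimal sketch of building such a node, using the same tensor layout as test_flash_attn_ext; this is illustrative only, not code from this package:

#include "ggml.h"
#include <cmath>

int main() {
    ggml_init_params params = { /*mem_size*/ 256*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t hs = 128, nh = 32, kv = 512, nb = 8;  // head size, heads, kv length, batch

    ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
    ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
    ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);

    // scale = 1/sqrt(head size); max_bias = 0.0f disables ALiBi, > 0.0f enables it
    ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f / std::sqrt((float) hs), /*max_bias*/ 0.0f);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_free(ctx);
    return 0;
}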