@fugood/llama.node 0.3.0 → 0.3.2

This diff shows the publicly available contents of the two package versions as released to their registry, and is provided for informational purposes only.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
package/src/llama.cpp/tests/test-grad0.cpp
@@ -1,10 +1,14 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #include "ggml.h"
 
+#include <cfloat>
 #include <cmath>
+#include <cstdint>
 #include <cstdio>
 #include <cstdlib>
 #include <cassert>
+#include <initializer_list>
+#include <vector>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -217,7 +221,8 @@ static bool check_gradient(
         int nargs,
         float eps,
         float max_error_abs,
-        float max_error_rel) {
+        float max_error_rel,
+        std::vector<double> expected_vals) {
 
     static int n_threads = -1;
     if (n_threads < 0) {
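Every call site below gains one trailing argument for the new parameter; an empty braced list keeps the old strict behavior, since `{}` simply constructs an empty `std::vector<double>`. A minimal sketch of the two calling styles (the `describe` helper is hypothetical, not part of the package):

```cpp
#include <cstdio>
#include <vector>

// Sketch of the new calling convention: the trailing std::vector<double> lists
// gradient values that are acceptable at discontinuities; {} keeps strict checking.
static void describe(const char * op, const std::vector<double> & expected_vals) {
    std::printf("%s: %s\n", op,
                expected_vals.empty() ? "strict finite-difference check"
                                      : "probes matching an expected value may be skipped");
}

int main() {
    describe("add f32", {});         // mirrors check_gradient("add f32", ..., {});
    describe("relu",    {0.0, 1.0}); // mirrors check_gradient("relu", ..., {0.0, 1.0});
}
```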
@@ -239,8 +244,10 @@ static bool check_gradient(
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
-    ggml_graph_reset  (gf);
-    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_reset(gb);
+    if (f->grad) {
+        ggml_set_f32(f->grad, 1.0f);
+    }
 
     ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
@@ -248,9 +255,10 @@ static bool check_gradient(
     // ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot");
 
     for (int i = 0; i < nargs; ++i) {
+        bool all_g0_bad = true;
         const int nelements = ggml_nelements(x[i]);
         for (int k = 0; k < nelements; ++k) {
-            // compute gradient using finite differences
+            // Calculate gradient numerically:
             const float x0 = ggml_get_f32_1d(x[i], k);
             const float xm = x0 - eps;
             const float xp = x0 + eps;
@@ -267,18 +275,42 @@ static bool check_gradient(
             const double f1 = ggml_get_f32_1d(f, 0);
             const double g0 = (f0 - f1)/(2.0*(double) eps);
 
+            // The numerical calculation of the gradient fails around discontinuities (e.g. 0 for ReLU).
+            // In such cases, provide a vector of expected values and skip the comparison for failed calculations.
+            if (!expected_vals.empty()) {
+                bool matches_any = false;
+                for (const double & ev : expected_vals) {
+                    const double error_abs = std::fabs(g0 - ev);
+                    if (error_abs > max_error_abs) {
+                        continue;
+                    }
+                    const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0;
+                    if (error_rel > max_error_rel) {
+                        continue;
+                    }
+                    matches_any = true;
+                    break;
+                }
+                if (!matches_any) {
+                    continue;
+                }
+            }
+            all_g0_bad = false;
+
             ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
-            ggml_graph_reset  (gf);
-            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_reset(gb);
+            if (f->grad) {
+                ggml_set_f32(f->grad, 1.0f);
+            }
 
             ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
             const double g1 = ggml_get_f32_1d(x[i]->grad, k);
 
             const double error_abs = fabs(g0 - g1);
-            const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0;
+            const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0;
 
             if (error_abs > max_error_abs || error_rel > max_error_rel) {
                 printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
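The hunk above is the heart of the change: the numerical gradient is still the central difference g0 = (f(x+eps) - f(x-eps)) / (2*eps), but a probe whose estimate matches one of the caller-supplied expected values is accepted, and one that matches none of them is skipped rather than compared against the backward pass. A minimal standalone sketch of the same logic on scalar doubles (`check_gradient_1d` is a hypothetical name; the real code operates on ggml tensors):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Central difference: g0 ~= (f(x+eps) - f(x-eps)) / (2*eps). If expected_vals is
// non-empty and g0 matches none of them, the probe is skipped (assumed to sit on a
// discontinuity) rather than failed, mirroring the hunk above.
template <typename F>
static bool check_gradient_1d(F f, double x, double g_backward, double eps,
                              double max_err_abs, double max_err_rel,
                              const std::vector<double> & expected_vals) {
    const double g0 = (f(x + eps) - f(x - eps)) / (2.0*eps);

    if (!expected_vals.empty()) {
        bool matches_any = false;
        for (double ev : expected_vals) {
            const double err_abs = std::fabs(g0 - ev);
            const double err_rel = g0 != 0.0 ? err_abs/std::fabs(g0) : 0.0;
            if (err_abs <= max_err_abs && err_rel <= max_err_rel) {
                matches_any = true;
                break;
            }
        }
        if (!matches_any) {
            return true; // numerical estimate unusable here -> skip this probe
        }
    }

    const double err_abs = std::fabs(g0 - g_backward);
    const double err_rel = g0 != 0.0 ? err_abs/std::fabs(g0) : 0.0;
    return err_abs <= max_err_abs && err_rel <= max_err_rel;
}

int main() {
    // relu'(x) is 0 for x < 0 and 1 for x > 0, hence expected_vals = {0.0, 1.0}.
    auto relu = [](double v) { return v > 0.0 ? v : 0.0; };
    const bool ok = check_gradient_1d(relu, 0.5, /*g_backward=*/1.0, 1e-3, 1e-3, INFINITY, {0.0, 1.0});
    std::printf("relu gradient check: %s\n", ok ? "ok" : "failed");
}
```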
@@ -287,6 +319,10 @@ static bool check_gradient(
                 return false;
             }
         }
+        if (all_g0_bad) {
+            printf("%s: numerical calculation of the gradient failed for all values\n", op_name);
+            return false;
+        }
     }
 
     return true;
@@ -404,7 +440,7 @@ int main(int argc, const char ** argv) {
         seed_iter = rand();
         unsigned seed = rand();
 
-        printf("test-grad0: iter:%d/%d\n", iter, niter);
+        printf("test-grad0: iter:%d/%d\n", (iter+1), niter);
         struct ggml_context * ctx0 = ggml_init(params);
 
         get_random_dims(ne, 4);
@@ -424,7 +460,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
 
-            check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+            check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {});
         }
     }
 
@@ -441,7 +477,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
 
-            check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
+            check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {});
         }
     }
 
@@ -458,7 +494,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
 
-            check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -475,7 +511,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
 
-            check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -492,7 +528,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
 
-            check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
+            check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {});
         }
     }
 
@@ -509,7 +545,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
 
-            check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -526,7 +562,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
 
-            check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f);
+            check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {});
        }
     }
 
@@ -543,7 +579,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
 
-            check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+            check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {});
         }
     }
 
@@ -560,7 +596,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
 
-            check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -578,7 +614,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
 
-            check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
         }
     }
 
@@ -596,7 +632,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
 
-            check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -614,7 +650,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
 
-            check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -637,7 +673,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
 
-            check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
         }
     }
 
@@ -660,25 +696,25 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
 
-            check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
         }
     }
 
-    // abs (finite differences do not work)
-    //{
-    //    const int nargs = 1;
+    // abs
+    {
+        const int nargs = 1;
 
-    //    for (int ndims = 1; ndims <= 2; ++ndims) {
-    //        for (int i = 0; i < nargs; ++i) {
-    //            x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-    //            ggml_set_param(ctx0, x[i]);
-    //        }
+        for (int ndims = 1; ndims <= 4; ++ndims) {
+            for (int i = 0; i < nargs; ++i) {
+                x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[i]);
+            }
 
-    //        struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
 
-    //        check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
-    //    }
-    //}
+            check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0});
+        }
+    }
 
     // sgn
     {
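The abs test above could be re-enabled because of the new escape hatch: away from zero, d|x|/dx is exactly -1 or +1, but a central difference whose stencil straddles the kink lands anywhere in between. A worked example with the test's eps = 1e-3:

$$ g_0 = \frac{|x_0+\varepsilon| - |x_0-\varepsilon|}{2\varepsilon}, \qquad x_0 = 4\cdot 10^{-4} \;\Rightarrow\; g_0 = \frac{0.0014 - 0.0006}{0.002} = 0.4, $$

far from either true subgradient. With expected_vals = {-1.0, 1.0} such probes are skipped rather than reported as failures, and if every probe of an argument fails this way, all_g0_bad triggers an explicit error.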
@@ -693,7 +729,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
 
-            check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
         }
     }
 
@@ -710,7 +746,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
 
-            check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -727,7 +763,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
 
-            check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
         }
     }
 
@@ -745,7 +781,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
 
-            check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -776,7 +812,7 @@ int main(int argc, const char ** argv) {
 
             GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
 
-            check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             if (ndims == 2) {
                 // check_mat_mul does not support ndims > 2
                 check_mat_mul(m, x[1], x[0]);
@@ -800,7 +836,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
 
-            check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -817,7 +853,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
 
-            check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0});
         }
     }
 
@@ -835,7 +871,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
 
-            check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
         }
     }
 
@@ -854,9 +890,9 @@ int main(int argc, const char ** argv) {
 
 #ifdef GGML_SILU_FP16
             // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
-            check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
+            check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {});
 #else
-            check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
 #endif
         }
     }
@@ -874,7 +910,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
 
-            check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
+            check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {});
         }
     }
 
@@ -892,7 +928,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s));
 
-            check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -910,7 +946,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
 
-            check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -928,7 +964,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
 
-            check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+            check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
         }
     }
 
@@ -952,7 +988,7 @@ int main(int argc, const char ** argv) {
 
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-            check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -976,7 +1012,7 @@ int main(int argc, const char ** argv) {
 
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-            check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1004,7 +1040,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-            check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1037,7 +1073,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-            check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
        }
     }
 
@@ -1072,7 +1108,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-            check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1109,7 +1145,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-            check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1137,7 +1173,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
 
-            check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1170,7 +1206,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
 
-            check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1194,7 +1230,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
 
-            check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1225,7 +1261,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
 
-            check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1257,7 +1293,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
 
-            check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1291,7 +1327,7 @@ int main(int argc, const char ** argv) {
             // sum requires contiguous tensor rows
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
 
-            check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1319,7 +1355,7 @@ int main(int argc, const char ** argv) {
             // sum requires contiguous tensor rows
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
 
-            check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1337,7 +1373,7 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
 
-        check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
     }
 
     // diag_mask_inf
@@ -1353,7 +1389,7 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
 
-        check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
    }
 
     // diag_mask_zero
@@ -1369,7 +1405,7 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
 
-        check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
     }
 
     // softmax
@@ -1395,7 +1431,7 @@ int main(int argc, const char ** argv) {
                                                     1.0f - eps),
                                                 ggml_new_f32(ctx0, eps))));
 
-            check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
+            check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {});
             // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
             // this may result in different gradients to finite differences.
             // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
@@ -1412,7 +1448,7 @@ int main(int argc, const char ** argv) {
         get_random_dims(ne2, 4);
 
         for (int ndims = 1; ndims <= 4; ++ndims) {
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f);
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
             x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
             // the second argument to cross_entropy_loss must sum up to 1 for each row
             int nr = ggml_nrows(x[1]);
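The hunk above widens the cross-entropy inputs from ±0.1 to ±1.0, and the next hunk bumps eps from 1e-4 to 1e-3. A plausible rationale, not stated in the diff itself: a central difference in floating point trades truncation error against round-off,

$$ E(\varepsilon) \approx \frac{\varepsilon^{2}}{6}\,\bigl|f'''(x)\bigr| \;+\; \frac{\delta\,\bigl|f(x)\bigr|}{\varepsilon}, $$

where $\delta$ is the relative precision of the arithmetic (about $1.2\cdot10^{-7}$ for f32). The total error is minimized near $\varepsilon^{*} = (3\,\delta\,|f|/|f'''|)^{1/3}$; below that, round-off dominates. Larger inputs mean larger $|f|$, so the larger step size is the safer choice.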
@@ -1430,7 +1466,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
 
-            check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY);
+            check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
     }
 
@@ -1468,7 +1504,7 @@ int main(int argc, const char ** argv) {
                     struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                     GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                    check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+                    check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
                 }
             }
         }
@@ -1508,12 +1544,93 @@ int main(int argc, const char ** argv) {
                     struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                     GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                    check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+                    check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
                 }
             }
         }
     }
 
+    // im2col f32
+    {
+        srand(seed);
+        const int nargs = 1;
+        const int ndims = 4;
+
+        for (const bool is_2D : {false, true}) {
+            int64_t ne0[ndims];
+            int64_t ne1[ndims];
+            get_random_dims(ne0, ndims);
+            get_random_dims(ne1, ndims);
+
+            // Ensure that the output is not zero-sized:
+            ne1[0] += 8;
+            ne1[1] += 8;
+
+            if (is_2D) {
+                ne1[2] = ne0[2];
+            } else {
+                ne1[1] = ne0[1];
+                ne0[3] = 1;
+                ne1[3] = 1;
+            }
+
+            // The order of arguments is swapped because the first tensor is only used for its shape.
+            x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f);
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f);
+
+            ggml_set_param(ctx0, x[0]);
+
+            const int s0 =         1 + irand(2);
+            const int s1 = is_2D ? 1 + irand(2) : 0;
+            const int p0 =         0 + irand(2);
+            const int p1 = is_2D ? 0 + irand(2) : 0;
+            const int d0 =         1 + irand(2);
+            const int d1 = is_2D ? 1 + irand(2) : 0;
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32));
+
+            GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1);
+            check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
+        }
+    }
+
+    // pool_2d f32
+    {
+        srand(seed);
+        const int nargs = 1;
+        const int ndims = 4;
+
+        for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+            int64_t ne0[ndims];
+            get_random_dims(ne0, ndims);
+
+            ne0[0] += 8;
+            ne0[1] += 8;
+
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f);
+
+            ggml_set_param(ctx0, x[0]);
+
+            const int k0 = 2 + irand(2);
+            const int k1 = 2 + irand(2);
+            const int s0 = 2 + irand(2);
+            const int s1 = 2 + irand(2);
+            const int p0 = 0 + irand(2);
+            const int p1 = 0 + irand(2);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1));
+
+            GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n",
+                             op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1);
+            std::vector<double> expected_vals;
+            if (op == GGML_OP_POOL_MAX) {
+                expected_vals.push_back(0.0);
+                expected_vals.push_back(1.0);
+            }
+            check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals);
+        }
+    }
+
     // flash_attn f32
     // TODO: adapt to ggml_flash_attn_ext() changes
     //{
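For GGML_OP_POOL_MAX the derivative of the summed output with respect to any single input is 1 if that input is the maximum of some pooling window and 0 otherwise, which is where expected_vals = {0.0, 1.0} comes from; average pooling is smooth, so it keeps the strict check (empty expected_vals). A toy 1-D illustration of that 0/1 structure (hypothetical helper, ignores ties and overlapping windows, not the ggml kernel):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// d(sum of max-pooled outputs)/d(input[i]) for non-overlapping 1-D windows of size k:
// 1 if input[i] is the maximum of its window, 0 otherwise.
static std::vector<double> max_pool_grad(const std::vector<float> & in, size_t k) {
    std::vector<double> grad(in.size(), 0.0);
    for (size_t w = 0; w + k <= in.size(); w += k) {
        const size_t arg = std::max_element(in.begin() + w, in.begin() + w + k) - in.begin();
        grad[arg] = 1.0; // only the window maximum receives gradient
    }
    return grad;
}

int main() {
    const std::vector<float> in = {0.1f, 0.9f, -0.3f, 0.4f};
    for (double g : max_pool_grad(in, 2)) {
        std::printf("%.0f ", g); // prints: 0 1 0 1
    }
    std::printf("\n");
}
```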
@@ -1553,7 +1670,7 @@ int main(int argc, const char ** argv) {
 
     //            struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
 
-    //            check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+    //            check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {});
     //        }
     //    }
     //}