@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -187,16 +187,6 @@
  # define GGML_API
  #endif
 
- #ifdef GGML_MULTIPLATFORM
- # if defined(_WIN32)
- # define GGML_CALL
- # else
- # define GGML_CALL __attribute__((__ms_abi__))
- # endif
- #else
- # define GGML_CALL
- #endif
-
  // TODO: support for clang
  #ifdef __GNUC__
  # define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
@@ -220,7 +210,7 @@
  #include <stdio.h>
 
  #define GGML_FILE_MAGIC 0x67676d6c // "ggml"
- #define GGML_FILE_VERSION 1
+ #define GGML_FILE_VERSION 2
 
  #define GGML_QNT_VERSION 2 // bump this on quantization format changes
  #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
@@ -229,12 +219,16 @@
  #define GGML_MAX_PARAMS 2048
  #define GGML_MAX_CONTEXTS 64
  #define GGML_MAX_SRC 10
+ #define GGML_MAX_N_THREADS 512
+ #define GGML_MAX_OP_PARAMS 64
+
  #ifndef GGML_MAX_NAME
- #define GGML_MAX_NAME 64
+ # define GGML_MAX_NAME 64
  #endif
- #define GGML_MAX_OP_PARAMS 64
+
  #define GGML_DEFAULT_N_THREADS 4
  #define GGML_DEFAULT_GRAPH_SIZE 2048
+
  #if UINTPTR_MAX == 0xFFFFFFFF
  #define GGML_MEM_ALIGN 4
  #else
@@ -244,6 +238,8 @@
  #define GGML_EXIT_SUCCESS 0
  #define GGML_EXIT_ABORTED 1
 
+ #define GGML_ROPE_TYPE_NEOX 2
+
  #define GGUF_MAGIC "GGUF"
 
  #define GGUF_VERSION 3
@@ -255,21 +251,21 @@
  #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
  #ifndef NDEBUG
- #define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
+ # define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
  #elif defined(__GNUC__)
- #define GGML_UNREACHABLE() __builtin_unreachable()
+ # define GGML_UNREACHABLE() __builtin_unreachable()
  #elif defined(_MSC_VER)
- #define GGML_UNREACHABLE() __assume(0)
+ # define GGML_UNREACHABLE() __assume(0)
  #else
- #define GGML_UNREACHABLE() ((void) 0)
+ # define GGML_UNREACHABLE() ((void) 0)
  #endif
 
  #ifdef __cplusplus
- #define GGML_NORETURN [[noreturn]]
+ # define GGML_NORETURN [[noreturn]]
  #elif defined(_MSC_VER)
- #define GGML_NORETURN __declspec(noreturn)
+ # define GGML_NORETURN __declspec(noreturn)
  #else
- #define GGML_NORETURN _Noreturn
+ # define GGML_NORETURN _Noreturn
  #endif
 
  #define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
@@ -334,7 +330,7 @@ extern "C" {
  };
 
  // get ggml_status name string
- GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
+ GGML_API const char * ggml_status_to_string(enum ggml_status status);
 
  // ieee 754-2008 half-precision float16
  // todo: make this not an integral type
@@ -349,10 +345,12 @@ extern "C" {
  GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
  GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16
  GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
+ GGML_API void ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t);
  GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
 
  struct ggml_object;
  struct ggml_context;
+ struct ggml_cgraph;
 
  // NOTE: always add types at the end of the enum to keep backward compatibility
  enum ggml_type {
@@ -390,6 +388,8 @@ extern "C" {
  GGML_TYPE_Q4_0_4_4 = 31,
  GGML_TYPE_Q4_0_4_8 = 32,
  GGML_TYPE_Q4_0_8_8 = 33,
+ GGML_TYPE_TQ1_0 = 34,
+ GGML_TYPE_TQ2_0 = 35,
  GGML_TYPE_COUNT,
  };
 
@@ -450,10 +450,13 @@ extern "C" {
  GGML_OP_SQR,
  GGML_OP_SQRT,
  GGML_OP_LOG,
+ GGML_OP_SIN,
+ GGML_OP_COS,
  GGML_OP_SUM,
  GGML_OP_SUM_ROWS,
  GGML_OP_MEAN,
  GGML_OP_ARGMAX,
+ GGML_OP_COUNT_EQUAL,
  GGML_OP_REPEAT,
  GGML_OP_REPEAT_BACK,
  GGML_OP_CONCAT,
@@ -487,9 +490,11 @@ extern "C" {
  GGML_OP_CLAMP,
  GGML_OP_CONV_TRANSPOSE_1D,
  GGML_OP_IM2COL,
+ GGML_OP_IM2COL_BACK,
  GGML_OP_CONV_TRANSPOSE_2D,
  GGML_OP_POOL_1D,
  GGML_OP_POOL_2D,
+ GGML_OP_POOL_2D_BACK,
  GGML_OP_UPSCALE, // nearest interpolate
  GGML_OP_PAD,
  GGML_OP_ARANGE,
@@ -505,6 +510,7 @@ extern "C" {
  GGML_OP_WIN_UNPART,
  GGML_OP_GET_REL_POS,
  GGML_OP_ADD_REL_POS,
+ GGML_OP_RWKV_WKV,
 
  GGML_OP_UNARY,
 
@@ -521,6 +527,7 @@ extern "C" {
 
  GGML_OP_CROSS_ENTROPY_LOSS,
  GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+ GGML_OP_OPT_STEP_ADAMW,
 
  GGML_OP_COUNT,
  };
@@ -539,6 +546,7 @@ extern "C" {
  GGML_UNARY_OP_SILU,
  GGML_UNARY_OP_HARDSWISH,
  GGML_UNARY_OP_HARDSIGMOID,
+ GGML_UNARY_OP_EXP,
 
  GGML_UNARY_OP_COUNT,
  };
@@ -550,35 +558,25 @@ extern "C" {
  };
 
  enum ggml_log_level {
- GGML_LOG_LEVEL_ERROR = 2,
- GGML_LOG_LEVEL_WARN = 3,
- GGML_LOG_LEVEL_INFO = 4,
- GGML_LOG_LEVEL_DEBUG = 5
+ GGML_LOG_LEVEL_NONE = 0,
+ GGML_LOG_LEVEL_INFO = 1,
+ GGML_LOG_LEVEL_WARN = 2,
+ GGML_LOG_LEVEL_ERROR = 3,
+ GGML_LOG_LEVEL_DEBUG = 4,
+ GGML_LOG_LEVEL_CONT = 5, // continue previous log
  };
 
+ // this tensor...
  enum ggml_tensor_flag {
- GGML_TENSOR_FLAG_INPUT = 1,
- GGML_TENSOR_FLAG_OUTPUT = 2,
- GGML_TENSOR_FLAG_PARAM = 4,
- };
-
- // ggml object
- struct ggml_object {
- size_t offs;
- size_t size;
-
- struct ggml_object * next;
-
- enum ggml_object_type type;
-
- char padding[4];
+ GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
+ GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+ GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
+ GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
  };
 
- static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
-
  // n-dimensional tensor
  struct ggml_tensor {
- enum ggml_type type;
+ enum ggml_type type;
 
  GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
 
@@ -621,6 +619,29 @@ extern "C" {
  // If it returns true, the computation is aborted
  typedef bool (*ggml_abort_callback)(void * data);
 
+ // Scheduling priorities
+ enum ggml_sched_priority {
+ GGML_SCHED_PRIO_NORMAL,
+ GGML_SCHED_PRIO_MEDIUM,
+ GGML_SCHED_PRIO_HIGH,
+ GGML_SCHED_PRIO_REALTIME
+ };
+
+ // Threadpool params
+ // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults
+ struct ggml_threadpool_params {
+ bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
+ int n_threads; // number of threads
+ enum ggml_sched_priority prio; // thread priority
+ uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
+ bool strict_cpu; // strict cpu placement
+ bool paused; // start in paused state
+ };
+
+ struct ggml_threadpool; // forward declaration, see ggml.c
+
+ typedef struct ggml_threadpool * ggml_threadpool_t;
+
  // the compute plan that needs to be prepared for ggml_graph_compute()
  // since https://github.com/ggerganov/ggml/issues/287
  struct ggml_cplan {
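Note for downstream users: the threadpool declarations above pair with the ggml_threadpool_* functions and the new three-argument ggml_graph_plan() added near the end of this diff. A minimal sketch of the intended lifecycle, using only names visible in this header (the thread count and tuning values are placeholders):

    // create a threadpool from defaults, tweak, use, destroy
    struct ggml_threadpool_params tpp = ggml_threadpool_params_default(8); // 8 threads
    tpp.prio = GGML_SCHED_PRIO_HIGH; // raise scheduling priority
    tpp.poll = 50;                   // mid-level polling
    struct ggml_threadpool * tp = ggml_threadpool_new(&tpp);
    // ... plan and compute graphs against tp (see ggml_graph_plan further down) ...
    ggml_threadpool_free(tp);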
@@ -628,41 +649,13 @@ extern "C" {
  uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
 
  int n_threads;
+ struct ggml_threadpool * threadpool;
 
  // abort ggml_graph_compute when true
  ggml_abort_callback abort_callback;
  void * abort_callback_data;
  };
 
- enum ggml_cgraph_eval_order {
- GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
- GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
- GGML_CGRAPH_EVAL_ORDER_COUNT
- };
-
- typedef uint32_t ggml_bitset_t;
-
- struct ggml_hash_set {
- size_t size;
- ggml_bitset_t * used;
- struct ggml_tensor ** keys;
- };
-
- // computation graph
- struct ggml_cgraph {
- int size;
- int n_nodes;
- int n_leafs;
-
- struct ggml_tensor ** nodes;
- struct ggml_tensor ** grads;
- struct ggml_tensor ** leafs;
-
- struct ggml_hash_set visited_hash_set;
-
- enum ggml_cgraph_eval_order order;
- };
-
  // scratch buffer
  struct ggml_scratch {
  size_t offs;
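Note: struct ggml_cgraph is no longer defined in the public header, so code that read its fields directly must switch to the accessor functions introduced later in this diff. A hedged migration sketch (gf is assumed to be an already-built graph):

    // before (0.3.0): fields were public
    //   for (int i = 0; i < gf->n_nodes; i++) { use(gf->nodes[i]); }
    // after (0.3.2): the struct is opaque
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * node = ggml_graph_node(gf, i); // per the header, i < 0 counts from the end
    }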
@@ -714,46 +707,46 @@ extern "C" {
  GGML_API void ggml_print_object (const struct ggml_object * obj);
  GGML_API void ggml_print_objects(const struct ggml_context * ctx);
 
- GGML_API GGML_CALL int64_t ggml_nelements (const struct ggml_tensor * tensor);
- GGML_API GGML_CALL int64_t ggml_nrows (const struct ggml_tensor * tensor);
- GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor);
- GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
+ GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
+ GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
+ GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
+ GGML_API size_t ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
 
- GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
- GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
- GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
+ GGML_API int64_t ggml_blck_size(enum ggml_type type);
+ GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
+ GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
 
  GGML_DEPRECATED(
  GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
  "use ggml_row_size() instead");
 
- GGML_API GGML_CALL const char * ggml_type_name(enum ggml_type type);
- GGML_API GGML_CALL const char * ggml_op_name (enum ggml_op op);
- GGML_API const char * ggml_op_symbol(enum ggml_op op);
+ GGML_API const char * ggml_type_name(enum ggml_type type);
+ GGML_API const char * ggml_op_name (enum ggml_op op);
+ GGML_API const char * ggml_op_symbol(enum ggml_op op);
 
- GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
- GGML_API GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+ GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+ GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
 
- GGML_API GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor);
+ GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
 
- GGML_API GGML_CALL bool ggml_is_quantized(enum ggml_type type);
+ GGML_API bool ggml_is_quantized(enum ggml_type type);
 
  // TODO: temporary until model loading of ggml examples is refactored
  GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
- GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
- GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
- GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor);
- GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
- GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
- GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
- GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
- GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
+ GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
+ GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
 
- GGML_API GGML_CALL bool ggml_is_contiguous (const struct ggml_tensor * tensor);
- GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
- GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
- GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+ GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+ GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+ GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
 
  GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
  GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
@@ -845,7 +838,7 @@ extern "C" {
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
- GGML_API GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+ GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
 
  GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
  GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
@@ -966,6 +959,22 @@ extern "C" {
  struct ggml_context * ctx,
  struct ggml_tensor * a);
 
+ GGML_API struct ggml_tensor * ggml_sin(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sin_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_cos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_cos_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
  // return scalar
  GGML_API struct ggml_tensor * ggml_sum(
  struct ggml_context * ctx,
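The new ggml_sin()/ggml_cos() entries are ordinary element-wise graph ops. A hedged sketch, assuming an existing context ctx and a float tensor t:

    struct ggml_tensor * s = ggml_sin(ctx, t);         // allocates a result tensor
    struct ggml_tensor * c = ggml_cos_inplace(ctx, t); // returns a view that writes into t
    // as with all ggml ops, nothing runs until the containing graph is computed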
@@ -986,6 +995,12 @@ extern "C" {
  struct ggml_context * ctx,
  struct ggml_tensor * a);
 
+ // count number of equal elements in a and b
+ GGML_API struct ggml_tensor * ggml_count_equal(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
  // if a is the same shape as b, and a is not parameter, return a
  // otherwise, return a new tensor: repeat(a) to fit in b
  GGML_API struct ggml_tensor * ggml_repeat(
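ggml_count_equal() is useful as an accuracy counter. A hedged sketch in the spirit of a classifier eval (logits and labels are assumed pre-existing tensors; ggml_argmax() already exists in this header):

    struct ggml_tensor * pred    = ggml_argmax(ctx, logits);          // per-row class ids
    struct ggml_tensor * n_match = ggml_count_equal(ctx, pred, labels);
    // n_match is a scalar node; read it back after computing the graph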
@@ -1116,6 +1131,14 @@ extern "C" {
  struct ggml_context * ctx,
  struct ggml_tensor * a);
 
+ GGML_API struct ggml_tensor * ggml_exp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_exp_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
  // normalize along rows
  GGML_API struct ggml_tensor * ggml_norm(
  struct ggml_context * ctx,
@@ -1139,16 +1162,17 @@ extern "C" {
 
  // group normalize along ne0*ne1*n_groups
  // used in stable-diffusion
- // TODO: eps is hardcoded to 1e-6 for now
  GGML_API struct ggml_tensor * ggml_group_norm(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int n_groups);
+ int n_groups,
+ float eps);
 
  GGML_API struct ggml_tensor * ggml_group_norm_inplace(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int n_groups);
+ int n_groups,
+ float eps);
 
  // a - x
  // b - dy
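Migration note: eps was previously hardcoded to 1e-6 (see the removed TODO above), so existing callers can preserve behavior by passing it explicitly:

    // before: ggml_group_norm(ctx, a, n_groups);
    struct ggml_tensor * out = ggml_group_norm(ctx, a, n_groups, 1e-6f);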
@@ -1210,7 +1234,7 @@ extern "C" {
  size_t nb1,
  size_t nb2,
  size_t nb3,
- size_t offset);
+ size_t offset); // in bytes
 
  // b -> view(a,offset,nb1,nb2,3), return view(a)
  GGML_API struct ggml_tensor * ggml_set_inplace(
@@ -1220,19 +1244,19 @@ extern "C" {
  size_t nb1,
  size_t nb2,
  size_t nb3,
- size_t offset);
+ size_t offset); // in bytes
 
  GGML_API struct ggml_tensor * ggml_set_1d(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
- size_t offset);
+ size_t offset); // in bytes
 
  GGML_API struct ggml_tensor * ggml_set_1d_inplace(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
- size_t offset);
+ size_t offset); // in bytes
 
  // b -> view(a,offset,nb1,nb2,3), return modified a
  GGML_API struct ggml_tensor * ggml_set_2d(
@@ -1240,7 +1264,7 @@ extern "C" {
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  size_t nb1,
- size_t offset);
+ size_t offset); // in bytes
 
  // b -> view(a,offset,nb1,nb2,3), return view(a)
  GGML_API struct ggml_tensor * ggml_set_2d_inplace(
@@ -1248,7 +1272,7 @@ extern "C" {
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  size_t nb1,
- size_t offset);
+ size_t offset); // in bytes
 
  // a -> b, return view(b)
  GGML_API struct ggml_tensor * ggml_cpy(
@@ -1383,14 +1407,14 @@ extern "C" {
  // supports 3D: a->ne[2] == b->ne[1]
  GGML_API struct ggml_tensor * ggml_get_rows(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b);
+ struct ggml_tensor * a, // data
+ struct ggml_tensor * b); // row indices
 
  GGML_API struct ggml_tensor * ggml_get_rows_back(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- struct ggml_tensor * c);
+ struct ggml_tensor * a, // gradients of ggml_get_rows result
+ struct ggml_tensor * b, // row indices
+ struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
 
  GGML_API struct ggml_tensor * ggml_diag(
  struct ggml_context * ctx,
@@ -1451,11 +1475,10 @@ extern "C" {
  struct ggml_tensor * b);
 
  // rotary position embedding
- // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
- // if mode & 2 == 1, GPT-NeoX style
+ // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
+ // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
  //
  // b is an int32 vector with size a->ne[2], it contains the positions
- // c is freq factors (e.g. phi3-128k), (optional)
  GGML_API struct ggml_tensor * ggml_rope(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
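Migration note: mode checks should use the named constant rather than the bare bit:

    // before: if (mode & 2) { /* GPT-NeoX style */ }
    if (mode & GGML_ROPE_TYPE_NEOX) {
        // GPT-NeoX style rotary embedding (GGML_ROPE_TYPE_NEOX == 2, so behavior is unchanged)
    }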
@@ -1472,6 +1495,7 @@ extern "C" {
  int mode);
 
  // custom RoPE
+ // c is freq factors (e.g. phi3-128k), (optional)
  GGML_API struct ggml_tensor * ggml_rope_ext(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
@@ -1534,16 +1558,16 @@ extern "C" {
  "use ggml_rope_ext_inplace instead");
 
  // compute correction dims for YaRN RoPE scaling
- GGML_CALL void ggml_rope_yarn_corr_dims(
+ void ggml_rope_yarn_corr_dims(
  int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
  // rotary position embedding backward, i.e compute dx from dy
  // a - dy
  GGML_API struct ggml_tensor * ggml_rope_back(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- struct ggml_tensor * c,
+ struct ggml_tensor * a, // gradients of ggml_rope result
+ struct ggml_tensor * b, // positions
+ struct ggml_tensor * c, // freq factors
  int n_dims,
  int mode,
  int n_ctx_orig,
@@ -1562,34 +1586,49 @@ extern "C" {
  float min,
  float max);
 
+ // im2col
+ // converts data into a format that effectively results in a convolution when combined with matrix multiplication
  GGML_API struct ggml_tensor * ggml_im2col(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int s0,
- int s1,
- int p0,
- int p1,
- int d0,
- int d1,
- bool is_2D,
- enum ggml_type dst_type);
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // data
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1, // dilation dimension 1
+ bool is_2D,
+ enum ggml_type dst_type);
+
+ GGML_API struct ggml_tensor * ggml_im2col_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // gradient of im2col output
+ int64_t * ne, // shape of im2col input
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1, // dilation dimension 1
+ bool is_2D);
 
  GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int s0,
- int s1,
- int p0,
- int p1,
- int d0,
- int d1);
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // data
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1); // dilation dimension 1
 
  GGML_API struct ggml_tensor * ggml_conv_1d(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // data
  int s0, // stride
  int p0, // padding
  int d0); // dilation
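The new im2col comment summarizes the standard trick: unfold input patches so that a convolution becomes one matrix multiplication. A hedged sketch of the idea (kernel a, input b; stride/padding/dilation values and the F16 destination type are placeholders):

    // each output position gets a flattened KW*KH*IC patch of b ...
    struct ggml_tensor * col = ggml_im2col(ctx, a, b,
                                           1, 1, 0, 0, 1, 1, // s0, s1, p0, p1, d0, d1
                                           /*is_2D=*/true, GGML_TYPE_F16);
    // ... so multiplying col by the flattened kernel yields the convolution,
    // which appears to be how the higher-level ggml_conv_2d() is put together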
@@ -1598,29 +1637,29 @@ extern "C" {
  // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
  GGML_API struct ggml_tensor* ggml_conv_1d_ph(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int s,
- int d);
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // data
+ int s, // stride
+ int d); // dilation
 
  GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int s0,
- int p0,
- int d0);
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // data
+ int s0, // stride
+ int p0, // padding
+ int d0); // dilation
 
  GGML_API struct ggml_tensor * ggml_conv_2d(
  struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int s0,
- int s1,
- int p0,
- int p1,
- int d0,
- int d1);
+ struct ggml_tensor * a, // convolution kernel
+ struct ggml_tensor * b, // data
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1); // dilation dimension 1
 
 
  // kernel size is a->ne[0] x a->ne[1]
@@ -1682,6 +1721,18 @@ extern "C" {
  float p0,
  float p1);
 
+ GGML_API struct ggml_tensor * ggml_pool_2d_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * af, // "a"/input used in forward pass
+ enum ggml_op_pool op,
+ int k0,
+ int k1,
+ int s0,
+ int s1,
+ float p0,
+ float p1);
+
  // nearest interpolate
  // multiplies ne0 and ne1 by scale factor
  // used in stable-diffusion
@@ -1756,7 +1807,8 @@ extern "C" {
  struct ggml_tensor * v,
  struct ggml_tensor * mask,
  float scale,
- float max_bias);
+ float max_bias,
+ float logit_softcap);
 
  GGML_API void ggml_flash_attn_ext_set_prec(
  struct ggml_tensor * a,
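Migration note: existing ggml_flash_attn_ext() callers gain one trailing argument; passing 0.0f appears to disable the soft-cap and keep the previous behavior:

    // before: ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias);
    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask,
                                                   scale, max_bias,
                                                   /*logit_softcap=*/0.0f);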
@@ -1773,10 +1825,8 @@ extern "C" {
 
  GGML_API struct ggml_tensor * ggml_ssm_conv(
  struct ggml_context * ctx,
- struct ggml_tensor * s,
- struct ggml_tensor * x,
- struct ggml_tensor * c,
- struct ggml_tensor * sq);
+ struct ggml_tensor * sx,
+ struct ggml_tensor * c);
 
  GGML_API struct ggml_tensor * ggml_ssm_scan(
  struct ggml_context * ctx,
@@ -1785,8 +1835,7 @@ extern "C" {
  struct ggml_tensor * dt,
  struct ggml_tensor * A,
  struct ggml_tensor * B,
- struct ggml_tensor * C,
- struct ggml_tensor * sq);
+ struct ggml_tensor * C);
 
  // partition into non-overlapping windows with padding if needed
  // example:
@@ -1838,6 +1887,15 @@ extern "C" {
  struct ggml_tensor * pw,
  struct ggml_tensor * ph);
 
+ GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * r,
+ struct ggml_tensor * tf,
+ struct ggml_tensor * td,
+ struct ggml_tensor * state);
+
  // custom operators
 
  typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1921,7 +1979,8 @@ extern "C" {
  typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
  typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
 
- #define GGML_N_TASKS_MAX -1
+ #define GGML_N_TASKS_MAX (-1)
+ // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks
 
  GGML_API struct ggml_tensor * ggml_map_custom1(
  struct ggml_context * ctx,
@@ -1974,44 +2033,84 @@ extern "C" {
  // loss function
 
  GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b);
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // logits
+ struct ggml_tensor * b); // labels
 
  GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- struct ggml_tensor * c);
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // logits
+ struct ggml_tensor * b, // labels
+ struct ggml_tensor * c); // gradients of cross_entropy_loss result
+
+ // AdamW optimizer step
+ // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+ // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+ GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * grad,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd); // weight decay
 
  //
  // automatic differentiation
  //
 
- GGML_API void ggml_set_param(
- struct ggml_context * ctx,
- struct ggml_tensor * tensor);
-
+ GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+ GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
  GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
- GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+ GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
+
+ GGML_API void ggml_build_opt_adamw(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd); // weight decay
 
  // graph allocation in a context
- GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
- GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
- GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
- GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
- GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
- GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
+ GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+ GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
+ GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+ GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
+ GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
+
+ GGML_API int ggml_graph_size (struct ggml_cgraph * cgraph);
+ GGML_API struct ggml_tensor * ggml_graph_node (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
+ GGML_API struct ggml_tensor ** ggml_graph_nodes (struct ggml_cgraph * cgraph);
+ GGML_API int ggml_graph_n_nodes(struct ggml_cgraph * cgraph);
+
+ GGML_API void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
  GGML_API size_t ggml_graph_overhead(void);
  GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
 
+ GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
+ GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
+ GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
+ GGML_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
+ GGML_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
+ GGML_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool);
+ GGML_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
+ GGML_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
+
  // ggml_graph_plan() has to be called before ggml_graph_compute()
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
- GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
- GGML_API enum ggml_status ggml_graph_compute( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+ GGML_API struct ggml_cplan ggml_graph_plan(
+ const struct ggml_cgraph * cgraph,
+ int n_threads, /* = GGML_DEFAULT_N_THREADS */
+ struct ggml_threadpool * threadpool /* = NULL */ );
+ GGML_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+
  // same as ggml_graph_compute() but the work data is allocated as a part of the context
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
  GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
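Taken together, the new training entry points suggest a step like the following hedged sketch (ctx, a forward graph ending in a scalar loss tensor, and parameters already marked via ggml_set_param() are assumptions; the hyperparameters are placeholders):

    ggml_set_loss(loss); // mark the loss tensor (new in this release)

    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads=*/true);
    ggml_build_forward_expand(gf, loss);

    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
    ggml_build_backward_expand(ctx, gf, gb, /*accumulate=*/false);
    ggml_build_opt_adamw(ctx, gf, gb, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f); // alpha, beta1, beta2, eps, wd

    ggml_graph_reset(gb); // per the comment above: zero grads/momenta, set loss grad to 1
    ggml_graph_compute_with_ctx(ctx, gb, /*n_threads=*/4);

For plain inference, note that ggml_graph_plan() now takes the threadpool as a third argument; passing NULL keeps the previous single-call behavior.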
@@ -2075,6 +2174,10 @@ extern "C" {
  typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
  typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
 
+ // Set callback for all future logging events.
+ // If this is not called, or NULL is supplied, everything is output on stderr.
+ GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);
+
  // optimization parameters
  //
  // see ggml.c (ggml_opt_default_params) for default values
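Note the renumbered log levels earlier in this diff (ERROR moved from 2 to 3, INFO from 4 to 1), so callbacks that compare raw numeric level values need rechecking. A hedged sketch of a filter installed via the new ggml_log_set():

    #include <stdio.h>

    // drop debug output; treat CONT as a continuation of the previous level
    static void my_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
        static enum ggml_log_level last = GGML_LOG_LEVEL_NONE;
        if (level != GGML_LOG_LEVEL_CONT) {
            last = level;
        }
        if (last != GGML_LOG_LEVEL_DEBUG) {
            fputs(text, stderr);
        }
        (void) user_data;
    }

    // ggml_log_set(my_log_cb, NULL); // call once at startup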
@@ -2400,6 +2503,7 @@ extern "C" {
  GGML_API int ggml_cpu_has_gpublas (void);
  GGML_API int ggml_cpu_has_sse3 (void);
  GGML_API int ggml_cpu_has_ssse3 (void);
+ GGML_API int ggml_cpu_has_riscv_v (void);
  GGML_API int ggml_cpu_has_sycl (void);
  GGML_API int ggml_cpu_has_rpc (void);
  GGML_API int ggml_cpu_has_vsx (void);
@@ -2407,6 +2511,9 @@ extern "C" {
  GGML_API int ggml_cpu_has_cann (void);
  GGML_API int ggml_cpu_has_llamafile (void);
 
+ // get the sve vector length in bytes
+ GGML_API int ggml_cpu_get_sve_cnt(void);
+
  //
  // Internal types and functions exposed for tests and benchmarks
  //