cui-llama.rn 1.6.0 → 1.6.1

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/ggml.c CHANGED
@@ -4,6 +4,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"
+#include "ggml-cpu.h"
 #include "ggml.h"
 
 // FIXME: required here for quantization functions
@@ -382,58 +383,16 @@ void lm_ggml_fp16_to_fp32_row(const lm_ggml_fp16_t * x, float * y, int64_t n) {
     }
 }
 
-// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
-// currently, the lm_ggml_cpu_has_* functions are entirely compile-time
 void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__F16C__)
-    //if (lm_ggml_cpu_has_f16c()) {
-        for (; i + 7 < n; i += 8) {
-            __m256 x_vec = _mm256_loadu_ps(x + i);
-            __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storeu_si128((__m128i *)(y + i), y_vec);
-        }
-        for(; i + 3 < n; i += 4) {
-            __m128 x_vec = _mm_loadu_ps(x + i);
-            __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storel_epi64((__m128i *)(y + i), y_vec);
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = LM_GGML_FP32_TO_FP16(x[i]);
     }
 }
 
 void lm_ggml_bf16_to_fp32_row(const lm_ggml_bf16_t * x, float * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__AVX512F__)
-    //if (lm_ggml_cpu_has_avx512()) {
-        for (; i + 16 <= n; i += 16) {
-            _mm512_storeu_ps(y + i,
-                             _mm512_castsi512_ps(
-                                 _mm512_slli_epi32(
-                                     _mm512_cvtepu16_epi32(
-                                         _mm256_loadu_si256(
-                                             (const __m256i *)(x + i))),
-                                     16)));
-        }
-    //}
-#endif
-#if defined(__AVX2__)
-    //if (lm_ggml_cpu_has_avx2()) {
-        for (; i + 8 <= n; i += 8) {
-            _mm256_storeu_ps(y + i,
-                             _mm256_castsi256_ps(
-                                 _mm256_slli_epi32(
-                                     _mm256_cvtepu16_epi32(
-                                         _mm_loadu_si128(
-                                             (const __m128i *)(x + i))),
-                                     16)));
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = LM_GGML_BF16_TO_FP32(x[i]);
     }
 }
@@ -969,6 +928,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -995,23 +955,18 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(LM_GGML_OP_COUNT == 85, "LM_GGML_OP_COUNT != 85");
+static_assert(LM_GGML_OP_COUNT == 82, "LM_GGML_OP_COUNT != 82");
 
 static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
     "none",
@@ -1068,6 +1023,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "conv_2d_dw(x)",
    "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1094,23 +1050,18 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(LM_GGML_OP_COUNT == 85, "LM_GGML_OP_COUNT != 85");
+static_assert(LM_GGML_OP_COUNT == 82, "LM_GGML_OP_COUNT != 82");
 
 static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");
 
@@ -1367,6 +1318,13 @@ bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
+bool lm_ggml_is_contiguous_channels(const struct lm_ggml_tensor * tensor) {
+    return
+        tensor->nb[0] > tensor->nb[2] &&
+        tensor->nb[1] > tensor->nb[0] &&
+        tensor->nb[2] == lm_ggml_type_size(tensor->type);
+}
+
 static inline bool lm_ggml_is_padded_1d(const struct lm_ggml_tensor * tensor) {
     static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");
 
@@ -4073,6 +4031,46 @@ struct lm_ggml_tensor * lm_ggml_conv_2d_dw(
     return result;
 }
 
+// lm_ggml_conv_2d_dw_direct
+
+struct lm_ggml_tensor * lm_ggml_conv_2d_dw_direct(
+        struct lm_ggml_context * ctx,
+        struct lm_ggml_tensor * a,
+        struct lm_ggml_tensor * b,
+        int stride0,
+        int stride1,
+        int pad0,
+        int pad1,
+        int dilation0,
+        int dilation1) {
+    LM_GGML_ASSERT(a->ne[2] == 1);
+    LM_GGML_ASSERT(a->ne[3] == b->ne[2]);
+    int64_t ne[4];
+    ne[0] = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+    ne[1] = lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+    ne[2] = b->ne[2];
+    ne[3] = b->ne[3];
+
+    struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, b->type, 4, ne);
+
+    if (lm_ggml_is_contiguous_channels(b)) {
+        // Result will be permuted the same way as input (CWHN order)
+        const int64_t type_size = lm_ggml_type_size(result->type);
+        LM_GGML_ASSERT(lm_ggml_blck_size(result->type) == 1);
+        result->nb[0] = result->ne[2] * type_size;
+        result->nb[1] = result->ne[0] * result->nb[0];
+        result->nb[2] = type_size;
+    }
+
+    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+    lm_ggml_set_op_params(result, params, sizeof(params));
+
+    result->op = LM_GGML_OP_CONV_2D_DW;
+    result->src[0] = a;
+    result->src[1] = b;
+    return result;
+}
+
 // lm_ggml_conv_transpose_2d_p0
 
 static int64_t lm_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
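The new op computes a depthwise 2D convolution directly, without the im2col round trip. A hedged usage sketch with hypothetical tensor names; shapes follow the two assertions above (kernel a carries one filter per channel, a->ne[2] == 1 and a->ne[3] == b->ne[2]):

// kernel: ne = [3, 3, 1, C]   one 3x3 filter per channel
// input : ne = [W, H, C, N]
struct lm_ggml_tensor * out = lm_ggml_conv_2d_dw_direct(
    ctx, kernel, input,
    /*stride*/ 1, 1, /*pad*/ 1, 1, /*dilation*/ 1, 1);

// A channels-innermost (CWHN) view, e.g. a contiguous [C, W, H, N] tensor
// permuted to logical [W, H, C, N], satisfies lm_ggml_is_contiguous_channels(),
// and the layout branch above then gives the result the same CWHN memory order:
struct lm_ggml_tensor * input_cwhn = lm_ggml_permute(ctx, x_cwhn_contig, 2, 0, 1, 3);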
@@ -4197,7 +4195,8 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl(
         int ne0,
         int ne1,
         int ne2,
-        int ne3) {
+        int ne3,
+        enum lm_ggml_scale_mode mode) {
     LM_GGML_ASSERT(a->ne[0] <= ne0);
     LM_GGML_ASSERT(a->ne[1] <= ne1);
     LM_GGML_ASSERT(a->ne[2] <= ne2);
@@ -4205,6 +4204,8 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl(
 
     struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    lm_ggml_set_op_params_i32(result, 0, mode);
+
     result->op = LM_GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4214,8 +4215,9 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl(
 struct lm_ggml_tensor * lm_ggml_upscale(
         struct lm_ggml_context * ctx,
         struct lm_ggml_tensor * a,
-        int scale_factor) {
-    return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+        int scale_factor,
+        enum lm_ggml_scale_mode mode) {
+    return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct lm_ggml_tensor * lm_ggml_upscale_ext(
  struct lm_ggml_tensor * lm_ggml_upscale_ext(
@@ -4224,8 +4226,9 @@ struct lm_ggml_tensor * lm_ggml_upscale_ext(
4224
4226
  int ne0,
4225
4227
  int ne1,
4226
4228
  int ne2,
4227
- int ne3) {
4228
- return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
4229
+ int ne3,
4230
+ enum lm_ggml_scale_mode mode) {
4231
+ return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4229
4232
  }
4230
4233
 
4231
4234
  // lm_ggml_pad
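Both upscale entry points now thread an interpolation mode through to the op params (stored at index 0), so existing call sites gain one argument. A migration sketch, assuming the lm_ggml_scale_mode enum declared in this release's ggml.h, with LM_GGML_SCALE_MODE_NEAREST matching the old behaviour:

// before 1.6.1: lm_ggml_upscale(ctx, a, 2);
struct lm_ggml_tensor * up = lm_ggml_upscale(ctx, a, 2, LM_GGML_SCALE_MODE_NEAREST);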
@@ -4855,179 +4858,6 @@ struct lm_ggml_tensor * lm_ggml_unary_inplace(
     return lm_ggml_unary_impl(ctx, a, op, true);
 }
 
-// lm_ggml_map_unary
-
-static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        const lm_ggml_unary_op_f32_t fun,
-        bool inplace) {
-    struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-    lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = LM_GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct lm_ggml_tensor * lm_ggml_map_unary_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        const lm_ggml_unary_op_f32_t fun) {
-    return lm_ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        const lm_ggml_unary_op_f32_t fun) {
-    return lm_ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// lm_ggml_map_binary
-
-static struct lm_ggml_tensor * lm_ggml_map_binary_impl_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        const lm_ggml_binary_op_f32_t fun,
-        bool inplace) {
-    LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
-
-    struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-    lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = LM_GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct lm_ggml_tensor * lm_ggml_map_binary_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        const lm_ggml_binary_op_f32_t fun) {
-    return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        const lm_ggml_binary_op_f32_t fun) {
-    return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// lm_ggml_map_custom1_f32
-
-static struct lm_ggml_tensor * lm_ggml_map_custom1_impl_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        const lm_ggml_custom1_op_f32_t fun,
-        bool inplace) {
-    struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-    lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = LM_GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct lm_ggml_tensor * lm_ggml_map_custom1_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        const lm_ggml_custom1_op_f32_t fun) {
-    return lm_ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct lm_ggml_tensor * lm_ggml_map_custom1_inplace_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        const lm_ggml_custom1_op_f32_t fun) {
-    return lm_ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// lm_ggml_map_custom2_f32
-
-static struct lm_ggml_tensor * lm_ggml_map_custom2_impl_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        const lm_ggml_custom2_op_f32_t fun,
-        bool inplace) {
-    struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-    lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = LM_GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct lm_ggml_tensor * lm_ggml_map_custom2_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        const lm_ggml_custom2_op_f32_t fun) {
-    return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct lm_ggml_tensor * lm_ggml_map_custom2_inplace_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        const lm_ggml_custom2_op_f32_t fun) {
-    return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// lm_ggml_map_custom3_f32
-
-static struct lm_ggml_tensor * lm_ggml_map_custom3_impl_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        struct lm_ggml_tensor * c,
-        const lm_ggml_custom3_op_f32_t fun,
-        bool inplace) {
-    struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-    lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = LM_GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct lm_ggml_tensor * lm_ggml_map_custom3_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        struct lm_ggml_tensor * c,
-        const lm_ggml_custom3_op_f32_t fun) {
-    return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32(
-        struct lm_ggml_context * ctx,
-        struct lm_ggml_tensor * a,
-        struct lm_ggml_tensor * b,
-        struct lm_ggml_tensor * c,
-        const lm_ggml_custom3_op_f32_t fun) {
-    return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // lm_ggml_map_custom1
 
 static struct lm_ggml_tensor * lm_ggml_map_custom1_impl(
@@ -5046,7 +4876,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom1_impl(
         /*.n_tasks =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    lm_ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = LM_GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5091,7 +4921,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl(
         /*.n_tasks =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    lm_ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = LM_GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5140,7 +4970,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl(
         /*.n_tasks =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    lm_ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = LM_GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5172,6 +5002,66 @@ struct lm_ggml_tensor * lm_ggml_map_custom3_inplace(
     return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct lm_ggml_tensor * lm_ggml_custom_4d(
+        struct lm_ggml_context * ctx,
+        enum lm_ggml_type type,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        struct lm_ggml_tensor ** args,
+        int n_args,
+        lm_ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata) {
+
+    LM_GGML_ASSERT(n_args < LM_GGML_MAX_SRC);
+
+    struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct lm_ggml_custom_op_params params = {
+        /*.fun =*/ fun,
+        /*.n_tasks =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    lm_ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = LM_GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct lm_ggml_tensor * lm_ggml_custom_inplace(
+        struct lm_ggml_context * ctx,
+        struct lm_ggml_tensor * a,
+        struct lm_ggml_tensor ** args,
+        int n_args,
+        lm_ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata) {
+
+    LM_GGML_ASSERT(n_args < LM_GGML_MAX_SRC - 1);
+
+    struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
+
+    struct lm_ggml_custom_op_params params = {
+        /*.fun =*/ fun,
+        /*.n_tasks =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    lm_ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = LM_GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // lm_ggml_cross_entropy_loss
 
 struct lm_ggml_tensor * lm_ggml_cross_entropy_loss(
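lm_ggml_custom_4d and lm_ggml_custom_inplace replace the removed map_*_f32 family with one generic op: any source count below LM_GGML_MAX_SRC, plus an explicit output shape and type. The op-table bookkeeping above follows from this rework: five MAP_* entries removed, CONV_2D_DW and CUSTOM added, so 85 - 5 + 2 = 82 in both static_asserts. A hedged sketch, assuming the lm_ggml_custom_op_t callback shape declared in ggml.h (dst, ith, nth, userdata), with sources read back from dst->src:

// Hypothetical callback: thread ith of nth processes its share of rows.
static void my_custom_op(struct lm_ggml_tensor * dst, int ith, int nth, void * userdata) {
    const struct lm_ggml_tensor * src0 = dst->src[0]; // args[0] from the call below
    (void) src0; (void) ith; (void) nth; (void) userdata;
    // ... fill the rows of dst owned by this thread ...
}

// Two-source custom op producing an F32 tensor shaped like a:
struct lm_ggml_tensor * srcs[] = { a, b };
struct lm_ggml_tensor * out = lm_ggml_custom_4d(
    ctx, LM_GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3],
    srcs, 2, my_custom_op, /*n_tasks*/ 4, /*userdata*/ NULL);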