cui-llama.rn 1.6.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (285)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +22 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
  4. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  5. package/android/src/main/jni.cpp +173 -18
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  16. package/cpp/LICENSE +21 -0
  17. package/cpp/chat.cpp +129 -107
  18. package/cpp/chat.h +2 -0
  19. package/cpp/common.cpp +58 -78
  20. package/cpp/common.h +29 -21
  21. package/cpp/ggml-alloc.c +4 -1
  22. package/cpp/ggml-backend.cpp +9 -5
  23. package/cpp/ggml-backend.h +4 -4
  24. package/cpp/ggml-cpp.h +1 -1
  25. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  26. package/cpp/ggml-cpu/amx/amx.h +8 -0
  27. package/cpp/ggml-cpu/amx/common.h +91 -0
  28. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  29. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  31. package/cpp/ggml-cpu/common.h +72 -0
  32. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
  33. package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
  34. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
  35. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
  36. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
  37. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  38. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  39. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  40. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  41. package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
  42. package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
  43. package/cpp/ggml-cpu.h +5 -0
  44. package/cpp/ggml-impl.h +16 -9
  45. package/cpp/ggml-llama-sim.metallib +0 -0
  46. package/cpp/ggml-llama.metallib +0 -0
  47. package/cpp/ggml-metal-impl.h +36 -11
  48. package/cpp/ggml-metal.m +810 -176
  49. package/cpp/ggml-opt.cpp +373 -190
  50. package/cpp/ggml-opt.h +49 -28
  51. package/cpp/ggml-quants.c +0 -6
  52. package/cpp/ggml.c +227 -282
  53. package/cpp/ggml.h +82 -101
  54. package/cpp/gguf.cpp +33 -33
  55. package/cpp/json-schema-to-grammar.cpp +3 -0
  56. package/cpp/llama-adapter.cpp +6 -0
  57. package/cpp/llama-arch.cpp +49 -17
  58. package/cpp/llama-arch.h +9 -0
  59. package/cpp/llama-batch.cpp +8 -2
  60. package/cpp/llama-batch.h +2 -1
  61. package/cpp/llama-chat.cpp +39 -16
  62. package/cpp/llama-chat.h +4 -2
  63. package/cpp/llama-context.cpp +440 -611
  64. package/cpp/llama-context.h +44 -33
  65. package/cpp/llama-cparams.h +1 -0
  66. package/cpp/llama-graph.cpp +214 -291
  67. package/cpp/llama-graph.h +69 -21
  68. package/cpp/llama-hparams.cpp +17 -1
  69. package/cpp/llama-hparams.h +39 -5
  70. package/cpp/llama-kv-cache.cpp +2067 -620
  71. package/cpp/llama-kv-cache.h +410 -108
  72. package/cpp/llama-memory.h +12 -1
  73. package/cpp/llama-model-loader.cpp +24 -15
  74. package/cpp/llama-model-saver.cpp +281 -0
  75. package/cpp/llama-model-saver.h +37 -0
  76. package/cpp/llama-model.cpp +1089 -359
  77. package/cpp/llama-model.h +19 -3
  78. package/cpp/llama-sampling.cpp +20 -7
  79. package/cpp/llama-vocab.cpp +54 -9
  80. package/cpp/llama-vocab.h +6 -0
  81. package/cpp/llama.cpp +14 -0
  82. package/cpp/llama.h +86 -142
  83. package/cpp/minja/chat-template.hpp +9 -5
  84. package/cpp/minja/minja.hpp +69 -36
  85. package/cpp/rn-llama.cpp +602 -190
  86. package/cpp/rn-llama.h +34 -8
  87. package/cpp/sampling.cpp +57 -50
  88. package/cpp/tools/mtmd/clip-impl.h +462 -0
  89. package/cpp/tools/mtmd/clip.cpp +4024 -0
  90. package/cpp/tools/mtmd/clip.h +101 -0
  91. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  92. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  93. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  94. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  95. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  96. package/cpp/tools/mtmd/mtmd.h +362 -0
  97. package/cpp/tools/mtmd/stb_image.h +7988 -0
  98. package/ios/CMakeLists.txt +20 -10
  99. package/ios/RNLlama.h +6 -0
  100. package/ios/RNLlama.mm +82 -3
  101. package/ios/RNLlamaContext.h +5 -1
  102. package/ios/RNLlamaContext.mm +131 -38
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  131. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  132. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  133. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  134. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  160. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  161. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
  162. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  163. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  164. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  175. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  176. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  177. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  178. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  179. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  180. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  181. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
  182. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  183. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  184. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  185. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  188. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  189. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  190. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  191. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  192. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  193. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  194. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  195. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  196. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  197. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  198. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  199. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  200. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  201. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  202. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  203. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  204. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  205. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  206. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  207. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  208. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  209. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  210. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  211. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  212. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  213. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  214. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  215. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  216. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  217. package/jest/mock.js +33 -7
  218. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  219. package/lib/commonjs/index.js +153 -21
  220. package/lib/commonjs/index.js.map +1 -1
  221. package/lib/module/NativeRNLlama.js.map +1 -1
  222. package/lib/module/index.js +152 -20
  223. package/lib/module/index.js.map +1 -1
  224. package/lib/typescript/NativeRNLlama.d.ts +54 -4
  225. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  226. package/lib/typescript/index.d.ts +72 -6
  227. package/lib/typescript/index.d.ts.map +1 -1
  228. package/package.json +1 -1
  229. package/src/NativeRNLlama.ts +72 -4
  230. package/src/index.ts +212 -38
  231. package/cpp/binary-ops.h +0 -16
  232. package/cpp/ops.h +0 -128
  233. package/cpp/simd-mappings.h +0 -888
  234. package/cpp/unary-ops.h +0 -28
  235. package/cpp/vec.h +0 -802
  236. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  237. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  238. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  239. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  240. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  241. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  242. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  243. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  244. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  245. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  246. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  247. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  248. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  249. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  250. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  251. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  271. package/lib/commonjs/chat.js +0 -37
  272. package/lib/commonjs/chat.js.map +0 -1
  273. package/lib/module/chat.js +0 -33
  274. package/lib/module/chat.js.map +0 -1
  275. package/lib/typescript/chat.d.ts +0 -10
  276. package/lib/typescript/chat.d.ts.map +0 -1
  277. package/src/chat.ts +0 -44
  278. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  279. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  280. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  281. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  282. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  283. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  284. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  285. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
package/cpp/ggml.c CHANGED
@@ -4,6 +4,7 @@
  #include "ggml-backend.h"
  #include "ggml-impl.h"
  #include "ggml-threading.h"
+ #include "ggml-cpu.h"
  #include "ggml.h"

  // FIXME: required here for quantization functions
@@ -63,12 +64,17 @@
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
  float lm_ggml_table_f32_f16[1 << 16];

- #if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
-     (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
+ #if defined(__linux__) || \
+     defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
+     (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
+
  #include <unistd.h>
  #include <sys/types.h>
  #include <sys/stat.h>
  #include <sys/wait.h>
+ #if defined(__linux__)
+ #include <sys/prctl.h>
+ #endif

  #if defined(__ANDROID__)
  #include <unwind.h>
@@ -132,10 +138,36 @@ static void lm_ggml_print_backtrace(void) {
      if (LM_GGML_NO_BACKTRACE) {
          return;
      }
-     char attach[32];
-     snprintf(attach, sizeof(attach), "attach %d", getpid());
-     int pid = fork();
-     if (pid == 0) {
+ #if defined(__linux__)
+     FILE * f = fopen("/proc/self/status", "r");
+     size_t size = 0;
+     char * line = NULL;
+     ssize_t length = 0;
+     while ((length = getline(&line, &size, f)) > 0) {
+         if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
+             (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
+             // Already being debugged, and the breakpoint is the later abort()
+             free(line);
+             fclose(f);
+             return;
+         }
+     }
+     free(line);
+     fclose(f);
+     int lock[2] = { -1, -1 };
+     (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
+ #endif
+     const int parent_pid = getpid();
+     const int child_pid = fork();
+     if (child_pid < 0) { // error
+         return;
+     } else if (child_pid == 0) { // child
+         char attach[32];
+         snprintf(attach, sizeof(attach), "attach %d", parent_pid);
+ #if defined(__linux__)
+         close(lock[1]);
+         (void) !read(lock[0], lock, 1);
+ #endif
          // try gdb
          execlp("gdb", "gdb", "--batch",
              "-ex", "set style enabled on",
@@ -148,18 +180,18 @@ static void lm_ggml_print_backtrace(void) {
          execlp("lldb", "lldb", "--batch",
              "-o", "bt",
              "-o", "quit",
-             "-p", attach,
+             "-p", &attach[sizeof("attach ") - 1],
              (char *) NULL);
-         exit(EXIT_FAILURE);
-     } else {
-         int wstatus;
-         waitpid(pid, &wstatus, 0);
-         if (WIFEXITED(wstatus)) {
-             if (WEXITSTATUS(wstatus) == EXIT_FAILURE) {
-                 // gdb failed, fallback to backtrace_symbols
-                 lm_ggml_print_backtrace_symbols();
-             }
-         }
+         // gdb failed, fallback to backtrace_symbols
+         lm_ggml_print_backtrace_symbols();
+         _Exit(0);
+     } else { // parent
+ #if defined(__linux__)
+         prctl(PR_SET_PTRACER, child_pid);
+         close(lock[1]);
+         close(lock[0]);
+ #endif
+         waitpid(child_pid, NULL, 0);
      }
  }
  #else
@@ -382,58 +414,16 @@ void lm_ggml_fp16_to_fp32_row(const lm_ggml_fp16_t * x, float * y, int64_t n) {
      }
  }

- // FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
- // currently, the lm_ggml_cpu_has_* functions are entirely compile-time
  void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) {
-     int64_t i = 0;
- #if defined(__F16C__)
-     //if (lm_ggml_cpu_has_f16c()) {
-         for (; i + 7 < n; i += 8) {
-             __m256 x_vec = _mm256_loadu_ps(x + i);
-             __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-             _mm_storeu_si128((__m128i *)(y + i), y_vec);
-         }
-         for(; i + 3 < n; i += 4) {
-             __m128 x_vec = _mm_loadu_ps(x + i);
-             __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-             _mm_storel_epi64((__m128i *)(y + i), y_vec);
-         }
-     //}
- #endif
-     for (; i < n; i++) {
+     int i = 0;
+     for (; i < n; ++i) {
          y[i] = LM_GGML_FP32_TO_FP16(x[i]);
      }
  }

  void lm_ggml_bf16_to_fp32_row(const lm_ggml_bf16_t * x, float * y, int64_t n) {
-     int64_t i = 0;
- #if defined(__AVX512F__)
-     //if (lm_ggml_cpu_has_avx512()) {
-         for (; i + 16 <= n; i += 16) {
-             _mm512_storeu_ps(y + i,
-                 _mm512_castsi512_ps(
-                     _mm512_slli_epi32(
-                         _mm512_cvtepu16_epi32(
-                             _mm256_loadu_si256(
-                                 (const __m256i *)(x + i))),
-                         16)));
-         }
-     //}
- #endif
- #if defined(__AVX2__)
-     //if (lm_ggml_cpu_has_avx2()) {
-         for (; i + 8 <= n; i += 8) {
-             _mm256_storeu_ps(y + i,
-                 _mm256_castsi256_ps(
-                     _mm256_slli_epi32(
-                         _mm256_cvtepu16_epi32(
-                             _mm_loadu_si128(
-                                 (const __m128i *)(x + i))),
-                         16)));
-         }
-     //}
- #endif
-     for (; i < n; i++) {
+     int i = 0;
+     for (; i < n; ++i) {
          y[i] = LM_GGML_BF16_TO_FP32(x[i]);
      }
  }
@@ -969,6 +959,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
      "CONV_TRANSPOSE_1D",
      "IM2COL",
      "IM2COL_BACK",
+     "CONV_2D_DW",
      "CONV_TRANSPOSE_2D",
      "POOL_1D",
      "POOL_2D",
@@ -995,23 +986,18 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {

      "UNARY",

-     "MAP_UNARY",
-     "MAP_BINARY",
-
-     "MAP_CUSTOM1_F32",
-     "MAP_CUSTOM2_F32",
-     "MAP_CUSTOM3_F32",
-
      "MAP_CUSTOM1",
      "MAP_CUSTOM2",
      "MAP_CUSTOM3",

+     "CUSTOM",
+
      "CROSS_ENTROPY_LOSS",
      "CROSS_ENTROPY_LOSS_BACK",
      "OPT_STEP_ADAMW",
  };

- static_assert(LM_GGML_OP_COUNT == 85, "LM_GGML_OP_COUNT != 85");
+ static_assert(LM_GGML_OP_COUNT == 82, "LM_GGML_OP_COUNT != 82");

  static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
      "none",
@@ -1068,6 +1054,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
      "conv_transpose_1d(x)",
      "im2col(x)",
      "im2col_back(x)",
+     "conv_2d_dw(x)",
      "conv_transpose_2d(x)",
      "pool_1d(x)",
      "pool_2d(x)",
@@ -1094,23 +1081,18 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {

      "unary(x)",

-     "f(x)",
-     "f(x,y)",
-
-     "custom_f32(x)",
-     "custom_f32(x,y)",
-     "custom_f32(x,y,z)",
+     "map_custom(x)",
+     "map_custom(x,y)",
+     "map_custom(x,y,z)",

      "custom(x)",
-     "custom(x,y)",
-     "custom(x,y,z)",

      "cross_entropy_loss(x,y)",
      "cross_entropy_loss_back(x,y)",
      "adamw(x)",
  };

- static_assert(LM_GGML_OP_COUNT == 85, "LM_GGML_OP_COUNT != 85");
+ static_assert(LM_GGML_OP_COUNT == 82, "LM_GGML_OP_COUNT != 82");

  static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");

@@ -1130,9 +1112,10 @@ static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = {
      "HARDSWISH",
      "HARDSIGMOID",
      "EXP",
+     "GELU_ERF",
  };

- static_assert(LM_GGML_UNARY_OP_COUNT == 14, "LM_GGML_UNARY_OP_COUNT != 14");
+ static_assert(LM_GGML_UNARY_OP_COUNT == 15, "LM_GGML_UNARY_OP_COUNT != 15");


  static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN");
@@ -1361,12 +1344,23 @@ bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) {
      return lm_ggml_is_contiguous_n(tensor, 2);
  }

+ bool lm_ggml_is_contiguously_allocated(const struct lm_ggml_tensor * tensor) {
+     return lm_ggml_nbytes(tensor) == lm_ggml_nelements(tensor) * lm_ggml_type_size(tensor->type)/lm_ggml_blck_size(tensor->type);
+ }
+
  bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) {
      static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");

      return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
  }

+ bool lm_ggml_is_contiguous_channels(const struct lm_ggml_tensor * tensor) {
+     return
+         tensor->nb[0] > tensor->nb[2] &&
+         tensor->nb[1] > tensor->nb[0] &&
+         tensor->nb[2] == lm_ggml_type_size(tensor->type);
+ }
+
  static inline bool lm_ggml_is_padded_1d(const struct lm_ggml_tensor * tensor) {
      static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");

@@ -2521,6 +2515,20 @@ struct lm_ggml_tensor * lm_ggml_gelu_inplace(
      return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU);
  }

+ // lm_ggml_gelu_erf
+
+ struct lm_ggml_tensor * lm_ggml_gelu_erf(
+         struct lm_ggml_context * ctx,
+         struct lm_ggml_tensor * a) {
+     return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_GELU_ERF);
+ }
+
+ struct lm_ggml_tensor * lm_ggml_gelu_erf_inplace(
+         struct lm_ggml_context * ctx,
+         struct lm_ggml_tensor * a) {
+     return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU_ERF);
+ }
+
  // lm_ggml_gelu_quick

  struct lm_ggml_tensor * lm_ggml_gelu_quick(
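The two entry points added above expose the exact, erf-based GELU as a regular unary op (lm_ggml_gelu keeps the tanh approximation). Below is a minimal, hypothetical usage sketch that is not part of the package: it assumes the standard lm_ggml context setup, and the later sketches in this diff reuse the same ctx.

    #include "ggml.h"

    void example_gelu_erf(void) {
        struct lm_ggml_init_params ip = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct lm_ggml_context * ctx = lm_ggml_init(ip);

        struct lm_ggml_tensor * x = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 8);
        // exact erf-based GELU, routed through LM_GGML_UNARY_OP_GELU_ERF
        struct lm_ggml_tensor * y = lm_ggml_gelu_erf(ctx, x);

        struct lm_ggml_cgraph * gf = lm_ggml_new_graph(ctx);
        lm_ggml_build_forward_expand(gf, y);
        // ... fill x->data and run the graph on a backend of choice ...

        lm_ggml_free(ctx);
    }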
@@ -2783,11 +2791,11 @@ void lm_ggml_mul_mat_set_prec(
      c = lm_ggml_mul_mat_id(ctx, as, b, ids);

      as  -> [cols, rows, n_expert]
-     ids -> [n_experts_used, n_tokens] (i32)
      b   -> [cols, n_expert_used, n_tokens]
+     ids -> [n_expert_used, n_tokens] (i32)
      c   -> [rows, n_expert_used, n_tokens]

-     in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+     in b, n_expert_used can be broadcasted to match the n_expert_used of ids

      c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
  */
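To make the corrected shape convention concrete, here is an illustrative, hypothetical call with 4 experts, 2 experts used per token, 3 tokens, and an 8-to-16 projection, assuming an initialized ctx as in the earlier sketch:

    struct lm_ggml_tensor * as  = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, 8, 16, 4); // [cols, rows, n_expert]
    struct lm_ggml_tensor * b   = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, 8, 2, 3);  // [cols, n_expert_used, n_tokens]
    struct lm_ggml_tensor * ids = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_I32, 2, 3);     // [n_expert_used, n_tokens]
    struct lm_ggml_tensor * c   = lm_ggml_mul_mat_id(ctx, as, b, ids);                    // [rows, n_expert_used, n_tokens] = [16, 2, 3]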
@@ -4073,6 +4081,46 @@ struct lm_ggml_tensor * lm_ggml_conv_2d_dw(
      return result;
  }

+ // lm_ggml_conv_2d_dw_direct
+
+ struct lm_ggml_tensor * lm_ggml_conv_2d_dw_direct(
+         struct lm_ggml_context * ctx,
+         struct lm_ggml_tensor * a,
+         struct lm_ggml_tensor * b,
+         int stride0,
+         int stride1,
+         int pad0,
+         int pad1,
+         int dilation0,
+         int dilation1) {
+     LM_GGML_ASSERT(a->ne[2] == 1);
+     LM_GGML_ASSERT(a->ne[3] == b->ne[2]);
+     int64_t ne[4];
+     ne[0] = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+     ne[1] = lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+     ne[2] = b->ne[2];
+     ne[3] = b->ne[3];
+
+     struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, b->type, 4, ne);
+
+     if (lm_ggml_is_contiguous_channels(b)) {
+         // Result will be permuted the same way as input (CWHN order)
+         const int64_t type_size = lm_ggml_type_size(result->type);
+         LM_GGML_ASSERT(lm_ggml_blck_size(result->type) == 1);
+         result->nb[0] = result->ne[2] * type_size;
+         result->nb[1] = result->ne[0] * result->nb[0];
+         result->nb[2] = type_size;
+     }
+
+     int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+     lm_ggml_set_op_params(result, params, sizeof(params));
+
+     result->op = LM_GGML_OP_CONV_2D_DW;
+     result->src[0] = a;
+     result->src[1] = b;
+     return result;
+ }
+
  // lm_ggml_conv_transpose_2d_p0

  static int64_t lm_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
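The asserts above pin down the expected layout: the kernel a must be [KW, KH, 1, C] (one filter per input channel) and the input b is [W, H, C, N]. A hypothetical call with illustrative sizes, reusing the ctx setup from the first sketch:

    // 3x3 depthwise conv over a 64x64, 32-channel image; stride 1, pad 1, dilation 1
    struct lm_ggml_tensor * k = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 3, 3, 1, 32);
    struct lm_ggml_tensor * x = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 64, 64, 32, 1);
    // output keeps the spatial size: [64, 64, 32, 1]
    struct lm_ggml_tensor * y = lm_ggml_conv_2d_dw_direct(ctx, k, x, 1, 1, 1, 1, 1, 1);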
@@ -4197,7 +4245,8 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl(
          int ne0,
          int ne1,
          int ne2,
-         int ne3) {
+         int ne3,
+         enum lm_ggml_scale_mode mode) {
      LM_GGML_ASSERT(a->ne[0] <= ne0);
      LM_GGML_ASSERT(a->ne[1] <= ne1);
      LM_GGML_ASSERT(a->ne[2] <= ne2);
@@ -4205,6 +4254,8 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl(

      struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);

+     lm_ggml_set_op_params_i32(result, 0, mode);
+
      result->op = LM_GGML_OP_UPSCALE;
      result->src[0] = a;

@@ -4214,8 +4265,9 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl(
  struct lm_ggml_tensor * lm_ggml_upscale(
          struct lm_ggml_context * ctx,
          struct lm_ggml_tensor * a,
-         int scale_factor) {
-     return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+         int scale_factor,
+         enum lm_ggml_scale_mode mode) {
+     return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
  }

  struct lm_ggml_tensor * lm_ggml_upscale_ext(
@@ -4224,8 +4276,9 @@ struct lm_ggml_tensor * lm_ggml_upscale_ext(
          int ne0,
          int ne1,
          int ne2,
-         int ne3) {
-     return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+         int ne3,
+         enum lm_ggml_scale_mode mode) {
+     return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
  }

  // lm_ggml_pad
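Callers of the upscale API now choose the interpolation explicitly. A hypothetical sketch; the enum constant name below is assumed to follow the package's LM_GGML_ renaming of upstream ggml's GGML_SCALE_MODE_NEAREST / GGML_SCALE_MODE_BILINEAR:

    struct lm_ggml_tensor * img = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 16, 16, 3, 1);
    // 2x upscale; the scale-mode argument is the new parameter threaded through above
    struct lm_ggml_tensor * up  = lm_ggml_upscale(ctx, img, 2, LM_GGML_SCALE_MODE_NEAREST);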
@@ -4855,179 +4908,6 @@ struct lm_ggml_tensor * lm_ggml_unary_inplace(
      return lm_ggml_unary_impl(ctx, a, op, true);
  }

- // lm_ggml_map_unary
-
- static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         const lm_ggml_unary_op_f32_t fun,
-         bool inplace) {
-     struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-     lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = LM_GGML_OP_MAP_UNARY;
-     result->src[0] = a;
-
-     return result;
- }
-
- struct lm_ggml_tensor * lm_ggml_map_unary_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         const lm_ggml_unary_op_f32_t fun) {
-     return lm_ggml_map_unary_impl_f32(ctx, a, fun, false);
- }
-
- struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         const lm_ggml_unary_op_f32_t fun) {
-     return lm_ggml_map_unary_impl_f32(ctx, a, fun, true);
- }
-
- // lm_ggml_map_binary
-
- static struct lm_ggml_tensor * lm_ggml_map_binary_impl_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         const lm_ggml_binary_op_f32_t fun,
-         bool inplace) {
-     LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
-
-     struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-     lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = LM_GGML_OP_MAP_BINARY;
-     result->src[0] = a;
-     result->src[1] = b;
-
-     return result;
- }
-
- struct lm_ggml_tensor * lm_ggml_map_binary_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         const lm_ggml_binary_op_f32_t fun) {
-     return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, false);
- }
-
- struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         const lm_ggml_binary_op_f32_t fun) {
-     return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, true);
- }
-
- // lm_ggml_map_custom1_f32
-
- static struct lm_ggml_tensor * lm_ggml_map_custom1_impl_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         const lm_ggml_custom1_op_f32_t fun,
-         bool inplace) {
-     struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-     lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = LM_GGML_OP_MAP_CUSTOM1_F32;
-     result->src[0] = a;
-
-     return result;
- }
-
- struct lm_ggml_tensor * lm_ggml_map_custom1_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         const lm_ggml_custom1_op_f32_t fun) {
-     return lm_ggml_map_custom1_impl_f32(ctx, a, fun, false);
- }
-
- struct lm_ggml_tensor * lm_ggml_map_custom1_inplace_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         const lm_ggml_custom1_op_f32_t fun) {
-     return lm_ggml_map_custom1_impl_f32(ctx, a, fun, true);
- }
-
- // lm_ggml_map_custom2_f32
-
- static struct lm_ggml_tensor * lm_ggml_map_custom2_impl_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         const lm_ggml_custom2_op_f32_t fun,
-         bool inplace) {
-     struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-     lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = LM_GGML_OP_MAP_CUSTOM2_F32;
-     result->src[0] = a;
-     result->src[1] = b;
-
-     return result;
- }
-
- struct lm_ggml_tensor * lm_ggml_map_custom2_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         const lm_ggml_custom2_op_f32_t fun) {
-     return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
- }
-
- struct lm_ggml_tensor * lm_ggml_map_custom2_inplace_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         const lm_ggml_custom2_op_f32_t fun) {
-     return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
- }
-
- // lm_ggml_map_custom3_f32
-
- static struct lm_ggml_tensor * lm_ggml_map_custom3_impl_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         struct lm_ggml_tensor * c,
-         const lm_ggml_custom3_op_f32_t fun,
-         bool inplace) {
-     struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
-
-     lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = LM_GGML_OP_MAP_CUSTOM3_F32;
-     result->src[0] = a;
-     result->src[1] = b;
-     result->src[2] = c;
-
-     return result;
- }
-
- struct lm_ggml_tensor * lm_ggml_map_custom3_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         struct lm_ggml_tensor * c,
-         const lm_ggml_custom3_op_f32_t fun) {
-     return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
- }
-
- struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32(
-         struct lm_ggml_context * ctx,
-         struct lm_ggml_tensor * a,
-         struct lm_ggml_tensor * b,
-         struct lm_ggml_tensor * c,
-         const lm_ggml_custom3_op_f32_t fun) {
-     return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
- }
-
  // lm_ggml_map_custom1

  static struct lm_ggml_tensor * lm_ggml_map_custom1_impl(
@@ -5046,7 +4926,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom1_impl(
          /*.n_tasks  =*/ n_tasks,
          /*.userdata =*/ userdata
      };
-     lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
+     lm_ggml_set_op_params(result, &params, sizeof(params));

      result->op = LM_GGML_OP_MAP_CUSTOM1;
      result->src[0] = a;
@@ -5091,7 +4971,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl(
          /*.n_tasks  =*/ n_tasks,
          /*.userdata =*/ userdata
      };
-     lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
+     lm_ggml_set_op_params(result, &params, sizeof(params));

      result->op = LM_GGML_OP_MAP_CUSTOM2;
      result->src[0] = a;
@@ -5140,7 +5020,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl(
          /*.n_tasks  =*/ n_tasks,
          /*.userdata =*/ userdata
      };
-     lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
+     lm_ggml_set_op_params(result, &params, sizeof(params));

      result->op = LM_GGML_OP_MAP_CUSTOM3;
      result->src[0] = a;
@@ -5172,6 +5052,66 @@ struct lm_ggml_tensor * lm_ggml_map_custom3_inplace(
      return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
  }

+ struct lm_ggml_tensor * lm_ggml_custom_4d(
+         struct lm_ggml_context * ctx,
+         enum lm_ggml_type type,
+         int64_t ne0,
+         int64_t ne1,
+         int64_t ne2,
+         int64_t ne3,
+         struct lm_ggml_tensor ** args,
+         int n_args,
+         lm_ggml_custom_op_t fun,
+         int n_tasks,
+         void * userdata) {
+
+     LM_GGML_ASSERT(n_args < LM_GGML_MAX_SRC);
+
+     struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+     struct lm_ggml_custom_op_params params = {
+         /*.fun      =*/ fun,
+         /*.n_tasks  =*/ n_tasks,
+         /*.userdata =*/ userdata
+     };
+     lm_ggml_set_op_params(result, &params, sizeof(params));
+
+     result->op = LM_GGML_OP_CUSTOM;
+     for (int i = 0; i < n_args; i++) {
+         result->src[i] = args[i];
+     }
+
+     return result;
+ }
+
+ struct lm_ggml_tensor * lm_ggml_custom_inplace(
+         struct lm_ggml_context * ctx,
+         struct lm_ggml_tensor * a,
+         struct lm_ggml_tensor ** args,
+         int n_args,
+         lm_ggml_custom_op_t fun,
+         int n_tasks,
+         void * userdata) {
+
+     LM_GGML_ASSERT(n_args < LM_GGML_MAX_SRC - 1);
+
+     struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
+
+     struct lm_ggml_custom_op_params params = {
+         /*.fun      =*/ fun,
+         /*.n_tasks  =*/ n_tasks,
+         /*.userdata =*/ userdata
+     };
+     lm_ggml_set_op_params(result, &params, sizeof(params));
+
+     result->op = LM_GGML_OP_CUSTOM;
+     result->src[0] = a;
+     for (int i = 0; i < n_args; i++) {
+         result->src[i + 1] = args[i];
+     }
+
+     return result;
+ }
  // lm_ggml_cross_entropy_loss

  struct lm_ggml_tensor * lm_ggml_cross_entropy_loss(
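The new LM_GGML_OP_CUSTOM entry points generalize the surviving map_custom1/2/3 API to any number of sources and an explicit output shape. A hypothetical sketch, assuming lm_ggml_custom_op_t has the upstream ggml_custom_op_t shape void (*)(struct lm_ggml_tensor * dst, int ith, int nth, void * userdata):

    // Fill dst with a constant taken from userdata; sources (if any) are
    // reachable through dst->src[0..n_args-1]. ith/nth partition the work
    // when n_tasks > 1; this sketch stays single-threaded.
    static void fill_const(struct lm_ggml_tensor * dst, int ith, int nth, void * userdata) {
        (void) nth;
        if (ith != 0) {
            return; // n_tasks == 1: thread 0 does everything
        }
        const float v = *(const float *) userdata;
        float * d = (float *) dst->data;
        for (int64_t i = 0; i < lm_ggml_nelements(dst); ++i) {
            d[i] = v;
        }
    }

    // ... during graph construction, with ctx as in the first sketch:
    static float fill_value = 1.0f; // static so the pointer stays valid at compute time
    struct lm_ggml_tensor * src    = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, 8, 8);
    struct lm_ggml_tensor * args[] = { src };
    struct lm_ggml_tensor * t      = lm_ggml_custom_4d(ctx, LM_GGML_TYPE_F32, 8, 8, 1, 1,
                                                       args, 1, fill_const, /*n_tasks=*/1, &fill_value);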
@@ -5618,7 +5558,7 @@ static void lm_ggml_compute_backward(
          // tensor = src0 * 1 + src1 * 0
          if (src0_needs_grads) {
              // dsrc0 = dtensor * 1
-             lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
+             lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_reshape(ctx, grad, src0));
          }
          if (src1_needs_grads) {
              // dsrc1 = dtensor * 0 -> noop
@@ -5899,10 +5839,9 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
  }

  void lm_ggml_build_backward_expand(
-         struct lm_ggml_context * ctx_static,
-         struct lm_ggml_context * ctx_compute,
-         struct lm_ggml_cgraph * cgraph,
-         bool accumulate) {
+         struct lm_ggml_context * ctx,
+         struct lm_ggml_cgraph * cgraph,
+         struct lm_ggml_tensor ** grad_accs) {
      LM_GGML_ASSERT(cgraph->n_nodes > 0);
      LM_GGML_ASSERT(cgraph->grads);
      LM_GGML_ASSERT(cgraph->grad_accs);
@@ -5975,21 +5914,24 @@ void lm_ggml_build_backward_expand(
          LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW ||
              node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE);

-         const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node);
-         LM_GGML_ASSERT(igrad != LM_GGML_HASHSET_FULL);
-         LM_GGML_ASSERT(lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
-         if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) {
-             cgraph->grad_accs[igrad] = lm_ggml_dup_tensor(ctx_static, node);
-             cgraph->grads[igrad] = cgraph->grad_accs[igrad];
-             lm_ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+         const size_t ihash = lm_ggml_hash_find(&cgraph->visited_hash_set, node);
+         LM_GGML_ASSERT(ihash != LM_GGML_HASHSET_FULL);
+         LM_GGML_ASSERT(lm_ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+         if (grad_accs && grad_accs[i]) {
+             cgraph->grad_accs[ihash] = grad_accs[i];
+             cgraph->grads[ihash] = cgraph->grad_accs[ihash];
+         } else if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) {
+             // loss tensors always need a gradient accumulator
+             cgraph->grad_accs[ihash] = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
+             cgraph->grads[ihash] = cgraph->grad_accs[ihash];
          }
-         grads_needed[igrad] = true;
+         grads_needed[ihash] = true;
      }

      for (int i = n_nodes_f - 1; i >= 0; --i) {
          // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation
          // use allocator to automatically make inplace operations
-         lm_ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+         lm_ggml_compute_backward(ctx, cgraph, i, grads_needed);
      }

      free(grads_needed);
@@ -6135,8 +6077,8 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst)
      }
  }

- struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph) {
-     struct lm_ggml_cgraph * result = lm_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+ struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, bool force_grads) {
+     struct lm_ggml_cgraph * result = lm_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
      lm_ggml_graph_cpy(cgraph, result);
      return result;
  }
@@ -6155,6 +6097,9 @@ struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) {
  }

  void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) {
+     if (!cgraph) {
+         return;
+     }
      LM_GGML_ASSERT(cgraph->grads != NULL);

      for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6464,8 +6409,8 @@ void lm_ggml_set_output(struct lm_ggml_tensor * tensor) {
      tensor->flags |= LM_GGML_TENSOR_FLAG_OUTPUT;
  }

- void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) {
-     LM_GGML_UNUSED(ctx); // TODO: remove this parameter
+ void lm_ggml_set_param(struct lm_ggml_tensor * tensor) {
+     LM_GGML_ASSERT(tensor->op == LM_GGML_OP_NONE);
      tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM;
  }

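Taken together with the lm_ggml_build_backward_expand and lm_ggml_graph_reset hunks above, the training-facing API gets simpler in 1.7.0: parameters are flagged without a context, backward expansion takes one context plus optional per-node gradient accumulators, and resetting a NULL graph is now a no-op. A hypothetical migration sketch (toy loss, default accumulators; lm_ggml_set_loss is assumed available, matching upstream ggml's ggml_set_loss):

    struct lm_ggml_tensor * w = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 4);
    lm_ggml_set_param(w); // was lm_ggml_set_param(ctx, w); must be flagged before w is used in an op

    struct lm_ggml_tensor * loss = lm_ggml_sum(ctx, lm_ggml_sqr(ctx, w)); // toy loss: sum(w^2)
    lm_ggml_set_loss(loss); // sets LM_GGML_TENSOR_FLAG_LOSS, referenced by the accumulator logic above

    struct lm_ggml_cgraph * gb = lm_ggml_new_graph_custom(ctx, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads=*/true);
    lm_ggml_build_forward_expand(gb, loss);
    // was lm_ggml_build_backward_expand(ctx_static, ctx_compute, gb, accumulate)
    lm_ggml_build_backward_expand(ctx, gb, /*grad_accs=*/NULL);
    lm_ggml_graph_reset(gb); // zero the gradients before the next step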