cui-llama.rn 1.6.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/README.md CHANGED
@@ -123,22 +123,50 @@ console.log('Result:', textResult.text)
  console.log('Timings:', textResult.timings)
  ```
 
- The binding’s deisgn inspired by [server.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) example in llama.cpp, so you can map its API to LlamaContext:
+ The binding’s deisgn inspired by [server.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) example in llama.cpp:
 
  - `/completion` and `/chat/completions`: `context.completion(params, partialCompletionCallback)`
  - `/tokenize`: `context.tokenize(content)`
  - `/detokenize`: `context.detokenize(tokens)`
  - `/embedding`: `context.embedding(content)`
- - Other methods
- - `context.loadSession(path)`
- - `context.saveSession(path)`
- - `context.stopCompletion()`
- - `context.release()`
+ - ... Other methods
 
  Please visit the [Documentation](docs/API) for more details.
 
  You can also visit the [example](example) to see how to use it.
 
+ ## Session (State)
+
+ The session file is a binary file that contains the state of the context, it can saves time of prompt processing.
+
+ ```js
+ const context = await initLlama({ ...params })
+
+ // After prompt processing or completion ...
+
+ // Save the session
+ await context.saveSession('<path to save session>')
+
+ // Load the session
+ await context.loadSession('<path to load session>')
+ ```
+
+ ## Embedding
+
+ The embedding API is used to get the embedding of a text.
+
+ ```js
+ const context = await initLlama({
+ ...params,
+ embedding: true,
+ })
+
+ const { embedding } = await context.embedding('Hello, world!')
+ ```
+
+ - You can use model like [nomic-ai/nomic-embed-text-v1.5-GGUF](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF) for better embedding quality.
+ - You can use DB like [op-sqlite](https://github.com/OP-Engineering/op-sqlite) with sqlite-vec support to store and search embeddings.
+
  ## Tool Calling
 
  `llama.rn` has universal tool call support by using [minja](https://github.com/google/minja) (as Jinja template parser) and [chat.cpp](https://github.com/ggerganov/llama.cpp/blob/master/common/chat.cpp) in llama.cpp.
@@ -273,7 +301,7 @@ jest.mock('llama.rn', () => require('llama.rn/jest/mock'))
 
  iOS:
 
- - The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
+ - The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) and [Increased Memory Limit](https://developer.apple.com/documentation/bundleresources/entitlements/com.apple.developer.kernel.increased-memory-limit?language=objc) capabilities are recommended to enable on iOS project.
  - Metal:
  - We have tested to know some devices is not able to use Metal (GPU) due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
  - It's also not supported in iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609), we used constant buffers more than 14.
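Note on the README additions above: the new Session section documents `context.saveSession(path)` and `context.loadSession(path)`. Below is a minimal sketch of reusing a saved session across app launches to skip repeated prompt processing; the `fileExists` helper, paths, and wiring are illustrative assumptions, while `initLlama`, `completion`, `saveSession`, and `loadSession` come from the README itself.

```js
import { initLlama } from 'llama.rn'

// Illustrative helper (not part of the package): restore a previously saved
// session before running a completion, then persist the updated state.
async function completionWithSession(params, completionParams, sessionPath, fileExists) {
  const context = await initLlama(params)
  if (await fileExists(sessionPath)) {
    await context.loadSession(sessionPath) // reuse cached prompt state
  }
  const result = await context.completion(completionParams)
  await context.saveSession(sessionPath) // save state for the next launch
  return result
}
```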
package/android/src/main/CMakeLists.txt CHANGED
@@ -11,7 +11,10 @@ endif(CCACHE_FOUND)
  set(CMAKE_CXX_STANDARD 17)
  set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)
 
- include_directories(${RNLLAMA_LIB_DIR})
+ include_directories(
+ ${RNLLAMA_LIB_DIR}
+ ${RNLLAMA_LIB_DIR}/ggml-cpu
+ )
 
  set(
  SOURCE_FILES
@@ -19,15 +22,18 @@ set(
  ${RNLLAMA_LIB_DIR}/ggml-alloc.c
  ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
  ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
- ${RNLLAMA_LIB_DIR}/ops.cpp
- ${RNLLAMA_LIB_DIR}/unary-ops.cpp
- ${RNLLAMA_LIB_DIR}/binary-ops.cpp
- ${RNLLAMA_LIB_DIR}/vec.cpp
- ${RNLLAMA_LIB_DIR}/ggml-cpu.c
- ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
- ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
- ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
- ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/amx/amx.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/amx/mmq.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/ggml-cpu.c
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/ggml-cpu.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/ggml-cpu-aarch64.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/ggml-cpu-quants.c
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/ggml-cpu-traits.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/unary-ops.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/binary-ops.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/sgemm.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/vec.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-cpu/ops.cpp
  ${RNLLAMA_LIB_DIR}/ggml-opt.cpp
  ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
  ${RNLLAMA_LIB_DIR}/ggml-quants.c
@@ -56,7 +62,6 @@ set(
  ${RNLLAMA_LIB_DIR}/sampling.cpp
  ${RNLLAMA_LIB_DIR}/unicode-data.cpp
  ${RNLLAMA_LIB_DIR}/unicode.cpp
- ${RNLLAMA_LIB_DIR}/sgemm.cpp
  ${RNLLAMA_LIB_DIR}/common.cpp
  ${RNLLAMA_LIB_DIR}/chat.cpp
  ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -170,6 +170,8 @@ public class LlamaContext {
  params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
  // int pooling_type,
  params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
+ // boolean ctx_shift,
+ params.hasKey("ctx_shift") ? params.getBoolean("ctx_shift") : true,
  // LoadProgressCallback load_progress_callback
  params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
  );
@@ -536,7 +538,7 @@ public class LlamaContext {
  String[] skip
  );
  protected static native long initContext(
- String model,
+ String model_path,
  String chat_template,
  String reasoning_format,
  boolean embedding,
@@ -558,6 +560,7 @@ public class LlamaContext {
  float rope_freq_base,
  float rope_freq_scale,
  int pooling_type,
+ boolean ctx_shift,
  LoadProgressCallback load_progress_callback
  );
  protected static native void interruptLoad(long contextPtr);
package/android/src/main/jni.cpp CHANGED
@@ -9,6 +9,7 @@
  #include <string>
  #include <thread>
  #include <unordered_map>
+ #include "json.hpp"
  #include "json-schema-to-grammar.h"
  #include "llama.h"
  #include "chat.h"
@@ -252,6 +253,7 @@ Java_com_rnllama_LlamaContext_initContext(
  jfloat rope_freq_base,
  jfloat rope_freq_scale,
  jint pooling_type,
+ jboolean ctx_shift,
  jobject load_progress_callback
  ) {
  UNUSED(thiz);
@@ -264,7 +266,7 @@ Java_com_rnllama_LlamaContext_initContext(
  }
 
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
- defaultParams.model = { model_path_chars };
+ defaultParams.model.path = model_path_chars;
 
  const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
  defaultParams.chat_template = chat_template_chars;
@@ -279,6 +281,7 @@ Java_com_rnllama_LlamaContext_initContext(
  defaultParams.n_ctx = n_ctx;
  defaultParams.n_batch = n_batch;
  defaultParams.n_ubatch = n_ubatch;
+ defaultParams.ctx_shift = ctx_shift;
 
  if (pooling_type != -1) {
  defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
@@ -298,7 +301,7 @@ Java_com_rnllama_LlamaContext_initContext(
  int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
  defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
- // defaultParams.n_gpu_layers = n_gpu_layers;
+ defaultParams.n_gpu_layers = n_gpu_layers;
  defaultParams.flash_attn = flash_attn;
 
  const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
@@ -534,9 +537,15 @@ Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
  pushString(env, additional_stops, stop.c_str());
  }
  putArray(env, result, "additional_stops", additional_stops);
+ } catch (const nlohmann::json_abi_v3_11_3::detail::parse_error& e) {
+ std::string errorMessage = "JSON parse error in getFormattedChat: " + std::string(e.what());
+ putString(env, result, "_error", errorMessage.c_str());
+ LOGI("[RNLlama] %s", errorMessage.c_str());
  } catch (const std::runtime_error &e) {
- LOGI("[RNLlama] Error: %s", e.what());
  putString(env, result, "_error", e.what());
+ LOGI("[RNLlama] Error: %s", e.what());
+ } catch (...) {
+ putString(env, result, "_error", "Unknown error in getFormattedChat");
  }
  env->ReleaseStringUTFChars(tools, tools_chars);
  env->ReleaseStringUTFChars(messages, messages_chars);
@@ -855,6 +864,12 @@ Java_com_rnllama_LlamaContext_doCompletion(
  llama->beginCompletion();
  llama->loadPrompt();
 
+ if (llama->context_full) {
+ auto result = createWriteableMap(env);
+ putString(env, result, "error", "Context is full");
+ return reinterpret_cast<jobject>(result);
+ }
+
  size_t sent_count = 0;
  size_t sent_token_probs_index = 0;
 
@@ -945,7 +960,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  toolCallsSize++;
  }
  } catch (const std::exception &e) {
- // LOGI("Error parsing tool calls: %s", e.what());
+ } catch (...) {
  }
  }
 
@@ -964,6 +979,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
  putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
  putInt(env, result, "truncated", llama->truncated);
+ putBoolean(env, result, "context_full", llama->context_full);
  putInt(env, result, "stopped_eos", llama->stopped_eos);
  putInt(env, result, "stopped_word", llama->stopped_word);
  putInt(env, result, "stopped_limit", llama->stopped_limit);
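Note on the Android changes above: `LlamaContext.java` and `jni.cpp` now thread a `ctx_shift` flag into context initialization and report a `context_full` flag on completion results (the matching typings change in `NativeRNLlama.ts` appears in the file list). A hedged sketch of how these fields might be used from JavaScript, assuming the JS keys mirror the native names read via `params.getBoolean("ctx_shift")` and written via `putBoolean(..., "context_full", ...)`; the model path and prompt are placeholders.

```js
import { initLlama } from 'llama.rn'

const context = await initLlama({
  model: '<path to model>', // placeholder
  n_ctx: 2048,
  ctx_shift: false, // assumed key: disable automatic context shifting (native default is true)
})

const result = await context.completion({ prompt: 'Hello' })
if (result.context_full) {
  // Assumed field: with ctx_shift disabled the context can fill up;
  // trim the prompt/history and retry instead of relying on KV-cache shifting.
}
```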
package/cpp/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023-2024 The ggml authors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
package/cpp/chat.cpp CHANGED
@@ -1612,7 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
  }
 
  // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
- if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
+ if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) {
  return common_chat_params_init_hermes_2_pro(tmpl, params);
  }
 
package/cpp/common.cpp CHANGED
@@ -837,7 +837,7 @@ std::string fs_get_cache_directory() {
  if (getenv("LLAMA_CACHE")) {
  cache_directory = std::getenv("LLAMA_CACHE");
  } else {
- #ifdef __linux__
+ #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
  if (std::getenv("XDG_CACHE_HOME")) {
  cache_directory = std::getenv("XDG_CACHE_HOME");
  } else {
@@ -847,7 +847,9 @@ std::string fs_get_cache_directory() {
  cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
  #elif defined(_WIN32)
  cache_directory = std::getenv("LOCALAPPDATA");
- #endif // __linux__
+ #else
+ # error Unknown architecture
+ #endif
  cache_directory = ensure_trailing_slash(cache_directory);
  cache_directory += "llama.cpp";
  }
@@ -1034,6 +1036,19 @@ struct common_init_result common_init_from_params(common_params & params) {
  return iparams;
  }
 
+ std::string get_model_endpoint() {
+ const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
+ // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
+ const char * hf_endpoint_env = getenv("HF_ENDPOINT");
+ const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
+ std::string model_endpoint = "https://huggingface.co/";
+ if (endpoint_env) {
+ model_endpoint = endpoint_env;
+ if (model_endpoint.back() != '/') model_endpoint += '/';
+ }
+ return model_endpoint;
+ }
+
  void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
  llama_clear_adapter_lora(ctx);
  for (auto & la : lora) {
package/cpp/common.h CHANGED
@@ -355,8 +355,10 @@ struct common_params {
 
  common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
 
- // multimodal models (see examples/llava)
+ // multimodal models (see tools/llava)
  struct common_params_model mmproj;
+ bool mmproj_use_gpu = true; // use GPU for multimodal model
+ bool no_mmproj = false; // explicitly disable multimodal model
  std::vector<std::string> image; // path to image file(s)
 
  // embedding
@@ -427,8 +429,8 @@ struct common_params {
  int n_pca_batch = 100;
  int n_pca_iterations = 1000;
  dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
- std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
- std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
+ std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
+ std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
 
  bool spm_infill = false; // suffix/prefix/middle pattern for infill
 
@@ -558,6 +560,8 @@ struct lm_ggml_threadpool_params lm_ggml_threadpool_params_from_cpu_params(const
  // clear LoRA adapters from context, then apply new list of adapters
  void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
+ std::string get_model_endpoint();
+
  //
  // Batch utils
  //
package/cpp/ggml-alloc.c CHANGED
@@ -816,7 +816,10 @@ static void lm_ggml_gallocr_init_tensor(lm_ggml_gallocr_t galloc, struct lm_ggml
  static bool lm_ggml_gallocr_node_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, struct tensor_alloc * talloc) {
  size_t node_size = 0;
  if (!node->data && !node->view_src) {
- LM_GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
+ // If we previously had data but don't now then reallocate
+ if (talloc->buffer_id < 0) {
+ return false;
+ }
  node_size = lm_ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
  }
  return talloc->size_max >= node_size;
package/cpp/ggml-cpp.h CHANGED
@@ -24,7 +24,7 @@ typedef std::unique_ptr<lm_gguf_context, lm_gguf_context_deleter> lm_gguf_contex
 
  struct lm_ggml_gallocr_deleter { void operator()(lm_ggml_gallocr_t galloc) { lm_ggml_gallocr_free(galloc); } };
 
- typedef std::unique_ptr<lm_ggml_gallocr_t, lm_ggml_gallocr_deleter> lm_ggml_gallocr_ptr;
+ typedef std::unique_ptr<lm_ggml_gallocr, lm_ggml_gallocr_deleter> lm_ggml_gallocr_ptr;
 
  // ggml-backend
 
package/cpp/ggml-cpu/amx/amx.cpp ADDED
@@ -0,0 +1,221 @@
+ #include "amx.h"
+ #include "common.h"
+ #include "mmq.h"
+ #include "ggml-backend-impl.h"
+ #include "ggml-backend.h"
+ #include "ggml-impl.h"
+ #include "ggml-cpu.h"
+ #include "ggml-cpu-traits.h"
+
+ #if defined(__gnu_linux__)
+ #include <sys/syscall.h>
+ #include <unistd.h>
+ #endif
+
+ #include <cstdlib>
+ #include <cstring>
+ #include <memory>
+
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+ // AMX type_trais
+ namespace ggml::cpu::amx {
+ class tensor_traits : public ggml::cpu::tensor_traits {
+ bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
+ size = lm_ggml_backend_amx_desired_wsize(op);
+ return true;
+ }
+
+ bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
+ if (op->op == LM_GGML_OP_MUL_MAT) {
+ lm_ggml_backend_amx_mul_mat(params, op);
+ return true;
+ }
+ return false;
+ }
+ };
+
+ static ggml::cpu::tensor_traits * get_tensor_traits(lm_ggml_backend_buffer_t, struct lm_ggml_tensor *) {
+ static tensor_traits traits;
+ return &traits;
+ }
+ } // namespace ggml::cpu::amx
+
+ // AMX buffer interface
+ static void lm_ggml_backend_amx_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+ }
+
+ static void * lm_ggml_backend_amx_buffer_get_base(lm_ggml_backend_buffer_t buffer) {
+ return (void *) (buffer->context);
+ }
+
+ static enum lm_ggml_status lm_ggml_backend_amx_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
+ tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
+
+ LM_GGML_UNUSED(buffer);
+ return LM_GGML_STATUS_SUCCESS;
+ }
+
+ static void lm_ggml_backend_amx_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+ uint8_t value, size_t offset, size_t size) {
+ memset((char *) tensor->data + offset, value, size);
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static void lm_ggml_backend_amx_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+ const void * data, size_t offset, size_t size) {
+ if (qtype_has_amx_kernels(tensor->type)) {
+ LM_GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, lm_ggml_type_name(tensor->type));
+ lm_ggml_backend_amx_convert_weight(tensor, data, offset, size);
+ } else {
+ memcpy((char *) tensor->data + offset, data, size);
+ }
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ /*
+ // need to figure what we need to do with buffer->extra.
+ static void lm_ggml_backend_amx_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ LM_GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static bool lm_ggml_backend_amx_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) {
+ if (lm_ggml_backend_buffer_is_host(src->buffer)) {
+ if (qtype_has_amx_kernels(src->type)) {
+ lm_ggml_backend_amx_convert_weight(dst, src->data, 0, lm_ggml_nbytes(dst));
+ } else {
+ memcpy(dst->data, src->data, lm_ggml_nbytes(src));
+ }
+ return true;
+ }
+ return false;
+
+ LM_GGML_UNUSED(buffer);
+ }
+ */
+
+ static void lm_ggml_backend_amx_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) {
+ memset(buffer->context, value, buffer->size);
+ }
+
+ static lm_ggml_backend_buffer_i lm_ggml_backend_amx_buffer_interface = {
+ /* .free_buffer = */ lm_ggml_backend_amx_buffer_free_buffer,
+ /* .get_base = */ lm_ggml_backend_amx_buffer_get_base,
+ /* .init_tensor = */ lm_ggml_backend_amx_buffer_init_tensor,
+ /* .memset_tensor = */ lm_ggml_backend_amx_buffer_memset_tensor,
+ /* .set_tensor = */ lm_ggml_backend_amx_buffer_set_tensor,
+ /* .get_tensor = */ nullptr,
+ /* .cpy_tensor = */ nullptr,
+ /* .clear = */ lm_ggml_backend_amx_buffer_clear,
+ /* .reset = */ nullptr,
+ };
+
+ static const char * lm_ggml_backend_amx_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
+ return "AMX";
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ static lm_ggml_backend_buffer_t lm_ggml_backend_amx_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
+ void * data = lm_ggml_aligned_malloc(size);
+ if (data == NULL) {
+ fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+ return NULL;
+ }
+
+ return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_amx_buffer_interface, data, size);
+ }
+
+ static size_t lm_ggml_backend_amx_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
+ return TENSOR_ALIGNMENT;
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ namespace ggml::cpu::amx {
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
+ bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override {
+ // handle only 2d gemm for now
+ auto is_contiguous_2d = [](const struct lm_ggml_tensor * t) {
+ return lm_ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+ };
+
+ if (op->op == LM_GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
+ is_contiguous_2d(op->src[1]) && // src1 must be contiguous
+ op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_amx_buffer_type() &&
+ op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
+ (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == LM_GGML_TYPE_F16))) {
+ // src1 must be host buffer
+ if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+ return false;
+ }
+ // src1 must be float32
+ if (op->src[1]->type == LM_GGML_TYPE_F32) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override {
+ if (op->op == LM_GGML_OP_MUL_MAT && op->src[0]->buffer &&
+ op->src[0]->buffer->buft == lm_ggml_backend_amx_buffer_type()) {
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+ }
+
+ return nullptr;
+ }
+ };
+ } // namespace ggml::cpu::amx
+
+ static size_t lm_ggml_backend_amx_buffer_type_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const lm_ggml_tensor * tensor) {
+ return lm_ggml_backend_amx_get_alloc_size(tensor);
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ #define ARCH_GET_XCOMP_PERM 0x1022
+ #define ARCH_REQ_XCOMP_PERM 0x1023
+ #define XFEATURE_XTILECFG 17
+ #define XFEATURE_XTILEDATA 18
+
+ static bool lm_ggml_amx_init() {
+ #if defined(__gnu_linux__)
+ if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+ fprintf(stderr, "AMX is not ready to be used!\n");
+ return false;
+ }
+ return true;
+ #elif defined(_WIN32)
+ return true;
+ #endif
+ }
+
+ lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type() {
+ static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_amx = {
+ /* .iface = */ {
+ /* .get_name = */ lm_ggml_backend_amx_buffer_type_get_name,
+ /* .alloc_buffer = */ lm_ggml_backend_amx_buffer_type_alloc_buffer,
+ /* .get_alignment = */ lm_ggml_backend_amx_buffer_type_get_alignment,
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ lm_ggml_backend_amx_buffer_type_get_alloc_size,
+ /* .is_host = */ nullptr,
+ },
+ /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
+ /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
+ };
+
+ if (!lm_ggml_amx_init()) {
+ return nullptr;
+ }
+
+ return &lm_ggml_backend_buffer_type_amx;
+ }
+
+ #endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
package/cpp/ggml-cpu/amx/amx.h ADDED
@@ -0,0 +1,8 @@
+ #include "ggml-backend.h"
+ #include "ggml-cpu-impl.h"
+
+ // GGML internal header
+
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+ lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type(void);
+ #endif