cui-llama.rn 1.4.6 → 1.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (366)
  1. package/LICENSE +20 -20
  2. package/README.md +317 -319
  3. package/android/build.gradle +116 -116
  4. package/android/gradle.properties +5 -5
  5. package/android/src/main/AndroidManifest.xml +4 -4
  6. package/android/src/main/CMakeLists.txt +124 -117
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +645 -645
  8. package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
  10. package/android/src/main/jni-utils.h +100 -100
  11. package/android/src/main/jni.cpp +1263 -1245
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  14. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  15. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  16. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  17. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  20. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
  21. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
  22. package/cpp/README.md +4 -4
  23. package/cpp/binary-ops.cpp +158 -0
  24. package/cpp/binary-ops.h +16 -0
  25. package/cpp/chat.cpp +1769 -1779
  26. package/cpp/chat.h +9 -1
  27. package/cpp/common.cpp +20 -522
  28. package/cpp/common.h +13 -36
  29. package/cpp/cpu-common.h +72 -0
  30. package/cpp/ggml-common.h +12 -6
  31. package/cpp/ggml-cpu-aarch64.cpp +1557 -80
  32. package/cpp/ggml-cpu-impl.h +2 -21
  33. package/cpp/ggml-cpu-quants.c +904 -405
  34. package/cpp/ggml-cpu.c +909 -13237
  35. package/cpp/ggml-impl.h +50 -23
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +597 -523
  39. package/cpp/ggml-metal.m +798 -580
  40. package/cpp/ggml.c +92 -3
  41. package/cpp/ggml.h +30 -6
  42. package/cpp/gguf.cpp +1 -0
  43. package/cpp/llama-adapter.cpp +55 -20
  44. package/cpp/llama-adapter.h +11 -9
  45. package/cpp/llama-arch.cpp +217 -16
  46. package/cpp/llama-arch.h +25 -0
  47. package/cpp/llama-batch.h +2 -2
  48. package/cpp/llama-chat.cpp +54 -2
  49. package/cpp/llama-chat.h +3 -0
  50. package/cpp/llama-context.cpp +2294 -1238
  51. package/cpp/llama-context.h +214 -77
  52. package/cpp/llama-cparams.h +1 -0
  53. package/cpp/llama-graph.cpp +1695 -0
  54. package/cpp/llama-graph.h +592 -0
  55. package/cpp/llama-hparams.cpp +8 -0
  56. package/cpp/llama-hparams.h +17 -0
  57. package/cpp/llama-io.cpp +15 -0
  58. package/cpp/llama-io.h +35 -0
  59. package/cpp/llama-kv-cache.cpp +965 -303
  60. package/cpp/llama-kv-cache.h +145 -151
  61. package/cpp/llama-memory.cpp +1 -0
  62. package/cpp/llama-memory.h +21 -0
  63. package/cpp/llama-mmap.cpp +1 -1
  64. package/cpp/llama-model-loader.cpp +10 -5
  65. package/cpp/llama-model-loader.h +5 -3
  66. package/cpp/llama-model.cpp +9194 -201
  67. package/cpp/llama-model.h +40 -1
  68. package/cpp/llama-sampling.cpp +5 -0
  69. package/cpp/llama-vocab.cpp +36 -5
  70. package/cpp/llama.cpp +51 -9984
  71. package/cpp/llama.h +102 -22
  72. package/cpp/log.cpp +34 -0
  73. package/cpp/minja/chat-template.hpp +15 -7
  74. package/cpp/minja/minja.hpp +120 -94
  75. package/cpp/ops.cpp +8723 -0
  76. package/cpp/ops.h +128 -0
  77. package/cpp/rn-llama.cpp +873 -882
  78. package/cpp/rn-llama.h +138 -148
  79. package/cpp/sampling.cpp +3 -0
  80. package/cpp/sampling.h +107 -107
  81. package/cpp/sgemm.cpp +533 -88
  82. package/cpp/simd-mappings.h +888 -0
  83. package/cpp/speculative.cpp +4 -4
  84. package/cpp/unary-ops.cpp +186 -0
  85. package/cpp/unary-ops.h +28 -0
  86. package/cpp/unicode-data.cpp +7034 -7034
  87. package/cpp/unicode-data.h +20 -20
  88. package/cpp/unicode.cpp +849 -849
  89. package/cpp/unicode.h +66 -66
  90. package/cpp/vec.cpp +258 -0
  91. package/cpp/vec.h +802 -0
  92. package/ios/CMakeLists.txt +116 -105
  93. package/ios/RNLlama.h +7 -7
  94. package/ios/RNLlama.mm +418 -405
  95. package/ios/RNLlamaContext.h +57 -57
  96. package/ios/RNLlamaContext.mm +835 -819
  97. package/ios/rnllama.xcframework/Info.plist +74 -74
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +677 -0
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1434 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
  143. package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/chat-template.hpp +15 -7
  144. package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/minja.hpp +120 -94
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +128 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +802 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  184. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  191. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  192. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  193. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  194. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  195. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  196. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  197. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  198. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  199. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  200. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  201. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  202. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  203. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  204. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  205. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  206. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  207. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  208. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  209. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  210. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  211. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  212. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  213. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  214. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  215. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  216. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  217. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  218. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +677 -0
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  225. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  226. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  227. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  228. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  229. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  230. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  231. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  232. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  233. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  234. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  235. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  236. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  237. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  238. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  239. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  240. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
  241. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  242. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  243. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  244. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  245. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  246. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  247. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  248. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  249. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  250. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  251. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  261. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  262. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1434 -0
  263. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
  264. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  265. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  266. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +128 -0
  267. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  268. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
  269. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  270. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
  272. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  273. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  274. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
  275. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +802 -0
  276. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  277. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  278. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  279. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  280. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  281. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  282. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  283. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  284. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  285. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  286. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  287. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  288. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  289. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  290. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  291. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  292. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  293. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  294. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  295. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  296. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  297. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  298. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  299. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  300. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  301. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  302. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  303. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  304. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  305. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  306. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  307. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  308. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  309. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  310. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  311. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  312. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  313. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  314. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  315. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  316. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  317. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  318. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  319. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  320. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  321. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  322. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  323. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  324. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  325. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  326. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  327. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  328. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  329. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  330. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  331. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  332. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  333. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  334. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  335. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  336. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  337. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  338. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  339. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  340. package/jest/mock.js +203 -203
  341. package/lib/commonjs/NativeRNLlama.js +1 -2
  342. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  343. package/lib/commonjs/chat.js.map +1 -1
  344. package/lib/commonjs/grammar.js +12 -31
  345. package/lib/commonjs/grammar.js.map +1 -1
  346. package/lib/commonjs/index.js +47 -47
  347. package/lib/commonjs/index.js.map +1 -1
  348. package/lib/commonjs/package.json +1 -0
  349. package/lib/module/NativeRNLlama.js +2 -0
  350. package/lib/module/NativeRNLlama.js.map +1 -1
  351. package/lib/module/chat.js +2 -0
  352. package/lib/module/chat.js.map +1 -1
  353. package/lib/module/grammar.js +14 -31
  354. package/lib/module/grammar.js.map +1 -1
  355. package/lib/module/index.js +47 -45
  356. package/lib/module/index.js.map +1 -1
  357. package/lib/module/package.json +1 -0
  358. package/lib/typescript/NativeRNLlama.d.ts +6 -4
  359. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  360. package/lib/typescript/index.d.ts.map +1 -1
  361. package/llama-rn.podspec +48 -48
  362. package/package.json +233 -233
  363. package/src/NativeRNLlama.ts +426 -424
  364. package/src/chat.ts +44 -44
  365. package/src/grammar.ts +854 -854
  366. package/src/index.ts +495 -485
package/cpp/sgemm.cpp CHANGED
@@ -55,6 +55,7 @@
  
  #include <atomic>
  #include <array>
+ #include <type_traits>
  
  #ifdef _MSC_VER
  #define NOINLINE __declspec(noinline)
@@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC {
  }
  }
  
- template<typename VA, typename VB>
- void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+ template<typename VA, typename VB, int size>
+ void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
  int64_t i, j;
  TA *aoffset = NULL;
  VA *vecOffset = NULL;
  TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
  TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+ VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+ VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+ VB t1, t2, t3, t4, t5, t6, t7, t8;
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ const vector signed char v8 = vec_splats((signed char)0x8);
+ aoffset = const_cast<TA*>(a);
+ vecOffset = vec;
+ vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+ vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+ vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+ vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+ vector signed int vsum = {0};
+ vector signed int vsum2 = {0};
+
+ j = (rows >> 3);
+ if (j > 0) {
+ do {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset5 = aoffset4 + lda;
+ aoffset6 = aoffset5 + lda;
+ aoffset7 = aoffset6 + lda;
+ aoffset8 = aoffset7 + lda;
+ aoffset += 8 * lda;
+
+ i = (cols >> 2);
+ if (i > 0) {
+ do {
+ c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+ c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
+ c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
+ c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
+ c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
+
+ c1[0] = vec_and(c1[1], lowMask);
+ c1[1] = vec_sr(c1[1], v4);
+ c1[0] = vec_sub(c1[0], v8);
+ c1[1] = vec_sub(c1[1], v8);
+ vsum = vec_sum4s(c1[0], vsum);
+ vsum2 = vec_sum4s(c1[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c2[0] = vec_and(c2[1], lowMask);
+ c2[1] = vec_sr(c2[1], v4);
+ c2[0] = vec_sub(c2[0], v8);
+ c2[1] = vec_sub(c2[1], v8);
+ vsum = vec_sum4s(c2[0], vsum);
+ vsum2 = vec_sum4s(c2[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c3[0] = vec_and(c3[1], lowMask);
+ c3[1] = vec_sr(c3[1], v4);
+ c3[0] = vec_sub(c3[0], v8);
+ c3[1] = vec_sub(c3[1], v8);
+ vsum = vec_sum4s(c3[0], vsum);
+ vsum2 = vec_sum4s(c3[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c4[0] = vec_and(c4[1], lowMask);
+ c4[1] = vec_sr(c4[1], v4);
+ c4[0] = vec_sub(c4[0], v8);
+ c4[1] = vec_sub(c4[1], v8);
+ vsum = vec_sum4s(c4[0], vsum);
+ vsum2 = vec_sum4s(c4[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c5[0] = vec_and(c5[1], lowMask);
+ c5[1] = vec_sr(c5[1], v4);
+ c5[0] = vec_sub(c5[0], v8);
+ c5[1] = vec_sub(c5[1], v8);
+ vsum = vec_sum4s(c5[0], vsum);
+ vsum2 = vec_sum4s(c5[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c6[0] = vec_and(c6[1], lowMask);
+ c6[1] = vec_sr(c6[1], v4);
+ c6[0] = vec_sub(c6[0], v8);
+ c6[1] = vec_sub(c6[1], v8);
+ vsum = vec_sum4s(c6[0], vsum);
+ vsum2 = vec_sum4s(c6[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c7[0] = vec_and(c7[1], lowMask);
+ c7[1] = vec_sr(c7[1], v4);
+ c7[0] = vec_sub(c7[0], v8);
+ c7[1] = vec_sub(c7[1], v8);
+ vsum = vec_sum4s(c7[0], vsum);
+ vsum2 = vec_sum4s(c7[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c8[0] = vec_and(c8[1], lowMask);
+ c8[1] = vec_sr(c8[1], v4);
+ c8[0] = vec_sub(c8[0], v8);
+ c8[1] = vec_sub(c8[1], v8);
+ vsum = vec_sum4s(c8[0], vsum);
+ vsum2 = vec_sum4s(c8[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ t1 = vec_perm(c1[0], c2[0], swiz1);
+ t2 = vec_perm(c1[0], c2[0], swiz2);
+ t3 = vec_perm(c3[0], c4[0], swiz1);
+ t4 = vec_perm(c3[0], c4[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+
+ t1 = vec_perm(c1[1], c2[1], swiz1);
+ t2 = vec_perm(c1[1], c2[1], swiz2);
+ t3 = vec_perm(c3[1], c4[1], swiz1);
+ t4 = vec_perm(c3[1], c4[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset+64);
+ vec_xst(t6, 0, vecOffset+80);
+ vec_xst(t7, 0, vecOffset+96);
+ vec_xst(t8, 0, vecOffset+112);
+
+ t1 = vec_perm(c5[0], c6[0], swiz1);
+ t2 = vec_perm(c5[0], c6[0], swiz2);
+ t3 = vec_perm(c7[0], c8[0], swiz1);
+ t4 = vec_perm(c7[0], c8[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset+128);
+ vec_xst(t6, 0, vecOffset+144);
+ vec_xst(t7, 0, vecOffset+160);
+ vec_xst(t8, 0, vecOffset+176);
+
+ t1 = vec_perm(c5[1], c6[1], swiz1);
+ t2 = vec_perm(c5[1], c6[1], swiz2);
+ t3 = vec_perm(c7[1], c8[1], swiz1);
+ t4 = vec_perm(c7[1], c8[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset+192);
+ vec_xst(t6, 0, vecOffset+208);
+ vec_xst(t7, 0, vecOffset+224);
+ vec_xst(t8, 0, vecOffset+240);
+
+ aoffset1 += lda;
+ aoffset2 += lda;
+ aoffset3 += lda;
+ aoffset4 += lda;
+ aoffset5 += lda;
+ aoffset6 += lda;
+ aoffset7 += lda;
+ aoffset8 += lda;
+ vecOffset += 256;
+ i--;
+ } while (i > 0);
+ }
+ j--;
+ } while (j > 0);
+ }
+
+ if (rows & 4) {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset += 4 * lda;
+
+ i = (cols >> 2);
+ if (i > 0) {
+ do {
+ c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+
+ c1[0] = vec_and(c1[1], lowMask);
+ c1[1] = vec_sr(c1[1], v4);
+ c1[0] = vec_sub(c1[0], v8);
+ c1[1] = vec_sub(c1[1], v8);
+ vsum = vec_sum4s(c1[0], vsum);
+ vsum2 = vec_sum4s(c1[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c2[0] = vec_and(c2[1], lowMask);
+ c2[1] = vec_sr(c2[1], v4);
+ c2[0] = vec_sub(c2[0], v8);
+ c2[1] = vec_sub(c2[1], v8);
+ vsum = vec_sum4s(c2[0], vsum);
+ vsum2 = vec_sum4s(c2[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c3[0] = vec_and(c3[1], lowMask);
+ c3[1] = vec_sr(c3[1], v4);
+ c3[0] = vec_sub(c3[0], v8);
+ c3[1] = vec_sub(c3[1], v8);
+ vsum = vec_sum4s(c3[0], vsum);
+ vsum2 = vec_sum4s(c3[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c4[0] = vec_and(c4[1], lowMask);
+ c4[1] = vec_sr(c4[1], v4);
+ c4[0] = vec_sub(c4[0], v8);
+ c4[1] = vec_sub(c4[1], v8);
+ vsum = vec_sum4s(c4[0], vsum);
+ vsum2 = vec_sum4s(c4[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats( 0);
+
+ t1 = vec_perm(c1[0], c2[0], swiz1);
+ t2 = vec_perm(c1[0], c2[0], swiz2);
+ t3 = vec_perm(c3[0], c4[0], swiz1);
+ t4 = vec_perm(c3[0], c4[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+
+ t1 = vec_perm(c1[1], c2[1], swiz1);
+ t2 = vec_perm(c1[1], c2[1], swiz2);
+ t3 = vec_perm(c3[1], c4[1], swiz1);
+ t4 = vec_perm(c3[1], c4[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset+64);
+ vec_xst(t6, 0, vecOffset+80);
+ vec_xst(t7, 0, vecOffset+96);
+ vec_xst(t8, 0, vecOffset+112);
+
+ aoffset1 += lda;
+ aoffset2 += lda;
+ aoffset3 += lda;
+ aoffset4 += lda;
+ vecOffset += 128;
+ i--;
+ } while (i > 0);
+ }
+ }
+
+ if (rows & 3) {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ i = (cols >> 2);
+ if (i > 0) {
+ do {
+ switch(rows) {
+ case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+ case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+ case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+ break;
+ }
+ c1[0] = vec_and(c1[1], lowMask);
+ c1[1] = vec_sr(c1[1], v4);
+ c1[0] = vec_sub(c1[0], v8);
+ c1[1] = vec_sub(c1[1], v8);
+ vsum = vec_sum4s(c1[0], vsum);
+ vsum2 = vec_sum4s(c1[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c2[0] = vec_and(c2[1], lowMask);
+ c2[1] = vec_sr(c2[1], v4);
+ c2[0] = vec_sub(c2[0], v8);
+ c2[1] = vec_sub(c2[1], v8);
+ vsum = vec_sum4s(c2[0], vsum);
+ vsum2 = vec_sum4s(c2[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c3[0] = vec_and(c3[1], lowMask);
+ c3[1] = vec_sr(c3[1], v4);
+ c3[0] = vec_sub(c3[0], v8);
+ c3[1] = vec_sub(c3[1], v8);
+ vsum = vec_sum4s(c3[0], vsum);
+ vsum2 = vec_sum4s(c3[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ c4[0] = vec_and(c4[1], lowMask);
+ c4[1] = vec_sr(c4[1], v4);
+ c4[0] = vec_sub(c4[0], v8);
+ c4[1] = vec_sub(c4[1], v8);
+ vsum = vec_sum4s(c4[0], vsum);
+ vsum2 = vec_sum4s(c4[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ vsum = vec_splats(0);
+ vsum2 = vec_splats(0);
+
+ t1 = vec_perm(c1[0], c2[0], swiz1);
+ t2 = vec_perm(c1[0], c2[0], swiz2);
+ t3 = vec_perm(c3[0], c4[0], swiz1);
+ t4 = vec_perm(c3[0], c4[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+
+ t1 = vec_perm(c1[1], c2[1], swiz1);
+ t2 = vec_perm(c1[1], c2[1], swiz2);
+ t3 = vec_perm(c3[1], c4[1], swiz1);
+ t4 = vec_perm(c3[1], c4[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ vec_xst(t5, 0, vecOffset+64);
+ vec_xst(t6, 0, vecOffset+80);
+ vec_xst(t7, 0, vecOffset+96);
+ vec_xst(t8, 0, vecOffset+112);
+ aoffset1 += lda;
+ aoffset2 += lda;
+ aoffset3 += lda;
+ vecOffset += 128;
+ i--;
+ } while(i > 0);
+ }
+ }
+ }
+
+ template<typename VA, typename VB>
+ void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+ int64_t i, j;
+ TB *aoffset = NULL;
+ VA *vecOffset = NULL;
+ TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+ TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
  __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
  VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
  VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
@@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC {
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
  
- aoffset = const_cast<TA*>(a);
+ aoffset = const_cast<TB*>(a);
  vecOffset = vec;
  j = (rows >> 3);
  if (j > 0) {
  do {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset5 = aoffset4 + lda;
- aoffset6 = aoffset5 + lda;
- aoffset7 = aoffset6 + lda;
- aoffset8 = aoffset7 + lda;
- aoffset += 8 * lda;
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset5 = aoffset4 + lda;
+ aoffset6 = aoffset5 + lda;
+ aoffset7 = aoffset6 + lda;
+ aoffset8 = aoffset7 + lda;
+ aoffset += 8 * lda;
  
- i = (cols >> 3);
- if (i > 0) {
- do {
+ i = (cols >> 3);
+ if (i > 0) {
+ do {
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
  C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC {
  t7 = vec_perm(t2, t4, swiz3);
  t8 = vec_perm(t2, t4, swiz4);
  if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
  }
  vec_xst(t5, 0, vecOffset);
  vec_xst(t6, 0, vecOffset+16);
@@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC {
  t7 = vec_perm(t2, t4, swiz3);
  t8 = vec_perm(t2, t4, swiz4);
  if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
  }
  vec_xst(t5, 0, vecOffset+64);
  vec_xst(t6, 0, vecOffset+80);
@@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC {
  t7 = vec_perm(t2, t4, swiz3);
  t8 = vec_perm(t2, t4, swiz4);
  if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
  }
  vec_xst(t5, 0, vecOffset+128);
  vec_xst(t6, 0, vecOffset+144);
@@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC {
  t7 = vec_perm(t2, t4, swiz3);
  t8 = vec_perm(t2, t4, swiz4);
  if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
  }
  vec_xst(t5, 0, vecOffset+192);
  vec_xst(t6, 0, vecOffset+208);
@@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC {
  }
  
  if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset += 4 * lda;
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset += 4 * lda;
  
  i = (cols >> 3);
  if (i > 0) {
@@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC {
  aoffset2 = aoffset1 + lda;
  aoffset3 = aoffset2 + lda;
  i = (cols >> 3);
- if (i > 0) {
+ if (i > 0) {
  do {
  switch(rows) {
  case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC {
  void KERNEL_4x8(int64_t ii, int64_t jj) {
  vec_t vec_A[8], vec_B[16] = {0};
  acc_t acc_0, acc_1;
- std::array<int, 4> comparray;
+ std::array<int, 4> comparray {};
  vector float fin_res[8] = {0};
  vector float vs[8] = {0};
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  for (int l = 0; l < k; l++) {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+ if (std::is_same_v<TA, block_q4_0>) {
+ packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+ } else {
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+ }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC {
  *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
  }
  }
- auto aoffset = A+(ii*lda)+l;
- for (int i = 0; i < 4; i++) {
- comparray[i] = 0;
- int ca = 0;
- const int8_t *at = aoffset->qs;
- for (int j = 0; j < 32; j++)
- ca += (int)*at++;
- comparray[i] = ca;
- aoffset += lda;
+ if (!isAblock_q4) {
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < 4; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ auto *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
  }
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
@@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC {
  void KERNEL_8x4(int64_t ii, int64_t jj) {
  vec_t vec_A[16], vec_B[8] = {0};
  acc_t acc_0, acc_1;
- std::array<int, 8> comparray;
+ std::array<int, 8> comparray {};
  vector float fin_res[8] = {0};
  vector float vs[8] = {0};
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  for (int l = 0; l < k; l++) {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ if (std::is_same_v<TA, block_q4_0>) {
+ packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ } else {
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC {
  *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
  }
  }
- auto aoffset = A+(ii*lda)+l;
- for (int i = 0; i < 8; i++) {
- comparray[i] = 0;
- int ca = 0;
- const int8_t *at = aoffset->qs;
- for (int j = 0; j < 32; j++)
- ca += (int)*at++;
- comparray[i] = ca;
- aoffset += lda;
+ if (!isAblock_q4) {
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < 8; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ auto *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
  }
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC {
  void KERNEL_8x8(int64_t ii, int64_t jj) {
  vec_t vec_A[16], vec_B[16] = {0};
  acc_t acc_0, acc_1, acc_2, acc_3;
- std::array<int, 8> comparray;
+ std::array<int, 8> comparray {};
  vector float fin_res[16] = {0};
  vector float vs[16] = {0};
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  for (int l = 0; l < k; l++) {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  __builtin_mma_xxsetaccz(&acc_2);
  __builtin_mma_xxsetaccz(&acc_3);
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ if (std::is_same_v<TA, block_q4_0>) {
+ packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ } else {
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC {
  *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
  }
  }
- auto aoffset = A+(ii*lda)+l;
- for (int i = 0; i < 8; i++) {
- comparray[i] = 0;
- int ca = 0;
- const int8_t *at = aoffset->qs;
- for (int j = 0; j < 32; j++)
- ca += (int)*at++;
- comparray[i] = ca;
- aoffset += lda;
+ if (!isAblock_q4) {
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < 8; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ auto *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
  }
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC {
  int64_t duty = (tiles + nth - 1) / nth;
  int64_t start = duty * ith;
  int64_t end = start + duty;
- vec_t vec_A[8], vec_B[8] = {0};
+ vec_t vec_A[8] = {0}, vec_B[8] = {0};
  vector signed int vec_C[4];
  acc_t acc_0;
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  
  if (end > tiles)
  end = tiles;
  for (int64_t job = start; job < end; ++job) {
  int64_t ii = m0 + job / xtiles * RM;
  int64_t jj = n0 + job % xtiles * RN;
- std::array<int, RM> comparray;
+ std::array<int, 4> comparray{};
  vector float res[4] = {0};
  vector float fin_res[4] = {0};
  vector float vs[4] = {0};
@@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC {
  __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  __builtin_mma_xxsetaccz(&acc_0);
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ if (isAblock_q4) {
+ packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+ } else {
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x+=4) {
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC {
  }
  }
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
- auto aoffset = A+(ii*lda)+l;
- for (int i = 0; i < RM; i++) {
- comparray[i] = 0;
- int ca = 0;
- const int8_t *at = aoffset->qs;
- for (int j = 0; j < 32; j++)
- ca += (int)*at++;
- comparray[i] = ca;
- aoffset += lda;
+ if (!isAblock_q4) {
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < RM; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ auto *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
  }
-
  for (int i = 0; i < RM; i++) {
  CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
  res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
@@ -2013,6 +2431,7 @@ class tinyBLAS_PPC {
  }
  }
  }
+
  void KERNEL_4x4(int64_t ii, int64_t jj) {
  vec_t vec_A[4], vec_B[4], vec_C[4];
  acc_t acc_0;
@@ -2259,15 +2678,27 @@ class tinyBLAS_PPC {
  vec_t vec_C[4];
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
- vec_t vec_A[4], vec_B[4];
+ vec_t vec_A[4] {0}, vec_B[4] = {0};
  for (int l=0; l<k; l+=4) {
- if (RN >= 4 && RM == 1) {
+ /* 'GEMV Forwarding' concept is used in first two conditional loops.
+ * when one of the matrix has a single row/column, the elements are
+ * broadcasted, instead of using packing routine to prepack the
+ * matrix elements.
+ */
+ if (RM == 1) {
  TA* a = const_cast<TA*>(A+(ii)*lda+l);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
  vec_A[0] = (vec_t)vec_xl(0,a);
  vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
  vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
  vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+ } else if (RN == 1) {
+ packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+ TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+ vec_B[0] = (vec_t)vec_xl(0,b);
+ vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
  } else {
  packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
  packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
@@ -2371,8 +2802,10 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
  assert(params->ith < params->nth);
  
  // only enable sgemm for prompt processing
+ #if !defined(__MMA__)
  if (n < 2)
  return false;
+ #endif
  
  if (Ctype != LM_GGML_TYPE_F32)
  return false;
@@ -2503,8 +2936,8 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
-
  #elif defined(__MMA__)
+ //TO-DO: Remove this condition once gemv forwarding is enabled.
  if (n < 8 && n != 4)
  return false;
  if (m < 8 && m != 4)
@@ -2516,7 +2949,6 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
-
  #else
  return false;
  #endif
@@ -2541,6 +2973,19 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
+ #elif defined(__MMA__)
+ //TO-DO: Remove this condition once gemv forwarding is enabled.
+ if (n < 8 && n != 4)
+ return false;
+ if (m < 8 && m != 4)
+ return false;
+ tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+ k, (const block_q4_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ params->ith, params->nth};
+ tb.matmul(m, n);
+ return true;
  #else
  return false;
  #endif
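
A note on the sgemm.cpp hunks above: the recurring change in the KERNEL_4x8 / KERNEL_8x4 / KERNEL_8x8 kernels is a type-based dispatch that routes block_q4_0 A matrices through the new packNormalInt4 routine (which also precomputes the per-row sums in comparray), while other quant types keep using packNormal. Below is a minimal, self-contained sketch of that dispatch pattern, not the package's actual implementation: the struct layouts and function bodies are placeholders, and the sketch uses if constexpr where the real code uses a plain runtime if plus a cast to (const TB*).

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <type_traits>

// Hypothetical stand-ins for the quantized block types used in sgemm.cpp.
struct block_q4_0 { uint8_t qs[16]; };
struct block_q8_0 { int8_t  qs[32]; };

// Generic packing path (placeholder body).
template <typename T>
void pack_normal(const T * /*a*/, int rows) {
    std::printf("packNormal path, rows=%d\n", rows);
}

// Int4-specific packing path that also fills a per-row sum array (placeholder body).
template <int size>
void pack_normal_int4(const block_q4_0 * /*a*/, int rows, std::array<int, size> &comparray) {
    comparray.fill(0);  // the real routine accumulates nibble sums here
    std::printf("packNormalInt4 path, rows=%d\n", rows);
}

// Compile-time dispatch on the A-block type, mirroring the pattern the diff
// introduces in the MMA kernels: q4_0 inputs take the int4 path, everything
// else stays on the generic path.
template <typename TA>
void kernel_pack(const TA *a, int rows) {
    if constexpr (std::is_same_v<TA, block_q4_0>) {
        std::array<int, 8> comparray{};   // value-initialized, as in the diff
        pack_normal_int4<8>(a, rows, comparray);
    } else {
        pack_normal(a, rows);
    }
}

int main() {
    block_q4_0 a4[1]{};
    block_q8_0 a8[1]{};
    kernel_pack(a4, 8);   // takes the int4 path
    kernel_pack(a8, 8);   // takes the generic path
}
```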