cui-llama.rn 1.6.0 → 1.6.1

This diff shows the changes between publicly released versions of the package as they appear in a supported public registry, and is provided for informational purposes only.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
@@ -65,6 +65,7 @@ void lm_ggml_compute_forward_conv_transpose_1d(const struct lm_ggml_compute_para
  void lm_ggml_compute_forward_im2col(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_im2col_back_f32(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_conv_transpose_2d(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+ void lm_ggml_compute_forward_conv_2d_dw(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_pool_1d(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_pool_2d(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_pool_2d_back(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
@@ -96,29 +97,10 @@ void lm_ggml_compute_forward_add_rel_pos(const struct lm_ggml_compute_params * p
  void lm_ggml_compute_forward_rwkv_wkv6(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_rwkv_wkv7(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_gla(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
- void lm_ggml_compute_forward_map_unary(
-         const struct lm_ggml_compute_params * params,
-         struct lm_ggml_tensor * dst,
-         const lm_ggml_unary_op_f32_t fun);
- void lm_ggml_compute_forward_map_binary(
-         const struct lm_ggml_compute_params * params,
-         struct lm_ggml_tensor * dst,
-         const lm_ggml_binary_op_f32_t fun);
- void lm_ggml_compute_forward_map_custom1_f32(
-         const struct lm_ggml_compute_params * params,
-         struct lm_ggml_tensor * dst,
-         const lm_ggml_custom1_op_f32_t fun);
- void lm_ggml_compute_forward_map_custom2_f32(
-         const struct lm_ggml_compute_params * params,
-         struct lm_ggml_tensor * dst,
-         const lm_ggml_custom2_op_f32_t fun);
- void lm_ggml_compute_forward_map_custom3_f32(
-         const struct lm_ggml_compute_params * params,
-         struct lm_ggml_tensor * dst,
-         const lm_ggml_custom3_op_f32_t fun);
  void lm_ggml_compute_forward_map_custom1(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_map_custom2(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_map_custom3(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+ void lm_ggml_compute_forward_custom(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_cross_entropy_loss(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_cross_entropy_loss_back(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
  void lm_ggml_compute_forward_opt_step_adamw(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
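The hunk above drops the five typed `map_*` entry points that took a function pointer at compute time; the retained `map_custom*` variants and the new generic `lm_ggml_compute_forward_custom` read their callback from the op's parameters instead. A minimal sketch of how user code reaches the retained `map_custom1` path through the public API, assuming this package's `lm_`/`LM_` prefixes; the callback and its body are illustrative, not from this package:

```cpp
#include "ggml.h"

// illustrative custom1 callback: dst = -a, with work split across nth threads
static void neg_op(struct lm_ggml_tensor * dst, const struct lm_ggml_tensor * a,
                   int ith, int nth, void * userdata) {
    (void) userdata;
    const int64_t n   = lm_ggml_nelements(dst);
    const int64_t per = (n + nth - 1) / nth;      // elements per thread
    const int64_t i0  = per * ith;
    const int64_t i1  = i0 + per < n ? i0 + per : n;
    const float * src = (const float *) a->data;  // assumes contiguous f32
    float       * out = (float *) dst->data;
    for (int64_t i = i0; i < i1; ++i) {
        out[i] = -src[i];
    }
}

// usage sketch:
// struct lm_ggml_tensor * out =
//     lm_ggml_map_custom1(ctx, t, neg_op, LM_GGML_N_TASKS_MAX, NULL);
```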
@@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
  } \
  } \
 
+ template <typename TA, typename TB, typename TC>
+ class tinyBLAS_BF16_PPC {
+   public:
+     tinyBLAS_BF16_PPC(int64_t k,
+                       const TA *A, int64_t lda,
+                       const TB *B, int64_t ldb,
+                       TC *C, int64_t ldc,
+                       int ith, int nth)
+         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+     }
+
+     void matmul(int64_t m, int64_t n) {
+         mnpack(0, m, 0, n);
+     }
+
+   private:
+     void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
+         vec_t t[8], s[8];
+         vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+         vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+         vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+         vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+         if (numVec == 2) {
+             t[0] = vec_perm(c[0], c[1], swiz1);
+             t[1] = vec_perm(c[2], c[3], swiz1);
+             s[0] = vec_perm(t[0], t[1], swiz3);
+             s[1] = vec_perm(t[0], t[1], swiz4);
+             vec_xst(s[0], 0, (vec_t*)vecOffset);
+             vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
+         } else if (numVec == 4) {
+             t[0] = vec_perm(c[0], c[1], swiz1);
+             t[1] = vec_perm(c[0], c[1], swiz2);
+             t[2] = vec_perm(c[2], c[3], swiz1);
+             t[3] = vec_perm(c[2], c[3], swiz2);
+             s[0] = vec_perm(t[0], t[2], swiz3);
+             s[1] = vec_perm(t[0], t[2], swiz4);
+             s[2] = vec_perm(t[1], t[3], swiz3);
+             s[3] = vec_perm(t[1], t[3], swiz4);
+             for (int i = 0; i < 4; ++i)
+                 vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+         } else if (numVec == 8) {
+             for (int i = 0; i < 4; i += 2) {
+                 t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                 t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+             }
+             for (int i = 4; i < 8; i += 2) {
+                 t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                 t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+             }
+             s[0] = vec_perm(t[0], t[2], swiz3);
+             s[1] = vec_perm(t[0], t[2], swiz4);
+             s[2] = vec_perm(t[1], t[3], swiz3);
+             s[3] = vec_perm(t[1], t[3], swiz4);
+             s[4] = vec_perm(t[4], t[6], swiz3);
+             s[5] = vec_perm(t[4], t[6], swiz4);
+             s[6] = vec_perm(t[5], t[7], swiz3);
+             s[7] = vec_perm(t[5], t[7], swiz4);
+             for (int i = 0; i < 8; ++i)
+                 vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+         }
+     }
+
+     void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+         int64_t i, j;
+         TA *aoffset = NULL;
+         unsigned char *vecOffset = NULL;
+         TA * aoffsets[8];
+         vector unsigned char c_arr[8];
+         aoffset = const_cast<TA*>(a);
+         vecOffset = vec;
+         j = (rows >> 3);
+         if (j > 0) {
+             do {
+                 if (cols == 4) {
+                     aoffsets[0] = aoffset;
+                     for (int it = 1; it < 4; ++it)
+                         aoffsets[it] = aoffsets[it-1] + lda;
+                     aoffset += 4 * lda;
+                     for (int i = 0; i < 4; ++i)
+                         c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
+                     vector_permute_store(c_arr, 4, vecOffset);
+                     for (int i = 0; i < 4; i++)
+                         aoffsets[i] = aoffsets[i] + lda;
+                     vecOffset += 64;
+                 }
+                 i = (cols >> 3);
+                 if (i > 0) {
+                     aoffsets[0] = aoffset;
+                     for (int it = 1; it < 8; ++it) {
+                         aoffsets[it] = aoffsets[it-1] + lda;
+                     }
+                     aoffset += 8 * lda;
+                     do {
+                         for (int it = 0; it < 8; ++it)
+                             c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                         vector_permute_store(c_arr, 8, vecOffset);
+                         for (int it = 0; it < 8; ++it)
+                             aoffsets[it] = aoffsets[it] + 8*lda;
+                         vecOffset += 128;
+                         i--;
+                     } while(i > 0);
+                 }
+                 j--;
+             } while(j > 0);
+         }
+         if (rows & 4) {
+             aoffsets[0] = aoffset;
+             for (int it = 1; it < 4; ++it)
+                 aoffsets[it] = aoffsets[it-1] + lda;
+             aoffset += 4 * lda;
+             if (cols == 4) {
+                 for (int it = 0; it < 4; ++it)
+                     c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                 vector_permute_store(c_arr, 2, vecOffset);
+                 for (int it = 0; it < 4; it++)
+                     aoffsets[it] = aoffsets[it] + lda;
+                 vecOffset += 32;
+             }
+             i = (cols >> 3);
+             if (i > 0) {
+                 do {
+                     for (int it = 0; it < 4; ++it)
+                         c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                     vector_permute_store(c_arr, 4, vecOffset);
+                     for (int it = 0; it < 4; it++)
+                         aoffsets[it] = aoffsets[it] + 8*lda;
+                     vecOffset += 64;
+                     i--;
+                 } while(i > 0);
+             }
+         }
+         if (rows & 3) {
+             aoffsets[0] = aoffset;
+             for (int it = 1; it < 4; ++it)
+                 aoffsets[it] = aoffsets[it-1] + lda;
+             if (cols == 4) {
+                 switch(rows) {
+                     case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                     case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                     case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                             break;
+                 }
+                 vector_permute_store(c_arr, 2, vecOffset);
+                 for (int it = 0; it < 4; it++)
+                     aoffsets[it] = aoffsets[it] + lda;
+                 vecOffset += 32;
+             }
+             i = (cols >> 3);
+             if (i > 0) {
+                 do {
+                     switch(rows) {
+                         case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                         case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                         case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                                 break;
+                     }
+                     vector_permute_store(c_arr, 4, vecOffset);
+                     for (int it = 0; it < 4; it++)
+                         aoffsets[it] = aoffsets[it] + 8*lda;
+                     vecOffset += 64;
+                     i--;
+                 } while(i > 0);
+             }
+         }
+     }
+
+     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+         int64_t mc, nc, mp, np;
+         int m_rem = MIN(m - m0, 8);
+         int n_rem = MIN(n - n0, 8);
+
+         if (m_rem >= 8 && n_rem >= 8) {
+             mc = 8;
+             nc = 8;
+             gemm<8,8>(m0, m, n0, n);
+         } else if (m_rem >= 4 && n_rem >= 8) {
+             mc = 4;
+             nc = 8;
+             gemm<4,8>(m0, m, n0, n);
+         } else if (m_rem >= 8 && n_rem >= 4) {
+             mc = 8;
+             nc = 4;
+             gemm<8,4>(m0, m, n0, n);
+         } else if ((m_rem < 4) && (n_rem >= 8)) {
+             nc = 8;
+             switch(m_rem) {
+                 case 1:
+                     mc = 1;
+                     gemm_Mx8<1>(m0, m, n0, n);
+                     break;
+                 case 2:
+                     mc = 2;
+                     gemm_Mx8<2>(m0, m, n0, n);
+                     break;
+                 case 3:
+                     mc = 3;
+                     gemm_Mx8<3>(m0, m, n0, n);
+                     break;
+                 default:
+                     return;
+             }
+         } else if (m_rem >= 4 && n_rem >= 4) {
+             mc = 4;
+             nc = 4;
+             gemm_small<4, 4>(m0, m, n0, n);
+         } else if ((m_rem > 4) && (n_rem < 4)) {
+             mc = 4;
+             switch(n_rem) {
+                 case 1:
+                     nc = 1;
+                     gemm_small<4, 1>(m0, m, n0, n);
+                     break;
+                 case 2:
+                     nc = 2;
+                     gemm_small<4, 2>(m0, m, n0, n);
+                     break;
+                 case 3:
+                     nc = 3;
+                     gemm_small<4, 3>(m0, m, n0, n);
+                     break;
+                 default:
+                     return;
+             }
+         } else {
+             switch((m_rem << 4) | n_rem) {
+                 case 0x43:
+                     mc = 4;
+                     nc = 3;
+                     gemm_small<4, 3>(m0, m, n0, n);
+                     break;
+                 case 0x42:
+                     mc = 4;
+                     nc = 2;
+                     gemm_small<4, 2>(m0, m, n0, n);
+                     break;
+                 case 0x41:
+                     mc = 4;
+                     nc = 1;
+                     gemm_small<4, 1>(m0, m, n0, n);
+                     break;
+                 case 0x34:
+                     mc = 3;
+                     nc = 4;
+                     gemm_small<3, 4>(m0, m, n0, n);
+                     break;
+                 case 0x33:
+                     mc = 3;
+                     nc = 3;
+                     gemm_small<3, 3>(m0, m, n0, n);
+                     break;
+                 case 0x32:
+                     mc = 3;
+                     nc = 2;
+                     gemm_small<3, 2>(m0, m, n0, n);
+                     break;
+                 case 0x31:
+                     mc = 3;
+                     nc = 1;
+                     gemm_small<3, 1>(m0, m, n0, n);
+                     break;
+                 case 0x24:
+                     mc = 2;
+                     nc = 4;
+                     gemm_small<2, 4>(m0, m, n0, n);
+                     break;
+                 case 0x23:
+                     mc = 2;
+                     nc = 3;
+                     gemm_small<2, 3>(m0, m, n0, n);
+                     break;
+                 case 0x22:
+                     mc = 2;
+                     nc = 2;
+                     gemm_small<2, 2>(m0, m, n0, n);
+                     break;
+                 case 0x21:
+                     mc = 2;
+                     nc = 1;
+                     gemm_small<2, 1>(m0, m, n0, n);
+                     break;
+                 case 0x14:
+                     mc = 1;
+                     nc = 4;
+                     gemm_small<1, 4>(m0, m, n0, n);
+                     break;
+                 case 0x13:
+                     mc = 1;
+                     nc = 3;
+                     gemm_small<1, 3>(m0, m, n0, n);
+                     break;
+                 case 0x12:
+                     mc = 1;
+                     nc = 2;
+                     gemm_small<1, 2>(m0, m, n0, n);
+                     break;
+                 case 0x11:
+                     mc = 1;
+                     nc = 1;
+                     gemm_small<1, 1>(m0, m, n0, n);
+                     break;
+                 default:
+                     return;
+             }
+         }
+         mp = m0 + (m - m0) / mc * mc;
+         np = n0 + (n - n0) / nc * nc;
+         mnpack(mp, m, n0, np);
+         mnpack(m0, m, np, n);
+     }
+
+     void KERNEL_4x8(int64_t ii, int64_t jj) {
+         vec_t vec_A[4], vec_B[8], vec_C[4];
+         acc_t acc_0, acc_1;
+         __builtin_mma_xxsetaccz(&acc_0);
+         __builtin_mma_xxsetaccz(&acc_1);
+         for (int l = 0; l < k; l+=8) {
+             packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+             packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+             for (int x = 0; x < 4; x++) {
+                 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                 __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+             }
+         }
+         SAVE_ACC(&acc_0, ii, jj);
+         SAVE_ACC(&acc_1, ii, jj+4);
+     }
+
+     void KERNEL_8x4(int64_t ii, int64_t jj) {
+         vec_t vec_A[8], vec_B[4], vec_C[4];
+         acc_t acc_0, acc_1;
+         __builtin_mma_xxsetaccz(&acc_0);
+         __builtin_mma_xxsetaccz(&acc_1);
+         for (int l = 0; l < k; l+=8) {
+             packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+             packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+             for (int x = 0; x < 4; x++) {
+                 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                 __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+             }
+         }
+         SAVE_ACC(&acc_0, ii, jj);
+         SAVE_ACC(&acc_1, ii+4, jj);
+     }
+
+     void KERNEL_8x8(int64_t ii, int64_t jj) {
+         vec_t vec_A[8], vec_B[8], vec_C[4];
+         acc_t acc_0, acc_1, acc_2, acc_3;
+         __builtin_mma_xxsetaccz(&acc_0);
+         __builtin_mma_xxsetaccz(&acc_1);
+         __builtin_mma_xxsetaccz(&acc_2);
+         __builtin_mma_xxsetaccz(&acc_3);
+         for (int l = 0; l < k; l+=8) {
+             packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+             packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+             for (int x = 0; x < 4; x++) {
+                 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                 __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
+                 __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
+                 __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+             }
+         }
+
+         SAVE_ACC(&acc_0, ii, jj);
+         SAVE_ACC(&acc_1, ii, jj+4);
+         SAVE_ACC(&acc_2, ii+4, jj);
+         SAVE_ACC(&acc_3, ii+4, jj+4);
+     }
+
+     template<int RM, int RN>
+     void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+         int64_t ytiles = (m - m0) / RM;
+         int64_t xtiles = (n - n0) / RN;
+         int64_t tiles = xtiles * ytiles;
+         int64_t duty = (tiles + nth - 1) / nth;
+         int64_t start = duty * ith;
+         int64_t end = start + duty;
+         if (end > tiles)
+             end = tiles;
+         for (int64_t job = start; job < end; ++job) {
+             int64_t ii = m0 + job / xtiles * RM;
+             int64_t jj = n0 + job % xtiles * RN;
+             vec_t vec_C[4];
+             acc_t acc_0;
+             __builtin_mma_xxsetaccz(&acc_0);
+             vec_t vec_A[2], vec_B[2];
+             for (int l = 0; l < k; l += 4) {
+                 packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
+                 packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
+                 for (int x = 0; x < 2; x++) {
+                     __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                 }
+             }
+             __builtin_mma_disassemble_acc(vec_C, &acc_0);
+             for (int I = 0; I < RM; I++) {
+                 for (int J = 0; J < RN; J++) {
+                     *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                 }
+             }
+         }
+     }
+
+     template<int RM>
+     void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+         int RN = 8;
+         int64_t ytiles = (m - m0) / RM;
+         int64_t xtiles = (n - n0) / RN;
+         int64_t tiles = xtiles * ytiles;
+         int64_t duty = (tiles + nth - 1) / nth;
+         int64_t start = duty * ith;
+         int64_t end = start + duty;
+         if (end > tiles)
+             end = tiles;
+         for (int64_t job = start; job < end; ++job) {
+             int64_t ii = m0 + job / xtiles * RM;
+             int64_t jj = n0 + job % xtiles * RN;
+             vec_t vec_C[4];
+             acc_t acc_0, acc_1;
+             __builtin_mma_xxsetaccz(&acc_0);
+             __builtin_mma_xxsetaccz(&acc_1);
+             vec_t vec_A[4], vec_B[8];
+             for (int l = 0; l < k; l += 8) {
+                 packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
+                 packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
+                 for (int x = 0; x < 4; x++) {
+                     __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                     __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                 }
+             }
+             __builtin_mma_disassemble_acc(vec_C, &acc_0);
+             for (int I = 0; I < RM; I++) {
+                 for (int J = 0; J < 4; J++) {
+                     *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                 }
+             }
+             __builtin_mma_disassemble_acc(vec_C, &acc_1);
+             for (int I = 0; I < RM; I++) {
+                 for (int J = 0; J < 4; J++) {
+                     *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                 }
+             }
+         }
+     }
+
+     template<int RM, int RN>
+     inline void kernel(int64_t ii, int64_t jj) {
+         if constexpr(RM == 4 && RN == 8) {
+             KERNEL_4x8(ii,jj);
+         } else if constexpr(RM == 8 && RN == 8) {
+             KERNEL_8x8(ii,jj);
+         } else if constexpr(RM == 8 && RN == 4) {
+             KERNEL_8x4(ii,jj);
+         } else {
+             static_assert(false, "RN/RM values not supported");
+         }
+     }
+
+     template <int RM, int RN>
+     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+         int64_t ytiles = (m - m0) / RM;
+         int64_t xtiles = (n - n0) / RN;
+         int64_t tiles = xtiles * ytiles;
+         int64_t duty = (tiles + nth - 1) / nth;
+         int64_t start = duty * ith;
+         int64_t end = start + duty;
+         if (end > tiles)
+             end = tiles;
+         for (int64_t job = start; job < end; ++job) {
+             int64_t ii = m0 + job / xtiles * RM;
+             int64_t jj = n0 + job % xtiles * RN;
+             kernel<RM, RN>(ii, jj);
+         }
+     }
+
+     const TA *const A;
+     const TB *const B;
+     TC *C;
+     const int64_t k;
+     const int64_t lda;
+     const int64_t ldb;
+     const int64_t ldc;
+     const int ith;
+     const int nth;
+ };
+
  template <typename TA, typename TB, typename TC>
  class tinyBLAS_Q0_PPC {
    public:
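As a reference point for the hunk above: the new `tinyBLAS_BF16_PPC` class packs bf16 operands and accumulates output tiles with POWER10 MMA outer-product builtins, storing C column-major. A scalar sketch of what one RM×RN tile computes, assuming `lm_ggml_bf16_t` carries the top 16 bits of a float32 payload; the indexing mirrors `gemm_small` above, but this is a reference model, not the shipped kernel:

```cpp
#include <cstdint>
#include <cstring>

// bf16 -> float: the payload is the top 16 bits of a float32
static float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// One RM x RN output tile: C[i,j] = sum_l A[ii+i, l] * B[jj+j, l].
// A and B are k-major rows with strides lda/ldb (B effectively transposed);
// C is stored column-major, as in the SAVE_ACC/gemm_small stores above.
static void gemm_tile_ref(int64_t k,
                          const uint16_t *A, int64_t lda,
                          const uint16_t *B, int64_t ldb,
                          float *C, int64_t ldc,
                          int64_t ii, int64_t jj, int RM, int RN) {
    for (int i = 0; i < RM; ++i) {
        for (int j = 0; j < RN; ++j) {
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l) {
                acc += bf16_to_f32(A[(ii + i) * lda + l]) *
                       bf16_to_f32(B[(jj + j) * ldb + l]);
            }
            C[(jj + j) * ldc + (ii + i)] = acc;  // column-major store
        }
    }
}
```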
@@ -2202,6 +2689,7 @@ class tinyBLAS_PPC {
      boffset = vec;
      j = (rows >> 3);
      if (j > 0) {
+
          do {
              aoffset1 = aoffset;
              aoffset2 = aoffset1 + lda;
@@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
          (float *)C, ldc};
      return tb.matmul(m, n);
  }
+ #elif defined(__MMA__)
+     if ((k % 8))
+         return false;
+     if (Btype == LM_GGML_TYPE_BF16) {
+         tinyBLAS_BF16_PPC<lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ k,
+             (const lm_ggml_bf16_t *)A, lda,
+             (const lm_ggml_bf16_t *)B, ldb,
+             (float *)C, ldc,
+             params->ith, params->nth};
+         tb.matmul(m, n);
+         return true;
+     }
  #endif
      return false;
  }
+
  case LM_GGML_TYPE_F16: {
  #if defined(__AVX512F__)
      if (Btype == LM_GGML_TYPE_F16) {
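The dispatch hunk above only takes the new POWER MMA path when `k` is a multiple of 8 (each kernel step consumes 8 bf16 values along k) and B is BF16; otherwise it returns false. A hedged sketch of the caller-side contract this implies, with the fallback comment illustrative rather than quoted from this package:

```cpp
// llamafile_sgemm reports false for shapes/types it cannot accelerate; the
// CPU backend then runs its generic matmul path instead. Sketch only.
if (!llamafile_sgemm(params, m, n, k,
                     A, lda, B, ldb, C, ldc,
                     Atype, Btype, Ctype)) {
    // fall back to the generic lm_ggml mul_mat implementation
}
```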
@@ -341,7 +341,7 @@ static inline void __avx_f32cx8_store(lm_ggml_fp16_t *x, __m256 y) {
  #define LM_GGML_F32_EPR 4
 
  #define LM_GGML_F32x4 vector float
- #define LM_GGML_F32x4_ZERO 0.0f
+ #define LM_GGML_F32x4_ZERO {0.0f}
  #define LM_GGML_F32x4_SET1 vec_splats
  #define LM_GGML_F32x4_LOAD(p) vec_xl(0, p)
  #define LM_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
@@ -855,13 +855,17 @@ static inline __vector float __lzs_f16cx4_load(const lm_ggml_fp16_t * x) {
      tmp[i] = LM_GGML_FP16_TO_FP32(x[i]);
  }
 
- return vec_xl(0, tmp);
+ // note: keep type-cast here to prevent compiler bugs
+ // see: https://github.com/ggml-org/llama.cpp/issues/12846
+ return vec_xl(0, (const float *)(tmp));
  }
 
  static inline void __lzs_f16cx4_store(lm_ggml_fp16_t * x, __vector float y) {
      float arr[4];
 
- vec_xst(y, 0, arr);
+ // note: keep type-cast here to prevent compiler bugs
+ // see: https://github.com/ggml-org/llama.cpp/issues/12846
+ vec_xst(y, 0, (float *)(arr));
 
      for (int i = 0; i < 4; i++) {
          x[i] = LM_GGML_FP32_TO_FP16(arr[i]);
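The `LM_GGML_F32x4_ZERO` change above fixes an initializer, not a value: a bare scalar `0.0f` does not convert to a GCC/clang vector type, while a braced list follows array semantics and zero-fills the unlisted lanes. A minimal standalone sketch of that rule (vector-extension syntax, not tied to this codebase):

```cpp
// GCC/clang vector extension: brace initializers behave like array
// initializers, so unlisted lanes are implicitly zero-initialized.
typedef float v4sf __attribute__((vector_size(16)));

v4sf make_zero(void) {
    v4sf zero = {0.0f};  // lanes 1..3 implicitly zero
    return zero;
}
```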
@@ -1,6 +1,6 @@
  #pragma once
 
- #include "cpu-common.h"
+ #include "common.h"
 
  #ifdef __cplusplus
  extern "C" {
package/cpp/ggml-cpu.h CHANGED
@@ -133,6 +133,11 @@ extern "C" {
 
  LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void);
 
+ LM_GGML_BACKEND_API void lm_ggml_cpu_fp32_to_fp16(const float *, lm_ggml_fp16_t *, int64_t);
+ LM_GGML_BACKEND_API void lm_ggml_cpu_fp16_to_fp32(const lm_ggml_fp16_t *, float *, int64_t);
+ LM_GGML_BACKEND_API void lm_ggml_cpu_fp32_to_bf16(const float *, lm_ggml_bf16_t *, int64_t);
+ LM_GGML_BACKEND_API void lm_ggml_cpu_bf16_to_fp32(const lm_ggml_bf16_t *, float *, int64_t);
+
  #ifdef __cplusplus
  }
  #endif
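The four converters declared above each take `(src, dst, count)` and convert `count` elements in bulk. A hedged usage sketch; the function and buffer names here are illustrative:

```cpp
#include <vector>
#include "ggml-cpu.h"

void roundtrip_bf16(const std::vector<float> & src) {
    std::vector<lm_ggml_bf16_t> tmp(src.size());
    std::vector<float>          out(src.size());
    lm_ggml_cpu_fp32_to_bf16(src.data(), tmp.data(), (int64_t) src.size());
    lm_ggml_cpu_bf16_to_fp32(tmp.data(), out.data(), (int64_t) src.size());
    // out now holds src rounded through bf16 (top 16 bits of each float32)
}
```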
package/cpp/ggml-impl.h CHANGED
@@ -16,6 +16,14 @@
  #include <arm_sve.h>
  #endif // __ARM_FEATURE_SVE
 
+ #if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+ //
+ //   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+ //
+ #include <arm_neon.h>
+ #endif
+
  #if defined(__F16C__)
  #include <immintrin.h>
  #endif
@@ -140,8 +148,14 @@ struct lm_ggml_map_custom2_op_params {
 
  struct lm_ggml_map_custom3_op_params {
      lm_ggml_custom3_op_t fun;
-     int n_tasks;
-     void * userdata;
+     int                 n_tasks;
+     void              * userdata;
+ };
+
+ struct lm_ggml_custom_op_params {
+     lm_ggml_custom_op_t fun;
+     int                 n_tasks;
+     void              * userdata;
  };
 
  // bitset
@@ -311,13 +325,6 @@ LM_GGML_API void lm_ggml_aligned_free(void * ptr, size_t size);
  // for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
  //
  #if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
- //
- //   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
- //
- #include <arm_neon.h>
-
  #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x)
  #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x)