llama_cpp 0.16.2 → 0.17.0

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -43
  5. data/ext/llama_cpp/llama_cpp.cpp +8 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +3 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1124
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2225
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -236
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -314
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3273
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14994
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1178
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -6351
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -40
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -144508
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7183
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22506
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2458
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18985
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1147
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1032
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -7033
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -810
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
data/vendor/tmp/llama.cpp/sgemm.cpp
@@ -1,1032 +0,0 @@
- // Copyright 2024 Mozilla Foundation
- //
- // Permission is hereby granted, free of charge, to any person obtaining
- // a copy of this software and associated documentation files (the
- // "Software"), to deal in the Software without restriction, including
- // without limitation the rights to use, copy, modify, merge, publish,
- // distribute, sublicense, and/or sell copies of the Software, and to
- // permit persons to whom the Software is furnished to do so, subject to
- // the following conditions:
- //
- // The above copyright notice and this permission notice shall be
- // included in all copies or substantial portions of the Software.
- //
- // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
- // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
- // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
- // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- // SOFTWARE.
-
- //
- //                   _   _          ___ _      _   ___
- //                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
- //                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
- //                   \__|_|_||_\_, |___/____/_/ \_\___/
- //                             |__/
- //
- //                    BASIC LINEAR ALGEBRA SUBPROGRAMS
- //
- //
- // This file implements multithreaded CPU matrix multiplication for the
- // common contiguous use case C = Aᵀ * B. These kernels are designed to
- // have excellent performance[1] for matrices that fit in the CPU cache
- // without imposing any overhead such as cache filling or malloc calls.
- //
- // This implementation does not guarantee any upper bound with rounding
- // errors, which grow along with k. Our goal's to maximally exploit the
- // hardware for performance, and then use whatever resources remain for
- // improving numerical accuracy.
- //
- // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
- //     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
-
- #if defined(__GNUC__)
- #pragma GCC diagnostic ignored "-Wpedantic"
- #pragma GCC diagnostic ignored "-Wignored-attributes"
- #endif
-
- #include "sgemm.h"
- #include "ggml-impl.h"
- #include "ggml-quants.h"
-
- #ifdef _MSC_VER
- #define NOINLINE __declspec(noinline)
- #else
- #define NOINLINE __attribute__((__noinline__))
- #endif
-
- #if defined(__ARM_NEON) || defined(__AVX512F__)
- #define VECTOR_REGISTERS 32
- #else
- #define VECTOR_REGISTERS 16
- #endif
-
- #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
- namespace {
-
- inline float unhalf(ggml_fp16_t d) {
-     return GGML_FP16_TO_FP32(d);
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED ARITHMETIC OPERATIONS
-
- #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
- inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
- inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
- #endif // __SSE__
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
- inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
- inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
- #endif // __AVX__
-
- #if defined(__AVX512F__)
- inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
- inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
- inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
- #endif // __AVX512F__
-
- #if defined(__ARM_NEON)
- inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
- inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
- inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
- #endif // __ARM_NEON
-
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
- inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
- inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
- inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
- #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED FUSED MULTIPLY ADD
-
- /**
-  * Computes a * b + c.
-  */
- template <typename T, typename U>
- inline U madd(T a, T b, U c) {
-     return add(mul(a, b), c);
- }
-
- #if defined(__FMA__)
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- template <>
- inline __m256 madd(__m256 a, __m256 b, __m256 c) {
-     return _mm256_fmadd_ps(a, b, c);
- }
- #endif
- #if defined(__AVX512F__)
- template <>
- inline __m512 madd(__m512 a, __m512 b, __m512 c) {
-     return _mm512_fmadd_ps(a, b, c);
- }
- #endif
- #endif
-
- #if defined(__ARM_FEATURE_FMA)
- template <>
- inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
-     return vfmaq_f32(c, b, a);
- }
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
- template <>
- inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
-     return vfmaq_f16(c, b, a);
- }
- #endif
- #endif
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED HORIZONTAL SUM
-
- #if defined(__ARM_NEON)
- inline float hsum(float32x4_t x) {
-     return vaddvq_f32(x);
- }
- #endif // __ARM_NEON
-
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
- inline float hsum(float16x8_t x) {
-     return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
-                                 vcvt_f32_f16(vget_high_f16(x))));
- }
- #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
- #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline float hsum(__m128 x) {
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-     x = _mm_add_ps(x, _mm_movehl_ps(x, x));
-     x = _mm_add_ss(x, _mm_movehdup_ps(x));
- #else
-     __m128 t;
-     t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
-     x = _mm_add_ps(x, t);
-     t = _mm_movehl_ps(t, x);
-     x = _mm_add_ss(x, t);
- #endif
-     return _mm_cvtss_f32(x);
- }
- #endif
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline float hsum(__m256 x) {
-     return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
-                            _mm256_castps256_ps128(x)));
- }
- #endif // __AVX__
-
- #if defined(__AVX512F__)
- inline float hsum(__m512 x) {
-     return _mm512_reduce_add_ps(x);
- }
- #endif // __AVX512F__
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED MEMORY LOADING
-
- template <typename T, typename U> T load(const U *);
-
- #if defined(__ARM_NEON)
- template <> inline float32x4_t load(const float *p) {
-     return vld1q_f32(p);
- }
- #if !defined(_MSC_VER)
- template <> inline float16x8_t load(const ggml_fp16_t *p) {
-     return vld1q_f16((const float16_t *)p);
- }
- template <> inline float32x4_t load(const ggml_fp16_t *p) {
-     return vcvt_f32_f16(vld1_f16((const float16_t *)p));
- }
- #endif // _MSC_VER
- #endif // __ARM_NEON
-
- #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- template <> inline __m128 load(const float *p) {
-     return _mm_loadu_ps(p);
- }
- #endif // __SSE__
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- template <> inline __m256 load(const float *p) {
-     return _mm256_loadu_ps(p);
- }
- #endif // __AVX__
-
- #if defined(__F16C__)
- template <> inline __m256 load(const ggml_fp16_t *p) {
-     return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
- }
- #endif // __F16C__
-
- #if defined(__AVX512F__)
- template <> inline __m512 load(const float *p) {
-     return _mm512_loadu_ps(p);
- }
- template <> inline __m512 load(const ggml_fp16_t *p) {
-     return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
- }
- #endif // __AVX512F__
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // FLOATING POINT MATRIX MULTIPLICATION
-
- template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
- class tinyBLAS {
-   public:
-     tinyBLAS(int64_t k,
-              const TA *A, int64_t lda,
-              const TB *B, int64_t ldb,
-              TC *C, int64_t ldc,
-              int ith, int nth)
-         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-     }
-
-     void matmul(int64_t m, int64_t n, int task) {
-         if (task == GGML_TASK_TYPE_COMPUTE)
-             mnpack(0, m, 0, n);
-     }
-
-   private:
-     NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-         int64_t mc, nc, mp, np;
-         switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
- #if VECTOR_REGISTERS == 32
-         case 0x55:
-             mc = 5;
-             nc = 5;
-             gemm<5, 5>(m0, m, n0, n);
-             break;
-         case 0x45:
-             mc = 4;
-             nc = 5;
-             gemm<4, 5>(m0, m, n0, n);
-             break;
-         case 0x54:
-             mc = 5;
-             nc = 4;
-             gemm<5, 4>(m0, m, n0, n);
-             break;
-         case 0x44:
-             mc = 4;
-             nc = 4;
-             gemm<4, 4>(m0, m, n0, n);
-             break;
-         case 0x53:
-             mc = 5;
-             nc = 3;
-             gemm<5, 3>(m0, m, n0, n);
-             break;
-         case 0x35:
-             mc = 3;
-             nc = 5;
-             gemm<3, 5>(m0, m, n0, n);
-             break;
-         case 0x43:
-             mc = 4;
-             nc = 3;
-             gemm<4, 3>(m0, m, n0, n);
-             break;
- #else
-         case 0x55:
-         case 0x54:
-         case 0x53:
-         case 0x45:
-         case 0x44:
-         case 0x43:
-             mc = 4;
-             nc = 3;
-             gemm<4, 3>(m0, m, n0, n);
-             break;
-         case 0x35:
- #endif
-         case 0x34:
-             mc = 3;
-             nc = 4;
-             gemm<3, 4>(m0, m, n0, n);
-             break;
-         case 0x52:
-             mc = 5;
-             nc = 2;
-             gemm<5, 2>(m0, m, n0, n);
-             break;
-         case 0x33:
-             mc = 3;
-             nc = 3;
-             gemm<3, 3>(m0, m, n0, n);
-             break;
-         case 0x25:
-             mc = 2;
-             nc = 5;
-             gemm<2, 5>(m0, m, n0, n);
-             break;
-         case 0x42:
-             mc = 4;
-             nc = 2;
-             gemm<4, 2>(m0, m, n0, n);
-             break;
-         case 0x24:
-             mc = 2;
-             nc = 4;
-             gemm<2, 4>(m0, m, n0, n);
-             break;
-         case 0x32:
-             mc = 3;
-             nc = 2;
-             gemm<3, 2>(m0, m, n0, n);
-             break;
-         case 0x23:
-             mc = 2;
-             nc = 3;
-             gemm<2, 3>(m0, m, n0, n);
-             break;
-         case 0x51:
-             mc = 5;
-             nc = 1;
-             gemm<5, 1>(m0, m, n0, n);
-             break;
-         case 0x41:
-             mc = 4;
-             nc = 1;
-             gemm<4, 1>(m0, m, n0, n);
-             break;
-         case 0x22:
-             mc = 2;
-             nc = 2;
-             gemm<2, 2>(m0, m, n0, n);
-             break;
-         case 0x15:
-             mc = 1;
-             nc = 5;
-             gemm<1, 5>(m0, m, n0, n);
-             break;
-         case 0x14:
-             mc = 1;
-             nc = 4;
-             gemm<1, 4>(m0, m, n0, n);
-             break;
-         case 0x31:
-             mc = 3;
-             nc = 1;
-             gemm<3, 1>(m0, m, n0, n);
-             break;
-         case 0x13:
-             mc = 1;
-             nc = 3;
-             gemm<1, 3>(m0, m, n0, n);
-             break;
-         case 0x21:
-             mc = 2;
-             nc = 1;
-             gemm<2, 1>(m0, m, n0, n);
-             break;
-         case 0x12:
-             mc = 1;
-             nc = 2;
-             gemm<1, 2>(m0, m, n0, n);
-             break;
-         case 0x11:
-             mc = 1;
-             nc = 1;
-             gemm<1, 1>(m0, m, n0, n);
-             break;
-         default:
-             return;
-         }
-         mp = m0 + (m - m0) / mc * mc;
-         np = n0 + (n - n0) / nc * nc;
-         mnpack(mp, m, n0, np);
-         mnpack(m0, m, np, n);
-     }
-
-     template <int RM, int RN>
-     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-         int64_t ytiles = (m - m0) / RM;
-         int64_t xtiles = (n - n0) / RN;
-         int64_t tiles = xtiles * ytiles;
-         int64_t duty = (tiles + nth - 1) / nth;
-         int64_t start = duty * ith;
-         int64_t end = start + duty;
-         if (end > tiles)
-             end = tiles;
-         for (int64_t job = start; job < end; ++job) {
-             int64_t ii = m0 + job / xtiles * RM;
-             int64_t jj = n0 + job % xtiles * RN;
-             D Cv[RN][RM] = {};
-             for (int64_t l = 0; l < k; l += KN)
-                 for (int64_t j = 0; j < RN; ++j)
-                     for (int64_t i = 0; i < RM; ++i)
-                         Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
-                                         load<V>(B + ldb * (jj + j) + l),
-                                         Cv[j][i]);
-             for (int64_t j = 0; j < RN; ++j)
-                 for (int64_t i = 0; i < RM; ++i)
-                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-         }
-     }
-
-     const TA *const A;
-     const TB *const B;
-     TC *const C;
-     const int64_t k;
-     const int64_t lda;
-     const int64_t ldb;
-     const int64_t ldc;
-     const int ith;
-     const int nth;
- };
-
- //////////////////////////////////////////////////////////////////////////////////////////
- // QUANT ZERO MATRIX MULTIPLICATION
-
- #if defined(__ARM_FEATURE_DOTPROD)
- template <typename TA>
- class tinyBLAS_Q0_ARM {
-   public:
-     tinyBLAS_Q0_ARM(int64_t k,
-                     const TA *A, int64_t lda,
-                     const block_q8_0 *B, int64_t ldb,
-                     float *C, int64_t ldc,
-                     int ith, int nth)
-         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-     }
-
-     void matmul(int64_t m, int64_t n, int task) {
-         if (task == GGML_TASK_TYPE_COMPUTE)
-             mnpack(0, m, 0, n);
-     }
-
-   private:
-     NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-         int64_t mc, nc, mp, np;
-         switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
-         case 0x33:
-             mc = 3;
-             nc = 3;
-             gemm<3, 3>(m0, m, n0, n);
-             break;
-         case 0x32:
-             mc = 3;
-             nc = 2;
-             gemm<3, 2>(m0, m, n0, n);
-             break;
-         case 0x23:
-             mc = 2;
-             nc = 3;
-             gemm<2, 3>(m0, m, n0, n);
-             break;
-         case 0x22:
-             mc = 2;
-             nc = 2;
-             gemm<2, 2>(m0, m, n0, n);
-             break;
-         case 0x31:
-             mc = 3;
-             nc = 1;
-             gemm<3, 1>(m0, m, n0, n);
-             break;
-         case 0x13:
-             mc = 1;
-             nc = 3;
-             gemm<1, 3>(m0, m, n0, n);
-             break;
-         case 0x21:
-             mc = 2;
-             nc = 1;
-             gemm<2, 1>(m0, m, n0, n);
-             break;
-         case 0x12:
-             mc = 1;
-             nc = 2;
-             gemm<1, 2>(m0, m, n0, n);
-             break;
-         case 0x11:
-             mc = 1;
-             nc = 1;
-             gemm<1, 1>(m0, m, n0, n);
-             break;
-         default:
-             return;
-         }
-         mp = m0 + (m - m0) / mc * mc;
-         np = n0 + (n - n0) / nc * nc;
-         mnpack(mp, m, n0, np);
-         mnpack(m0, m, np, n);
-     }
-
-     template <int RM, int RN>
-     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-         int64_t ytiles = (m - m0) / RM;
-         int64_t xtiles = (n - n0) / RN;
-         int64_t tiles = xtiles * ytiles;
-         int64_t duty = (tiles + nth - 1) / nth;
-         int64_t start = duty * ith;
-         int64_t end = start + duty;
-         if (end > tiles)
-             end = tiles;
-         for (int64_t job = start; job < end; ++job) {
-             int64_t ii = m0 + job / xtiles * RM;
-             int64_t jj = n0 + job % xtiles * RN;
-             float32x4_t Cv[RN][RM] = {};
-             for (int64_t l = 0; l < k; ++l)
-                 for (int64_t j = 0; j < RN; ++j)
-                     for (int64_t i = 0; i < RM; ++i)
-                         Cv[j][i] = vmlaq_n_f32(Cv[j][i],
-                                                vcvtq_f32_s32(vdotq_s32(
-                                                    vdotq_s32(vdupq_n_s32(0),
-                                                              load_lo(A + lda * (ii + i) + l),
-                                                              load_lo(B + ldb * (jj + j) + l)),
-                                                    load_hi(A + lda * (ii + i) + l),
-                                                    load_hi(B + ldb * (jj + j) + l))),
-                                                unhalf(A[lda * (ii + i) + l].d) *
-                                                unhalf(B[ldb * (jj + j) + l].d));
-             for (int64_t j = 0; j < RN; ++j)
-                 for (int64_t i = 0; i < RM; ++i)
-                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-         }
-     }
-
-     inline int8x16_t load_lo(const block_q8_0 *b) {
-         return vld1q_s8(b->qs);
-     }
-
-     inline int8x16_t load_hi(const block_q8_0 *b) {
-         return vld1q_s8(b->qs + 16);
-     }
-
-     inline int8x16_t load_lo(const block_q4_0 *b) {
-         return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
-                                                      vdupq_n_u8(0x0f))),
-                         vdupq_n_s8(0x8));
-     }
-
-     inline int8x16_t load_hi(const block_q4_0 *b) {
-         return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
-                         vdupq_n_s8(0x8));
-     }
-
-     const TA *const A;
-     const block_q8_0 *const B;
-     float *const C;
-     const int64_t k;
-     const int64_t lda;
-     const int64_t ldb;
-     const int64_t ldc;
-     const int ith;
-     const int nth;
- };
- #endif // __ARM_FEATURE_DOTPROD
-
- #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
- template <typename TA, typename TB, typename TC>
- class tinyBLAS_Q0_AVX {
-   public:
-     tinyBLAS_Q0_AVX(int64_t k,
-                     const TA *A, int64_t lda,
-                     const TB *B, int64_t ldb,
-                     TC *C, int64_t ldc,
-                     int ith, int nth)
-         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-     }
-
-     void matmul(int64_t m, int64_t n, int task) {
-         if (task == GGML_TASK_TYPE_COMPUTE)
-             mnpack(0, m, 0, n);
-     }
-
-   private:
-     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-         int64_t mc, nc, mp, np;
-         switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
- #if VECTOR_REGISTERS == 32
-         case 0x44:
-             mc = 4;
-             nc = 4;
-             gemm<4, 4>(m0, m, n0, n);
-             break;
-         case 0x43:
-             mc = 4;
-             nc = 3;
-             gemm<4, 3>(m0, m, n0, n);
-             break;
-         case 0x34:
-             mc = 3;
-             nc = 4;
-             gemm<3, 4>(m0, m, n0, n);
-             break;
-         case 0x33:
-             mc = 3;
-             nc = 3;
-             gemm<3, 3>(m0, m, n0, n);
-             break;
-         case 0x42:
-             mc = 4;
-             nc = 2;
-             gemm<4, 2>(m0, m, n0, n);
-             break;
-         case 0x24:
-             mc = 2;
-             nc = 4;
-             gemm<2, 4>(m0, m, n0, n);
-             break;
- #else
-         case 0x44:
-         case 0x43:
-         case 0x42:
-             mc = 4;
-             nc = 2;
-             gemm<4, 2>(m0, m, n0, n);
-             break;
-         case 0x34:
-         case 0x24:
-             mc = 2;
-             nc = 4;
-             gemm<2, 4>(m0, m, n0, n);
-             break;
-         case 0x33:
- #endif
-         case 0x32:
-             mc = 3;
-             nc = 2;
-             gemm<3, 2>(m0, m, n0, n);
-             break;
-         case 0x23:
-             mc = 2;
-             nc = 3;
-             gemm<2, 3>(m0, m, n0, n);
-             break;
-         case 0x41:
-             mc = 4;
-             nc = 1;
-             gemm<4, 1>(m0, m, n0, n);
-             break;
-         case 0x22:
-             mc = 2;
-             nc = 2;
-             gemm<2, 2>(m0, m, n0, n);
-             break;
-         case 0x14:
-             mc = 1;
-             nc = 4;
-             gemm<1, 4>(m0, m, n0, n);
-             break;
-         case 0x31:
-             mc = 3;
-             nc = 1;
-             gemm<3, 1>(m0, m, n0, n);
-             break;
-         case 0x13:
-             mc = 1;
-             nc = 3;
-             gemm<1, 3>(m0, m, n0, n);
-             break;
-         case 0x21:
-             mc = 2;
-             nc = 1;
-             gemm<2, 1>(m0, m, n0, n);
-             break;
-         case 0x12:
-             mc = 1;
-             nc = 2;
-             gemm<1, 2>(m0, m, n0, n);
-             break;
-         case 0x11:
-             mc = 1;
-             nc = 1;
-             gemm<1, 1>(m0, m, n0, n);
-             break;
-         default:
-             return;
-         }
-         mp = m0 + (m - m0) / mc * mc;
-         np = n0 + (n - n0) / nc * nc;
-         mnpack(mp, m, n0, np);
-         mnpack(m0, m, np, n);
-     }
-
-     template <int RM, int RN>
-     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-         int64_t ytiles = (m - m0) / RM;
-         int64_t xtiles = (n - n0) / RN;
-         int64_t tiles = xtiles * ytiles;
-         int64_t duty = (tiles + nth - 1) / nth;
-         int64_t start = duty * ith;
-         int64_t end = start + duty;
-         if (end > tiles)
-             end = tiles;
-         for (int64_t job = start; job < end; ++job) {
-             int64_t ii = m0 + job / xtiles * RM;
-             int64_t jj = n0 + job % xtiles * RN;
-             __m256 Cv[RN][RM] = {};
-             for (int64_t l = 0; l < k; ++l)
-                 for (int64_t j = 0; j < RN; ++j)
-                     for (int64_t i = 0; i < RM; ++i) {
- #if defined(__AVX2__)
-                         __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
-                                                               load(A + lda * (ii + i) + l)),
-                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
-                                                               load(A + lda * (ii + i) + l)));
- #else
-                         __m128i ali0 = load0(A + lda * (ii + i) + l);
-                         __m128i ali1 = load1(A + lda * (ii + i) + l);
-                         __m128i blj0 = load0(B + ldb * (jj + j) + l);
-                         __m128i blj1 = load1(B + ldb * (jj + j) + l);
-
-                         __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
-                         __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
-                         __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
-                         __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
-
-                         // updot
-                         const __m128i oneFill = _mm_set1_epi16(1);
-                         __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
-                         __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
-                         __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
- #endif
-                         Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
-                                                        unhalf(B[ldb * (jj + j) + l].d)),
-                                         udTmp,
-                                         Cv[j][i]);
-                     }
-             for (int64_t j = 0; j < RN; ++j)
-                 for (int64_t i = 0; i < RM; ++i)
-                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-         }
-     }
-
-     inline __m256i load(const block_q8_0 *b) {
-         return _mm256_loadu_si256((const __m256i *)b->qs);
-     }
-
-     inline __m128i load0(const block_q8_0 *b) {
-         return _mm_loadu_si128((const __m128i *)b->qs);
-     }
-
-     inline __m128i load1(const block_q8_0 *b) {
-         return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
-     }
-
-     inline __m256i load(const block_q4_0 *b) {
-         return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
-     }
-
-     inline __m128i load0(const block_q4_0 *b) {
-         const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
-         return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
-     }
-
-     inline __m128i load1(const block_q4_0 *b) {
-         const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
-         return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
-     }
-
-     inline __m256 updot(__m256i u, __m256i s) {
-         __m256i res;
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
-         res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
- #else
-         res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
- #endif
-         return _mm256_cvtepi32_ps(res);
-     }
-
-     static inline __m256i denibble(const uint8_t *p) {
-         __m128i x = _mm_loadu_si128((const __m128i *)p);
-         return _mm256_and_si256(_mm256_set1_epi8(15),
-                                 _mm256_insertf128_si256(_mm256_castsi128_si256(x),
-                                                         _mm_srli_epi16(x, 4), 1));
-     }
-
-     const TA *const A;
-     const TB *const B;
-     TC *const C;
-     const int64_t k;
-     const int64_t lda;
-     const int64_t ldb;
-     const int64_t ldc;
-     const int ith;
-     const int nth;
- };
- #endif // __AVX__
-
- } // namespace
-
- /**
-  * Performs optimized matrix multiplication on CPU.
-  *
-  * This subroutine may compute C = Aᵀ * B with column major ordering.
-  * Despite its name, this isn't a generalized implementation. Work is
-  * only performed when a handwritten kernel is written and available.
-  * Otherwise the caller should fall back to a general matmul routine.
-  *
-  * For example, for single-threaded single-precision GEMM you can say
-  *
-  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
-  *                     0, 1, GGML_TASK_TYPE_COMPUTE,
-  *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
-  *
-  * @param m is rows in `A` and `C`
-  * @param n is cols in `B` and `C`
-  * @param k is cols in `A` and rows in `B`
-  * @param A is first input matrix (always transposed)
-  * @param lda is row stride of `A`
-  * @param B is second input matrix (never transposed)
-  * @param ldb is row stride of `B`
-  * @param C is input/output array of output matrices
-  * @param ldc is row stride of `C`
-  * @param ith is thread id (must be less than `nth`)
-  * @param nth is number of threads (must be greater than zero)
-  * @param task is GGML task type
-  * @param Atype is GGML data type of `A`
-  * @param Btype is GGML data type of `B`
-  * @param Ctype is GGML data type of `C`
-  * @return true if this function was able to service the matmul request
-  */
- bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                      int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
-
-     assert(m >= 0);
-     assert(n >= 0);
-     assert(k >= 0);
-     assert(lda >= k);
-     assert(ldb >= k);
-     assert(ldc >= m);
-     assert(nth > 0);
-     assert(ith < nth);
-
-     if (Ctype != GGML_TYPE_F32)
-         return false;
-
-     switch (Atype) {
-
-     case GGML_TYPE_F32: {
-         if (Btype != GGML_TYPE_F32)
-             return false;
- #if defined(__AVX512F__)
-         if (k % 16)
-             return false;
-         tinyBLAS<16, __m512, __m512, float, float, float> tb{
-             k, (const float *)A, lda,
-             (const float *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif defined(__AVX__) || defined(__AVX2__)
-         if (k % 8)
-             return false;
-         tinyBLAS<8, __m256, __m256, float, float, float> tb{
-             k, (const float *)A, lda,
-             (const float *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif defined(__ARM_NEON)
-         if (n < 4)
-             return false;
-         if (k % 4)
-             return false;
-         tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
-             k, (const float *)A, lda,
-             (const float *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #else
-         return false;
- #endif
-     }
-
-     case GGML_TYPE_F16: {
- #if defined(__AVX512F__)
-         if (k % 16)
-             return false;
-         if (Btype != GGML_TYPE_F32)
-             return false;
-         tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
-             k, (const ggml_fp16_t *)A, lda,
-             (const float *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
-         if (k % 8)
-             return false;
-         if (Btype != GGML_TYPE_F32)
-             return false;
-         tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
-             k, (const ggml_fp16_t *)A, lda,
-             (const float *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-         if (n < 8)
-             return false;
-         if (k % 8)
-             return false;
-         if (Btype != GGML_TYPE_F16)
-             return false;
-         tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
-             k, (const ggml_fp16_t *)A, lda,
-             (const ggml_fp16_t *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif defined(__ARM_NEON) && !defined(_MSC_VER)
-         if (k % 4)
-             return false;
-         if (Btype != GGML_TYPE_F32)
-             return false;
-         tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
-             k, (const ggml_fp16_t *)A, lda,
-             (const float *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #else
-         return false;
- #endif
-     }
-
-     case GGML_TYPE_Q8_0: {
-         if (Btype != GGML_TYPE_Q8_0)
-             return false;
- #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-         tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
-             k, (const block_q8_0 *)A, lda,
-             (const block_q8_0 *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif defined(__ARM_FEATURE_DOTPROD)
-         tinyBLAS_Q0_ARM<block_q8_0> tb{
-             k, (const block_q8_0 *)A, lda,
-             (const block_q8_0 *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #else
-         return false;
- #endif
-     }
-
-     case GGML_TYPE_Q4_0: {
-         if (Btype != GGML_TYPE_Q8_0)
-             return false;
- #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-         tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
-             k, (const block_q4_0 *)A, lda,
-             (const block_q8_0 *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #elif defined(__ARM_FEATURE_DOTPROD)
-         tinyBLAS_Q0_ARM<block_q4_0> tb{
-             k, (const block_q4_0 *)A, lda,
-             (const block_q8_0 *)B, ldb,
-             (float *)C, ldc,
-             ith, nth};
-         tb.matmul(m, n, task);
-         return true;
- #else
-         return false;
- #endif
-     }
-
-     default:
-         return false;
-     }
-
-     (void)m;
-     (void)n;
-     (void)k;
-     (void)A;
-     (void)lda;
-     (void)B;
-     (void)ldb;
-     (void)C;
-     (void)ldc;
-     (void)ith;
-     (void)nth;
-     (void)task;
-     (void)Atype;
-     (void)Btype;
-     (void)Ctype;
- }
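
For readers skimming the removed file: every kernel above computes the operation the docstring describes, C = Aᵀ * B with column-major C. The sketch below is illustrative only and is not part of the gem or of llama.cpp; reference_sgemm is a hypothetical name. It mirrors the indexing of the removed gemm<> kernels, assuming A is m×k with row stride lda, B is n×k with row stride ldb, and C is written at C[ldc * j + i].

    // A minimal, unoptimized reference for the removed kernels (assumption: float inputs).
    #include <cstdint>

    static void reference_sgemm(int64_t m, int64_t n, int64_t k,
                                const float *A, int64_t lda,
                                const float *B, int64_t ldb,
                                float *C, int64_t ldc) {
        for (int64_t j = 0; j < n; ++j)        // column of C, row of B
            for (int64_t i = 0; i < m; ++i) {  // row of C, row of A
                float sum = 0.0f;
                for (int64_t l = 0; l < k; ++l)
                    sum += A[lda * i + l] * B[ldb * j + l];  // dot product along k
                C[ldc * j + i] = sum;  // same indexing as the removed gemm<> kernel
            }
    }

The removed tinyBLAS classes specialize this loop nest: gemm<RM, RN> keeps an RM×RN tile of C in vector registers, mnpack's switch picks the largest tile shape that fits the register budget, and the resulting tiles are divided evenly across the nth threads via the ith/nth arguments.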