llama_cpp 0.16.1 → 0.17.0

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -42
  5. data/ext/llama_cpp/llama_cpp.cpp +20 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +5 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1116
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2214
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -233
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -286
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3267
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14380
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1173
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -17429
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -49
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -140820
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7271
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22589
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2452
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18692
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1143
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1030
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -6983
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -796
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
data/vendor/tmp/llama.cpp/sgemm.cpp
@@ -1,1030 +0,0 @@
- // Copyright 2024 Mozilla Foundation
- //
- // Permission is hereby granted, free of charge, to any person obtaining
- // a copy of this software and associated documentation files (the
- // "Software"), to deal in the Software without restriction, including
- // without limitation the rights to use, copy, modify, merge, publish,
- // distribute, sublicense, and/or sell copies of the Software, and to
- // permit persons to whom the Software is furnished to do so, subject to
- // the following conditions:
- //
- // The above copyright notice and this permission notice shall be
- // included in all copies or substantial portions of the Software.
- //
- // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
- // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
- // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
- // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- // SOFTWARE.
-
- //
- // _ _ ___ _ _ ___
- // | |_(_)_ _ _ _| _ ) | /_\ / __|
- // | _| | ' \ || | _ \ |__ / _ \\__ \.
- // \__|_|_||_\_, |___/____/_/ \_\___/
- // |__/
- //
- // BASIC LINEAR ALGEBRA SUBPROGRAMS
- //
- //
- // This file implements multithreaded CPU matrix multiplication for the
- // common contiguous use case C = Aᵀ * B. These kernels are designed to
- // have excellent performance[1] for matrices that fit in the CPU cache
- // without imposing any overhead such as cache filling or malloc calls.
- //
- // This implementation does not guarantee any upper bound with rounding
- // errors, which grow along with k. Our goal's to maximally exploit the
- // hardware for performance, and then use whatever resources remain for
- // improving numerical accuracy.
- //
- // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
- // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
-
- #pragma GCC diagnostic ignored "-Wpedantic"
- #pragma GCC diagnostic ignored "-Wignored-attributes"
-
- #include "sgemm.h"
- #include "ggml-impl.h"
- #include "ggml-quants.h"
-
- #ifdef _MSC_VER
- #define NOINLINE __declspec(noinline)
- #else
- #define NOINLINE __attribute__((__noinline__))
- #endif
-
- #if defined(__ARM_NEON) || defined(__AVX512F__)
- #define VECTOR_REGISTERS 32
- #else
- #define VECTOR_REGISTERS 16
- #endif
-
- #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
- namespace {
-
- inline float unhalf(ggml_fp16_t d) {
- return GGML_FP16_TO_FP32(d);
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED ARITHMETIC OPERATIONS
-
- #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
- inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
- inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
- #endif // __SSE__
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
- inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
- inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
- #endif // __AVX__
-
- #if defined(__AVX512F__)
- inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
- inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
- inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
- #endif // __AVX512F__
-
- #if defined(__ARM_NEON)
- inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
- inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
- inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
- #endif // __ARM_NEON
-
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
- inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
- inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
- inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
- #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED FUSED MULTIPLY ADD
-
- /**
- * Computes a * b + c.
- */
- template <typename T, typename U>
- inline U madd(T a, T b, U c) {
- return add(mul(a, b), c);
- }
-
- #if defined(__FMA__)
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- template <>
- inline __m256 madd(__m256 a, __m256 b, __m256 c) {
- return _mm256_fmadd_ps(a, b, c);
- }
- #endif
- #if defined(__AVX512F__)
- template <>
- inline __m512 madd(__m512 a, __m512 b, __m512 c) {
- return _mm512_fmadd_ps(a, b, c);
- }
- #endif
- #endif
-
- #if defined(__ARM_FEATURE_FMA)
- template <>
- inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
- return vfmaq_f32(c, b, a);
- }
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
- template <>
- inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
- return vfmaq_f16(c, b, a);
- }
- #endif
- #endif
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED HORIZONTAL SUM
-
- #if defined(__ARM_NEON)
- inline float hsum(float32x4_t x) {
- return vaddvq_f32(x);
- }
- #endif // __ARM_NEON
-
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
- inline float hsum(float16x8_t x) {
- return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
- vcvt_f32_f16(vget_high_f16(x))));
- }
- #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
- #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline float hsum(__m128 x) {
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- x = _mm_add_ps(x, _mm_movehl_ps(x, x));
- x = _mm_add_ss(x, _mm_movehdup_ps(x));
- #else
- __m128 t;
- t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
- x = _mm_add_ps(x, t);
- t = _mm_movehl_ps(t, x);
- x = _mm_add_ss(x, t);
- #endif
- return _mm_cvtss_f32(x);
- }
- #endif
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- inline float hsum(__m256 x) {
- return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
- _mm256_castps256_ps128(x)));
- }
- #endif // __AVX__
-
- #if defined(__AVX512F__)
- inline float hsum(__m512 x) {
- return _mm512_reduce_add_ps(x);
- }
- #endif // __AVX512F__
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // VECTORIZED MEMORY LOADING
-
- template <typename T, typename U> T load(const U *);
-
- #if defined(__ARM_NEON)
- template <> inline float32x4_t load(const float *p) {
- return vld1q_f32(p);
- }
- #if !defined(_MSC_VER)
- template <> inline float16x8_t load(const ggml_fp16_t *p) {
- return vld1q_f16((const float16_t *)p);
- }
- template <> inline float32x4_t load(const ggml_fp16_t *p) {
- return vcvt_f32_f16(vld1_f16((const float16_t *)p));
- }
- #endif // _MSC_VER
- #endif // __ARM_NEON
-
- #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- template <> inline __m128 load(const float *p) {
- return _mm_loadu_ps(p);
- }
- #endif // __SSE__
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- template <> inline __m256 load(const float *p) {
- return _mm256_loadu_ps(p);
- }
- #endif // __AVX__
-
- #if defined(__F16C__)
- template <> inline __m256 load(const ggml_fp16_t *p) {
- return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
- }
- #endif // __F16C__
-
- #if defined(__AVX512F__)
- template <> inline __m512 load(const float *p) {
- return _mm512_loadu_ps(p);
- }
- template <> inline __m512 load(const ggml_fp16_t *p) {
- return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
- }
- #endif // __AVX512F__
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // FLOATING POINT MATRIX MULTIPLICATION
-
- template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
- class tinyBLAS {
- public:
- tinyBLAS(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
- }
-
- void matmul(int64_t m, int64_t n, int task) {
- if (task == GGML_TASK_TYPE_COMPUTE)
- mnpack(0, m, 0, n);
- }
-
- private:
- NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
- #if VECTOR_REGISTERS == 32
- case 0x55:
- mc = 5;
- nc = 5;
- gemm<5, 5>(m0, m, n0, n);
- break;
- case 0x45:
- mc = 4;
- nc = 5;
- gemm<4, 5>(m0, m, n0, n);
- break;
- case 0x54:
- mc = 5;
- nc = 4;
- gemm<5, 4>(m0, m, n0, n);
- break;
- case 0x44:
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
- break;
- case 0x53:
- mc = 5;
- nc = 3;
- gemm<5, 3>(m0, m, n0, n);
- break;
- case 0x35:
- mc = 3;
- nc = 5;
- gemm<3, 5>(m0, m, n0, n);
- break;
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
- #else
- case 0x55:
- case 0x54:
- case 0x53:
- case 0x45:
- case 0x44:
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
- case 0x35:
- #endif
- case 0x34:
- mc = 3;
- nc = 4;
- gemm<3, 4>(m0, m, n0, n);
- break;
- case 0x52:
- mc = 5;
- nc = 2;
- gemm<5, 2>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x25:
- mc = 2;
- nc = 5;
- gemm<2, 5>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x51:
- mc = 5;
- nc = 1;
- gemm<5, 1>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm<4, 1>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x15:
- mc = 1;
- nc = 5;
- gemm<1, 5>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm<1, 4>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
- }
-
- template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- D Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; l += KN)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
- load<V>(B + ldb * (jj + j) + l),
- Cv[j][i]);
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
- }
- }
-
- const TA *const A;
- const TB *const B;
- TC *const C;
- const int64_t k;
- const int64_t lda;
- const int64_t ldb;
- const int64_t ldc;
- const int ith;
- const int nth;
- };
-
- //////////////////////////////////////////////////////////////////////////////////////////
- // QUANT ZERO MATRIX MULTIPLICATION
-
- #if defined(__ARM_FEATURE_DOTPROD)
- template <typename TA>
- class tinyBLAS_Q0_ARM {
- public:
- tinyBLAS_Q0_ARM(int64_t k,
- const TA *A, int64_t lda,
- const block_q8_0 *B, int64_t ldb,
- float *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
- }
-
- void matmul(int64_t m, int64_t n, int task) {
- if (task == GGML_TASK_TYPE_COMPUTE)
- mnpack(0, m, 0, n);
- }
-
- private:
- NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
- }
-
- template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- float32x4_t Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; ++l)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- Cv[j][i] = vmlaq_n_f32(Cv[j][i],
- vcvtq_f32_s32(vdotq_s32(
- vdotq_s32(vdupq_n_s32(0),
- load_lo(A + lda * (ii + i) + l),
- load_lo(B + ldb * (jj + j) + l)),
- load_hi(A + lda * (ii + i) + l),
- load_hi(B + ldb * (jj + j) + l))),
- unhalf(A[lda * (ii + i) + l].d) *
- unhalf(B[ldb * (jj + j) + l].d));
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
- }
- }
-
- inline int8x16_t load_lo(const block_q8_0 *b) {
- return vld1q_s8(b->qs);
- }
-
- inline int8x16_t load_hi(const block_q8_0 *b) {
- return vld1q_s8(b->qs + 16);
- }
-
- inline int8x16_t load_lo(const block_q4_0 *b) {
- return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
- vdupq_n_u8(0x0f))),
- vdupq_n_s8(0x8));
- }
-
- inline int8x16_t load_hi(const block_q4_0 *b) {
- return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
- vdupq_n_s8(0x8));
- }
-
- const TA *const A;
- const block_q8_0 *const B;
- float *const C;
- const int64_t k;
- const int64_t lda;
- const int64_t ldb;
- const int64_t ldc;
- const int ith;
- const int nth;
- };
- #endif // __ARM_FEATURE_DOTPROD
-
- #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
- template <typename TA, typename TB, typename TC>
- class tinyBLAS_Q0_AVX {
- public:
- tinyBLAS_Q0_AVX(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
- }
-
- void matmul(int64_t m, int64_t n, int task) {
- if (task == GGML_TASK_TYPE_COMPUTE)
- mnpack(0, m, 0, n);
- }
-
- private:
- void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
- #if VECTOR_REGISTERS == 32
- case 0x44:
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
- break;
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm<3, 4>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
- #else
- case 0x44:
- case 0x43:
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x34:
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
- case 0x33:
- #endif
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm<4, 1>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm<1, 4>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
- }
-
- template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- __m256 Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; ++l)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i) {
- #if defined(__AVX2__)
- __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
- load(A + lda * (ii + i) + l)),
- _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
- load(A + lda * (ii + i) + l)));
- #else
- __m128i ali0 = load0(A + lda * (ii + i) + l);
- __m128i ali1 = load1(A + lda * (ii + i) + l);
- __m128i blj0 = load0(B + ldb * (jj + j) + l);
- __m128i blj1 = load1(B + ldb * (jj + j) + l);
-
- __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
- __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
- __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
- __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
-
- // updot
- const __m128i oneFill = _mm_set1_epi16(1);
- __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
- __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
- __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
- #endif
- Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
- unhalf(B[ldb * (jj + j) + l].d)),
- udTmp,
- Cv[j][i]);
- }
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
- }
- }
-
- inline __m256i load(const block_q8_0 *b) {
- return _mm256_loadu_si256((const __m256i *)b->qs);
- }
-
- inline __m128i load0(const block_q8_0 *b) {
- return _mm_loadu_si128((const __m128i *)b->qs);
- }
-
- inline __m128i load1(const block_q8_0 *b) {
- return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
- }
-
- inline __m256i load(const block_q4_0 *b) {
- return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
- }
-
- inline __m128i load0(const block_q4_0 *b) {
- const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
- return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
- }
-
- inline __m128i load1(const block_q4_0 *b) {
- const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
- return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
- }
-
- inline __m256 updot(__m256i u, __m256i s) {
- __m256i res;
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
- res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
- #else
- res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
- #endif
- return _mm256_cvtepi32_ps(res);
- }
-
- static inline __m256i denibble(const uint8_t *p) {
- __m128i x = _mm_loadu_si128((const __m128i *)p);
- return _mm256_and_si256(_mm256_set1_epi8(15),
- _mm256_insertf128_si256(_mm256_castsi128_si256(x),
- _mm_srli_epi16(x, 4), 1));
- }
-
- const TA *const A;
- const TB *const B;
- TC *const C;
- const int64_t k;
- const int64_t lda;
- const int64_t ldb;
- const int64_t ldc;
- const int ith;
- const int nth;
- };
- #endif // __AVX__
-
- } // namespace
-
- /**
- * Performs optimized matrix multiplication on CPU.
- *
- * This subroutine may compute C = Aᵀ * B with column major ordering.
- * Despite its name, this isn't a generalized implementation. Work is
- * only performed when a handwritten kernel is written and available.
- * Otherwise the caller should fall back to a general matmul routine.
- *
- * For example, for single-threaded single-precision GEMM you can say
- *
- * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- * 0, 1, GGML_TASK_TYPE_COMPUTE,
- * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
- *
- * @param m is rows in `A` and `C`
- * @param n is cols in `B` and `C`
- * @param k is cols in `A` and rows in `B`
- * @param A is first input matrix (always transposed)
- * @param lda is row stride of `A`
- * @param B is second input matrix (never transposed)
- * @param ldb is row stride of `B`
- * @param C is input/output array of output matrices
- * @param ldc is row stride of `C`
- * @param ith is thread id (must be less than `nth`)
- * @param nth is number of threads (must be greater than zero)
- * @param task is GGML task type
- * @param Atype is GGML data type of `A`
- * @param Btype is GGML data type of `B`
- * @param Ctype is GGML data type of `C`
- * @return true if this function was able to service the matmul request
- */
- bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
- int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
-
- assert(m >= 0);
- assert(n >= 0);
- assert(k >= 0);
- assert(lda >= k);
- assert(ldb >= k);
- assert(ldc >= m);
- assert(nth > 0);
- assert(ith < nth);
-
- if (Ctype != GGML_TYPE_F32)
- return false;
-
- switch (Atype) {
-
- case GGML_TYPE_F32: {
- if (Btype != GGML_TYPE_F32)
- return false;
- #if defined(__AVX512F__)
- if (k % 16)
- return false;
- tinyBLAS<16, __m512, __m512, float, float, float> tb{
- k, (const float *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif defined(__AVX__) || defined(__AVX2__)
- if (k % 8)
- return false;
- tinyBLAS<8, __m256, __m256, float, float, float> tb{
- k, (const float *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif defined(__ARM_NEON)
- if (n < 4)
- return false;
- if (k % 4)
- return false;
- tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
- k, (const float *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #else
- return false;
- #endif
- }
-
- case GGML_TYPE_F16: {
- #if defined(__AVX512F__)
- if (k % 16)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
- if (k % 8)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
- if (n < 8)
- return false;
- if (k % 8)
- return false;
- if (Btype != GGML_TYPE_F16)
- return false;
- tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const ggml_fp16_t *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif defined(__ARM_NEON) && !defined(_MSC_VER)
- if (k % 4)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #else
- return false;
- #endif
- }
-
- case GGML_TYPE_Q8_0: {
- if (Btype != GGML_TYPE_Q8_0)
- return false;
- #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
- tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
- k, (const block_q8_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif defined(__ARM_FEATURE_DOTPROD)
- tinyBLAS_Q0_ARM<block_q8_0> tb{
- k, (const block_q8_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #else
- return false;
- #endif
- }
-
- case GGML_TYPE_Q4_0: {
- if (Btype != GGML_TYPE_Q8_0)
- return false;
- #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
- tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
- k, (const block_q4_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #elif defined(__ARM_FEATURE_DOTPROD)
- tinyBLAS_Q0_ARM<block_q4_0> tb{
- k, (const block_q4_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n, task);
- return true;
- #else
- return false;
- #endif
- }
-
- default:
- return false;
- }
-
- (void)m;
- (void)n;
- (void)k;
- (void)A;
- (void)lda;
- (void)B;
- (void)ldb;
- (void)C;
- (void)ldc;
- (void)ith;
- (void)nth;
- (void)task;
- (void)Atype;
- (void)Btype;
- (void)Ctype;
- }
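
For context on what the removed file provided: the deleted doc comment above describes llamafile_sgemm as computing C = Aᵀ * B with column-major ordering, returning true only when a handwritten kernel covers the requested types, so callers are expected to fall back to a generic matmul otherwise. Below is a minimal sketch of that calling pattern; it assumes the vendored sgemm.h and ggml.h headers from llama.cpp are on the include path, and the wrapper name and the naive fallback loop are illustrative additions, not part of the gem or of llama.cpp.

    // Sketch only: single-threaded f32 GEMM via llamafile_sgemm, mirroring the
    // example in the removed doc comment, with a naive fallback when no
    // handwritten kernel services the request.
    #include <cstdint>
    #include "ggml.h"   // GGML_TYPE_F32, GGML_TASK_TYPE_COMPUTE
    #include "sgemm.h"  // bool llamafile_sgemm(...)

    void sgemm_f32(int64_t m, int64_t n, int64_t k,
                   const float *A, int64_t lda,
                   const float *B, int64_t ldb,
                   float *C, int64_t ldc) {
        if (llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
                            /*ith=*/0, /*nth=*/1, GGML_TASK_TYPE_COMPUTE,
                            GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32))
            return;  // an optimized kernel handled the multiplication
        // Generic fallback using the same layout the kernels above use:
        // C[ldc*j + i] = sum over l of A[lda*i + l] * B[ldb*j + l]
        for (int64_t j = 0; j < n; ++j)
            for (int64_t i = 0; i < m; ++i) {
                float sum = 0.0f;
                for (int64_t l = 0; l < k; ++l)
                    sum += A[lda * i + l] * B[ldb * j + l];
                C[ldc * j + i] = sum;
            }
    }

With 0.17.0 the gem no longer ships these vendored llama.cpp sources, so a pattern like this only applies to builds that still bundle them (e.g. 0.16.x) or that link against llama.cpp directly.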