whispercpp 1.3.0 → 1.3.1

Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
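
The largest single addition is llamafile's tinyBLAS implementation in data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp, whose diff follows. Every kernel in that file splits the output matrix into RM x RN tiles and hands each thread a contiguous run of tiles (the ytiles/xtiles/duty arithmetic that repeats throughout the diff). As a minimal standalone sketch of just that partitioning scheme — the function name partition_tiles and the print loop are illustrative, not part of the gem — the idea is:

// Simplified sketch of tinyBLAS tile/thread partitioning (hypothetical driver).
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Split an (m - m0) x (n - n0) output region into RM x RN tiles and give
// thread `ith` (out of `nth`) a contiguous chunk of tiles to compute.
void partition_tiles(int64_t m0, int64_t m, int64_t n0, int64_t n,
                     int64_t RM, int64_t RN, int ith, int nth) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles  = xtiles * ytiles;
    int64_t duty   = (tiles + nth - 1) / nth;      // ceil(tiles / nth)
    int64_t start  = duty * ith;
    int64_t end    = std::min(start + duty, tiles);
    for (int64_t job = start; job < end; ++job) {
        int64_t ii = m0 + job / xtiles * RM;       // top row of this tile of C
        int64_t jj = n0 + job % xtiles * RN;       // left column of this tile of C
        std::printf("thread %d computes C[%lld..%lld, %lld..%lld]\n",
                    ith, (long long)ii, (long long)(ii + RM - 1),
                    (long long)jj, (long long)(jj + RN - 1));
    }
}

int main() {
    // e.g. a 16 x 16 output tiled into 4 x 4 blocks across 4 threads
    for (int ith = 0; ith < 4; ++ith)
        partition_tiles(0, 16, 0, 16, 4, 4, ith, 4);
    return 0;
}

In the real file each "job" body is a hand-written SIMD micro-kernel (NEON, AVX/AVX2/AVX-512, or PowerPC MMA) rather than a print statement, and mnpack() recurses over the leftover rows and columns that do not fill a whole tile.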
data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -0,0 +1,1884 @@
1
+ // Copyright 2024 Mozilla Foundation
2
+ //
3
+ // Permission is hereby granted, free of charge, to any person obtaining
4
+ // a copy of this software and associated documentation files (the
5
+ // "Software"), to deal in the Software without restriction, including
6
+ // without limitation the rights to use, copy, modify, merge, publish,
7
+ // distribute, sublicense, and/or sell copies of the Software, and to
8
+ // permit persons to whom the Software is furnished to do so, subject to
9
+ // the following conditions:
10
+ //
11
+ // The above copyright notice and this permission notice shall be
12
+ // included in all copies or substantial portions of the Software.
13
+ //
14
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
+ // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
+ // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
+ // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
18
+ // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
19
+ // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20
+ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ // SOFTWARE.
22
+
23
+ //
24
+ // _ _ ___ _ _ ___
25
+ // | |_(_)_ _ _ _| _ ) | /_\ / __|
26
+ // | _| | ' \ || | _ \ |__ / _ \\__ \.
27
+ // \__|_|_||_\_, |___/____/_/ \_\___/
28
+ // |__/
29
+ //
30
+ // BASIC LINEAR ALGEBRA SUBPROGRAMS
31
+ //
32
+ //
33
+ // This file implements multithreaded CPU matrix multiplication for the
34
+ // common contiguous use case C = Aᵀ * B. These kernels are designed to
35
+ // have excellent performance[1] for matrices that fit in the CPU cache
36
+ // without imposing any overhead such as cache filling or malloc calls.
37
+ //
38
+ // This implementation does not guarantee any upper bound with rounding
39
+ // errors, which grow along with k. Our goal's to maximally exploit the
40
+ // hardware for performance, and then use whatever resources remain for
41
+ // improving numerical accuracy.
42
+ //
43
+ // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
44
+ // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
45
+
46
+ #if defined(__GNUC__)
47
+ #pragma GCC diagnostic ignored "-Wpedantic"
48
+ #pragma GCC diagnostic ignored "-Wignored-attributes"
49
+ #endif
50
+
51
+ #include "sgemm.h"
52
+ #include "ggml-impl.h"
53
+ #include "ggml-cpu-impl.h"
54
+ #include "ggml-quants.h"
55
+
56
+ #ifdef _MSC_VER
57
+ #define NOINLINE __declspec(noinline)
58
+ #else
59
+ #define NOINLINE __attribute__((__noinline__))
60
+ #endif
61
+
62
+ #if defined(__ARM_NEON) || defined(__AVX512F__)
63
+ #define VECTOR_REGISTERS 32
64
+ #else
65
+ #define VECTOR_REGISTERS 16
66
+ #endif
67
+
68
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
69
+
70
+ namespace {
71
+
72
+ inline float unhalf(ggml_fp16_t d) {
73
+ return GGML_FP16_TO_FP32(d);
74
+ }
75
+
76
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
77
+ // VECTORIZED ARITHMETIC OPERATIONS
78
+
79
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
80
+ inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
81
+ inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
82
+ inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
83
+ #endif // __SSE__
84
+
85
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
86
+ inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
87
+ inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
88
+ inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
89
+ #endif // __AVX__
90
+
91
+ #if defined(__AVX512F__)
92
+ inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
93
+ inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
94
+ inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
95
+ #endif // __AVX512F__
96
+
97
+ #if defined(__ARM_NEON)
98
+ inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
99
+ inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
100
+ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
101
+ #endif // __ARM_NEON
102
+
103
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
104
+ inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
105
+ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
106
+ inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
107
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
108
+
109
+ #if defined(__MMA__)
110
+ typedef vector unsigned char vec_t;
111
+ typedef __vector_quad acc_t;
112
+ #endif
113
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
114
+ // VECTORIZED FUSED MULTIPLY ADD
115
+
116
+ /**
117
+ * Computes a * b + c.
118
+ */
119
+ template <typename T, typename U>
120
+ inline U madd(T a, T b, U c) {
121
+ return add(mul(a, b), c);
122
+ }
123
+
124
+ #if defined(__FMA__)
125
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
126
+ template <>
127
+ inline __m256 madd(__m256 a, __m256 b, __m256 c) {
128
+ return _mm256_fmadd_ps(a, b, c);
129
+ }
130
+ #endif
131
+ #if defined(__AVX512F__)
132
+ template <>
133
+ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
134
+ return _mm512_fmadd_ps(a, b, c);
135
+ }
136
+ #endif
137
+ #endif
138
+
139
+ #if defined(__ARM_FEATURE_FMA)
140
+ template <>
141
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
142
+ return vfmaq_f32(c, b, a);
143
+ }
144
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
145
+ template <>
146
+ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
147
+ return vfmaq_f16(c, b, a);
148
+ }
149
+ #endif
150
+ #endif
151
+
152
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
153
+ // VECTORIZED HORIZONTAL SUM
154
+
155
+ #if defined(__ARM_NEON)
156
+ inline float hsum(float32x4_t x) {
157
+ return vaddvq_f32(x);
158
+ }
159
+ #endif // __ARM_NEON
160
+
161
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
162
+ inline float hsum(float16x8_t x) {
163
+ return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
164
+ vcvt_f32_f16(vget_high_f16(x))));
165
+ }
166
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
167
+
168
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
169
+ inline float hsum(__m128 x) {
170
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
171
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
172
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
173
+ #else
174
+ __m128 t;
175
+ t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
176
+ x = _mm_add_ps(x, t);
177
+ t = _mm_movehl_ps(t, x);
178
+ x = _mm_add_ss(x, t);
179
+ #endif
180
+ return _mm_cvtss_f32(x);
181
+ }
182
+ #endif
183
+
184
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
185
+ inline float hsum(__m256 x) {
186
+ return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
187
+ _mm256_castps256_ps128(x)));
188
+ }
189
+ #endif // __AVX__
190
+
191
+ #if defined(__AVX512F__)
192
+ inline float hsum(__m512 x) {
193
+ return _mm512_reduce_add_ps(x);
194
+ }
195
+ #endif // __AVX512F__
196
+
197
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
198
+ // VECTORIZED MEMORY LOADING
199
+
200
+ template <typename T, typename U> T load(const U *);
201
+
202
+ #if defined(__ARM_NEON)
203
+ template <> inline float32x4_t load(const float *p) {
204
+ return vld1q_f32(p);
205
+ }
206
+ #if !defined(_MSC_VER)
207
+ template <> inline float16x8_t load(const ggml_fp16_t *p) {
208
+ return vld1q_f16((const float16_t *)p);
209
+ }
210
+ template <> inline float32x4_t load(const ggml_fp16_t *p) {
211
+ return vcvt_f32_f16(vld1_f16((const float16_t *)p));
212
+ }
213
+ #endif // _MSC_VER
214
+ #endif // __ARM_NEON
215
+
216
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
217
+ template <> inline __m128 load(const float *p) {
218
+ return _mm_loadu_ps(p);
219
+ }
220
+ #endif // __SSE__
221
+
222
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
223
+ template <> inline __m256 load(const float *p) {
224
+ return _mm256_loadu_ps(p);
225
+ }
226
+ #endif // __AVX__
227
+
228
+ #if defined(__F16C__)
229
+ template <> inline __m256 load(const ggml_fp16_t *p) {
230
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
231
+ }
232
+ #endif // __F16C__
233
+
234
+ #if defined(__AVX512F__)
235
+ template <> inline __m512 load(const float *p) {
236
+ return _mm512_loadu_ps(p);
237
+ }
238
+ template <> inline __m512 load(const ggml_fp16_t *p) {
239
+ return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
240
+ }
241
+ #endif // __AVX512F__
242
+
243
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
244
+ // CONSTANTS
245
+
246
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
247
+ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
248
+ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
249
+ #endif
250
+
251
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
252
+ // FLOATING POINT MATRIX MULTIPLICATION
253
+
254
+ template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
255
+ class tinyBLAS {
256
+ public:
257
+ tinyBLAS(int64_t k,
258
+ const TA *A, int64_t lda,
259
+ const TB *B, int64_t ldb,
260
+ TC *C, int64_t ldc,
261
+ int ith, int nth)
262
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
263
+ }
264
+
265
+ void matmul(int64_t m, int64_t n) {
266
+ mnpack(0, m, 0, n);
267
+ }
268
+
269
+ private:
270
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
271
+ int64_t mc, nc, mp, np;
272
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
273
+ #if VECTOR_REGISTERS == 32
274
+ case 0x55:
275
+ mc = 5;
276
+ nc = 5;
277
+ gemm<5, 5>(m0, m, n0, n);
278
+ break;
279
+ case 0x45:
280
+ mc = 4;
281
+ nc = 5;
282
+ gemm<4, 5>(m0, m, n0, n);
283
+ break;
284
+ case 0x54:
285
+ mc = 5;
286
+ nc = 4;
287
+ gemm<5, 4>(m0, m, n0, n);
288
+ break;
289
+ case 0x44:
290
+ mc = 4;
291
+ nc = 4;
292
+ gemm<4, 4>(m0, m, n0, n);
293
+ break;
294
+ case 0x53:
295
+ mc = 5;
296
+ nc = 3;
297
+ gemm<5, 3>(m0, m, n0, n);
298
+ break;
299
+ case 0x35:
300
+ mc = 3;
301
+ nc = 5;
302
+ gemm<3, 5>(m0, m, n0, n);
303
+ break;
304
+ case 0x43:
305
+ mc = 4;
306
+ nc = 3;
307
+ gemm<4, 3>(m0, m, n0, n);
308
+ break;
309
+ #else
310
+ case 0x55:
311
+ case 0x54:
312
+ case 0x53:
313
+ case 0x45:
314
+ case 0x44:
315
+ case 0x43:
316
+ mc = 4;
317
+ nc = 3;
318
+ gemm<4, 3>(m0, m, n0, n);
319
+ break;
320
+ case 0x35:
321
+ #endif
322
+ case 0x34:
323
+ mc = 3;
324
+ nc = 4;
325
+ gemm<3, 4>(m0, m, n0, n);
326
+ break;
327
+ case 0x52:
328
+ mc = 5;
329
+ nc = 2;
330
+ gemm<5, 2>(m0, m, n0, n);
331
+ break;
332
+ case 0x33:
333
+ mc = 3;
334
+ nc = 3;
335
+ gemm<3, 3>(m0, m, n0, n);
336
+ break;
337
+ case 0x25:
338
+ mc = 2;
339
+ nc = 5;
340
+ gemm<2, 5>(m0, m, n0, n);
341
+ break;
342
+ case 0x42:
343
+ mc = 4;
344
+ nc = 2;
345
+ gemm<4, 2>(m0, m, n0, n);
346
+ break;
347
+ case 0x24:
348
+ mc = 2;
349
+ nc = 4;
350
+ gemm<2, 4>(m0, m, n0, n);
351
+ break;
352
+ case 0x32:
353
+ mc = 3;
354
+ nc = 2;
355
+ gemm<3, 2>(m0, m, n0, n);
356
+ break;
357
+ case 0x23:
358
+ mc = 2;
359
+ nc = 3;
360
+ gemm<2, 3>(m0, m, n0, n);
361
+ break;
362
+ case 0x51:
363
+ mc = 5;
364
+ nc = 1;
365
+ gemm<5, 1>(m0, m, n0, n);
366
+ break;
367
+ case 0x41:
368
+ mc = 4;
369
+ nc = 1;
370
+ gemm<4, 1>(m0, m, n0, n);
371
+ break;
372
+ case 0x22:
373
+ mc = 2;
374
+ nc = 2;
375
+ gemm<2, 2>(m0, m, n0, n);
376
+ break;
377
+ case 0x15:
378
+ mc = 1;
379
+ nc = 5;
380
+ gemm<1, 5>(m0, m, n0, n);
381
+ break;
382
+ case 0x14:
383
+ mc = 1;
384
+ nc = 4;
385
+ gemm<1, 4>(m0, m, n0, n);
386
+ break;
387
+ case 0x31:
388
+ mc = 3;
389
+ nc = 1;
390
+ gemm<3, 1>(m0, m, n0, n);
391
+ break;
392
+ case 0x13:
393
+ mc = 1;
394
+ nc = 3;
395
+ gemm<1, 3>(m0, m, n0, n);
396
+ break;
397
+ case 0x21:
398
+ mc = 2;
399
+ nc = 1;
400
+ gemm<2, 1>(m0, m, n0, n);
401
+ break;
402
+ case 0x12:
403
+ mc = 1;
404
+ nc = 2;
405
+ gemm<1, 2>(m0, m, n0, n);
406
+ break;
407
+ case 0x11:
408
+ mc = 1;
409
+ nc = 1;
410
+ gemm<1, 1>(m0, m, n0, n);
411
+ break;
412
+ default:
413
+ return;
414
+ }
415
+ mp = m0 + (m - m0) / mc * mc;
416
+ np = n0 + (n - n0) / nc * nc;
417
+ mnpack(mp, m, n0, np);
418
+ mnpack(m0, m, np, n);
419
+ }
420
+
421
+ template <int RM, int RN>
422
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
423
+ int64_t ytiles = (m - m0) / RM;
424
+ int64_t xtiles = (n - n0) / RN;
425
+ int64_t tiles = xtiles * ytiles;
426
+ int64_t duty = (tiles + nth - 1) / nth;
427
+ int64_t start = duty * ith;
428
+ int64_t end = start + duty;
429
+ if (end > tiles)
430
+ end = tiles;
431
+ for (int64_t job = start; job < end; ++job) {
432
+ int64_t ii = m0 + job / xtiles * RM;
433
+ int64_t jj = n0 + job % xtiles * RN;
434
+ D Cv[RN][RM] = {};
435
+ for (int64_t l = 0; l < k; l += KN)
436
+ for (int64_t j = 0; j < RN; ++j)
437
+ for (int64_t i = 0; i < RM; ++i)
438
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
439
+ load<V>(B + ldb * (jj + j) + l),
440
+ Cv[j][i]);
441
+ for (int64_t j = 0; j < RN; ++j)
442
+ for (int64_t i = 0; i < RM; ++i)
443
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
444
+ }
445
+ }
446
+
447
+ const TA *const A;
448
+ const TB *const B;
449
+ TC *const C;
450
+ const int64_t k;
451
+ const int64_t lda;
452
+ const int64_t ldb;
453
+ const int64_t ldc;
454
+ const int ith;
455
+ const int nth;
456
+ };
457
+
458
+ //////////////////////////////////////////////////////////////////////////////////////////
459
+ // QUANT ZERO MATRIX MULTIPLICATION
460
+
461
+ #if defined(__ARM_FEATURE_DOTPROD)
462
+ template <typename TA>
463
+ class tinyBLAS_Q0_ARM {
464
+ public:
465
+ tinyBLAS_Q0_ARM(int64_t k,
466
+ const TA *A, int64_t lda,
467
+ const block_q8_0 *B, int64_t ldb,
468
+ float *C, int64_t ldc,
469
+ int ith, int nth)
470
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
471
+ }
472
+
473
+ void matmul(int64_t m, int64_t n) {
474
+ mnpack(0, m, 0, n);
475
+ }
476
+
477
+ private:
478
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
479
+ int64_t mc, nc, mp, np;
480
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
481
+ case 0x33:
482
+ mc = 3;
483
+ nc = 3;
484
+ gemm<3, 3>(m0, m, n0, n);
485
+ break;
486
+ case 0x32:
487
+ mc = 3;
488
+ nc = 2;
489
+ gemm<3, 2>(m0, m, n0, n);
490
+ break;
491
+ case 0x23:
492
+ mc = 2;
493
+ nc = 3;
494
+ gemm<2, 3>(m0, m, n0, n);
495
+ break;
496
+ case 0x22:
497
+ mc = 2;
498
+ nc = 2;
499
+ gemm<2, 2>(m0, m, n0, n);
500
+ break;
501
+ case 0x31:
502
+ mc = 3;
503
+ nc = 1;
504
+ gemm<3, 1>(m0, m, n0, n);
505
+ break;
506
+ case 0x13:
507
+ mc = 1;
508
+ nc = 3;
509
+ gemm<1, 3>(m0, m, n0, n);
510
+ break;
511
+ case 0x21:
512
+ mc = 2;
513
+ nc = 1;
514
+ gemm<2, 1>(m0, m, n0, n);
515
+ break;
516
+ case 0x12:
517
+ mc = 1;
518
+ nc = 2;
519
+ gemm<1, 2>(m0, m, n0, n);
520
+ break;
521
+ case 0x11:
522
+ mc = 1;
523
+ nc = 1;
524
+ gemm<1, 1>(m0, m, n0, n);
525
+ break;
526
+ default:
527
+ return;
528
+ }
529
+ mp = m0 + (m - m0) / mc * mc;
530
+ np = n0 + (n - n0) / nc * nc;
531
+ mnpack(mp, m, n0, np);
532
+ mnpack(m0, m, np, n);
533
+ }
534
+
535
+ template <int RM, int RN>
536
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
537
+ int64_t ytiles = (m - m0) / RM;
538
+ int64_t xtiles = (n - n0) / RN;
539
+ int64_t tiles = xtiles * ytiles;
540
+ int64_t duty = (tiles + nth - 1) / nth;
541
+ int64_t start = duty * ith;
542
+ int64_t end = start + duty;
543
+ if (end > tiles)
544
+ end = tiles;
545
+ for (int64_t job = start; job < end; ++job) {
546
+ int64_t ii = m0 + job / xtiles * RM;
547
+ int64_t jj = n0 + job % xtiles * RN;
548
+ float32x4_t Cv[RN][RM] = {};
549
+ for (int64_t l = 0; l < k; ++l)
550
+ for (int64_t j = 0; j < RN; ++j)
551
+ for (int64_t i = 0; i < RM; ++i)
552
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
553
+ vcvtq_f32_s32(vdotq_s32(
554
+ vdotq_s32(vdupq_n_s32(0),
555
+ load_lo(A + lda * (ii + i) + l),
556
+ load_lo(B + ldb * (jj + j) + l)),
557
+ load_hi(A + lda * (ii + i) + l),
558
+ load_hi(B + ldb * (jj + j) + l))),
559
+ unhalf(A[lda * (ii + i) + l].d) *
560
+ unhalf(B[ldb * (jj + j) + l].d));
561
+ for (int64_t j = 0; j < RN; ++j)
562
+ for (int64_t i = 0; i < RM; ++i)
563
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
564
+ }
565
+ }
566
+
567
+ inline int8x16_t load_lo(const block_q8_0 *b) {
568
+ return vld1q_s8(b->qs);
569
+ }
570
+
571
+ inline int8x16_t load_hi(const block_q8_0 *b) {
572
+ return vld1q_s8(b->qs + 16);
573
+ }
574
+
575
+ inline int8x16_t load_lo(const block_q4_0 *b) {
576
+ return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
577
+ vdupq_n_u8(0x0f))),
578
+ vdupq_n_s8(0x8));
579
+ }
580
+
581
+ inline int8x16_t load_hi(const block_q4_0 *b) {
582
+ return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
583
+ vdupq_n_s8(0x8));
584
+ }
585
+
586
+ const TA *const A;
587
+ const block_q8_0 *const B;
588
+ float *const C;
589
+ const int64_t k;
590
+ const int64_t lda;
591
+ const int64_t ldb;
592
+ const int64_t ldc;
593
+ const int ith;
594
+ const int nth;
595
+ };
596
+ #endif // __ARM_FEATURE_DOTPROD
597
+
598
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
599
+ template <typename TA, typename TB, typename TC>
600
+ class tinyBLAS_Q0_AVX {
601
+ public:
602
+ tinyBLAS_Q0_AVX(int64_t k,
603
+ const TA *A, int64_t lda,
604
+ const TB *B, int64_t ldb,
605
+ TC *C, int64_t ldc,
606
+ int ith, int nth)
607
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
608
+ }
609
+
610
+ void matmul(int64_t m, int64_t n) {
611
+ mnpack(0, m, 0, n);
612
+ }
613
+
614
+ private:
615
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
616
+ int64_t mc, nc, mp, np;
617
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
618
+ #if VECTOR_REGISTERS == 32
619
+ case 0x44:
620
+ mc = 4;
621
+ nc = 4;
622
+ #if defined(__AVX2__) && defined(__F16C__)
623
+ gemm4xN<4>(m0, m, n0, n);
624
+ #else
625
+ gemm<4, 4>(m0, m, n0, n);
626
+ #endif
627
+ break;
628
+ case 0x43:
629
+ mc = 4;
630
+ nc = 3;
631
+ #if defined(__AVX2__) && defined(__F16C__)
632
+ gemm4xN<3>(m0, m, n0, n);
633
+ #else
634
+ gemm<4, 3>(m0, m, n0, n);
635
+ #endif
636
+ break;
637
+ case 0x34:
638
+ mc = 3;
639
+ nc = 4;
640
+ #if defined(__AVX2__) && defined(__F16C__)
641
+ gemmMx4<3>(m0, m, n0, n);
642
+ #else
643
+ gemm<3, 4>(m0, m, n0, n);
644
+ #endif
645
+ break;
646
+ case 0x33:
647
+ mc = 3;
648
+ nc = 3;
649
+ gemm<3, 3>(m0, m, n0, n);
650
+ break;
651
+ case 0x42:
652
+ mc = 4;
653
+ nc = 2;
654
+ #if defined(__AVX2__) && defined(__F16C__)
655
+ gemm4xN<2>(m0, m, n0, n);
656
+ #else
657
+ gemm<4, 2>(m0, m, n0, n);
658
+ #endif
659
+ break;
660
+ case 0x24:
661
+ mc = 2;
662
+ nc = 4;
663
+ #if defined(__AVX2__) && defined(__F16C__)
664
+ gemmMx4<2>(m0, m, n0, n);
665
+ #else
666
+ gemm<2, 4>(m0, m, n0, n);
667
+ #endif
668
+ break;
669
+ #else
670
+ case 0x44:
671
+ case 0x43:
672
+ case 0x42:
673
+ mc = 4;
674
+ nc = 2;
675
+ #if defined(__AVX2__) && defined(__F16C__)
676
+ gemm4xN<2>(m0, m, n0, n);
677
+ #else
678
+ gemm<4, 2>(m0, m, n0, n);
679
+ #endif
680
+ break;
681
+ case 0x34:
682
+ case 0x24:
683
+ mc = 2;
684
+ nc = 4;
685
+ #if defined(__AVX2__) && defined(__F16C__)
686
+ gemmMx4<2>(m0, m, n0, n);
687
+ #else
688
+ gemm<2, 4>(m0, m, n0, n);
689
+ #endif
690
+ break;
691
+ case 0x33:
692
+ #endif
693
+ case 0x32:
694
+ mc = 3;
695
+ nc = 2;
696
+ gemm<3, 2>(m0, m, n0, n);
697
+ break;
698
+ case 0x23:
699
+ mc = 2;
700
+ nc = 3;
701
+ gemm<2, 3>(m0, m, n0, n);
702
+ break;
703
+ case 0x41:
704
+ mc = 4;
705
+ nc = 1;
706
+ #if defined(__AVX2__) && defined(__F16C__)
707
+ gemm4xN<1>(m0, m, n0, n);
708
+ #else
709
+ gemm<4, 1>(m0, m, n0, n);
710
+ #endif
711
+ break;
712
+ case 0x22:
713
+ mc = 2;
714
+ nc = 2;
715
+ gemm<2, 2>(m0, m, n0, n);
716
+ break;
717
+ case 0x14:
718
+ mc = 1;
719
+ nc = 4;
720
+ #if defined(__AVX2__) && defined(__F16C__)
721
+ gemmMx4<1>(m0, m, n0, n);
722
+ #else
723
+ gemm<1, 4>(m0, m, n0, n);
724
+ #endif
725
+ break;
726
+ case 0x31:
727
+ mc = 3;
728
+ nc = 1;
729
+ gemm<3, 1>(m0, m, n0, n);
730
+ break;
731
+ case 0x13:
732
+ mc = 1;
733
+ nc = 3;
734
+ gemm<1, 3>(m0, m, n0, n);
735
+ break;
736
+ case 0x21:
737
+ mc = 2;
738
+ nc = 1;
739
+ gemm<2, 1>(m0, m, n0, n);
740
+ break;
741
+ case 0x12:
742
+ mc = 1;
743
+ nc = 2;
744
+ gemm<1, 2>(m0, m, n0, n);
745
+ break;
746
+ case 0x11:
747
+ mc = 1;
748
+ nc = 1;
749
+ gemm<1, 1>(m0, m, n0, n);
750
+ break;
751
+ default:
752
+ return;
753
+ }
754
+ mp = m0 + (m - m0) / mc * mc;
755
+ np = n0 + (n - n0) / nc * nc;
756
+ mnpack(mp, m, n0, np);
757
+ mnpack(m0, m, np, n);
758
+ }
759
+
760
+ #if defined(__AVX2__) && defined(__F16C__)
761
+ // Templated functions for gemm of dimensions 4xN
762
+ template <int RN>
763
+ NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
764
+ int64_t ytiles = (m - m0) / 4;
765
+ int64_t xtiles = (n - n0) / RN;
766
+ int64_t tiles = xtiles * ytiles;
767
+ int64_t duty = (tiles + nth - 1) / nth;
768
+ int64_t start = duty * ith;
769
+ int64_t end = start + duty;
770
+ if (end > tiles)
771
+ end = tiles;
772
+ for (int64_t job = start; job < end; ++job) {
773
+ int64_t ii = m0 + job / xtiles * 4;
774
+ int64_t jj = n0 + job % xtiles * RN;
775
+ __m256 Cv[RN][4] = {};
776
+ for (int64_t l = 0; l < k; ++l) {
777
+ uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
778
+ // Convert delta values for four blocks to float values
779
+ __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
780
+ __m256i avec0 = load(A + lda * (ii + 0) + l);
781
+ __m256i avec1 = load(A + lda * (ii + 1) + l);
782
+ __m256i avec2 = load(A + lda * (ii + 2) + l);
783
+ __m256i avec3 = load(A + lda * (ii + 3) + l);
784
+ for (int64_t j = 0; j < RN; ++j) {
785
+ __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
786
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
787
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
788
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
789
+ // Computation of dot product and multiplication with appropriate delta value products
790
+ Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
791
+ updot(_mm256_sign_epi8(avec0, avec0),
792
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
793
+ Cv[j][0]);
794
+ Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
795
+ updot(_mm256_sign_epi8(avec1, avec1),
796
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
797
+ Cv[j][1]);
798
+ Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
799
+ updot(_mm256_sign_epi8(avec2, avec2),
800
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
801
+ Cv[j][2]);
802
+ Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
803
+ updot(_mm256_sign_epi8(avec3, avec3),
804
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
805
+ Cv[j][3]);
806
+ }
807
+ }
808
+
809
+ for (int64_t j = 0; j < RN; ++j)
810
+ for (int64_t i = 0; i < 4; ++i)
811
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
812
+ }
813
+ }
814
+
815
+ // Templated functions for gemm of dimensions Mx4
816
+ template <int RM>
817
+ NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
818
+ int64_t ytiles = (m - m0) / RM;
819
+ int64_t xtiles = (n - n0) / 4;
820
+ int64_t tiles = xtiles * ytiles;
821
+ int64_t duty = (tiles + nth - 1) / nth;
822
+ int64_t start = duty * ith;
823
+ int64_t end = start + duty;
824
+ if (end > tiles)
825
+ end = tiles;
826
+ for (int64_t job = start; job < end; ++job) {
827
+ int64_t ii = m0 + job / xtiles * RM;
828
+ int64_t jj = n0 + job % xtiles * 4;
829
+ __m256 Cv[4][RM] = {};
830
+ for (int64_t l = 0; l < k; ++l) {
831
+ uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
832
+ // Convert delta values for four blocks to float values
833
+ __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
834
+ __m256i bvec0 = load(B + ldb * (jj + 0) + l);
835
+ __m256i bvec1 = load(B + ldb * (jj + 1) + l);
836
+ __m256i bvec2 = load(B + ldb * (jj + 2) + l);
837
+ __m256i bvec3 = load(B + ldb * (jj + 3) + l);
838
+ for (int64_t i = 0; i < RM; ++i) {
839
+ __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
840
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
841
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
842
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
843
+ // Computation of dot product and multiplication with appropriate delta value products
844
+ Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
845
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
846
+ load(A + lda * (ii + i) + l)),
847
+ _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
848
+ Cv[0][i]);
849
+ Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
850
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
851
+ load(A + lda * (ii + i) + l)),
852
+ _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
853
+ Cv[1][i]);
854
+ Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
855
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
856
+ load(A + lda * (ii + i) + l)),
857
+ _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
858
+ Cv[2][i]);
859
+ Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
860
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
861
+ load(A + lda * (ii + i) + l)),
862
+ _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
863
+ Cv[3][i]);
864
+ }
865
+ }
866
+ for (int64_t j = 0; j < 4; ++j)
867
+ for (int64_t i = 0; i < RM; ++i)
868
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
869
+ }
870
+ }
871
+ #endif
872
+
873
+ template <int RM, int RN>
874
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
875
+ int64_t ytiles = (m - m0) / RM;
876
+ int64_t xtiles = (n - n0) / RN;
877
+ int64_t tiles = xtiles * ytiles;
878
+ int64_t duty = (tiles + nth - 1) / nth;
879
+ int64_t start = duty * ith;
880
+ int64_t end = start + duty;
881
+ if (end > tiles)
882
+ end = tiles;
883
+ for (int64_t job = start; job < end; ++job) {
884
+ int64_t ii = m0 + job / xtiles * RM;
885
+ int64_t jj = n0 + job % xtiles * RN;
886
+ __m256 Cv[RN][RM] = {};
887
+ for (int64_t l = 0; l < k; ++l)
888
+ for (int64_t j = 0; j < RN; ++j)
889
+ for (int64_t i = 0; i < RM; ++i) {
890
+ #if defined(__AVX2__)
891
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
892
+ load(A + lda * (ii + i) + l)),
893
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
894
+ load(A + lda * (ii + i) + l)));
895
+ #else
896
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
897
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
898
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
899
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
900
+
901
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
902
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
903
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
904
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
905
+
906
+ // updot
907
+ const __m128i oneFill = _mm_set1_epi16(1);
908
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
909
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
910
+ __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
911
+ #endif
912
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
913
+ unhalf(B[ldb * (jj + j) + l].d)),
914
+ udTmp,
915
+ Cv[j][i]);
916
+ }
917
+ for (int64_t j = 0; j < RN; ++j)
918
+ for (int64_t i = 0; i < RM; ++i)
919
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
920
+ }
921
+ }
922
+
923
+ inline __m256i load(const block_q8_0 *b) {
924
+ return _mm256_loadu_si256((const __m256i *)b->qs);
925
+ }
926
+
927
+ inline __m128i load0(const block_q8_0 *b) {
928
+ return _mm_loadu_si128((const __m128i *)b->qs);
929
+ }
930
+
931
+ inline __m128i load1(const block_q8_0 *b) {
932
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
933
+ }
934
+
935
+ inline __m256i load(const block_q4_0 *b) {
936
+ return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
937
+ }
938
+
939
+ inline __m128i load0(const block_q4_0 *b) {
940
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
941
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
942
+ }
943
+
944
+ inline __m128i load1(const block_q4_0 *b) {
945
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
946
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
947
+ }
948
+
949
+ inline __m256i load(const block_q5_0 *b) {
950
+ return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
951
+ }
952
+
953
+ inline __m128i load0(const block_q5_0* b) {
954
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
955
+ uint32_t x32;
956
+ memcpy(&x32, b->qh, sizeof(uint32_t));
957
+ __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
958
+ __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
959
+ _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
960
+ _mm_shuffle_epi8(_mm_set1_epi32(x32),
961
+ _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
962
+ bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
963
+ return _mm_or_si128(qxl, bytesl);
964
+ }
965
+
966
+ inline __m128i load1(const block_q5_0* b) {
967
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
968
+ uint32_t x32;
969
+ memcpy(&x32, b->qh, sizeof(uint32_t));
970
+ __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
971
+ __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
972
+ _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
973
+ _mm_shuffle_epi8(_mm_set1_epi32(x32),
974
+ _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
975
+ bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
976
+ return _mm_or_si128(qxh, bytesh);
977
+ }
978
+
979
+ inline __m256i load(const block_iq4_nl *b) {
980
+ return MM256_SET_M128I(load1(b), load0(b));
981
+ }
982
+
983
+ inline __m128i load0(const block_iq4_nl *b) {
984
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
985
+ return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
986
+ }
987
+
988
+ inline __m128i load1(const block_iq4_nl *b) {
989
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
990
+ return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
991
+ }
992
+
993
+ inline __m256 updot(__m256i u, __m256i s) {
994
+ __m256i res;
995
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
996
+ res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
997
+ #else
998
+ res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
999
+ #endif
1000
+ return _mm256_cvtepi32_ps(res);
1001
+ }
1002
+
1003
+ static inline __m256i denibble(const uint8_t *p) {
1004
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
1005
+ return _mm256_and_si256(_mm256_set1_epi8(15),
1006
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
1007
+ _mm_srli_epi16(x, 4), 1));
1008
+ }
1009
+
1010
+ static inline __m256i bittobyte(const uint8_t *p) {
1011
+ uint32_t x32;
1012
+ memcpy(&x32, p, sizeof(uint32_t));
1013
+ __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
1014
+ _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
1015
+ _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
1016
+ _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
1017
+ 0x0101010101010101, 0x0000000000000000))));
1018
+ return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
1019
+ }
1020
+
1021
+ const TA *const A;
1022
+ const TB *const B;
1023
+ TC *const C;
1024
+ const int64_t k;
1025
+ const int64_t lda;
1026
+ const int64_t ldb;
1027
+ const int64_t ldc;
1028
+ const int ith;
1029
+ const int nth;
1030
+ };
1031
+ #endif // __AVX__
1032
+
1033
+ //PPC Implementation
1034
+ #if defined(__MMA__)
1035
+
1036
+ #define SAVE_ACC(ACC, ii, jj) \
1037
+ __builtin_mma_disassemble_acc(vec_C, ACC); \
1038
+ for (int I = 0; I < 4; I++) { \
1039
+ for (int J = 0; J < 4; J++) { \
1040
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
1041
+ } \
1042
+ } \
1043
+
1044
+ template <typename TA, typename TB, typename TC>
1045
+ class tinyBLAS_PPC {
1046
+ public:
1047
+ tinyBLAS_PPC(int64_t k,
1048
+ const TA *A, int64_t lda,
1049
+ const TB *B, int64_t ldb,
1050
+ TC *C, int64_t ldc,
1051
+ int ith, int nth)
1052
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1053
+ }
1054
+
1055
+ void matmul(int64_t m, int64_t n) {
1056
+ mnpack(0, m, 0, n);
1057
+ }
1058
+
1059
+ private:
1060
+
1061
+ void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1062
+
1063
+ void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
1064
+ int64_t i, j;
1065
+ float *aoffset = NULL, *boffset = NULL;
1066
+ float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1067
+ float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1068
+
1069
+ aoffset = const_cast<float*>(a);
1070
+ boffset = vec;
1071
+ j = (rows >> 3);
1072
+ if (j > 0) {
1073
+ do {
1074
+ aoffset1 = aoffset;
1075
+ aoffset2 = aoffset1 + lda;
1076
+ aoffset3 = aoffset2 + lda;
1077
+ aoffset4 = aoffset3 + lda;
1078
+ aoffset5 = aoffset4 + lda;
1079
+ aoffset6 = aoffset5 + lda;
1080
+ aoffset7 = aoffset6 + lda;
1081
+ aoffset8 = aoffset7 + lda;
1082
+ aoffset += 8 * lda;
1083
+ i = (cols >> 3);
1084
+ if (i > 0) {
1085
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1086
+ vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
1087
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1088
+ do {
1089
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1090
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1091
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1092
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1093
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
1094
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
1095
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
1096
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
1097
+ __builtin_vsx_disassemble_pair(c1, &C1);
1098
+ __builtin_vsx_disassemble_pair(c2, &C2);
1099
+ __builtin_vsx_disassemble_pair(c3, &C3);
1100
+ __builtin_vsx_disassemble_pair(c4, &C4);
1101
+ __builtin_vsx_disassemble_pair(c5, &C5);
1102
+ __builtin_vsx_disassemble_pair(c6, &C6);
1103
+ __builtin_vsx_disassemble_pair(c7, &C7);
1104
+ __builtin_vsx_disassemble_pair(c8, &C8);
1105
+
1106
+ t1 = vec_mergeh(c1[0], c2[0]);
1107
+ t2 = vec_mergeh(c3[0], c4[0]);
1108
+ t3 = vec_mergeh(c5[0], c6[0]);
1109
+ t4 = vec_mergeh(c7[0], c8[0]);
1110
+ t5 = vec_xxpermdi(t1, t2, 0);
1111
+ t6 = vec_xxpermdi(t3, t4, 0);
1112
+ t7 = vec_xxpermdi(t1, t2, 3);
1113
+ t8 = vec_xxpermdi(t3, t4, 3);
1114
+ vec_xst(t5, 0, boffset);
1115
+ vec_xst(t6, 0, boffset+4);
1116
+ vec_xst(t7, 0, boffset+8);
1117
+ vec_xst(t8, 0, boffset+12);
1118
+
1119
+ t1 = vec_mergel(c1[0], c2[0]);
1120
+ t2 = vec_mergel(c3[0], c4[0]);
1121
+ t3 = vec_mergel(c5[0], c6[0]);
1122
+ t4 = vec_mergel(c7[0], c8[0]);
1123
+ t5 = vec_xxpermdi(t1, t2, 0);
1124
+ t6 = vec_xxpermdi(t3, t4, 0);
1125
+ t7 = vec_xxpermdi(t1, t2, 3);
1126
+ t8 = vec_xxpermdi(t3, t4, 3);
1127
+ vec_xst(t5, 0, boffset+16);
1128
+ vec_xst(t6, 0, boffset+20);
1129
+ vec_xst(t7, 0, boffset+24);
1130
+ vec_xst(t8, 0, boffset+28);
1131
+
1132
+ t1 = vec_mergeh(c1[1], c2[1]);
1133
+ t2 = vec_mergeh(c3[1], c4[1]);
1134
+ t3 = vec_mergeh(c5[1], c6[1]);
1135
+ t4 = vec_mergeh(c7[1], c8[1]);
1136
+ t5 = vec_xxpermdi(t1, t2, 0);
1137
+ t6 = vec_xxpermdi(t3, t4, 0);
1138
+ t7 = vec_xxpermdi(t1, t2, 3);
1139
+ t8 = vec_xxpermdi(t3, t4, 3);
1140
+ vec_xst(t5, 0, boffset+32);
1141
+ vec_xst(t6, 0, boffset+36);
1142
+ vec_xst(t7, 0, boffset+40);
1143
+ vec_xst(t8, 0, boffset+44);
1144
+
1145
+ t1 = vec_mergel(c1[1], c2[1]);
1146
+ t2 = vec_mergel(c3[1], c4[1]);
1147
+ t3 = vec_mergel(c5[1], c6[1]);
1148
+ t4 = vec_mergel(c7[1], c8[1]);
1149
+ t5 = vec_xxpermdi(t1, t2, 0);
1150
+ t6 = vec_xxpermdi(t3, t4, 0);
1151
+ t7 = vec_xxpermdi(t1, t2, 3);
1152
+ t8 = vec_xxpermdi(t3, t4, 3);
1153
+ vec_xst(t5, 0, boffset+48);
1154
+ vec_xst(t6, 0, boffset+52);
1155
+ vec_xst(t7, 0, boffset+56);
1156
+ vec_xst(t8, 0, boffset+60);
1157
+
1158
+ aoffset1 += 8*lda;
1159
+ aoffset2 += 8*lda;
1160
+ aoffset3 += 8*lda;
1161
+ aoffset4 += 8*lda;
1162
+ boffset += 64;
1163
+ i--;
1164
+ } while(i > 0);
1165
+ }
1166
+ if (cols & 4) {
1167
+ vector float c1, c2, c3, c4, c5, c6, c7, c8;
1168
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1169
+ c1 = vec_xl(0, aoffset1);
1170
+ c2 = vec_xl(0, aoffset2);
1171
+ c3 = vec_xl(0, aoffset3);
1172
+ c4 = vec_xl(0, aoffset4);
1173
+ c5 = vec_xl(0, aoffset5);
1174
+ c6 = vec_xl(0, aoffset6);
1175
+ c7 = vec_xl(0, aoffset7);
1176
+ c8 = vec_xl(0, aoffset8);
1177
+
1178
+ t1 = vec_mergeh(c1, c2);
1179
+ t2 = vec_mergeh(c3, c4);
1180
+ t3 = vec_mergeh(c5, c6);
1181
+ t4 = vec_mergeh(c7, c8);
1182
+ t5 = vec_xxpermdi(t1, t2, 0);
1183
+ t6 = vec_xxpermdi(t3, t4, 0);
1184
+ t7 = vec_xxpermdi(t1, t2, 3);
1185
+ t8 = vec_xxpermdi(t3, t4, 3);
1186
+ vec_xst(t5, 0, boffset);
1187
+ vec_xst(t6, 0, boffset+4);
1188
+ vec_xst(t7, 0, boffset+8);
1189
+ vec_xst(t8, 0, boffset+12);
1190
+
1191
+ t1 = vec_mergel(c1, c2);
1192
+ t2 = vec_mergel(c3, c4);
1193
+ t3 = vec_mergel(c5, c6);
1194
+ t4 = vec_mergel(c7, c8);
1195
+ t5 = vec_xxpermdi(t1, t2, 0);
1196
+ t6 = vec_xxpermdi(t3, t4, 0);
1197
+ t7 = vec_xxpermdi(t1, t2, 3);
1198
+ t8 = vec_xxpermdi(t3, t4, 3);
1199
+ vec_xst(t5, 0, boffset+16);
1200
+ vec_xst(t6, 0, boffset+20);
1201
+ vec_xst(t7, 0, boffset+24);
1202
+ vec_xst(t8, 0, boffset+28);
1203
+ }
1204
+ j--;
1205
+ } while(j > 0);
1206
+ }
1207
+
1208
+ if (rows & 4) {
1209
+ aoffset1 = aoffset;
1210
+ aoffset2 = aoffset1 + lda;
1211
+ aoffset3 = aoffset2 + lda;
1212
+ aoffset4 = aoffset3 + lda;
1213
+ aoffset += 4 * lda;
1214
+ i = (cols >> 3);
1215
+ if (i > 0) {
1216
+ __vector_pair C1, C2, C3, C4;
1217
+ vector float c1[2], c2[2], c3[2], c4[2];
1218
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1219
+ do {
1220
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1221
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1222
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1223
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1224
+ __builtin_vsx_disassemble_pair(c1, &C1);
1225
+ __builtin_vsx_disassemble_pair(c2, &C2);
1226
+ __builtin_vsx_disassemble_pair(c3, &C3);
1227
+ __builtin_vsx_disassemble_pair(c4, &C4);
1228
+
1229
+ t1 = vec_mergeh(c1[0], c2[0]);
1230
+ t2 = vec_mergeh(c3[0], c4[0]);
1231
+ t3 = vec_mergel(c1[0], c2[0]);
1232
+ t4 = vec_mergel(c3[0], c4[0]);
1233
+ t5 = vec_xxpermdi(t1, t2, 0);
1234
+ t6 = vec_xxpermdi(t1, t2, 3);
1235
+ t7 = vec_xxpermdi(t3, t4, 0);
1236
+ t8 = vec_xxpermdi(t3, t4, 3);
1237
+ vec_xst(t5, 0, boffset);
1238
+ vec_xst(t6, 0, boffset+4);
1239
+ vec_xst(t7, 0, boffset+8);
1240
+ vec_xst(t8, 0, boffset+12);
1241
+
1242
+ t1 = vec_mergeh(c1[1], c2[1]);
1243
+ t2 = vec_mergeh(c3[1], c4[1]);
1244
+ t3 = vec_mergel(c1[1], c2[1]);
1245
+ t4 = vec_mergel(c3[1], c4[1]);
1246
+ t5 = vec_xxpermdi(t1, t2, 0);
1247
+ t6 = vec_xxpermdi(t1, t2, 3);
1248
+ t7 = vec_xxpermdi(t3, t4, 0);
1249
+ t8 = vec_xxpermdi(t3, t4, 3);
1250
+ vec_xst(t5, 0, boffset+16);
1251
+ vec_xst(t6, 0, boffset+20);
1252
+ vec_xst(t7, 0, boffset+24);
1253
+ vec_xst(t8, 0, boffset+28);
1254
+
1255
+ aoffset1 += 8*lda;
1256
+ aoffset2 += 8*lda;
1257
+ aoffset3 += 8*lda;
1258
+ aoffset4 += 8*lda;
1259
+ boffset += 32;
1260
+ i--;
1261
+ } while(i > 0);
1262
+ }
1263
+
1264
+ if (cols & 4) {
1265
+ vector float c1, c2, c3, c4;
1266
+ vector float t1, t2, t3, t4;
1267
+ c1 = vec_xl(0, aoffset1);
1268
+ c2 = vec_xl(0, aoffset2);
1269
+ c3 = vec_xl(0, aoffset3);
1270
+ c4 = vec_xl(0, aoffset4);
1271
+
1272
+ t1 = vec_mergeh(c1, c2);
1273
+ t2 = vec_mergeh(c3, c4);
1274
+ t3 = vec_xxpermdi(t1, t2, 0);
1275
+ t4 = vec_xxpermdi(t1, t2, 3);
1276
+ vec_xst(t3, 0, boffset);
1277
+ vec_xst(t4, 0, boffset+4);
1278
+
1279
+ t1 = vec_mergel(c1, c2);
1280
+ t2 = vec_mergel(c3, c4);
1281
+ t3 = vec_xxpermdi(t1, t2, 0);
1282
+ t4 = vec_xxpermdi(t1, t2, 3);
1283
+ vec_xst(t3, 0, boffset+8);
1284
+ vec_xst(t4, 0, boffset+12);
1285
+ }
1286
+ }
1287
+ if (rows & 3) {
1288
+ aoffset1 = aoffset;
1289
+ aoffset2 = aoffset1 + lda;
1290
+ aoffset3 = aoffset2 + lda;
1291
+ if (cols & 4) {
1292
+ vector float c1, c2, c3, c4 = {0};
1293
+ vector float t1, t2, t3, t4;
1294
+ c1 = vec_xl(0, aoffset1);
1295
+ c2 = vec_xl(0, aoffset2);
1296
+ c3 = vec_xl(0, aoffset3);
1297
+
1298
+ t1 = vec_mergeh(c1, c2);
1299
+ t2 = vec_mergeh(c3, c4);
1300
+ t3 = vec_xxpermdi(t1, t2, 0);
1301
+ t4 = vec_xxpermdi(t1, t2, 3);
1302
+ vec_xst(t3, 0, boffset);
1303
+ vec_xst(t4, 0, boffset+4);
1304
+
1305
+ t1 = vec_mergel(c1, c2);
1306
+ t2 = vec_mergel(c3, c4);
1307
+ t3 = vec_xxpermdi(t1, t2, 0);
1308
+ t4 = vec_xxpermdi(t1, t2, 3);
1309
+ vec_xst(t3, 0, boffset+8);
1310
+ vec_xst(t4, 0, boffset+12);
1311
+ }
1312
+ }
1313
+ }
1314
+
1315
+ void KERNEL_4x4(int64_t ii, int64_t jj) {
1316
+ vec_t vec_A[4], vec_B[4], vec_C[4];
1317
+ acc_t acc_0;
1318
+ __builtin_mma_xxsetaccz(&acc_0);
1319
+ for (int l = 0; l < k; l+=4) {
1320
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1321
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1322
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1323
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1324
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
1325
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
1326
+ }
1327
+ SAVE_ACC(&acc_0, ii, jj);
1328
+ }
1329
+
1330
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
1331
+ vec_t vec_A[4], vec_B[8], vec_C[4];
1332
+ acc_t acc_0, acc_1;
1333
+ __builtin_mma_xxsetaccz(&acc_0);
1334
+ __builtin_mma_xxsetaccz(&acc_1);
1335
+ for (int64_t l = 0; l < k; l+=4) {
1336
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1337
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
1338
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
1339
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
1340
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
1341
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
1342
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
1343
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
1344
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
1345
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
1346
+ }
1347
+ SAVE_ACC(&acc_0, ii, jj);
1348
+ SAVE_ACC(&acc_1, ii, jj+4);
1349
+ }
1350
+
1351
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
1352
+ vec_t vec_A[8], vec_B[4], vec_C[4];
1353
+ acc_t acc_0, acc_1;
1354
+ __builtin_mma_xxsetaccz(&acc_0);
1355
+ __builtin_mma_xxsetaccz(&acc_1);
1356
+ for (int64_t l = 0; l < k; l+=4) {
1357
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
1358
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1359
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
1360
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
1361
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
1362
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
1363
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
1364
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
1365
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
1366
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
1367
+ }
1368
+ SAVE_ACC(&acc_0, ii, jj);
1369
+ SAVE_ACC(&acc_1, ii+4, jj);
1370
+ }
1371
+
1372
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
1373
+ vec_t vec_A[16], vec_B[16], vec_C[4];
1374
+ acc_t acc_0, acc_1, acc_2, acc_3;
1375
+ __builtin_mma_xxsetaccz(&acc_0);
1376
+ __builtin_mma_xxsetaccz(&acc_1);
1377
+ __builtin_mma_xxsetaccz(&acc_2);
1378
+ __builtin_mma_xxsetaccz(&acc_3);
1379
+ for (int l = 0; l < k; l+=8) {
1380
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
1381
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
1382
+ for(int x = 0; x < 16; x+=2) {
1383
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
1384
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
1385
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
1386
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
1387
+ }
1388
+ }
1389
+ SAVE_ACC(&acc_0, ii, jj);
1390
+ SAVE_ACC(&acc_1, ii, jj+4);
1391
+ SAVE_ACC(&acc_2, ii+4, jj);
1392
+ SAVE_ACC(&acc_3, ii+4, jj+4);
1393
+ }
1394
+
1395
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1396
+ int64_t mc, nc, mp, np;
1397
+ int m_rem = MIN(m - m0, 16);
1398
+ int n_rem = MIN(n - n0, 16);
1399
+ if (m_rem >= 16 && n_rem >= 8) {
1400
+ mc = 8;
1401
+ nc = 8;
1402
+ gemm<8,8>(m0, m, n0, n);
1403
+ } else if(m_rem >= 8 && n_rem >= 16) {
1404
+ mc = 8;
1405
+ nc = 8;
1406
+ gemm<8,8>(m0, m, n0, n);
1407
+ } else if (m_rem >= 8 && n_rem >= 8) {
1408
+ mc = 8;
1409
+ nc = 8;
1410
+ gemm<8,8>(m0, m, n0, n);
1411
+ } else if (m_rem >= 4 && n_rem >= 8) {
1412
+ mc = 4;
1413
+ nc = 8;
1414
+ gemm<4,8>(m0, m, n0, n);
1415
+ } else if (m_rem >= 8 && n_rem >= 4) {
1416
+ mc = 8;
1417
+ nc = 4;
1418
+ gemm<8,4>(m0, m, n0, n);
1419
+ } else if (m_rem >= 4 && n_rem >= 4) {
1420
+ mc = 4;
1421
+ nc = 4;
1422
+ gemm<4,4>(m0, m, n0, n);
1423
+ } else if ((m_rem < 4) && (n_rem > 4)) {
1424
+ nc = 4;
1425
+ switch(m_rem) {
1426
+ case 1:
1427
+ mc = 1;
1428
+ gemm_small(m0, m, n0, n, mc, nc);
1429
+ break;
1430
+ case 2:
1431
+ mc = 2;
1432
+ gemm_small(m0, m, n0, n, mc, nc);
1433
+ break;
1434
+ case 3:
1435
+ mc = 3;
1436
+ gemm_small(m0, m, n0, n, mc, nc);
1437
+ break;
1438
+ default:
1439
+ return;
1440
+ }
1441
+ } else if ((m_rem > 4) && (n_rem < 4)) {
1442
+ mc = 4;
1443
+ switch(n_rem) {
1444
+ case 1:
1445
+ nc = 1;
1446
+ gemm_small(m0, m, n0, n, mc, nc);
1447
+ break;
1448
+ case 2:
1449
+ nc = 2;
1450
+ gemm_small(m0, m, n0, n, mc, nc);
1451
+ break;
1452
+ case 3:
1453
+ nc = 3;
1454
+ gemm_small(m0, m, n0, n, mc, nc);
1455
+ break;
1456
+ default:
1457
+ return;
1458
+ }
1459
+ } else {
1460
+ switch((m_rem << 4) | n_rem) {
1461
+ case 0x43:
1462
+ mc = 4;
1463
+ nc = 3;
1464
+ gemm_small(m0, m, n0, n, mc, nc);
1465
+ break;
1466
+ case 0x42:
1467
+ mc = 4;
1468
+ nc = 2;
1469
+ gemm_small(m0, m, n0, n, mc, nc);
1470
+ break;
1471
+ case 0x41:
1472
+ mc = 4;
1473
+ nc = 1;
1474
+ gemm_small(m0, m, n0, n, mc, nc);
1475
+ break;
1476
+ case 0x34:
1477
+ mc = 3;
1478
+ nc = 4;
1479
+ gemm_small(m0, m, n0, n, mc, nc);
1480
+ break;
1481
+ case 0x33:
1482
+ mc = 3;
1483
+ nc = 3;
1484
+ gemm_small(m0, m, n0, n, mc, nc);
1485
+ break;
1486
+ case 0x32:
1487
+ mc = 3;
1488
+ nc = 2;
1489
+ gemm_small(m0, m, n0, n, mc, nc);
1490
+ break;
1491
+ case 0x31:
1492
+ mc = 3;
1493
+ nc = 1;
1494
+ gemm_small(m0, m, n0, n, mc, nc);
1495
+ break;
1496
+ case 0x24:
1497
+ mc = 2;
1498
+ nc = 4;
1499
+ gemm_small(m0, m, n0, n, mc, nc);
1500
+ break;
1501
+ case 0x23:
1502
+ mc = 2;
1503
+ nc = 3;
1504
+ gemm_small(m0, m, n0, n, mc, nc);
1505
+ break;
1506
+ case 0x22:
1507
+ mc = 2;
1508
+ nc = 2;
1509
+ gemm_small(m0, m, n0, n, mc, nc);
1510
+ break;
1511
+ case 0x21:
1512
+ mc = 2;
1513
+ nc = 1;
1514
+ gemm_small(m0, m, n0, n, mc, nc);
1515
+ break;
1516
+ case 0x14:
1517
+ mc = 1;
1518
+ nc = 4;
1519
+ gemm_small(m0, m, n0, n, mc, nc);
1520
+ break;
1521
+ case 0x13:
1522
+ mc = 1;
1523
+ nc = 3;
1524
+ gemm_small(m0, m, n0, n, mc, nc);
1525
+ break;
1526
+ case 0x12:
1527
+ mc = 1;
1528
+ nc = 2;
1529
+ gemm_small(m0, m, n0, n, mc, nc);
1530
+ break;
1531
+ case 0x11:
1532
+ mc = 1;
1533
+ nc = 1;
1534
+ gemm_small(m0, m, n0, n, mc, nc);
1535
+ break;
1536
+ default:
1537
+ return;
1538
+ }
1539
+ }
1540
+ mp = m0 + (m - m0) / mc * mc;
1541
+ np = n0 + (n - n0) / nc * nc;
1542
+ mnpack(mp, m, n0, np);
1543
+ mnpack(m0, m, np, n);
1544
+ }
+
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[4], vec_B[4];
+            for (int l=0; l<k; l+=4) {
+                if (RN >= 4 && RM == 1) {
+                    float* a = const_cast<float*>(A+(ii)*lda+l);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    vec_A[0] = (vec_t)vec_xl(0,a);
+                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                } else {
+                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                }
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (RM == 4 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x4;
+        } else if (RM == 4 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x8;
+        } else if (RM == 8 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x4;
+        } else if (RM == 8 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x8;
+        }
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            (this->*kernel)(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif
+} // namespace
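Both gemm() and gemm_small() above divide the output matrix into RM×RN tiles and split those tiles evenly across threads with the same duty/start/end arithmetic. The following standalone sketch restates that partitioning outside the class; the helper name partition_tiles and the example numbers are illustrative only and are not part of the gem.

    #include <algorithm>
    #include <cstdint>
    #include <utility>

    // Illustrative helper (not in the gem): given a thread id and the total
    // number of tiles, return the half-open [start, end) range of tile jobs
    // that thread should process, mirroring the duty/start/end logic above.
    static std::pair<int64_t, int64_t> partition_tiles(int64_t tiles, int ith, int nth) {
        int64_t duty  = (tiles + nth - 1) / nth;   // ceil(tiles / nth)
        int64_t start = duty * ith;
        int64_t end   = std::min(start + duty, tiles);
        return {start, end};
    }

For example, 10 tiles across 4 threads gives duty = 3, so threads 0–2 each take 3 tiles and thread 3 takes the remaining one.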
+
+ /**
+  * Performs optimized matrix multiplication on CPU.
+  *
+  * This subroutine may compute C = Aᵀ * B with column major ordering.
+  * Despite its name, this isn't a generalized implementation. Work is
+  * only performed when a handwritten kernel is written and available.
+  * Otherwise the caller should fall back to a general matmul routine.
+  *
+  * For example, for single-threaded single-precision GEMM you can say
+  *
+  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+  *                     0, 1,
+  *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
+  *
+  * @param m is rows in `A` and `C`
+  * @param n is cols in `B` and `C`
+  * @param k is cols in `A` and rows in `B`
+  * @param A is first input matrix (always transposed)
+  * @param lda is row stride of `A`
+  * @param B is second input matrix (never transposed)
+  * @param ldb is row stride of `B`
+  * @param C is input/output array of output matrices
+  * @param ldc is row stride of `C`
+  * @param ith is thread id (must be less than `nth`)
+  * @param nth is number of threads (must be greater than zero)
+  * @param Atype is GGML data type of `A`
+  * @param Btype is GGML data type of `B`
+  * @param Ctype is GGML data type of `C`
+  * @return true if this function was able to service the matmul request
+  */
+ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                      int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+
+     assert(m >= 0);
+     assert(n >= 0);
+     assert(k >= 0);
+     assert(lda >= k);
+     assert(ldb >= k);
+     assert(ldc >= m);
+     assert(nth > 0);
+     assert(ith < nth);
+
+     // only enable sgemm for prompt processing
+     if (n < 2)
+         return false;
+
+     if (Ctype != GGML_TYPE_F32)
+         return false;
+
+     switch (Atype) {
+
+     case GGML_TYPE_F32: {
+         if (Btype != GGML_TYPE_F32)
+             return false;
+ #if defined(__AVX512F__)
+         if (k % 16)
+             return false;
+         tinyBLAS<16, __m512, __m512, float, float, float> tb{
+             k, (const float *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__AVX__) || defined(__AVX2__)
+         if (k % 8)
+             return false;
+         tinyBLAS<8, __m256, __m256, float, float, float> tb{
+             k, (const float *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__ARM_NEON)
+         if (n < 4)
+             return false;
+         if (k % 4)
+             return false;
+         tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+             k, (const float *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__MMA__)
+         if (k % 8)
+             return false;
+         tinyBLAS_PPC<float, float, float> tb{
+             k, (const float *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #else
+         return false;
+ #endif
+     }
+
+     case GGML_TYPE_F16: {
+ #if defined(__AVX512F__)
+         if (k % 16)
+             return false;
+         if (Btype != GGML_TYPE_F32)
+             return false;
+         tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
+             k, (const ggml_fp16_t *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
+         if (k % 8)
+             return false;
+         if (Btype != GGML_TYPE_F32)
+             return false;
+         tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+             k, (const ggml_fp16_t *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+         if (n < 8)
+             return false;
+         if (k % 8)
+             return false;
+         if (Btype != GGML_TYPE_F16)
+             return false;
+         tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+             k, (const ggml_fp16_t *)A, lda,
+             (const ggml_fp16_t *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__ARM_NEON) && !defined(_MSC_VER)
+         if (k % 4)
+             return false;
+         if (Btype != GGML_TYPE_F32)
+             return false;
+         tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+             k, (const ggml_fp16_t *)A, lda,
+             (const float *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #else
+         return false;
+ #endif
+     }
+
+     case GGML_TYPE_Q8_0: {
+         if (Btype != GGML_TYPE_Q8_0)
+             return false;
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+         tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
+             k, (const block_q8_0 *)A, lda,
+             (const block_q8_0 *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__ARM_FEATURE_DOTPROD)
+         tinyBLAS_Q0_ARM<block_q8_0> tb{
+             k, (const block_q8_0 *)A, lda,
+             (const block_q8_0 *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #else
+         return false;
+ #endif
+     }
+
+     case GGML_TYPE_Q4_0: {
+         if (Btype != GGML_TYPE_Q8_0)
+             return false;
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+         tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
+             k, (const block_q4_0 *)A, lda,
+             (const block_q8_0 *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #elif defined(__ARM_FEATURE_DOTPROD)
+         tinyBLAS_Q0_ARM<block_q4_0> tb{
+             k, (const block_q4_0 *)A, lda,
+             (const block_q8_0 *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #else
+         return false;
+ #endif
+     }
+
+     case GGML_TYPE_Q5_0: {
+         if (Btype != GGML_TYPE_Q8_0)
+             return false;
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+         tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
+             k, (const block_q5_0 *)A, lda,
+             (const block_q8_0 *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #else
+         return false;
+ #endif
+     }
+
+     case GGML_TYPE_IQ4_NL: {
+         if (Btype != GGML_TYPE_Q8_0)
+             return false;
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+         tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
+             k, (const block_iq4_nl *)A, lda,
+             (const block_q8_0 *)B, ldb,
+             (float *)C, ldc,
+             ith, nth};
+         tb.matmul(m, n);
+         return true;
+ #else
+         return false;
+ #endif
+     }
+
+     default:
+         return false;
+     }
+
+     (void)m;
+     (void)n;
+     (void)k;
+     (void)A;
+     (void)lda;
+     (void)B;
+     (void)ldb;
+     (void)C;
+     (void)ldc;
+     (void)ith;
+     (void)nth;
+     (void)Atype;
+     (void)Btype;
+     (void)Ctype;
+ }
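As a caller-side illustration of the entry point above, here is a minimal single-threaded F32 × F32 sketch. It assumes the declaration from the gem's accompanying sgemm header and ggml's GGML_TYPE_F32 enum are visible; the driver function name and the matrix sizes are hypothetical, and inside the gem this routine is normally invoked by ggml's matmul path rather than by user code.

    #include <cstdint>
    #include <vector>

    // Hypothetical driver: C = Aᵀ * B with m = 4, n = 4, k = 16 on one thread.
    // lda/ldb are the row strides of the (transposed) A and of B, ldc the row
    // stride of the column-major C, matching the asserts in llamafile_sgemm().
    void example_matmul() {
        const int64_t m = 4, n = 4, k = 16;
        std::vector<float> A(m * k, 1.0f);   // lda = k
        std::vector<float> B(n * k, 2.0f);   // ldb = k
        std::vector<float> C(m * n, 0.0f);   // ldc = m
        bool ok = llamafile_sgemm(m, n, k,
                                  A.data(), k,
                                  B.data(), k,
                                  C.data(), m,
                                  /*ith=*/0, /*nth=*/1,
                                  GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
        if (!ok) {
            // No handwritten kernel serviced this shape/type on this CPU;
            // a real caller falls back to a general matmul routine here.
        }
    }

When a kernel is available, every element of C ends up as 1.0f * 2.0f summed over k = 16, i.e. 32.0f; when llamafile_sgemm() returns false (for example n < 2, an unsupported k multiple, or no SIMD support), the output buffer is left untouched.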