whisper.rn 0.4.0-rc.3 → 0.4.0-rc.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. package/android/src/main/CMakeLists.txt +2 -0
  2. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  3. package/android/src/main/java/com/rnwhisper/WhisperContext.java +3 -3
  4. package/android/src/main/jni.cpp +6 -2
  5. package/cpp/ggml-alloc.c +413 -280
  6. package/cpp/ggml-alloc.h +67 -8
  7. package/cpp/ggml-backend-impl.h +87 -0
  8. package/cpp/ggml-backend.c +950 -0
  9. package/cpp/ggml-backend.h +136 -0
  10. package/cpp/ggml-impl.h +243 -0
  11. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +591 -121
  12. package/cpp/ggml-metal.h +21 -0
  13. package/cpp/ggml-metal.m +623 -234
  14. package/cpp/ggml-quants.c +7377 -0
  15. package/cpp/ggml-quants.h +224 -0
  16. package/cpp/ggml.c +3773 -4455
  17. package/cpp/ggml.h +279 -146
  18. package/cpp/whisper.cpp +182 -103
  19. package/cpp/whisper.h +48 -11
  20. package/ios/RNWhisper.mm +8 -2
  21. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  22. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  23. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  24. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  25. package/ios/RNWhisperContext.h +5 -1
  26. package/ios/RNWhisperContext.mm +76 -10
  27. package/jest/mock.js +1 -1
  28. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  29. package/lib/commonjs/index.js +28 -9
  30. package/lib/commonjs/index.js.map +1 -1
  31. package/lib/commonjs/version.json +1 -1
  32. package/lib/module/NativeRNWhisper.js.map +1 -1
  33. package/lib/module/index.js +28 -9
  34. package/lib/module/index.js.map +1 -1
  35. package/lib/module/version.json +1 -1
  36. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  37. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  38. package/lib/typescript/index.d.ts +7 -2
  39. package/lib/typescript/index.d.ts.map +1 -1
  40. package/package.json +1 -1
  41. package/src/NativeRNWhisper.ts +8 -1
  42. package/src/index.ts +29 -17
  43. package/src/version.json +1 -1
  44. package/whisper-rn.podspec +1 -2
package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal}
@@ -13,23 +13,85 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;                // delta
-    half m;                // min
+    half d;                // delta
+    half m;                // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half d;           // delta
     int8_t qs[QK8_0]; // quants
 } block_q8_0;
 
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
 kernel void kernel_add(
-    device const float4 * src0,
-    device const float4 * src1,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig];
+    device const char * src0,
+    device const char * src1,
+    device       char * dst,
+    constant  int64_t & ne00,
+    constant  int64_t & ne01,
+    constant  int64_t & ne02,
+    constant  int64_t & ne03,
+    constant  int64_t & nb00,
+    constant  int64_t & nb01,
+    constant  int64_t & nb02,
+    constant  int64_t & nb03,
+    constant  int64_t & ne10,
+    constant  int64_t & ne11,
+    constant  int64_t & ne12,
+    constant  int64_t & ne13,
+    constant  int64_t & nb10,
+    constant  int64_t & nb11,
+    constant  int64_t & nb12,
+    constant  int64_t & nb13,
+    constant  int64_t & ne0,
+    constant  int64_t & ne1,
+    constant  int64_t & ne2,
+    constant  int64_t & ne3,
+    constant  int64_t & nb0,
+    constant  int64_t & nb1,
+    constant  int64_t & nb2,
+    constant  int64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+
+        src0_ptr += ntg.x*nb00;
+        src1_ptr += ntg.x*nb10;
+        dst_ptr  += ntg.x*nb0;
+    }
 }
 
 // assumption: src1 is a row
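Reviewer note on the new q5 formats: a block stores 32 weights as 4-bit low nibbles in `qs` plus one packed 5th bit per weight in `qh`. A host-side C++ sketch of how one `block_q5_0` expands, assuming a hypothetical mirror struct with the `half` delta already widened to `float` (an illustration, not code from this package):

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical host-side mirror of block_q5_0 (delta widened to float for clarity).
struct BlockQ5_0 {
    float   d;      // delta (half in the Metal source)
    uint8_t qh[4];  // packed 5th bits, one per weight
    uint8_t qs[16]; // 32 x 4-bit low nibbles (QK5_0 / 2 bytes)
};

// Expand one 32-weight block: each weight is a 5-bit unsigned value
// (low nibble from qs, 5th bit from qh), recentered by -16 and scaled by d.
void dequantize_block_q5_0(const BlockQ5_0 & b, float out[32]) {
    uint32_t qh;
    std::memcpy(&qh, b.qh, sizeof(qh));
    for (int i = 0; i < 16; ++i) {
        const uint8_t xh_0 = ((qh >> (i +  0)) << 4) & 0x10; // 5th bit of weight i
        const uint8_t xh_1 = ((qh >> (i + 16)) << 4) & 0x10; // 5th bit of weight i+16
        const int32_t x0 = (b.qs[i] & 0x0F) | xh_0;
        const int32_t x1 = (b.qs[i] >>   4) | xh_1;
        out[i]      = b.d * (x0 - 16);
        out[i + 16] = b.d * (x1 - 16);
    }
}
```

`block_q5_1` has the same layout plus a min offset `m`, so its values decode as `d*x + m` rather than `d*(x - 16)`.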
@@ -38,7 +100,7 @@ kernel void kernel_add_row(
     device const float4 * src0,
     device const float4 * src1,
     device       float4 * dst,
-    constant   int64_t & nb,
+    constant   int64_t & nb [[buffer(27)]],
     uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
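Reviewer note on `[[buffer(27)]]`: kernel_add_row shares the host-side argument table with the rewritten kernel_add (buffers 0-2 for the tensors, then 24 shape/stride scalars at indices 3-26), so its single `nb` argument is pinned to slot 27 explicitly. A hedged C++ sketch of the binding using Apple's metal-cpp (the package's ggml-metal.m does the equivalent in Objective-C; the helper and its names are illustrative only):

```cpp
#include <Metal/Metal.hpp> // metal-cpp headers; assumed available

// Bind kernel_add_row's arguments; the scalar index must match [[buffer(27)]].
void encode_add_row(MTL::ComputeCommandEncoder * enc,
                    MTL::ComputePipelineState * pso,
                    MTL::Buffer * src0, MTL::Buffer * src1, MTL::Buffer * dst,
                    int64_t nb) {
    enc->setComputePipelineState(pso);
    enc->setBuffer(src0, 0, 0);
    enc->setBuffer(src1, 0, 1);
    enc->setBuffer(dst,  0, 2);
    enc->setBytes(&nb, sizeof(nb), 27); // pinned slot, past the shared scalars at 3-26
}
```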
@@ -63,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+    device const float * src0,
+    device       float * dst,
+    constant     float & scale,
+    uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
     device const float4 * src0,
     device       float4 * dst,
-    constant     float  & scale,
+    constant     float  & scale,
     uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
@@ -85,6 +155,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sqr(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
@@ -107,36 +184,73 @@ kernel void kernel_soft_max(
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+    threadgroup float * buf [[threadgroup(0)]],
+    uint  tgpig[[threadgroup_position_in_grid]],
+    uint  tpitg[[thread_position_in_threadgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]],
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = psrc0[tpitg[0]];
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-    const float max = simd_max(lmax);
+
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        // whish to compute it twice.
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = simd_sum(lsum);
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         pdst[i00] /= sum;
     }
 }
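Reviewer note: the reworked kernel_soft_max still computes the standard numerically stable softmax, but now reduces across the whole threadgroup in two stages — `simd_max`/`simd_sum` within each 32-lane SIMD group, then a tree reduction over the per-group partials in `buf`. The scalar math it parallelizes, as a plain C++ sketch:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax over one row: the scalar equivalent of
// what kernel_soft_max distributes over a threadgroup.
void softmax_row(std::vector<float> & row) {
    float mx = -INFINITY;
    for (float v : row) mx = std::max(mx, v); // stage 1: row max

    float sum = 0.0f;
    for (float & v : row) {                   // stage 2: exp and sum
        v = std::exp(v - mx);                 // subtracting mx avoids overflow
        sum += v;
    }
    for (float & v : row) v /= sum;           // stage 3: normalize
}
```

Caching `exp(x - max)` in `pdst` during the sum pass, as the kernel's comment notes, avoids evaluating `exp` twice per element.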
@@ -147,37 +261,73 @@ kernel void kernel_soft_max_4(
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+    threadgroup float * buf [[threadgroup(0)]],
+    uint  tgpig[[threadgroup_position_in_grid]],
+    uint  tpitg[[thread_position_in_threadgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]],
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = psrc4[tpitg[0]];
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float max = simd_max(lmax);
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = simd_sum(lsum);
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
@@ -197,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-   }
+    }
 }
 
 kernel void kernel_diag_mask_inf_8(
@@ -291,10 +441,11 @@ kernel void kernel_rms_norm(
     uint sgitg[[simdgroup_index_in_threadgroup]],
     uint tiisg[[thread_index_in_simdgroup]],
     uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -307,6 +458,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // broadcast, simd group number is ntg / 32
     for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
         if (tpitg < i) {
@@ -314,7 +466,9 @@ kernel void kernel_rms_norm(
         }
     }
     if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
     }
 
@@ -329,7 +483,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
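Reviewer note: these kernel_rms_norm hunks only reflow one-line loops, but for orientation, RMS norm scales each row by 1/sqrt(mean(x²) + eps) with no mean subtraction; the scalar tail loops above handle the last `ne00 % 4` elements the `float4` path misses. A scalar C++ sketch of the whole computation:

```cpp
#include <cmath>
#include <cstddef>

// Scalar RMS norm over one row: y = x / sqrt(mean(x^2) + eps).
// The Metal kernel computes the same sum via float4 partial sums
// plus a simdgroup/threadgroup reduction.
void rms_norm_row(const float * x, float * y, size_t n, float eps) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) sum += x[i] * x[i];

    const float scale = 1.0f / std::sqrt(sum / n + eps);
    for (size_t i = 0; i < n; ++i) y[i] = x[i] * scale;
}
```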
@@ -339,8 +495,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -357,8 +516,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -368,9 +530,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }
 
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
@@ -381,18 +586,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16];       // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
@@ -403,6 +613,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
         sumy += yb[i] + yb[i+1];
         yl[i+0] = yb[i+ 0];
         yl[i+1] = yb[i+ 1]/256.f;
+
         sumy += yb[i+16] + yb[i+17];
         yl[i+8] = yb[i+16]/16.f;
         yl[i+9] = yb[i+17]/4096.f;
@@ -418,12 +629,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -436,12 +647,12 @@ kernel void kernel_mul_mat_q4_0_f32(
     constant   int64_t & ne1[[buffer(16)]],
     constant   uint    & gqa[[buffer(17)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
-    uint tiisg[[thread_index_in_simdgroup]],
-    uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -459,9 +670,46 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mv_q5_0_f32(
+    device const void * src0,
+    device const float * src1,
+    device float * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01[[buffer(4)]],
+    constant   int64_t & ne02[[buffer(5)]],
+    constant   int64_t & ne10[[buffer(9)]],
+    constant   int64_t & ne12[[buffer(11)]],
+    constant   int64_t & ne0[[buffer(15)]],
+    constant   int64_t & ne1[[buffer(16)]],
+    constant   uint    & gqa[[buffer(17)]],
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+    device const void * src0,
+    device const float * src1,
+    device float * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01[[buffer(4)]],
+    constant   int64_t & ne02[[buffer(5)]],
+    constant   int64_t & ne10[[buffer(9)]],
+    constant   int64_t & ne12[[buffer(11)]],
+    constant   int64_t & ne0[[buffer(15)]],
+    constant   int64_t & ne1[[buffer(16)]],
+    constant   uint    & gqa[[buffer(17)]],
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -525,7 +773,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
     device const char * src0,
     device const char * src1,
     device      float * dst,
@@ -596,7 +844,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
     device const char * src0,
     device const char * src1,
     device      float * dst,
@@ -615,7 +863,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
     constant   int64_t & ne0,
     constant   int64_t & ne1,
     uint3 tgpig[[threadgroup_position_in_grid]],
-    uint tiisg[[thread_index_in_simdgroup]]) {
+    uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -650,7 +898,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
     device const char * src0,
     device const char * src1,
     device      float * dst,
@@ -722,7 +970,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
     device const char * src0,
     device const char * src1,
     device      float * dst,
@@ -783,7 +1031,9 @@ kernel void kernel_alibi_f32(
     constant  uint64_t & nb1,
     constant  uint64_t & nb2,
     constant  uint64_t & nb3,
-    constant      float & m0,
+    constant     float & m0,
+    constant     float & m1,
+    constant       int & n_heads_log2_floor,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -799,37 +1049,122 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f,  ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
+typedef void (rope_t)(
+    device const    void * src0,
+    device const int32_t * src1,
+    device         float * dst,
+    constant     int64_t & ne00,
+    constant     int64_t & ne01,
+    constant     int64_t & ne02,
+    constant     int64_t & ne03,
+    constant    uint64_t & nb00,
+    constant    uint64_t & nb01,
+    constant    uint64_t & nb02,
+    constant    uint64_t & nb03,
+    constant     int64_t & ne0,
+    constant     int64_t & ne1,
+    constant     int64_t & ne2,
+    constant     int64_t & ne3,
+    constant    uint64_t & nb0,
+    constant    uint64_t & nb1,
+    constant    uint64_t & nb2,
+    constant    uint64_t & nb3,
+    constant         int & n_past,
+    constant         int & n_dims,
+    constant         int & mode,
+    constant         int & n_orig_ctx,
+    constant       float & freq_base,
+    constant       float & freq_scale,
+    constant       float & ext_factor,
+    constant       float & attn_factor,
+    constant       float & beta_fast,
+    constant       float & beta_slow,
+    uint  tiitg[[thread_index_in_threadgroup]],
+    uint3 tptg[[threads_per_threadgroup]],
+    uint3 tgpig[[threadgroup_position_in_grid]]);
+
+template<typename T>
 kernel void kernel_rope(
-    device const  void * src0,
-    device       float * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant       int & n_past,
-    constant       int & n_dims,
-    constant       int & mode,
-    constant     float & freq_base,
-    constant     float & freq_scale,
+    device const    void * src0,
+    device const int32_t * src1,
+    device         float * dst,
+    constant     int64_t & ne00,
+    constant     int64_t & ne01,
+    constant     int64_t & ne02,
+    constant     int64_t & ne03,
+    constant    uint64_t & nb00,
+    constant    uint64_t & nb01,
+    constant    uint64_t & nb02,
+    constant    uint64_t & nb03,
+    constant     int64_t & ne0,
+    constant     int64_t & ne1,
+    constant     int64_t & ne2,
+    constant     int64_t & ne3,
+    constant    uint64_t & nb0,
+    constant    uint64_t & nb1,
+    constant    uint64_t & nb2,
+    constant    uint64_t & nb3,
+    constant         int & n_past,
+    constant         int & n_dims,
+    constant         int & mode,
+    constant         int & n_orig_ctx,
+    constant       float & freq_base,
+    constant       float & freq_scale,
+    constant       float & ext_factor,
+    constant       float & attn_factor,
+    constant       float & beta_fast,
+    constant       float & beta_slow,
     uint  tiitg[[thread_index_in_threadgroup]],
     uint3 tptg[[threads_per_threadgroup]],
     uint3 tgpig[[threadgroup_position_in_grid]]) {
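Reviewer note: the new `m1`/`n_heads_log2_floor` arguments complete the standard ALiBi slope schedule for head counts that are not powers of two — the first 2^⌊log2 n_head⌋ heads use slopes m0^(h+1), the remainder interleave slopes m1^(2(h−floor)+1). A host-side C++ sketch of that schedule (the m0/m1 bases follow the usual ALiBi recipe; the helper is illustrative, not this package's API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// ALiBi slope per attention head, mirroring the kernel's two-branch formula.
std::vector<float> alibi_slopes(int n_head) {
    const int   floor_pow2 = 1 << (int)std::floor(std::log2((double)n_head));
    const float m0 = std::pow(2.0f, -8.0f / floor_pow2); // base for the first group
    const float m1 = std::pow(2.0f, -4.0f / floor_pow2); // base for the remainder

    std::vector<float> slopes(n_head);
    for (int h = 0; h < n_head; ++h) {
        slopes[h] = (h < floor_pow2)
            ? std::pow(m0, (float)(h + 1))                   // m0^(h+1)
            : std::pow(m1, (float)(2*(h - floor_pow2) + 1)); // m1^(2(h-floor)+1)
    }
    return slopes;
}

int main() {
    for (float s : alibi_slopes(12)) std::printf("%g ", s); // 12 heads: not a power of two
    std::printf("\n");
}
```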
@@ -839,23 +1174,28 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
-    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const float theta_0 = freq_scale * (float)p;
+    device const int32_t * pos = src1;
+
+    const int64_t p = pos[i2];
+
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            const float x0 = src[0];
-            const float x1 = src[1];
+            const T x0 = src[0];
+            const T x1 = src[1];
 
             dst_data[0] = x0*cos_theta - x1*sin_theta;
             dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -864,14 +1204,17 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 const int64_t i0 = ib*n_dims + ic/2;
 
-                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                 const float x0 = src[0];
                 const float x1 = src[n_dims/2];
@@ -883,6 +1226,9 @@ kernel void kernel_rope(
     }
 }
 
+template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
 kernel void kernel_cpy_f16_f16(
     device const half * src0,
     device       half * dst,
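Reviewer note: kernel_rope now reads token positions from a second tensor (`src1`) instead of deriving them from `n_past`, is templated over float/half via the `host_name` instantiations above, and routes angles through `rope_yarn`. The core operation is unchanged: each element pair is rotated by a position- and dimension-dependent angle. A minimal C++ sketch of the non-NeoX rotation with the YaRN correction disabled (`ext_factor == 0`):

```cpp
#include <cmath>
#include <cstdint>

// Rotate adjacent pairs (x[i], x[i+1]) of one row in place, as the non-NeoX
// branch of kernel_rope does. pos is the token position the kernel now reads
// from src1; freq_scale folds in linear position scaling.
void rope_row(float * x, int n_dims, int32_t pos, float freq_base, float freq_scale) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float theta = freq_scale * pos * std::pow(freq_base, -(float)i0 / n_dims);
        const float c = std::cos(theta);
        const float s = std::sin(theta);

        const float x0 = x[i0 + 0];
        const float x1 = x[i0 + 1];
        x[i0 + 0] = x0*c - x1*s;
        x[i0 + 1] = x0*s + x1*c;
    }
}
```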
@@ -1008,6 +1354,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_concat(
+    device const char * src0,
+    device const char * src1,
+    device       char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne10,
+    constant   int64_t & ne11,
+    constant   int64_t & ne12,
+    constant   int64_t & ne13,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
+    constant  uint64_t & nb12,
+    constant  uint64_t & nb13,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================
 
 #ifndef QK_K
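Reviewer note: the new kernel_concat joins two tensors along dim 2, reading from `src0` while `i02 < ne02` and from `src1` past that point. A scalar C++ sketch of the same concatenation for contiguous 3-D tensors (the helper name and flat indexing are illustrative):

```cpp
#include <cstddef>
#include <vector>

// Concatenate an [ne0 x ne1 x ne02a] tensor with an [ne0 x ne1 x ne02b]
// tensor along dim 2 -- the scalar analogue of the branch on i02 < ne02.
std::vector<float> concat_dim2(const std::vector<float> & a,
                               const std::vector<float> & b,
                               int ne0, int ne1, int ne02a, int ne02b) {
    std::vector<float> out((size_t)ne0 * ne1 * (ne02a + ne02b));
    for (int i2 = 0; i2 < ne02a + ne02b; ++i2)
    for (int i1 = 0; i1 < ne1; ++i1)
    for (int i0 = 0; i0 < ne0; ++i0) {
        const float v = (i2 < ne02a)
            ? a[((size_t)i2           * ne1 + i1) * ne0 + i0]  // from src0
            : b[((size_t)(i2 - ne02a) * ne1 + i1) * ne0 + i0]; // from src1
        out[((size_t)i2 * ne1 + i1) * ne0 + i0] = v;
    }
    return out;
}
```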
@@ -1100,7 +1502,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1244,7 +1646,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1273,8 +1675,8 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     float yl[32];
 
-    const uint16_t kmask1 = 0x3030;
-    const uint16_t kmask2 = 0x0f0f;
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
 
     const int tid = tiisg/4;
     const int ix = tiisg%4;
@@ -1396,7 +1798,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1467,7 +1869,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1573,7 +1975,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1662,7 +2064,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1835,7 +2237,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
     device const void * src0,
     device const float * src1,
     device float * dst,
@@ -1984,6 +2386,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2173,7 +2631,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2210,9 +2668,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
    const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2236,26 +2696,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-        = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2270,6 +2734,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2278,25 +2743,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
@@ -2317,6 +2783,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2345,6 +2813,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;