llama_cpp 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1133 @@
+ #include <metal_stdlib>
+
+ using namespace metal;
+
+ #define MAX(x, y) ((x) > (y) ? (x) : (y))
+
+ #define QK4_0 32
+ #define QR4_0 2
+ typedef struct {
+     half    d;             // delta
+     uint8_t qs[QK4_0 / 2]; // nibbles / quants
+ } block_q4_0;
+
+ #define QK4_1 32
+ typedef struct {
+     half    d;             // delta
+     half    m;             // min
+     uint8_t qs[QK4_1 / 2]; // nibbles / quants
+ } block_q4_1;
+
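+ // Layout note: a block_q4_0 packs 32 weights into 16 bytes of nibbles plus a
+ // half-precision scale d; each weight is reconstructed as d * (nibble - 8).
+ // block_q4_1 additionally stores a per-block minimum m, so each weight is
+ // d * nibble + m. The low nibbles hold elements 0..15 of a block and the high
+ // nibbles elements 16..31, as the dequantize routines below make explicit.
+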
+ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
+     const int qk = QK4_0;
+
+     assert(k % qk == 0);
+
+     const int nb = k / qk;
+
+     for (int i = 0; i < nb; i++) {
+         const half d = x[i].d;
+
+         for (int j = 0; j < qk/2; ++j) {
+             const int x0 = (x[i].qs[j] & 0x0F) - 8;
+             const int x1 = (x[i].qs[j] >>   4) - 8;
+
+             y[i*qk + j + 0   ] = x0*d;
+             y[i*qk + j + qk/2] = x1*d;
+         }
+     }
+ }
+
+ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
+     const int qk = QK4_1;
+
+     assert(k % qk == 0);
+
+     const int nb = k / qk;
+
+     for (int i = 0; i < nb; i++) {
+         const half d = x[i].d;
+         const half m = x[i].m;
+
+         for (int j = 0; j < qk/2; ++j) {
+             const int x0 = (x[i].qs[j] & 0x0F);
+             const int x1 = (x[i].qs[j] >>   4);
+
+             y[i*qk + j + 0   ] = x0*d + m;
+             y[i*qk + j + qk/2] = x1*d + m;
+         }
+     }
+ }
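+
+ // Worked example (q4_0): with d = 0.5 and qs[j] = 0xF0, the low nibble is 0x0
+ // and the high nibble 0xF, so the two recovered weights are
+ // (0 - 8)*0.5 = -4.0 and (15 - 8)*0.5 = 3.5.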
+
+ kernel void kernel_add(
+         device const float * src0,
+         device const float * src1,
+         device       float * dst,
+         uint tpig[[thread_position_in_grid]]) {
+     dst[tpig] = src0[tpig] + src1[tpig];
+ }
+
+ kernel void kernel_mul(
+         device const float * src0,
+         device const float * src1,
+         device       float * dst,
+         uint tpig[[thread_position_in_grid]]) {
+     dst[tpig] = src0[tpig] * src1[tpig];
+ }
+
+ // assumption: src1 is a row
+ // broadcast src1 into src0
+ kernel void kernel_mul_row(
+         device const float * src0,
+         device const float * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         uint tpig[[thread_position_in_grid]]) {
+     dst[tpig] = src0[tpig] * src1[tpig % ne00];
+ }
+
+ kernel void kernel_scale(
+         device const float * src0,
+         device       float * dst,
+         constant     float & scale,
+         uint tpig[[thread_position_in_grid]]) {
+     dst[tpig] = src0[tpig] * scale;
+ }
+
+ kernel void kernel_silu(
+         device const float * src0,
+         device       float * dst,
+         uint tpig[[thread_position_in_grid]]) {
+     float x = src0[tpig];
+     dst[tpig] = x / (1.0f + exp(-x));
+ }
+
+ kernel void kernel_relu(
+         device const float * src0,
+         device       float * dst,
+         uint tpig[[thread_position_in_grid]]) {
+     dst[tpig] = max(0.0f, src0[tpig]);
+ }
+
+ constant float GELU_COEF_A    = 0.044715f;
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+ kernel void kernel_gelu(
+         device const float * src0,
+         device       float * dst,
+         uint tpig[[thread_position_in_grid]]) {
+     float x = src0[tpig];
+     dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+ }
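+
+ // Note: kernel_silu computes x * sigmoid(x) in the equivalent form
+ // x / (1 + exp(-x)), and kernel_gelu uses the standard tanh approximation
+ // 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3))) rather than the exact
+ // erf-based GELU, matching the ggml CPU implementation.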
+
+ kernel void kernel_soft_max(
+         device const float * src0,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant   int64_t & ne02,
+         threadgroup float  * buf [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3   ntg[[threads_per_threadgroup]]) {
+     const int64_t i03 = tgpig[2];
+     const int64_t i02 = tgpig[1];
+     const int64_t i01 = tgpig[0];
+
+     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+     // parallel max
+     buf[tpitg[0]] = -INFINITY;
+     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+         buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+     }
+
+     // reduce
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     for (uint i = ntg[0]/2; i > 0; i /= 2) {
+         if (tpitg[0] < i) {
+             buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
+         }
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+     }
+
+     // after the reduction the row maximum is already in buf[0]
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     const float max = buf[0];
+
+     // parallel sum
+     buf[tpitg[0]] = 0.0f;
+     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+         buf[tpitg[0]] += exp(psrc0[i00] - max);
+     }
+
+     // reduce
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     for (uint i = ntg[0]/2; i > 0; i /= 2) {
+         if (tpitg[0] < i) {
+             buf[tpitg[0]] += buf[tpitg[0] + i];
+         }
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+     }
+
+     // after the reduction the row sum is already in buf[0]
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     const float sum = buf[0];
+
+     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+         pdst[i00] = exp(psrc0[i00] - max) / sum;
+     }
+ }
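+
+ // Together the three passes implement the numerically stable softmax
+ // softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x));
+ // subtracting the row maximum before exponentiating prevents overflow
+ // without changing the result.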
+
+ kernel void kernel_diag_mask_inf(
+         device const float * src0,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant       int & n_past,
+         uint3 tpig[[thread_position_in_grid]]) {
+     const int64_t i02 = tpig[2];
+     const int64_t i01 = tpig[1];
+     const int64_t i00 = tpig[0];
+
+     if (i00 > n_past + i01) {
+         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+     } else {
+         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+     }
+ }
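+
+ // This is the causal attention mask: row i01 keeps only its first
+ // n_past + i01 + 1 columns; everything beyond that becomes -INFINITY and
+ // therefore contributes nothing after the softmax.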
+
+ kernel void kernel_get_rows_f16(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint tpig[[thread_position_in_grid]]) {
+     const int i = tpig;
+     const int r = ((device int32_t *) src1)[i];
+
+     for (int j = 0; j < ne00; j++) {
+         dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
+     }
+ }
+
+ kernel void kernel_get_rows_q4_0(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint tpig[[thread_position_in_grid]]) {
+     const int i = tpig;
+     const int r = ((device int32_t *) src1)[i];
+
+     dequantize_row_q4_0(
+             (device const block_q4_0 *) ((device char *) src0 + r*nb01),
+                        (device float *) ((device char *)  dst + i*nb1), ne00);
+ }
+
+ kernel void kernel_get_rows_q4_1(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint tpig[[thread_position_in_grid]]) {
+     const int i = tpig;
+     const int r = ((device int32_t *) src1)[i];
+
+     dequantize_row_q4_1(
+             (device const block_q4_1 *) ((device char *) src0 + r*nb01),
+                        (device float *) ((device char *)  dst + i*nb1), ne00);
+ }
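+
+ // The get_rows kernels implement embedding lookup: src1 holds row indices
+ // into src0, and thread i copies (or dequantizes) row src1[i] of src0 into
+ // row i of dst.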
+
+ kernel void kernel_rms_norm(
+         device const  void * src0,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant     float & eps,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint tgpig[[threadgroup_position_in_grid]],
+         uint tpitg[[thread_position_in_threadgroup]],
+         uint   ntg[[threads_per_threadgroup]]) {
+     device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+
+     // parallel sum of squares
+     sum[tpitg] = 0.0f;
+     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+         sum[tpitg] += x[i00] * x[i00];
+     }
+
+     // reduce
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     for (uint i = ntg/2; i > 0; i /= 2) {
+         if (tpitg < i) {
+             sum[tpitg] += sum[tpitg + i];
+         }
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+     }
+
+     // thread 0 turns the sum into the mean; the barrier broadcasts it
+     if (tpitg == 0) {
+         sum[0] /= ne00;
+     }
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     const float mean  = sum[0];
+     const float scale = 1.0f/sqrt(mean + eps);
+
+     device float * y = dst + tgpig*ne00;
+     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+         y[i00] = x[i00] * scale;
+     }
+ }
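+
+ // RMSNorm: y = x / sqrt(mean(x^2) + eps). Unlike LayerNorm there is no mean
+ // subtraction; in llama.cpp the learned weight is applied afterwards by a
+ // separate elementwise multiply.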
+
+ kernel void kernel_mul_mat_q4_0_f32(
+         device const  void * src0,
+         device const float * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant   int64_t & ne10,
+         constant   int64_t & ne11,
+         constant  uint64_t & nb10,
+         constant  uint64_t & nb11,
+         constant  uint64_t & nb12,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint2 tgpig[[threadgroup_position_in_grid]],
+         uint2  tpig[[thread_position_in_grid]],
+         uint2 tpitg[[thread_position_in_threadgroup]],
+         uint2  tptg[[threads_per_threadgroup]]) {
+     const int nb = ne00/QK4_0;
+
+     const int8_t m8 = 8;
+
+     const int64_t r0 = tgpig.x;
+     const int64_t r1 = tgpig.y;
+
+     device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
+     device const float      * y = (device const float      *) src1 + r1*ne10;
+
+     const uint nth = tptg.x*tptg.y;
+     const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+     const int ix = tpitg.y/4;      // 0 or 1
+     const int iy = tpitg.y - 4*ix; // 0...3
+
+     const int first = 4 * iy;
+
+     float sumf = 0;
+
+     for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+         const float d = (float)x[i].d;
+
+         device const uint8_t * xl = x[i].qs + first;
+         device const float   * yl = y + i * QK4_0 + first;
+
+         float2 acc = {0.0f, 0.0f};
+
+         for (int j = 0; j < 4; ++j) {
+
+             acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
+             acc[1] += yl[j+16] * ((int8_t)(xl[j] >>  4) - m8);
+
+         }
+
+         sumf += d * (acc[0] + acc[1]);
+     }
+
+     sum[ith] = sumf;
+
+     //
+     // Accumulate the sum from all threads in the threadgroup
+     // This version is slightly faster than the commented out one below,
+     // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+     //
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%4 == 0) {
+         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%16 == 0) {
+         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith == 0) {
+         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+         dst[r1*ne0 + r0] = sum[0];
+     }
+
+     //// accumulate the sum from all threads in the threadgroup
+     //threadgroup_barrier(mem_flags::mem_threadgroup);
+     //for (uint i = nth/2; i > 0; i /= 2) {
+     //    if (ith < i) {
+     //        sum[ith] += sum[ith + i];
+     //    }
+     //    threadgroup_barrier(mem_flags::mem_threadgroup);
+     //}
+
+     //if (ith == 0) {
+     //    dst[r1*ne0 + r0] = sum[0];
+     //}
+ }
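+
+ // Each threadgroup produces one output element: the dot product of row r0 of
+ // the quantized matrix with row r1 of src1. Threads accumulate disjoint
+ // slices of the row, then the partial sums are combined in stages (groups of
+ // 4, then 16, then strides of 16) rather than with a full binary tree.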
+
+ kernel void kernel_mul_mat_q4_1_f32(
+         device const  void * src0,
+         device const float * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant   int64_t & ne10,
+         constant   int64_t & ne11,
+         constant  uint64_t & nb10,
+         constant  uint64_t & nb11,
+         constant  uint64_t & nb12,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint2 tgpig[[threadgroup_position_in_grid]],
+         uint2  tpig[[thread_position_in_grid]],
+         uint2 tpitg[[thread_position_in_threadgroup]],
+         uint2  tptg[[threads_per_threadgroup]]) {
+     const int nb = ne00/QK4_1;
+
+     const int64_t r0 = tgpig.x;
+     const int64_t r1 = tgpig.y;
+
+     device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
+     device const float      * y = (device const float      *) src1 + r1*ne10;
+
+     const uint nth = tptg.x*tptg.y;
+     const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+     const int ix = tpitg.y/4;      // 0 or 1
+     const int iy = tpitg.y - 4*ix; // 0...3
+
+     const int first = 4 * iy;
+
+     float sumf = 0;
+
+     for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+         const float d = (float)x[i].d;
+         const float m = (float)x[i].m;
+
+         device const uint8_t * xl = x[i].qs + first;
+         device const float   * yl = y + i * QK4_1 + first;
+
+         float2 acc = {0.0f, 0.0f};
+
+         for (int j = 0; j < 4; ++j) {
+
+             acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
+             acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);
+
+         }
+
+         sumf += acc[0] + acc[1];
+     }
+
+     sum[ith] = sumf;
+
+     //
+     // Accumulate the sum from all threads in the threadgroup
+     //
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%4 == 0) {
+         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%16 == 0) {
+         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith == 0) {
+         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+         dst[r1*ne0 + r0] = sum[0];
+     }
+ }
+
+ kernel void kernel_mul_mat_f16_f32(
+         device const  char * src0,
+         device const  char * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant   int64_t & ne10,
+         constant   int64_t & ne11,
+         constant  uint64_t & nb10,
+         constant  uint64_t & nb11,
+         constant  uint64_t & nb12,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3  tpig[[thread_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3  tptg[[threads_per_threadgroup]]) {
+
+     const int64_t r0 = tgpig.x;
+     const int64_t r1 = tgpig.y;
+     const int64_t im = tgpig.z;
+
+     device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+     sum[tpitg.x] = 0.0f;
+
+     for (int i = tpitg.x; i < ne00; i += tptg.x) {
+         sum[tpitg.x] += (float) x[i] * (float) y[i];
+     }
+
+     // accumulate the sum from all threads in the threadgroup
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     for (uint i = tptg.x/2; i > 0; i /= 2) {
+         if (tpitg.x < i) {
+             sum[tpitg.x] += sum[tpitg.x + i];
+         }
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+     }
+
+     if (tpitg.x == 0) {
+         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+     }
+ }
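+
+ // The f16 path uses the classic halving reduction: each step folds the upper
+ // half of the partial sums into the lower half, so the loop runs
+ // log2(tptg.x) times and the result lands in sum[0]. This assumes tptg.x is
+ // a power of two.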
+
+ kernel void kernel_rope(
+         device const  void * src0,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant   int64_t & ne02,
+         constant   int64_t & ne03,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant  uint64_t & nb03,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         constant   int64_t & ne2,
+         constant   int64_t & ne3,
+         constant  uint64_t & nb0,
+         constant  uint64_t & nb1,
+         constant  uint64_t & nb2,
+         constant  uint64_t & nb3,
+         constant       int & n_past,
+         constant       int & n_dims,
+         constant       int & mode,
+         uint3 tpig[[thread_position_in_grid]]) {
+     const int64_t i3 = tpig[2];
+     const int64_t i2 = tpig[1];
+     const int64_t i1 = tpig[0];
+
+     const bool is_neox = mode & 2;
+     const float theta_scale = pow(10000.0, -2.0f/n_dims);
+
+     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+
+     float theta = (float)p;
+
+     if (!is_neox) {
+         for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+             const float cos_theta = cos(theta);
+             const float sin_theta = sin(theta);
+
+             theta *= theta_scale;
+
+             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+             const float x0 = src[0];
+             const float x1 = src[1];
+
+             dst_data[0] = x0*cos_theta - x1*sin_theta;
+             dst_data[1] = x0*sin_theta + x1*cos_theta;
+         }
+     } else {
+         // TODO: implement
+     }
+ }
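+
+ // RoPE rotates each adjacent pair (x0, x1) by an angle that starts at the
+ // token position p and decays geometrically across dimension pairs:
+ // theta_k = p * 10000^(-2k/n_dims), applied as a standard 2-D rotation.
+ // The neox-style variant (mode & 2) is left unimplemented in this version.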
+
+ kernel void kernel_cpy_f32_f16(
+         device const float * src0,
+         device        half * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant   int64_t & ne02,
+         constant   int64_t & ne03,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant  uint64_t & nb03,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         constant   int64_t & ne2,
+         constant   int64_t & ne3,
+         constant  uint64_t & nb0,
+         constant  uint64_t & nb1,
+         constant  uint64_t & nb2,
+         constant  uint64_t & nb3,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3   ntg[[threads_per_threadgroup]]) {
+     const int64_t i03 = tgpig[2];
+     const int64_t i02 = tgpig[1];
+     const int64_t i01 = tgpig[0];
+
+     const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+     const int64_t i3 = n / (ne2*ne1*ne0);
+     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+     device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+         dst_data[i00] = src[0];
+     }
+ }
+
+ kernel void kernel_cpy_f32_f32(
+         device const float * src0,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant   int64_t & ne02,
+         constant   int64_t & ne03,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant  uint64_t & nb03,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         constant   int64_t & ne2,
+         constant   int64_t & ne3,
+         constant  uint64_t & nb0,
+         constant  uint64_t & nb1,
+         constant  uint64_t & nb2,
+         constant  uint64_t & nb3,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3   ntg[[threads_per_threadgroup]]) {
+     const int64_t i03 = tgpig[2];
+     const int64_t i02 = tgpig[1];
+     const int64_t i01 = tgpig[0];
+
+     const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+     const int64_t i3 = n / (ne2*ne1*ne0);
+     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+         dst_data[i00] = src[0];
+     }
+ }
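+
+ // The cpy kernels re-layout tensors: the flat source offset n is unflattened
+ // into destination coordinates (i3, i2, i1, i0) so that arbitrary byte
+ // strides nb0..nb3 can be honored on the destination side (with an f32->f16
+ // conversion in the first variant).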
+
+ //============================================ k-quants ======================================================
+
+ #define QK_K 256
+
+ typedef struct {
+     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+     uint8_t qs[QK_K/4];      // quants
+     half d;                  // super-block scale for quantized scales
+     half dmin;               // super-block scale for quantized mins
+ } block_q2_k;
+
+ typedef struct {
+     half d;                    // super-block scale for quantized scales
+     half dmin;                 // super-block scale for quantized mins
+     uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+     uint8_t qs[QK_K/2];        // 4-bit quants
+ } block_q4_k;
+
+ typedef struct {
+     uint8_t ql[QK_K/2];     // quants, lower 4 bits
+     uint8_t qh[QK_K/4];     // quants, upper 2 bits
+     int8_t scales[QK_K/16]; // scales, quantized with 8 bits
+     half d;                 // super-block scale
+ } block_q6_k;
+
+ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
+     uchar4 r;
+     if (j < 4) {
+         r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
+         r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+     } else {
+         r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+         r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+         r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+         r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
+     }
+     return r;
+ }
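+
+ // get_scale_min_k4 unpacks the 6-bit (scale, min) pairs of a q4_k
+ // super-block: the first four pairs occupy the low 6 bits of bytes 0..7,
+ // and the last four pairs are reassembled from the nibbles of bytes 8..11
+ // plus the top 2 bits of bytes 0..7.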
+
+ //========================================== dequantization =============================
+
+ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+     assert(k % QK_K == 0);
+     const int nb = k / QK_K;
+
+     for (int i = 0; i < nb; i++) {
+
+         const float d   = x[i].d;
+         const float min = x[i].dmin;
+
+         device const uint8_t * q = x[i].qs;
+
+         int is = 0;
+         float dl, ml;
+         for (int n = 0; n < QK_K; n += 128) {
+             int shift = 0;
+             for (int j = 0; j < 4; ++j) {
+
+                 uint8_t sc = x[i].scales[is++];
+                 dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                 for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                 sc = x[i].scales[is++];
+                 dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                 for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                 shift += 2;
+             }
+             q += 32;
+         }
+
+     }
+ }
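+
+ // q2_k packs four 2-bit quants per byte, read out across shift planes
+ // 0/2/4/6; each group of 16 values has a 4-bit scale and a 4-bit min, so a
+ // value is reconstructed as d*(sc & 0xF)*q - dmin*(sc >> 4).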
+
+ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+     assert(k % QK_K == 0);
+     const int nb = k / QK_K;
+
+     for (int i = 0; i < nb; i++) {
+
+         const float d   = x[i].d;
+         const float min = x[i].dmin;
+
+         device const uint8_t * q      = x[i].qs;
+         device const uint8_t * scales = x[i].scales;
+
+         int is = 0;
+         for (int j = 0; j < QK_K; j += 64) {
+             const uchar4 sc = get_scale_min_k4(is, scales);
+             const float d1 = d * sc[0]; const float m1 = min * sc[1];
+             const float d2 = d * sc[2]; const float m2 = min * sc[3];
+             for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+             for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >>  4) - m2;
+             q += 32; is += 2;
+         }
+
+     }
+ }
+
+ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+     assert(k % QK_K == 0);
+     const int nb = k / QK_K;
+
+     for (int i = 0; i < nb; i++) {
+
+         device const uint8_t * ql = x[i].ql;
+         device const uint8_t * qh = x[i].qh;
+         device const int8_t  * sc = x[i].scales;
+
+         const float d = x[i].d;
+
+         for (int n = 0; n < QK_K; n += 128) {
+             for (int l = 0; l < 32; ++l) {
+                 int is = l/16;
+                 const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                 const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                 const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                 const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                 y[l +  0] = d * sc[is + 0] * q1;
+                 y[l + 32] = d * sc[is + 2] * q2;
+                 y[l + 64] = d * sc[is + 4] * q3;
+                 y[l + 96] = d * sc[is + 6] * q4;
+             }
+             y  += 128;
+             ql += 64;
+             qh += 32;
+             sc += 8;
+         }
+     }
+ }
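+
+ // q6_k rebuilds each 6-bit value from a low nibble in ql and 2 high bits in
+ // qh, then recenters: q = (low4 | high2 << 4) - 32, giving a signed range of
+ // -32..31 before the per-group 8-bit scale and super-block scale d apply.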
+
+ kernel void kernel_get_rows_q2_k(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint tpig[[thread_position_in_grid]]) {
+     const int i = tpig;
+     const int r = ((device int32_t *) src1)[i];
+
+     dequantize_row_q2_k(
+             (device const block_q2_k *) ((device char *) src0 + r*nb01),
+                        (device float *) ((device char *)  dst + i*nb1), ne00);
+ }
+
+ kernel void kernel_get_rows_q4_k(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint tpig[[thread_position_in_grid]]) {
+     const int i = tpig;
+     const int r = ((device int32_t *) src1)[i];
+
+     dequantize_row_q4_k(
+             (device const block_q4_k *) ((device char *) src0 + r*nb01),
+                        (device float *) ((device char *)  dst + i*nb1), ne00);
+ }
+
+ kernel void kernel_get_rows_q6_k(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint tpig[[thread_position_in_grid]]) {
+     const int i = tpig;
+     const int r = ((device int32_t *) src1)[i];
+
+     dequantize_row_q6_k(
+             (device const block_q6_k *) ((device char *) src0 + r*nb01),
+                        (device float *) ((device char *)  dst + i*nb1), ne00);
+ }
+
+ //====================================== dot products =========================
+
+ kernel void kernel_mul_mat_q2_k_f32(
+         device const  void * src0,
+         device const float * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant   int64_t & ne10,
+         constant   int64_t & ne11,
+         constant  uint64_t & nb10,
+         constant  uint64_t & nb11,
+         constant  uint64_t & nb12,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint2 tgpig[[threadgroup_position_in_grid]],
+         uint2  tpig[[thread_position_in_grid]], // we don't use this for now
+         uint2 tpitg[[thread_position_in_threadgroup]],
+         uint2  tptg[[threads_per_threadgroup]]) {
+
+     const int nb = ne00/QK_K;
+
+     const int64_t r0 = tgpig.x;
+     const int64_t r1 = tgpig.y;
+
+     device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+     device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+     const int nth = tptg.x*tptg.y;
+     const int ith = tptg.y*tpitg.x + tpitg.y;
+
+     const int tid = tpitg.y;     // 0...15
+     const int il  = tid/4;       // 0...3
+     const int ir  = tid%4;       // 0...3
+     const int ip  = il/2;        // 0 or 1
+     const int shift1 = 4*(il%2); // 0 or 4
+     const int shift2 = shift1+2; // 2 or 6
+     const int n   = 8;
+     const int is  = 4*il + (n*ir)/16;
+
+     sum[ith] = 0.0f;
+
+     float sumf = 0;
+     for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+         device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+         device const uint8_t * scales = x[i].scales + is;
+
+         uint8_t d1 = scales[0] & 0xF;
+         uint8_t m1 = scales[0] >> 4;
+         uint8_t d2 = scales[2] & 0xF;
+         uint8_t m2 = scales[2] >> 4;
+
+         device const float * y = yy + i*QK_K + 64*il + n*ir;
+
+         const float dall = (float)x[i].d;
+         const float dmin = (float)x[i].dmin;
+
+         float4 s = {0.f, 0.f, 0.f, 0.f};
+         for (int l = 0; l < n; ++l) {
+             s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
+             s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+         }
+         sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
+
+     }
+     sum[ith] = sumf;
+
+     //
+     // Accumulate the sum from all threads in the threadgroup
+     // This version is slightly faster than the commented out one below,
+     // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+     //
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%4 == 0) {
+         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%16 == 0) {
+         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith == 0) {
+         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+         dst[r1*ne0 + r0] = sum[0];
+     }
+
+     //// accumulate the sum from all threads in the threadgroup
+     //threadgroup_barrier(mem_flags::mem_threadgroup);
+     //for (uint i = nth/2; i > 0; i /= 2) {
+     //    if (ith < i) {
+     //        sum[ith] += sum[ith + i];
+     //    }
+     //    threadgroup_barrier(mem_flags::mem_threadgroup);
+     //}
+
+     //if (ith == 0) {
+     //    dst[r1*ne0 + r0] = sum[0];
+     //}
+ }
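+
+ // Inner-loop trick: s[1] and s[3] accumulate plain sums of y so that the
+ // per-block minimum can be factored out of the dot product as
+ // dall*(s0*d1 + s2*d2) - dmin*(s1*m1 + s3*m2), costing one multiply per
+ // block instead of one per element.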
+
+ kernel void kernel_mul_mat_q4_k_f32(
+         device const  void * src0,
+         device const float * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant   int64_t & ne10,
+         constant   int64_t & ne11,
+         constant  uint64_t & nb10,
+         constant  uint64_t & nb11,
+         constant  uint64_t & nb12,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint2 tgpig[[threadgroup_position_in_grid]],
+         uint2  tpig[[thread_position_in_grid]], // we don't use this for now
+         uint2 tpitg[[thread_position_in_threadgroup]],
+         uint2  tptg[[threads_per_threadgroup]]) {
+
+     const int nb = ne00/QK_K;
+
+     const int64_t r0 = tgpig.x;
+     const int64_t r1 = tgpig.y;
+
+     device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+     device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+     const uint nth = tptg.x*tptg.y;
+     const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+     const int tid = tpitg.y; // 0...15
+     const int il  = tid/4;   // 0...3
+     const int ir  = tid%4;   // 0...3
+     const int n   = 8;
+     const int is  = 2*il;
+
+     sum[ith] = 0.0f;
+
+     float sumf = 0;
+     for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+         device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
+         device const float   * y = yy + i*QK_K + 64*il + n*ir;
+         device const uint8_t * scales = (x + i)->scales;
+
+         const float dall = (float)((x + i)->d);
+         const float dmin = (float)((x + i)->dmin);
+
+         const uchar4 sc = get_scale_min_k4(is, scales);
+
+         float4 s = {0.f, 0.f, 0.f, 0.f};
+         for (int l = 0; l < n; ++l) {
+             s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
+             s[2] += y[l+32] * (q[l] >>  4); s[3] += y[l+32];
+         }
+         sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+
+     }
+     sum[ith] = sumf;
+
+     //
+     // Accumulate the sum from all threads in the threadgroup
+     // This version is slightly faster than the commented out one below,
+     // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+     //
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%4 == 0) {
+         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%16 == 0) {
+         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith == 0) {
+         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+         dst[r1*ne0 + r0] = sum[0];
+     }
+
+     //// accumulate the sum from all threads in the threadgroup
+     //threadgroup_barrier(mem_flags::mem_threadgroup);
+     //for (uint i = nth/2; i > 0; i /= 2) {
+     //    if (ith < i) {
+     //        sum[ith] += sum[ith + i];
+     //    }
+     //    threadgroup_barrier(mem_flags::mem_threadgroup);
+     //}
+
+     //if (ith == 0) {
+     //    dst[r1*ne0 + r0] = sum[0];
+     //}
+ }
+
+ kernel void kernel_mul_mat_q6_k_f32(
+         device const  void * src0,
+         device const float * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant   int64_t & ne01,
+         constant  uint64_t & nb00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb02,
+         constant   int64_t & ne10,
+         constant   int64_t & ne11,
+         constant  uint64_t & nb10,
+         constant  uint64_t & nb11,
+         constant  uint64_t & nb12,
+         constant   int64_t & ne0,
+         constant   int64_t & ne1,
+         threadgroup float  * sum [[threadgroup(0)]],
+         uint2 tgpig[[threadgroup_position_in_grid]],
+         uint2  tpig[[thread_position_in_grid]], // we don't use this for now
+         uint2 tpitg[[thread_position_in_threadgroup]],
+         uint2  tptg[[threads_per_threadgroup]]) {
+
+     const uint8_t kmask1 = 0x03;
+     const uint8_t kmask2 = 0x0C;
+     const uint8_t kmask3 = 0x30;
+     const uint8_t kmask4 = 0xC0;
+
+     const int nb = ne00/QK_K;
+
+     const int64_t r0 = tgpig.x;
+     const int64_t r1 = tgpig.y;
+
+     device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+     device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+     const uint nth = tptg.x*tptg.y;
+     const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+     const int step = QK_K / tptg.y;     // we expect this to be 16
+     const int iqs  = step * tpitg.y;    // 0...240 in steps of 16
+     const int ip   = iqs / 128;         // 0 or 1
+     const int il   = (iqs - 128*ip)/16; // 0...7
+     const int n    = 4;
+     const int is   = 8*ip + (n*il)/16;
+
+     float sumf = 0;
+     for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+         device const uint8_t * ql = x[i].ql + 64*ip + n*il;
+         device const uint8_t * qh = x[i].qh + 32*ip + n*il;
+         device const int8_t  * sc = x[i].scales + is;
+
+         device const float * y = yy + i * QK_K + 128*ip + n*il;
+
+         const float dall = x[i].d;
+
+         float4 sums = {0.f, 0.f, 0.f, 0.f};
+         for (int l = 0; l < n; ++l) {
+             sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+             sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+             sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) << 0)) - 32);
+             sums[3] += y[l+96] * ((int8_t)((ql[l+32] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+         }
+
+         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+     }
+
+     sum[ith] = sumf;
+
+     //
+     // Accumulate the sum from all threads in the threadgroup
+     //
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%4 == 0) {
+         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith%16 == 0) {
+         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (ith == 0) {
+         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+         dst[r1*ne0 + r0] = sum[0];
+     }
+
+ }
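+
+ // kmask1..kmask4 select the four 2-bit high parts packed into each qh byte;
+ // the shifts (<< 4, << 2, << 0, >> 2) all move the selected field to bits
+ // 4..5, directly above the low nibble, before the -32 recentering.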