llama_cpp 0.6.0 → 0.7.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +49 -3
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +622 -150
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +358 -131
- data/ext/llama_cpp/src/ggml-metal.metal +137 -47
- data/ext/llama_cpp/src/ggml-opencl.cpp +136 -68
- data/ext/llama_cpp/src/ggml.c +812 -365
- data/ext/llama_cpp/src/ggml.h +25 -7
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +2387 -421
- data/ext/llama_cpp/src/llama.h +22 -6
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -0
- metadata +5 -2
@@ -13,8 +13,8 @@ typedef struct {
|
|
13
13
|
|
14
14
|
#define QK4_1 32
|
15
15
|
typedef struct {
|
16
|
-
half d;
|
17
|
-
half m;
|
16
|
+
half d; // delta
|
17
|
+
half m; // min
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
@@ -132,6 +132,13 @@ kernel void kernel_relu(
|
|
132
132
|
dst[tpig] = max(0.0f, src0[tpig]);
|
133
133
|
}
|
134
134
|
|
135
|
+
// Element-wise square of an f32 buffer: dst[k] = src0[k] * src0[k].
// Each thread handles exactly one element, indexed by its position in the grid.
// There is no bounds check here, so the dispatch must not launch more threads
// than there are elements.
kernel void kernel_sqr(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    // Read the element once, then square the local copy.
    const float v = src0[tpig];
    dst[tpig] = v * v;
}
|
141
|
+
|
135
142
|
constant float GELU_COEF_A = 0.044715f;
|
136
143
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
137
144
|
|
@@ -338,10 +345,11 @@ kernel void kernel_rms_norm(
|
|
338
345
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
339
346
|
uint tiisg[[thread_index_in_simdgroup]],
|
340
347
|
uint ntg[[threads_per_threadgroup]]) {
|
341
|
-
device const float4 * x
|
342
|
-
device const float
|
343
|
-
|
344
|
-
|
348
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
349
|
+
device const float * x_scalar = (device const float *) x;
|
350
|
+
|
351
|
+
float4 sumf = 0;
|
352
|
+
float all_sum = 0;
|
345
353
|
|
346
354
|
// parallel sum
|
347
355
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
@@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
|
|
354
362
|
}
|
355
363
|
|
356
364
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
365
|
+
|
357
366
|
// broadcast, simd group number is ntg / 32
|
358
367
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
359
368
|
if (tpitg < i) {
|
@@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
|
|
361
370
|
}
|
362
371
|
}
|
363
372
|
if (tpitg == 0) {
|
364
|
-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
373
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
374
|
+
sum[0] += x_scalar[i];
|
375
|
+
}
|
365
376
|
sum[0] /= ne00;
|
366
377
|
}
|
367
378
|
|
@@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
|
|
376
387
|
y[i00] = x[i00] * scale;
|
377
388
|
}
|
378
389
|
if (tpitg == 0) {
|
379
|
-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
390
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
391
|
+
y_scalar[i00] = x_scalar[i00] * scale;
|
392
|
+
}
|
380
393
|
}
|
381
394
|
}
|
382
395
|
|
@@ -416,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
416
429
|
}
|
417
430
|
|
418
431
|
// putting them in the kernel cause a significant performance penalty
|
419
|
-
#define N_DST 4
|
420
|
-
#define N_SIMDGROUP 2
|
432
|
+
#define N_DST 4 // each SIMD group works on 4 rows
|
433
|
+
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
421
434
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
422
435
|
//Note: This is a template, but strictly speaking it only applies to
|
423
436
|
// quantizations where the block size is 32. It also does not
|
@@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
428
441
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
429
442
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
430
443
|
const int nb = ne00/QK4_0;
|
444
|
+
|
431
445
|
const int r0 = tgpig.x;
|
432
446
|
const int r1 = tgpig.y;
|
433
447
|
const int im = tgpig.z;
|
448
|
+
|
434
449
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
450
|
+
|
435
451
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
452
|
+
|
436
453
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
437
454
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
438
|
-
float yl[16]; // src1 vector cache
|
439
|
-
float sumf[nr]={0.f};
|
440
455
|
|
441
|
-
|
442
|
-
|
456
|
+
float yl[16]; // src1 vector cache
|
457
|
+
float sumf[nr] = {0.f};
|
458
|
+
|
459
|
+
const int ix = (tiisg/2);
|
460
|
+
const int il = (tiisg%2)*8;
|
443
461
|
|
444
462
|
device const float * yb = y + ix * QK4_0 + il;
|
445
463
|
|
@@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
450
468
|
sumy += yb[i] + yb[i+1];
|
451
469
|
yl[i+0] = yb[i+ 0];
|
452
470
|
yl[i+1] = yb[i+ 1]/256.f;
|
471
|
+
|
453
472
|
sumy += yb[i+16] + yb[i+17];
|
454
473
|
yl[i+8] = yb[i+16]/16.f;
|
455
474
|
yl[i+9] = yb[i+17]/4096.f;
|
@@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
465
484
|
for (int row = 0; row < nr; ++row) {
|
466
485
|
const float tot = simd_sum(sumf[row]);
|
467
486
|
if (tiisg == 0 && first_row + row < ne01) {
|
468
|
-
dst[
|
487
|
+
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
469
488
|
}
|
470
489
|
}
|
471
490
|
}
|
472
491
|
|
473
|
-
kernel void
|
492
|
+
kernel void kernel_mul_mv_q4_0_f32(
|
474
493
|
device const void * src0,
|
475
494
|
device const float * src1,
|
476
495
|
device float * dst,
|
@@ -483,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
483
502
|
constant int64_t & ne1[[buffer(16)]],
|
484
503
|
constant uint & gqa[[buffer(17)]],
|
485
504
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
486
|
-
uint
|
487
|
-
uint
|
505
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
506
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
488
507
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
489
508
|
}
|
490
509
|
|
491
|
-
kernel void
|
510
|
+
kernel void kernel_mul_mv_q4_1_f32(
|
492
511
|
device const void * src0,
|
493
512
|
device const float * src1,
|
494
513
|
device float * dst,
|
@@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
508
527
|
|
509
528
|
#define NB_Q8_0 8
|
510
529
|
|
511
|
-
kernel void
|
530
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
512
531
|
device const void * src0,
|
513
532
|
device const float * src1,
|
514
533
|
device float * dst,
|
@@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
572
591
|
|
573
592
|
#define N_F32_F32 4
|
574
593
|
|
575
|
-
kernel void
|
594
|
+
kernel void kernel_mul_mv_f32_f32(
|
576
595
|
device const char * src0,
|
577
596
|
device const char * src1,
|
578
597
|
device float * dst,
|
@@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
|
|
643
662
|
}
|
644
663
|
}
|
645
664
|
|
646
|
-
kernel void
|
665
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
647
666
|
device const char * src0,
|
648
667
|
device const char * src1,
|
649
668
|
device float * dst,
|
@@ -662,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
662
681
|
constant int64_t & ne0,
|
663
682
|
constant int64_t & ne1,
|
664
683
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
665
|
-
uint
|
684
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
666
685
|
|
667
686
|
const int64_t r0 = tgpig.x;
|
668
687
|
const int64_t r1 = tgpig.y;
|
@@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
697
716
|
|
698
717
|
#define N_F16_F32 4
|
699
718
|
|
700
|
-
kernel void
|
719
|
+
kernel void kernel_mul_mv_f16_f32(
|
701
720
|
device const char * src0,
|
702
721
|
device const char * src1,
|
703
722
|
device float * dst,
|
@@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|
769
788
|
}
|
770
789
|
|
771
790
|
// Assumes row size (ne00) is a multiple of 4
|
772
|
-
kernel void
|
791
|
+
kernel void kernel_mul_mv_f16_f32_l4(
|
773
792
|
device const char * src0,
|
774
793
|
device const char * src1,
|
775
794
|
device float * dst,
|
@@ -830,7 +849,9 @@ kernel void kernel_alibi_f32(
|
|
830
849
|
constant uint64_t & nb1,
|
831
850
|
constant uint64_t & nb2,
|
832
851
|
constant uint64_t & nb3,
|
833
|
-
constant
|
852
|
+
constant float & m0,
|
853
|
+
constant float & m1,
|
854
|
+
constant int & n_heads_log2_floor,
|
834
855
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
835
856
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
836
857
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -846,7 +867,12 @@ kernel void kernel_alibi_f32(
|
|
846
867
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
847
868
|
|
848
869
|
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
849
|
-
float m_k
|
870
|
+
float m_k;
|
871
|
+
if (i2 < n_heads_log2_floor) {
|
872
|
+
m_k = pow(m0, i2 + 1);
|
873
|
+
} else {
|
874
|
+
m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
|
875
|
+
}
|
850
876
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
851
877
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
852
878
|
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
|
@@ -1091,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
|
|
1091
1117
|
}
|
1092
1118
|
}
|
1093
1119
|
|
1120
|
+
// Concatenate two f32 tensors along dimension 2: planes i02 < ne02 of dst are
// copied from src0, and planes i02 >= ne02 are copied from src1 plane (i02 - ne02).
//
// Layout as read by this kernel:
//   - one threadgroup per destination row: i01 = tgpig.x, i02 = tgpig.y, i03 = tgpig.z;
//   - the threads of a threadgroup stride over the ne0 elements of that row by ntg.x.
// All nb* values are byte strides; elements are read and written as float.
// Dims 1 and 3 of src1 are indexed modulo ne11 / ne13, as in the original kernel.
// NOTE(review): the row loop runs to ne0 for both sources, which presumes
// ne00 == ne10 == ne0 — confirm against the host-side op definition.
kernel void kernel_concat(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant   int64_t & ne13,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    // Whether this row comes from src0 or src1 depends only on i02, so the
    // source row and its element stride are selected once, outside the loop.
    device const char * src_row;
    uint64_t            src_nb0;
    if (i02 < ne02) {
        src_row = src0 + i03*nb03 + i02*nb02 + i01*nb01;
        src_nb0 = nb00;
    } else {
        // dst planes [ne02, ne02 + ne12) map to src1 planes [0, ne12).
        // (Using i02 % ne12 here only matched this when ne02 was a multiple of ne12.)
        const int64_t i13 = i03 % ne13;
        const int64_t i12 = i02 - ne02;
        const int64_t i11 = i01 % ne11;
        src_row = src1 + i13*nb13 + i12*nb12 + i11*nb11;
        src_nb0 = nb10;
    }

    device char * dst_row = dst + i03*nb3 + i02*nb2 + i01*nb1;

    // 64-bit index to match the int64_t extent ne0.
    for (int64_t i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        *((device float *)(dst_row + i0*nb0)) = *((device const float *)(src_row + i0*src_nb0));
    }
}
|
1175
|
+
|
1094
1176
|
//============================================ k-quants ======================================================
|
1095
1177
|
|
1096
1178
|
#ifndef QK_K
|
@@ -1183,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
1183
1265
|
|
1184
1266
|
//====================================== dot products =========================
|
1185
1267
|
|
1186
|
-
kernel void
|
1268
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
1187
1269
|
device const void * src0,
|
1188
1270
|
device const float * src1,
|
1189
1271
|
device float * dst,
|
@@ -1327,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1327
1409
|
}
|
1328
1410
|
|
1329
1411
|
#if QK_K == 256
|
1330
|
-
kernel void
|
1412
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
1331
1413
|
device const void * src0,
|
1332
1414
|
device const float * src1,
|
1333
1415
|
device float * dst,
|
@@ -1479,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1479
1561
|
}
|
1480
1562
|
}
|
1481
1563
|
#else
|
1482
|
-
kernel void
|
1564
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
1483
1565
|
device const void * src0,
|
1484
1566
|
device const float * src1,
|
1485
1567
|
device float * dst,
|
@@ -1550,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1550
1632
|
#endif
|
1551
1633
|
|
1552
1634
|
#if QK_K == 256
|
1553
|
-
kernel void
|
1635
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
1554
1636
|
device const void * src0,
|
1555
1637
|
device const float * src1,
|
1556
1638
|
device float * dst,
|
@@ -1656,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1656
1738
|
}
|
1657
1739
|
}
|
1658
1740
|
#else
|
1659
|
-
kernel void
|
1741
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
1660
1742
|
device const void * src0,
|
1661
1743
|
device const float * src1,
|
1662
1744
|
device float * dst,
|
@@ -1745,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1745
1827
|
}
|
1746
1828
|
#endif
|
1747
1829
|
|
1748
|
-
kernel void
|
1830
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
1749
1831
|
device const void * src0,
|
1750
1832
|
device const float * src1,
|
1751
1833
|
device float * dst,
|
@@ -1918,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1918
2000
|
|
1919
2001
|
}
|
1920
2002
|
|
1921
|
-
kernel void
|
2003
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
1922
2004
|
device const void * src0,
|
1923
2005
|
device const float * src1,
|
1924
2006
|
device float * dst,
|
@@ -2256,7 +2338,7 @@ kernel void kernel_get_rows(
|
|
2256
2338
|
}
|
2257
2339
|
|
2258
2340
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
2259
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix
|
2341
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
2260
2342
|
#define BLOCK_SIZE_K 32
|
2261
2343
|
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
2262
2344
|
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
@@ -2293,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2293
2375
|
const uint r0 = tgpig.y;
|
2294
2376
|
const uint r1 = tgpig.x;
|
2295
2377
|
const uint im = tgpig.z;
|
2378
|
+
|
2296
2379
|
// if this block is of 64x32 shape or smaller
|
2297
2380
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
2298
2381
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
2382
|
+
|
2299
2383
|
// a thread shouldn't load data outside of the matrix
|
2300
2384
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
2301
2385
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
@@ -2319,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2319
2403
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
2320
2404
|
|
2321
2405
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
2322
|
-
//load data and store to threadgroup memory
|
2406
|
+
// load data and store to threadgroup memory
|
2323
2407
|
half4x4 temp_a;
|
2324
2408
|
dequantize_func(x, il, temp_a);
|
2325
2409
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2410
|
+
|
2326
2411
|
#pragma unroll(16)
|
2327
2412
|
for (int i = 0; i < 16; i++) {
|
2328
2413
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
2329
|
-
+
|
2330
|
-
+
|
2414
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
2415
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
2331
2416
|
}
|
2332
|
-
|
2333
|
-
|
2417
|
+
|
2418
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
2419
|
+
|
2334
2420
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
2335
2421
|
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
2336
2422
|
y += BLOCK_SIZE_K;
|
2337
2423
|
|
2338
2424
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2339
|
-
|
2425
|
+
|
2426
|
+
// load matrices from threadgroup memory and conduct outer products
|
2340
2427
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
2341
2428
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
2429
|
+
|
2342
2430
|
#pragma unroll(4)
|
2343
2431
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
2344
2432
|
#pragma unroll(4)
|
@@ -2353,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2353
2441
|
|
2354
2442
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
2355
2443
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
2444
|
+
|
2356
2445
|
#pragma unroll(8)
|
2357
2446
|
for (int i = 0; i < 8; i++){
|
2358
2447
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
@@ -2361,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2361
2450
|
}
|
2362
2451
|
|
2363
2452
|
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
2364
|
-
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
2365
|
-
|
2453
|
+
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
2454
|
+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
2366
2455
|
for (int i = 0; i < 8; i++) {
|
2367
2456
|
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
2368
2457
|
}
|
2369
2458
|
} else {
|
2370
2459
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
2371
2460
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2372
|
-
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
2461
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
2373
2462
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
2374
2463
|
for (int i = 0; i < 8; i++) {
|
2375
2464
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
2376
2465
|
}
|
2377
2466
|
|
2378
2467
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2379
|
-
|
2380
|
-
|
2468
|
+
|
2469
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
2470
|
+
if (sgitg == 0) {
|
2381
2471
|
for (int i = 0; i < n_rows; i++) {
|
2382
|
-
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
|
2472
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
2383
2473
|
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
2384
2474
|
}
|
2385
2475
|
}
|