llama_cpp 0.6.0 → 0.7.1
This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +49 -3
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +622 -150
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +358 -131
- data/ext/llama_cpp/src/ggml-metal.metal +137 -47
- data/ext/llama_cpp/src/ggml-opencl.cpp +136 -68
- data/ext/llama_cpp/src/ggml.c +812 -365
- data/ext/llama_cpp/src/ggml.h +25 -7
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +2387 -421
- data/ext/llama_cpp/src/llama.h +22 -6
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -0
- metadata +5 -2
data/ext/llama_cpp/src/ggml-metal.metal

@@ -13,8 +13,8 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;
-    half m;
+    half d;          // delta
+    half m;          // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
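The new comments document the block_q4_1 layout: `d` is the per-block scale (delta) and `m` the per-block minimum, with QK4_1 = 32 4-bit quants packed two per byte in `qs`. A CPU-side sketch of the dequantization these fields imply (the standard ggml q4_1 scheme; the struct and helper below are illustrative, not code from the gem):

```cpp
#include <cstdint>

// Reference dequantization implied by the documented fields:
// value = d * nibble + m. Low nibbles fill the first 16 outputs,
// high nibbles the second 16, matching ggml's q4_1 packing.
struct block_q4_1_ref {
    float   d, m;    // stored as half on the GPU; float here for simplicity
    uint8_t qs[16];  // QK4_1 / 2 bytes, two 4-bit quants each
};

void dequantize_q4_1(const block_q4_1_ref & b, float out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = b.d * (b.qs[j] & 0x0F) + b.m; // low nibble
        out[j + 16] = b.d * (b.qs[j] >>   4) + b.m; // high nibble
    }
}
```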
@@ -132,6 +132,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
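`kernel_sqr` is a new element-wise kernel: one thread per element, writing the square of its input. For reference, a CPU equivalent of what the grid of threads computes (hypothetical helper, not code from the gem):

```cpp
#include <cstddef>

// CPU reference for kernel_sqr: dst[i] = src0[i]^2, one GPU thread per i.
void sqr_reference(const float * src0, float * dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = src0[i] * src0[i]; // same expression as dst[tpig] above
    }
}
```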
@@ -338,10 +345,11 @@ kernel void kernel_rms_norm(
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float  * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // broadcast, simd group number is ntg / 32
     for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
         if (tpitg < i) {
@@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
         }
     }
     if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
     }
 
@@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
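These two hunks unfold the one-line scalar tails of `kernel_rms_norm` into explicit loops: the row is processed as `float4` vectors, and thread 0 mops up the `ne00 % 4` leftover elements starting at `4 * (ne00 / 4)`. The same body-plus-tail pattern in a host-side sketch (hypothetical helper, not code from the gem):

```cpp
#include <cstddef>

// Vector body plus scalar tail, as in kernel_rms_norm: sum groups of 4
// (float4 on the GPU), then add the remainder elements one by one.
float sum_with_tail(const float * x, std::size_t ne00) {
    float sum = 0.0f;
    std::size_t i = 0;
    for (; i + 4 <= ne00; i += 4) {   // vectorized body
        sum += x[i] + x[i+1] + x[i+2] + x[i+3];
    }
    for (; i < ne00; ++i) {           // scalar tail: indices 4*(ne00/4)..ne00-1
        sum += x[i];
    }
    return sum;
}
```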
@@ -416,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 }
 
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4
-#define N_SIMDGROUP 2
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
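Per the new comments, each threadgroup runs N_SIMDGROUP = 2 SIMD groups and each SIMD group produces N_DST = 4 output rows, so one threadgroup covers 8 rows (these constants become the `nsg`/`nr` template arguments used in `first_row` in the next hunk). A sketch of the implied grid sizing (host-side, illustrative only):

```cpp
#include <cstdint>

// Grid arithmetic implied by N_DST and N_SIMDGROUP: a threadgroup covers
// N_SIMDGROUP * N_DST = 8 output rows, so ceil(ne01 / 8) threadgroups are
// needed along the row axis. Illustrative helper, not code from the gem.
int64_t threadgroups_for_rows(int64_t ne01) {
    const int64_t rows_per_tg = 2 /* N_SIMDGROUP */ * 4 /* N_DST */;
    return (ne01 + rows_per_tg - 1) / rows_per_tg;
}
```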
@@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                    uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16]; // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
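The new `ix`/`il` lines split each 32-lane SIMD group into 16 lane pairs, one QK4_0 = 32 block per pair: lane `tiisg` starts reading at `y + (tiisg/2)*32 + (tiisg%2)*8`, which is exactly the `yb` pointer above. A tiny host-side check of that mapping (illustrative only):

```cpp
#include <cstdio>

// Lane-to-offset mapping from mul_vec_q_n_f32: ix picks the block,
// il picks the 8-value strip inside it (lanes also read at +16 later).
int main() {
    const int QK4_0 = 32;
    for (int tiisg = 0; tiisg < 32; ++tiisg) {
        const int ix = tiisg / 2;        // block index for this lane
        const int il = (tiisg % 2) * 8;  // offset within the block
        std::printf("lane %2d -> y[%3d]\n", tiisg, ix * QK4_0 + il);
    }
    return 0;
}
```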
@@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
         sumy += yb[i] + yb[i+1];
         yl[i+0] = yb[i+ 0];
         yl[i+1] = yb[i+ 1]/256.f;
+
         sumy += yb[i+16] + yb[i+17];
         yl[i+8] = yb[i+16]/16.f;
         yl[i+9] = yb[i+17]/4096.f;
@@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
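From this hunk on, the matrix-vector kernels are renamed from `kernel_mul_mat_*_f32` to `kernel_mul_mv_*_f32` (they compute matrix-vector, not matrix-matrix, products), and the destination index is regrouped as `im*ne0*ne1 + r1*ne0 + first_row + row`, i.e. `[im][r1][row]` in row-major order. The rename hunks below apply the same pattern to the remaining f32/f16 and k-quant kernels.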
@@ -483,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32(
         constant int64_t & ne1[[buffer(16)]],
         constant uint & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,

@@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const char * src0,
         device const char * src1,
         device float * dst,

@@ -662,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
         constant int64_t & ne0,
         constant int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

@@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,

@@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -830,7 +849,9 @@ kernel void kernel_alibi_f32(
         constant uint64_t & nb1,
         constant uint64_t & nb2,
         constant uint64_t & nb3,
-        constant float & m0,
+        constant float & m0,
+        constant float & m1,
+        constant int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -846,7 +867,12 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
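The new `m1` and `n_heads_log2_floor` parameters complete the ALiBi slope schedule for head counts that are not powers of two: the first `n_heads_log2_floor` heads take successive powers of `m0`, the remaining heads take odd powers of `m1`. A host-side sketch of the resulting slopes (the `m0`/`m1` formulas follow the ALiBi paper and ggml's CPU path; they are not shown in this hunk, so treat them as an assumption):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// ALiBi head slopes matching the m_k branch added above. For n_heads = 12,
// floor = 8: heads 0..7 use pow(m0, h+1), heads 8..11 use odd powers of m1.
std::vector<float> alibi_slopes(int n_heads, float max_bias = 8.0f) {
    const int n_heads_log2_floor = 1 << (int)std::floor(std::log2((double)n_heads));
    const float m0 = std::pow(2.0f, -max_bias          / n_heads_log2_floor);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
    std::vector<float> m(n_heads);
    for (int h = 0; h < n_heads; ++h) {
        m[h] = (h < n_heads_log2_floor)
             ? std::pow(m0, (float)(h + 1))                             // m_k = pow(m0, i2 + 1)
             : std::pow(m1, (float)(2 * (h - n_heads_log2_floor) + 1)); // odd powers of m1
    }
    return m;
}

int main() {
    for (float s : alibi_slopes(12)) std::printf("%g\n", s);
    return 0;
}
```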
@@ -1091,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_concat(
+    device const char * src0,
+    device const char * src1,
+    device       char * dst,
+    constant  int64_t & ne00,
+    constant  int64_t & ne01,
+    constant  int64_t & ne02,
+    constant  int64_t & ne03,
+    constant uint64_t & nb00,
+    constant uint64_t & nb01,
+    constant uint64_t & nb02,
+    constant uint64_t & nb03,
+    constant  int64_t & ne10,
+    constant  int64_t & ne11,
+    constant  int64_t & ne12,
+    constant  int64_t & ne13,
+    constant uint64_t & nb10,
+    constant uint64_t & nb11,
+    constant uint64_t & nb12,
+    constant uint64_t & nb13,
+    constant  int64_t & ne0,
+    constant  int64_t & ne1,
+    constant  int64_t & ne2,
+    constant  int64_t & ne3,
+    constant uint64_t & nb0,
+    constant uint64_t & nb1,
+    constant uint64_t & nb2,
+    constant uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================
 
 #ifndef QK_K
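`kernel_concat` concatenates along dimension 2: output positions with `i02 < ne02` are copied from src0, the rest from src1, and the `% ne1x` lines let src1 broadcast over the outer dimensions. A simplified CPU reference for contiguous float tensors (the kernel itself walks raw byte strides `nb*`; its modulo indexing coincides with the subtraction below whenever ne02 is a multiple of ne12):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of kernel_concat for contiguous [ne0, ne1, ne2]
// float tensors: positions beyond ne02 along dim 2 come from src1.
std::vector<float> concat_dim2(const std::vector<float> & src0,
                               const std::vector<float> & src1,
                               int64_t ne0, int64_t ne1,
                               int64_t ne02, int64_t ne12) {
    const int64_t ne2 = ne02 + ne12; // output extent along dim 2
    std::vector<float> dst((std::size_t)(ne0 * ne1 * ne2));
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                const float v = (i2 < ne02)
                    ? src0[(std::size_t)((i2          * ne1 + i1) * ne0 + i0)]
                    : src1[(std::size_t)(((i2 - ne02) * ne1 + i1) * ne0 + i0)];
                dst[(std::size_t)((i2 * ne1 + i1) * ne0 + i0)] = v;
            }
        }
    }
    return dst;
}
```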
@@ -1183,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -1327,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -1479,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -1550,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -1656,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,

@@ -1745,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,

@@ -1918,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -2256,7 +2338,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2293,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2319,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+                                  + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+                                  + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-
-        *(threadgroup float2x4 *)(sb + 32 * 8 * (tiitg % THREAD_PER_COL) + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2353,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2361,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
         }
    }
 
    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+                         + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }