llama_cpp 0.10.3 → 0.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/llama_cpp/extconf.rb +35 -110
  5. data/ext/llama_cpp/llama_cpp.cpp +52 -28
  6. data/lib/llama_cpp/version.rb +2 -2
  7. data/sig/llama_cpp.rbs +3 -1
  8. data/vendor/include/.gitkeep +0 -0
  9. data/vendor/lib/.gitkeep +0 -0
  10. data/vendor/tmp/llama.cpp/Makefile +758 -0
  11. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +6 -2
  12. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +73 -63
  13. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +1 -0
  14. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +43 -20
  15. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +464 -245
  16. data/vendor/tmp/llama.cpp/ggml-opencl.h +25 -0
  17. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +61 -57
  18. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +171 -5
  19. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +1 -0
  20. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +222 -105
  21. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +31 -32
  22. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
  23. metadata +30 -27
  24. data/ext/llama_cpp/src/ggml-opencl.h +0 -25
  25. data/ext/llama_cpp/src/llama-util.h +0 -546
  26. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
  27. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
  28. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
  29. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
  30. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
  31. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
  32. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
  33. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
  34. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
  35. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
  36. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
  37. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
@@ -59,26 +59,26 @@ kernel void kernel_add(
59
59
  constant int64_t & ne01,
60
60
  constant int64_t & ne02,
61
61
  constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
62
+ constant uint64_t & nb00,
63
+ constant uint64_t & nb01,
64
+ constant uint64_t & nb02,
65
+ constant uint64_t & nb03,
66
66
  constant int64_t & ne10,
67
67
  constant int64_t & ne11,
68
68
  constant int64_t & ne12,
69
69
  constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
70
+ constant uint64_t & nb10,
71
+ constant uint64_t & nb11,
72
+ constant uint64_t & nb12,
73
+ constant uint64_t & nb13,
74
74
  constant int64_t & ne0,
75
75
  constant int64_t & ne1,
76
76
  constant int64_t & ne2,
77
77
  constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
78
+ constant uint64_t & nb0,
79
+ constant uint64_t & nb1,
80
+ constant uint64_t & nb2,
81
+ constant uint64_t & nb3,
82
82
  constant int64_t & offs,
83
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
84
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -109,26 +109,26 @@ kernel void kernel_mul(
109
109
  constant int64_t & ne01,
110
110
  constant int64_t & ne02,
111
111
  constant int64_t & ne03,
112
- constant int64_t & nb00,
113
- constant int64_t & nb01,
114
- constant int64_t & nb02,
115
- constant int64_t & nb03,
112
+ constant uint64_t & nb00,
113
+ constant uint64_t & nb01,
114
+ constant uint64_t & nb02,
115
+ constant uint64_t & nb03,
116
116
  constant int64_t & ne10,
117
117
  constant int64_t & ne11,
118
118
  constant int64_t & ne12,
119
119
  constant int64_t & ne13,
120
- constant int64_t & nb10,
121
- constant int64_t & nb11,
122
- constant int64_t & nb12,
123
- constant int64_t & nb13,
120
+ constant uint64_t & nb10,
121
+ constant uint64_t & nb11,
122
+ constant uint64_t & nb12,
123
+ constant uint64_t & nb13,
124
124
  constant int64_t & ne0,
125
125
  constant int64_t & ne1,
126
126
  constant int64_t & ne2,
127
127
  constant int64_t & ne3,
128
- constant int64_t & nb0,
129
- constant int64_t & nb1,
130
- constant int64_t & nb2,
131
- constant int64_t & nb3,
128
+ constant uint64_t & nb0,
129
+ constant uint64_t & nb1,
130
+ constant uint64_t & nb2,
131
+ constant uint64_t & nb3,
132
132
  uint3 tgpig[[threadgroup_position_in_grid]],
133
133
  uint3 tpitg[[thread_position_in_threadgroup]],
134
134
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -158,26 +158,26 @@ kernel void kernel_div(
158
158
  constant int64_t & ne01,
159
159
  constant int64_t & ne02,
160
160
  constant int64_t & ne03,
161
- constant int64_t & nb00,
162
- constant int64_t & nb01,
163
- constant int64_t & nb02,
164
- constant int64_t & nb03,
161
+ constant uint64_t & nb00,
162
+ constant uint64_t & nb01,
163
+ constant uint64_t & nb02,
164
+ constant uint64_t & nb03,
165
165
  constant int64_t & ne10,
166
166
  constant int64_t & ne11,
167
167
  constant int64_t & ne12,
168
168
  constant int64_t & ne13,
169
- constant int64_t & nb10,
170
- constant int64_t & nb11,
171
- constant int64_t & nb12,
172
- constant int64_t & nb13,
169
+ constant uint64_t & nb10,
170
+ constant uint64_t & nb11,
171
+ constant uint64_t & nb12,
172
+ constant uint64_t & nb13,
173
173
  constant int64_t & ne0,
174
174
  constant int64_t & ne1,
175
175
  constant int64_t & ne2,
176
176
  constant int64_t & ne3,
177
- constant int64_t & nb0,
178
- constant int64_t & nb1,
179
- constant int64_t & nb2,
180
- constant int64_t & nb3,
177
+ constant uint64_t & nb0,
178
+ constant uint64_t & nb1,
179
+ constant uint64_t & nb2,
180
+ constant uint64_t & nb3,
181
181
  uint3 tgpig[[threadgroup_position_in_grid]],
182
182
  uint3 tpitg[[thread_position_in_threadgroup]],
183
183
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
205
205
  device const float4 * src0,
206
206
  device const float4 * src1,
207
207
  device float4 * dst,
208
- constant int64_t & nb [[buffer(28)]],
208
+ constant uint64_t & nb [[buffer(28)]],
209
209
  uint tpig[[thread_position_in_grid]]) {
210
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
211
  }
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
214
214
  device const float4 * src0,
215
215
  device const float4 * src1,
216
216
  device float4 * dst,
217
- constant int64_t & nb [[buffer(28)]],
217
+ constant uint64_t & nb [[buffer(28)]],
218
218
  uint tpig[[thread_position_in_grid]]) {
219
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
220
  }
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
223
223
  device const float4 * src0,
224
224
  device const float4 * src1,
225
225
  device float4 * dst,
226
- constant int64_t & nb [[buffer(28)]],
226
+ constant uint64_t & nb [[buffer(28)]],
227
227
  uint tpig[[thread_position_in_grid]]) {
228
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
229
  }
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
307
307
  constant int64_t & ne01,
308
308
  constant int64_t & ne02,
309
309
  constant int64_t & ne03,
310
- constant int64_t & nb00,
311
- constant int64_t & nb01,
312
- constant int64_t & nb02,
313
- constant int64_t & nb03,
310
+ constant uint64_t & nb00,
311
+ constant uint64_t & nb01,
312
+ constant uint64_t & nb02,
313
+ constant uint64_t & nb03,
314
314
  constant int64_t & ne10,
315
315
  constant int64_t & ne11,
316
316
  constant int64_t & ne12,
317
317
  constant int64_t & ne13,
318
- constant int64_t & nb10,
319
- constant int64_t & nb11,
320
- constant int64_t & nb12,
321
- constant int64_t & nb13,
318
+ constant uint64_t & nb10,
319
+ constant uint64_t & nb11,
320
+ constant uint64_t & nb12,
321
+ constant uint64_t & nb13,
322
322
  constant int64_t & ne0,
323
323
  constant int64_t & ne1,
324
324
  constant int64_t & ne2,
325
325
  constant int64_t & ne3,
326
- constant int64_t & nb0,
327
- constant int64_t & nb1,
328
- constant int64_t & nb2,
329
- constant int64_t & nb3,
326
+ constant uint64_t & nb0,
327
+ constant uint64_t & nb1,
328
+ constant uint64_t & nb2,
329
+ constant uint64_t & nb3,
330
330
  uint3 tpig[[thread_position_in_grid]]) {
331
331
  int64_t i3 = tpig.z;
332
332
  int64_t i2 = tpig.y;
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
846
846
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
847
847
  //Note: This is a template, but strictly speaking it only applies to
848
848
  // quantizations where the block size is 32. It also does not
849
- // giard against the number of rows not being divisible by
849
+ // guard against the number of rows not being divisible by
850
850
  // N_DST, so this is another explicit assumption of the implementation.
851
851
  template<typename block_q_type, int nr, int nsg, int nw>
852
852
  void mul_vec_q_n_f32_impl(
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
920
920
  device const float * src1,
921
921
  device float * dst,
922
922
  constant int64_t & ne00,
923
- constant int64_t & ne01[[buffer(4)]],
924
- constant int64_t & ne02[[buffer(5)]],
925
- constant int64_t & ne10[[buffer(9)]],
926
- constant int64_t & ne12[[buffer(11)]],
927
- constant int64_t & ne0 [[buffer(15)]],
928
- constant int64_t & ne1 [[buffer(16)]],
929
- constant uint & r2 [[buffer(17)]],
930
- constant uint & r3 [[buffer(18)]],
923
+ constant int64_t & ne01,
924
+ constant int64_t & ne02,
925
+ constant uint64_t & nb00,
926
+ constant uint64_t & nb01,
927
+ constant uint64_t & nb02,
928
+ constant int64_t & ne10,
929
+ constant int64_t & ne11,
930
+ constant int64_t & ne12,
931
+ constant uint64_t & nb10,
932
+ constant uint64_t & nb11,
933
+ constant uint64_t & nb12,
934
+ constant int64_t & ne0,
935
+ constant int64_t & ne1,
936
+ constant uint & r2,
937
+ constant uint & r3,
931
938
  uint3 tgpig[[threadgroup_position_in_grid]],
932
939
  uint tiisg[[thread_index_in_simdgroup]],
933
940
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
939
946
  device const float * src1,
940
947
  device float * dst,
941
948
  constant int64_t & ne00,
942
- constant int64_t & ne01[[buffer(4)]],
943
- constant int64_t & ne02[[buffer(5)]],
944
- constant int64_t & ne10[[buffer(9)]],
945
- constant int64_t & ne12[[buffer(11)]],
946
- constant int64_t & ne0 [[buffer(15)]],
947
- constant int64_t & ne1 [[buffer(16)]],
948
- constant uint & r2 [[buffer(17)]],
949
- constant uint & r3 [[buffer(18)]],
949
+ constant int64_t & ne01,
950
+ constant int64_t & ne02,
951
+ constant uint64_t & nb00,
952
+ constant uint64_t & nb01,
953
+ constant uint64_t & nb02,
954
+ constant int64_t & ne10,
955
+ constant int64_t & ne11,
956
+ constant int64_t & ne12,
957
+ constant uint64_t & nb10,
958
+ constant uint64_t & nb11,
959
+ constant uint64_t & nb12,
960
+ constant int64_t & ne0,
961
+ constant int64_t & ne1,
962
+ constant uint & r2,
963
+ constant uint & r3,
950
964
  uint3 tgpig[[threadgroup_position_in_grid]],
951
965
  uint tiisg[[thread_index_in_simdgroup]],
952
966
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
958
972
  device const float * src1,
959
973
  device float * dst,
960
974
  constant int64_t & ne00,
961
- constant int64_t & ne01[[buffer(4)]],
962
- constant int64_t & ne02[[buffer(5)]],
963
- constant int64_t & ne10[[buffer(9)]],
964
- constant int64_t & ne12[[buffer(11)]],
965
- constant int64_t & ne0 [[buffer(15)]],
966
- constant int64_t & ne1 [[buffer(16)]],
967
- constant uint & r2 [[buffer(17)]],
968
- constant uint & r3 [[buffer(18)]],
975
+ constant int64_t & ne01,
976
+ constant int64_t & ne02,
977
+ constant uint64_t & nb00,
978
+ constant uint64_t & nb01,
979
+ constant uint64_t & nb02,
980
+ constant int64_t & ne10,
981
+ constant int64_t & ne11,
982
+ constant int64_t & ne12,
983
+ constant uint64_t & nb10,
984
+ constant uint64_t & nb11,
985
+ constant uint64_t & nb12,
986
+ constant int64_t & ne0,
987
+ constant int64_t & ne1,
988
+ constant uint & r2,
989
+ constant uint & r3,
969
990
  uint3 tgpig[[threadgroup_position_in_grid]],
970
991
  uint tiisg[[thread_index_in_simdgroup]],
971
992
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
977
998
  device const float * src1,
978
999
  device float * dst,
979
1000
  constant int64_t & ne00,
980
- constant int64_t & ne01[[buffer(4)]],
981
- constant int64_t & ne02[[buffer(5)]],
982
- constant int64_t & ne10[[buffer(9)]],
983
- constant int64_t & ne12[[buffer(11)]],
984
- constant int64_t & ne0 [[buffer(15)]],
985
- constant int64_t & ne1 [[buffer(16)]],
986
- constant uint & r2 [[buffer(17)]],
987
- constant uint & r3 [[buffer(18)]],
1001
+ constant int64_t & ne01,
1002
+ constant int64_t & ne02,
1003
+ constant uint64_t & nb00,
1004
+ constant uint64_t & nb01,
1005
+ constant uint64_t & nb02,
1006
+ constant int64_t & ne10,
1007
+ constant int64_t & ne11,
1008
+ constant int64_t & ne12,
1009
+ constant uint64_t & nb10,
1010
+ constant uint64_t & nb11,
1011
+ constant uint64_t & nb12,
1012
+ constant int64_t & ne0,
1013
+ constant int64_t & ne1,
1014
+ constant uint & r2,
1015
+ constant uint & r3,
988
1016
  uint3 tgpig[[threadgroup_position_in_grid]],
989
1017
  uint tiisg[[thread_index_in_simdgroup]],
990
1018
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
1071
1099
  constant int64_t & ne00,
1072
1100
  constant int64_t & ne01,
1073
1101
  constant int64_t & ne02,
1102
+ constant uint64_t & nb00,
1103
+ constant uint64_t & nb01,
1104
+ constant uint64_t & nb02,
1074
1105
  constant int64_t & ne10,
1106
+ constant int64_t & ne11,
1075
1107
  constant int64_t & ne12,
1108
+ constant uint64_t & nb10,
1109
+ constant uint64_t & nb11,
1110
+ constant uint64_t & nb12,
1076
1111
  constant int64_t & ne0,
1077
1112
  constant int64_t & ne1,
1078
- constant uint & r2 [[buffer(17)]],
1079
- constant uint & r3 [[buffer(18)]],
1113
+ constant uint & r2,
1114
+ constant uint & r3,
1080
1115
  uint3 tgpig[[threadgroup_position_in_grid]],
1081
1116
  uint tiisg[[thread_index_in_simdgroup]],
1082
1117
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
1182
1217
  constant uint64_t & nb12,
1183
1218
  constant int64_t & ne0,
1184
1219
  constant int64_t & ne1,
1185
- constant uint & r2 [[buffer(17)]],
1186
- constant uint & r3 [[buffer(18)]],
1220
+ constant uint & r2,
1221
+ constant uint & r3,
1187
1222
  uint3 tgpig[[threadgroup_position_in_grid]],
1188
1223
  uint tiisg[[thread_index_in_simdgroup]]) {
1189
1224
  kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
1209
1244
  constant uint64_t & nb12,
1210
1245
  constant int64_t & ne0,
1211
1246
  constant int64_t & ne1,
1212
- constant uint & r2 [[buffer(17)]],
1213
- constant uint & r3 [[buffer(18)]],
1247
+ constant uint & r2,
1248
+ constant uint & r3,
1214
1249
  uint3 tgpig[[threadgroup_position_in_grid]],
1215
1250
  uint tiisg[[thread_index_in_simdgroup]]) {
1216
1251
 
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1346
1381
  constant uint64_t & nb12,
1347
1382
  constant int64_t & ne0,
1348
1383
  constant int64_t & ne1,
1349
- constant uint & r2 [[buffer(17)]],
1350
- constant uint & r3 [[buffer(18)]],
1384
+ constant uint & r2,
1385
+ constant uint & r3,
1351
1386
  uint3 tgpig[[threadgroup_position_in_grid]],
1352
1387
  uint tiisg[[thread_index_in_simdgroup]]) {
1353
1388
  kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
1452
1487
  constant uint64_t & nb12,
1453
1488
  constant int64_t & ne0,
1454
1489
  constant int64_t & ne1,
1455
- constant uint & r2 [[buffer(17)]],
1456
- constant uint & r3 [[buffer(18)]],
1490
+ constant uint & r2,
1491
+ constant uint & r3,
1457
1492
  uint3 tgpig[[threadgroup_position_in_grid]],
1458
1493
  uint tiisg[[thread_index_in_simdgroup]]) {
1459
1494
  kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1478
1513
  constant uint64_t & nb12,
1479
1514
  constant int64_t & ne0,
1480
1515
  constant int64_t & ne1,
1481
- constant uint & r2 [[buffer(17)]],
1482
- constant uint & r3 [[buffer(18)]],
1516
+ constant uint & r2,
1517
+ constant uint & r3,
1483
1518
  uint3 tgpig[[threadgroup_position_in_grid]],
1484
1519
  uint tiisg[[thread_index_in_simdgroup]]) {
1485
1520
 
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
1543
1578
  const int64_t i3 = n / (ne2*ne1*ne0);
1544
1579
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1545
1580
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1546
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1581
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1582
+
1547
1583
  const int64_t k = i3*ne3 + i2;
1548
1584
 
1549
1585
  float m_k;
@@ -2410,22 +2446,6 @@ typedef struct {
2410
2446
  } block_q6_K;
2411
2447
  // 210 bytes / block
2412
2448
 
2413
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2414
- uchar4 r;
2415
- if (j < 4) {
2416
- r[0] = q[j+0] & 63;
2417
- r[2] = q[j+1] & 63;
2418
- r[1] = q[j+4] & 63;
2419
- r[3] = q[j+5] & 63;
2420
- } else {
2421
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2422
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2423
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2424
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2425
- }
2426
- return r;
2427
- }
2428
-
2429
2449
  //====================================== dot products =========================
2430
2450
 
2431
2451
  void kernel_mul_mv_q2_K_f32_impl(
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
2584
2604
  device const float * src1,
2585
2605
  device float * dst,
2586
2606
  constant int64_t & ne00,
2587
- constant int64_t & ne01[[buffer(4)]],
2588
- constant int64_t & ne02[[buffer(5)]],
2589
- constant int64_t & ne10[[buffer(9)]],
2590
- constant int64_t & ne12[[buffer(11)]],
2591
- constant int64_t & ne0 [[buffer(15)]],
2592
- constant int64_t & ne1 [[buffer(16)]],
2593
- constant uint & r2 [[buffer(17)]],
2594
- constant uint & r3 [[buffer(18)]],
2607
+ constant int64_t & ne01,
2608
+ constant int64_t & ne02,
2609
+ constant uint64_t & nb00,
2610
+ constant uint64_t & nb01,
2611
+ constant uint64_t & nb02,
2612
+ constant int64_t & ne10,
2613
+ constant int64_t & ne11,
2614
+ constant int64_t & ne12,
2615
+ constant uint64_t & nb10,
2616
+ constant uint64_t & nb11,
2617
+ constant uint64_t & nb12,
2618
+ constant int64_t & ne0,
2619
+ constant int64_t & ne1,
2620
+ constant uint & r2,
2621
+ constant uint & r3,
2595
2622
  uint3 tgpig[[threadgroup_position_in_grid]],
2596
2623
  uint tiisg[[thread_index_in_simdgroup]],
2597
2624
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
2841
2868
  device const float * src1,
2842
2869
  device float * dst,
2843
2870
  constant int64_t & ne00,
2844
- constant int64_t & ne01[[buffer(4)]],
2845
- constant int64_t & ne02[[buffer(5)]],
2846
- constant int64_t & ne10[[buffer(9)]],
2847
- constant int64_t & ne12[[buffer(11)]],
2848
- constant int64_t & ne0 [[buffer(15)]],
2849
- constant int64_t & ne1 [[buffer(16)]],
2850
- constant uint & r2 [[buffer(17)]],
2851
- constant uint & r3 [[buffer(18)]],
2871
+ constant int64_t & ne01,
2872
+ constant int64_t & ne02,
2873
+ constant uint64_t & nb00,
2874
+ constant uint64_t & nb01,
2875
+ constant uint64_t & nb02,
2876
+ constant int64_t & ne10,
2877
+ constant int64_t & ne11,
2878
+ constant int64_t & ne12,
2879
+ constant uint64_t & nb10,
2880
+ constant uint64_t & nb11,
2881
+ constant uint64_t & nb12,
2882
+ constant int64_t & ne0,
2883
+ constant int64_t & ne1,
2884
+ constant uint & r2,
2885
+ constant uint & r3,
2852
2886
  uint3 tgpig[[threadgroup_position_in_grid]],
2853
2887
  uint tiisg[[thread_index_in_simdgroup]],
2854
2888
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
2984
3018
  constant uint & r2,
2985
3019
  constant uint & r3,
2986
3020
  uint3 tgpig[[threadgroup_position_in_grid]],
2987
- uint tiisg[[thread_index_in_simdgroup]],
2988
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3021
+ uint tiisg[[thread_index_in_simdgroup]],
3022
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2989
3023
 
2990
3024
  const int ix = tiisg/4; // 0...7
2991
3025
  const int it = tiisg%4; // 0...3
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
2994
3028
  const int r0 = tgpig.x;
2995
3029
  const int r1 = tgpig.y;
2996
3030
  const int im = tgpig.z;
2997
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3031
+ const int first_row = r0 * N_DST;
2998
3032
  const int ib_row = first_row * nb;
2999
3033
 
3000
3034
  const uint i12 = im%ne12;
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3060
3094
  for (int row = 0; row < N_DST; ++row) {
3061
3095
  all_sum = simd_sum(sumf[row]);
3062
3096
  if (tiisg == 0) {
3063
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3097
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
3064
3098
  }
3065
3099
  }
3066
3100
  }
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
3072
3106
  device const float * src1,
3073
3107
  device float * dst,
3074
3108
  constant int64_t & ne00,
3075
- constant int64_t & ne01[[buffer(4)]],
3076
- constant int64_t & ne02[[buffer(5)]],
3077
- constant int64_t & ne10[[buffer(9)]],
3078
- constant int64_t & ne12[[buffer(11)]],
3079
- constant int64_t & ne0 [[buffer(15)]],
3080
- constant int64_t & ne1 [[buffer(16)]],
3081
- constant uint & r2 [[buffer(17)]],
3082
- constant uint & r3 [[buffer(18)]],
3109
+ constant int64_t & ne01,
3110
+ constant int64_t & ne02,
3111
+ constant uint64_t & nb00,
3112
+ constant uint64_t & nb01,
3113
+ constant uint64_t & nb02,
3114
+ constant int64_t & ne10,
3115
+ constant int64_t & ne11,
3116
+ constant int64_t & ne12,
3117
+ constant uint64_t & nb10,
3118
+ constant uint64_t & nb11,
3119
+ constant uint64_t & nb12,
3120
+ constant int64_t & ne0,
3121
+ constant int64_t & ne1,
3122
+ constant uint & r2,
3123
+ constant uint & r3,
3083
3124
  uint3 tgpig[[threadgroup_position_in_grid]],
3084
3125
  uint tiisg[[thread_index_in_simdgroup]],
3085
3126
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
3271
3312
  device const float * src1,
3272
3313
  device float * dst,
3273
3314
  constant int64_t & ne00,
3274
- constant int64_t & ne01[[buffer(4)]],
3275
- constant int64_t & ne02[[buffer(5)]],
3276
- constant int64_t & ne10[[buffer(9)]],
3277
- constant int64_t & ne12[[buffer(11)]],
3278
- constant int64_t & ne0 [[buffer(15)]],
3279
- constant int64_t & ne1 [[buffer(16)]],
3280
- constant uint & r2 [[buffer(17)]],
3281
- constant uint & r3 [[buffer(18)]],
3315
+ constant int64_t & ne01,
3316
+ constant int64_t & ne02,
3317
+ constant uint64_t & nb00,
3318
+ constant uint64_t & nb01,
3319
+ constant uint64_t & nb02,
3320
+ constant int64_t & ne10,
3321
+ constant int64_t & ne11,
3322
+ constant int64_t & ne12,
3323
+ constant uint64_t & nb10,
3324
+ constant uint64_t & nb11,
3325
+ constant uint64_t & nb12,
3326
+ constant int64_t & ne0,
3327
+ constant int64_t & ne1,
3328
+ constant uint & r2,
3329
+ constant uint & r3,
3282
3330
  uint3 tgpig[[threadgroup_position_in_grid]],
3283
3331
  uint tiisg[[thread_index_in_simdgroup]],
3284
3332
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
3398
3446
  device const float * src1,
3399
3447
  device float * dst,
3400
3448
  constant int64_t & ne00,
3401
- constant int64_t & ne01[[buffer(4)]],
3402
- constant int64_t & ne02[[buffer(5)]],
3403
- constant int64_t & ne10[[buffer(9)]],
3404
- constant int64_t & ne12[[buffer(11)]],
3405
- constant int64_t & ne0 [[buffer(15)]],
3406
- constant int64_t & ne1 [[buffer(16)]],
3407
- constant uint & r2 [[buffer(17)]],
3408
- constant uint & r3 [[buffer(18)]],
3449
+ constant int64_t & ne01,
3450
+ constant int64_t & ne02,
3451
+ constant uint64_t & nb00,
3452
+ constant uint64_t & nb01,
3453
+ constant uint64_t & nb02,
3454
+ constant int64_t & ne10,
3455
+ constant int64_t & ne11,
3456
+ constant int64_t & ne12,
3457
+ constant uint64_t & nb10,
3458
+ constant uint64_t & nb11,
3459
+ constant uint64_t & nb12,
3460
+ constant int64_t & ne0,
3461
+ constant int64_t & ne1,
3462
+ constant uint & r2,
3463
+ constant uint & r3,
3409
3464
  uint3 tgpig[[threadgroup_position_in_grid]],
3410
3465
  uint tiisg[[thread_index_in_simdgroup]],
3411
3466
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3523
3578
  device const int8_t * qs = ((device const int8_t *)xb->qs);
3524
3579
  const half d = xb->d;
3525
3580
 
3526
- for (int i=0;i<16;i++) {
3581
+ for (int i = 0; i < 16; i++) {
3527
3582
  reg[i/4][i%4] = (qs[i + 16*il] * d);
3528
3583
  }
3529
3584
  }
@@ -3774,6 +3829,35 @@ kernel void kernel_get_rows_f16(
3774
3829
  }
3775
3830
  }
3776
3831
 
3832
+ kernel void kernel_get_rows_i32(
3833
+ device const void * src0,
3834
+ device const char * src1,
3835
+ device int32_t * dst,
3836
+ constant int64_t & ne00,
3837
+ constant uint64_t & nb01,
3838
+ constant uint64_t & nb02,
3839
+ constant int64_t & ne10,
3840
+ constant uint64_t & nb10,
3841
+ constant uint64_t & nb11,
3842
+ constant uint64_t & nb1,
3843
+ constant uint64_t & nb2,
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiitg[[thread_index_in_threadgroup]],
3846
+ uint3 tptg [[threads_per_threadgroup]]) {
3847
+ const int64_t i10 = tgpig.x;
3848
+ const int64_t i11 = tgpig.y;
3849
+
3850
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3851
+
3852
+ const int64_t i02 = i11;
3853
+
3854
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3855
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3856
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3857
+ }
3858
+ }
3859
+
3860
+
3777
3861
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3778
3862
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3779
3863
  #define BLOCK_SIZE_K 32
@@ -3792,12 +3876,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
3792
3876
  device float * dst,
3793
3877
  constant int64_t & ne00,
3794
3878
  constant int64_t & ne02,
3795
- constant int64_t & nb01,
3796
- constant int64_t & nb02,
3879
+ constant uint64_t & nb01,
3880
+ constant uint64_t & nb02,
3797
3881
  constant int64_t & ne12,
3798
- constant int64_t & nb10,
3799
- constant int64_t & nb11,
3800
- constant int64_t & nb12,
3882
+ constant uint64_t & nb10,
3883
+ constant uint64_t & nb11,
3884
+ constant uint64_t & nb12,
3801
3885
  constant int64_t & ne0,
3802
3886
  constant int64_t & ne1,
3803
3887
  constant uint & r2,
@@ -3918,18 +4002,143 @@ void kernel_mul_mm_impl(device const uchar * src0,
3918
4002
  }
3919
4003
  }
3920
4004
 
4005
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
4006
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4007
+ void kernel_mul_mm_id_impl(
4008
+ device const uchar * src0,
4009
+ device const uchar * src1,
4010
+ thread short * src1ids,
4011
+ device float * dst,
4012
+ constant int64_t & ne00,
4013
+ constant int64_t & ne02,
4014
+ constant uint64_t & nb01,
4015
+ constant uint64_t & nb02,
4016
+ constant int64_t & ne12,
4017
+ constant uint64_t & nb10,
4018
+ constant uint64_t & nb11,
4019
+ constant uint64_t & nb12,
4020
+ constant int64_t & ne0,
4021
+ int64_t ne1,
4022
+ constant uint & r2,
4023
+ constant uint & r3,
4024
+ threadgroup uchar * shared_memory,
4025
+ uint3 tgpig[[threadgroup_position_in_grid]],
4026
+ uint tiitg[[thread_index_in_threadgroup]],
4027
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4028
+
4029
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
4030
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
4031
+
4032
+ const uint r0 = tgpig.y;
4033
+ const uint r1 = tgpig.x;
4034
+ const uint im = tgpig.z;
4035
+
4036
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
4037
+
4038
+ // if this block is of 64x32 shape or smaller
4039
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4040
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
4041
+
4042
+ // a thread shouldn't load data outside of the matrix
4043
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4044
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
4045
+
4046
+ simdgroup_half8x8 ma[4];
4047
+ simdgroup_float8x8 mb[2];
4048
+ simdgroup_float8x8 c_res[8];
4049
+ for (int i = 0; i < 8; i++){
4050
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
4051
+ }
4052
+
4053
+ short il = (tiitg % THREAD_PER_ROW);
4054
+
4055
+ const uint i12 = im%ne12;
4056
+ const uint i13 = im/ne12;
4057
+
4058
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4059
+ ushort offset1 = il/nl;
4060
+
4061
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4062
+ device const float * y = (device const float *)(src1
4063
+ + nb12 * im
4064
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
4065
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
4066
+
4067
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4068
+ // load data and store to threadgroup memory
4069
+ half4x4 temp_a;
4070
+ dequantize_func(x, il, temp_a);
4071
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4072
+
4073
+ for (int i = 0; i < 16; i++) {
4074
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4075
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4076
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4077
+ }
4078
+
4079
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
4080
+
4081
+ il = (il + 2 < nl) ? il + 2 : il % 2;
4082
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
4083
+ y += BLOCK_SIZE_K;
4084
+
4085
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4086
+
4087
+ // load matrices from threadgroup memory and conduct outer products
4088
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4089
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
4090
+
4091
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4092
+ for (int i = 0; i < 4; i++) {
4093
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4094
+ }
4095
+ simdgroup_barrier(mem_flags::mem_none);
4096
+ for (int i = 0; i < 2; i++) {
4097
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4098
+ }
4099
+
4100
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4101
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
4102
+
4103
+ for (int i = 0; i < 8; i++){
4104
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4105
+ }
4106
+ }
4107
+ }
4108
+
4109
+ {
4110
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4111
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4112
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4113
+ for (int i = 0; i < 8; i++) {
4114
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4115
+ }
4116
+
4117
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4118
+
4119
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
4120
+ if (sgitg == 0) {
4121
+ for (int i = 0; i < n_rows; i++) {
4122
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4123
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4124
+ }
4125
+ }
4126
+ }
4127
+ }
4128
+ }
4129
+
3921
4130
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3922
4131
  kernel void kernel_mul_mm(device const uchar * src0,
3923
4132
  device const uchar * src1,
3924
4133
  device float * dst,
3925
4134
  constant int64_t & ne00,
3926
4135
  constant int64_t & ne02,
3927
- constant int64_t & nb01,
3928
- constant int64_t & nb02,
4136
+ constant uint64_t & nb01,
4137
+ constant uint64_t & nb02,
3929
4138
  constant int64_t & ne12,
3930
- constant int64_t & nb10,
3931
- constant int64_t & nb11,
3932
- constant int64_t & nb12,
4139
+ constant uint64_t & nb10,
4140
+ constant uint64_t & nb11,
4141
+ constant uint64_t & nb12,
3933
4142
  constant int64_t & ne0,
3934
4143
  constant int64_t & ne1,
3935
4144
  constant uint & r2,
@@ -3964,20 +4173,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
3964
4173
  kernel void kernel_mul_mm_id(
3965
4174
  device const uchar * ids,
3966
4175
  device const uchar * src1,
3967
- device uchar * dst,
3968
- constant int64_t & nbi1,
4176
+ device float * dst,
4177
+ constant uint64_t & nbi1,
3969
4178
  constant int64_t & ne00,
3970
4179
  constant int64_t & ne02,
3971
- constant int64_t & nb01,
3972
- constant int64_t & nb02,
4180
+ constant uint64_t & nb01,
4181
+ constant uint64_t & nb02,
3973
4182
  constant int64_t & ne12,
3974
4183
  constant int64_t & ne13,
3975
- constant int64_t & nb10,
3976
- constant int64_t & nb11,
3977
- constant int64_t & nb12,
4184
+ constant uint64_t & nb10,
4185
+ constant uint64_t & nb11,
4186
+ constant uint64_t & nb12,
3978
4187
  constant int64_t & ne0,
3979
4188
  constant int64_t & ne1,
3980
- constant int64_t & nb1,
4189
+ constant uint64_t & nb1,
3981
4190
  constant uint & r2,
3982
4191
  constant uint & r3,
3983
4192
  constant int & idx,
@@ -3993,18 +4202,28 @@ kernel void kernel_mul_mm_id(
3993
4202
  uint3 tgpig[[threadgroup_position_in_grid]],
3994
4203
  uint tiitg[[thread_index_in_threadgroup]],
3995
4204
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3996
- device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4205
+ device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3997
4206
 
3998
- const int64_t bid = tgpig.z/(ne12*ne13);
4207
+ // expert id
4208
+ const int32_t id = tgpig.z/(ne12*ne13);
3999
4209
 
4000
4210
  tgpig.z = tgpig.z%(ne12*ne13);
4001
4211
 
4002
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4212
+ // row indices of src1 for expert id
4213
+ int64_t _ne1 = 0;
4214
+ short src1ids[512];
4003
4215
 
4004
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
4005
- src0[id],
4006
- src1 + bid*nb11,
4007
- (device float *) (dst + bid*nb1),
4216
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
4217
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
4218
+ src1ids[_ne1++] = i1;
4219
+ }
4220
+ }
4221
+
4222
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
4223
+ src0s[id],
4224
+ src1,
4225
+ src1ids,
4226
+ dst,
4008
4227
  ne00,
4009
4228
  ne02,
4010
4229
  nb01,
@@ -4014,7 +4233,7 @@ kernel void kernel_mul_mm_id(
4014
4233
  nb11,
4015
4234
  nb12,
4016
4235
  ne0,
4017
- ne1,
4236
+ _ne1,
4018
4237
  r2,
4019
4238
  r3,
4020
4239
  shared_memory,
@@ -4070,12 +4289,12 @@ typedef void (mat_mm_t)(
4070
4289
  device float * dst,
4071
4290
  constant int64_t & ne00,
4072
4291
  constant int64_t & ne02,
4073
- constant int64_t & nb01,
4074
- constant int64_t & nb02,
4292
+ constant uint64_t & nb01,
4293
+ constant uint64_t & nb02,
4075
4294
  constant int64_t & ne12,
4076
- constant int64_t & nb10,
4077
- constant int64_t & nb11,
4078
- constant int64_t & nb12,
4295
+ constant uint64_t & nb10,
4296
+ constant uint64_t & nb11,
4297
+ constant uint64_t & nb12,
4079
4298
  constant int64_t & ne0,
4080
4299
  constant int64_t & ne1,
4081
4300
  constant uint & r2,
@@ -4103,20 +4322,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4103
4322
  typedef void (mat_mm_id_t)(
4104
4323
  device const uchar * ids,
4105
4324
  device const uchar * src1,
4106
- device uchar * dst,
4107
- constant int64_t & nbi1,
4325
+ device float * dst,
4326
+ constant uint64_t & nbi1,
4108
4327
  constant int64_t & ne00,
4109
4328
  constant int64_t & ne02,
4110
- constant int64_t & nb01,
4111
- constant int64_t & nb02,
4329
+ constant uint64_t & nb01,
4330
+ constant uint64_t & nb02,
4112
4331
  constant int64_t & ne12,
4113
4332
  constant int64_t & ne13,
4114
- constant int64_t & nb10,
4115
- constant int64_t & nb11,
4116
- constant int64_t & nb12,
4333
+ constant uint64_t & nb10,
4334
+ constant uint64_t & nb11,
4335
+ constant uint64_t & nb12,
4117
4336
  constant int64_t & ne0,
4118
4337
  constant int64_t & ne1,
4119
- constant int64_t & nb1,
4338
+ constant uint64_t & nb1,
4120
4339
  constant uint & r2,
4121
4340
  constant uint & r3,
4122
4341
  constant int & idx,
@@ -4152,8 +4371,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
4152
4371
  kernel void kernel_mul_mv_id_f32_f32(
4153
4372
  device const char * ids,
4154
4373
  device const char * src1,
4155
- device uchar * dst,
4156
- constant int64_t & nbi1,
4374
+ device float * dst,
4375
+ constant uint64_t & nbi1,
4157
4376
  constant int64_t & ne00,
4158
4377
  constant int64_t & ne01,
4159
4378
  constant int64_t & ne02,
@@ -4169,7 +4388,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4169
4388
  constant uint64_t & nb12,
4170
4389
  constant int64_t & ne0,
4171
4390
  constant int64_t & ne1,
4172
- constant int64_t & nb1,
4391
+ constant uint64_t & nb1,
4173
4392
  constant uint & r2,
4174
4393
  constant uint & r3,
4175
4394
  constant int & idx,
@@ -4196,7 +4415,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4196
4415
  kernel_mul_mv_f32_f32_impl(
4197
4416
  src0[id],
4198
4417
  src1 + bid*nb11,
4199
- (device float *) (dst + bid*nb1),
4418
+ dst + bid*ne0,
4200
4419
  ne00,
4201
4420
  ne01,
4202
4421
  ne02,
@@ -4221,8 +4440,8 @@ kernel void kernel_mul_mv_id_f32_f32(
4221
4440
  kernel void kernel_mul_mv_id_f16_f32(
4222
4441
  device const char * ids,
4223
4442
  device const char * src1,
4224
- device uchar * dst,
4225
- constant int64_t & nbi1,
4443
+ device float * dst,
4444
+ constant uint64_t & nbi1,
4226
4445
  constant int64_t & ne00,
4227
4446
  constant int64_t & ne01,
4228
4447
  constant int64_t & ne02,
@@ -4238,7 +4457,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4238
4457
  constant uint64_t & nb12,
4239
4458
  constant int64_t & ne0,
4240
4459
  constant int64_t & ne1,
4241
- constant int64_t & nb1,
4460
+ constant uint64_t & nb1,
4242
4461
  constant uint & r2,
4243
4462
  constant uint & r3,
4244
4463
  constant int & idx,
@@ -4265,7 +4484,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4265
4484
  kernel_mul_mv_f16_f32_impl(
4266
4485
  src0[id],
4267
4486
  src1 + bid*nb11,
4268
- (device float *) (dst + bid*nb1),
4487
+ dst + bid*ne0,
4269
4488
  ne00,
4270
4489
  ne01,
4271
4490
  ne02,
@@ -4290,8 +4509,8 @@ kernel void kernel_mul_mv_id_f16_f32(
4290
4509
  kernel void kernel_mul_mv_id_q8_0_f32(
4291
4510
  device const char * ids,
4292
4511
  device const char * src1,
4293
- device uchar * dst,
4294
- constant int64_t & nbi1,
4512
+ device float * dst,
4513
+ constant uint64_t & nbi1,
4295
4514
  constant int64_t & ne00,
4296
4515
  constant int64_t & ne01,
4297
4516
  constant int64_t & ne02,
@@ -4307,7 +4526,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4307
4526
  constant uint64_t & nb12,
4308
4527
  constant int64_t & ne0,
4309
4528
  constant int64_t & ne1,
4310
- constant int64_t & nb1,
4529
+ constant uint64_t & nb1,
4311
4530
  constant uint & r2,
4312
4531
  constant uint & r3,
4313
4532
  constant int & idx,
@@ -4334,7 +4553,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4334
4553
  kernel_mul_mv_q8_0_f32_impl(
4335
4554
  src0[id],
4336
4555
  (device const float *) (src1 + bid*nb11),
4337
- (device float *) ( dst + bid*nb1),
4556
+ dst + bid*ne0,
4338
4557
  ne00,
4339
4558
  ne01,
4340
4559
  ne02,
@@ -4353,8 +4572,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4353
4572
  kernel void kernel_mul_mv_id_q4_0_f32(
4354
4573
  device const char * ids,
4355
4574
  device const char * src1,
4356
- device uchar * dst,
4357
- constant int64_t & nbi1,
4575
+ device float * dst,
4576
+ constant uint64_t & nbi1,
4358
4577
  constant int64_t & ne00,
4359
4578
  constant int64_t & ne01,
4360
4579
  constant int64_t & ne02,
@@ -4370,7 +4589,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4370
4589
  constant uint64_t & nb12,
4371
4590
  constant int64_t & ne0,
4372
4591
  constant int64_t & ne1,
4373
- constant int64_t & nb1,
4592
+ constant uint64_t & nb1,
4374
4593
  constant uint & r2,
4375
4594
  constant uint & r3,
4376
4595
  constant int & idx,
@@ -4397,7 +4616,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4397
4616
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4398
4617
  src0[id],
4399
4618
  (device const float *) (src1 + bid*nb11),
4400
- (device float *) ( dst + bid*nb1),
4619
+ dst + bid*ne0,
4401
4620
  ne00,
4402
4621
  ne01,
4403
4622
  ne02,
@@ -4416,8 +4635,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4416
4635
  kernel void kernel_mul_mv_id_q4_1_f32(
4417
4636
  device const char * ids,
4418
4637
  device const char * src1,
4419
- device uchar * dst,
4420
- constant int64_t & nbi1,
4638
+ device float * dst,
4639
+ constant uint64_t & nbi1,
4421
4640
  constant int64_t & ne00,
4422
4641
  constant int64_t & ne01,
4423
4642
  constant int64_t & ne02,
@@ -4433,7 +4652,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4433
4652
  constant uint64_t & nb12,
4434
4653
  constant int64_t & ne0,
4435
4654
  constant int64_t & ne1,
4436
- constant int64_t & nb1,
4655
+ constant uint64_t & nb1,
4437
4656
  constant uint & r2,
4438
4657
  constant uint & r3,
4439
4658
  constant int & idx,
@@ -4460,7 +4679,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4460
4679
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4461
4680
  src0[id],
4462
4681
  (device const float *) (src1 + bid*nb11),
4463
- (device float *) ( dst + bid*nb1),
4682
+ dst + bid*ne0,
4464
4683
  ne00,
4465
4684
  ne01,
4466
4685
  ne02,
@@ -4479,8 +4698,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4479
4698
  kernel void kernel_mul_mv_id_q5_0_f32(
4480
4699
  device const char * ids,
4481
4700
  device const char * src1,
4482
- device uchar * dst,
4483
- constant int64_t & nbi1,
4701
+ device float * dst,
4702
+ constant uint64_t & nbi1,
4484
4703
  constant int64_t & ne00,
4485
4704
  constant int64_t & ne01,
4486
4705
  constant int64_t & ne02,
@@ -4496,7 +4715,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4496
4715
  constant uint64_t & nb12,
4497
4716
  constant int64_t & ne0,
4498
4717
  constant int64_t & ne1,
4499
- constant int64_t & nb1,
4718
+ constant uint64_t & nb1,
4500
4719
  constant uint & r2,
4501
4720
  constant uint & r3,
4502
4721
  constant int & idx,
@@ -4523,7 +4742,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4523
4742
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4524
4743
  src0[id],
4525
4744
  (device const float *) (src1 + bid*nb11),
4526
- (device float *) ( dst + bid*nb1),
4745
+ dst + bid*ne0,
4527
4746
  ne00,
4528
4747
  ne01,
4529
4748
  ne02,
@@ -4542,8 +4761,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4542
4761
  kernel void kernel_mul_mv_id_q5_1_f32(
4543
4762
  device const char * ids,
4544
4763
  device const char * src1,
4545
- device uchar * dst,
4546
- constant int64_t & nbi1,
4764
+ device float * dst,
4765
+ constant uint64_t & nbi1,
4547
4766
  constant int64_t & ne00,
4548
4767
  constant int64_t & ne01,
4549
4768
  constant int64_t & ne02,
@@ -4559,7 +4778,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4559
4778
  constant uint64_t & nb12,
4560
4779
  constant int64_t & ne0,
4561
4780
  constant int64_t & ne1,
4562
- constant int64_t & nb1,
4781
+ constant uint64_t & nb1,
4563
4782
  constant uint & r2,
4564
4783
  constant uint & r3,
4565
4784
  constant int & idx,
@@ -4586,7 +4805,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4586
4805
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4587
4806
  src0[id],
4588
4807
  (device const float *) (src1 + bid*nb11),
4589
- (device float *) ( dst + bid*nb1),
4808
+ dst + bid*ne0,
4590
4809
  ne00,
4591
4810
  ne01,
4592
4811
  ne02,
@@ -4605,8 +4824,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4605
4824
  kernel void kernel_mul_mv_id_q2_K_f32(
4606
4825
  device const char * ids,
4607
4826
  device const char * src1,
4608
- device uchar * dst,
4609
- constant int64_t & nbi1,
4827
+ device float * dst,
4828
+ constant uint64_t & nbi1,
4610
4829
  constant int64_t & ne00,
4611
4830
  constant int64_t & ne01,
4612
4831
  constant int64_t & ne02,
@@ -4622,7 +4841,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4622
4841
  constant uint64_t & nb12,
4623
4842
  constant int64_t & ne0,
4624
4843
  constant int64_t & ne1,
4625
- constant int64_t & nb1,
4844
+ constant uint64_t & nb1,
4626
4845
  constant uint & r2,
4627
4846
  constant uint & r3,
4628
4847
  constant int & idx,
@@ -4649,7 +4868,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4649
4868
  kernel_mul_mv_q2_K_f32_impl(
4650
4869
  src0[id],
4651
4870
  (device const float *) (src1 + bid*nb11),
4652
- (device float *) ( dst + bid*nb1),
4871
+ dst + bid*ne0,
4653
4872
  ne00,
4654
4873
  ne01,
4655
4874
  ne02,
@@ -4668,8 +4887,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4668
4887
  kernel void kernel_mul_mv_id_q3_K_f32(
4669
4888
  device const char * ids,
4670
4889
  device const char * src1,
4671
- device uchar * dst,
4672
- constant int64_t & nbi1,
4890
+ device float * dst,
4891
+ constant uint64_t & nbi1,
4673
4892
  constant int64_t & ne00,
4674
4893
  constant int64_t & ne01,
4675
4894
  constant int64_t & ne02,
@@ -4685,7 +4904,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4685
4904
  constant uint64_t & nb12,
4686
4905
  constant int64_t & ne0,
4687
4906
  constant int64_t & ne1,
4688
- constant int64_t & nb1,
4907
+ constant uint64_t & nb1,
4689
4908
  constant uint & r2,
4690
4909
  constant uint & r3,
4691
4910
  constant int & idx,
@@ -4712,7 +4931,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4712
4931
  kernel_mul_mv_q3_K_f32_impl(
4713
4932
  src0[id],
4714
4933
  (device const float *) (src1 + bid*nb11),
4715
- (device float *) ( dst + bid*nb1),
4934
+ dst + bid*ne0,
4716
4935
  ne00,
4717
4936
  ne01,
4718
4937
  ne02,
@@ -4731,8 +4950,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4731
4950
  kernel void kernel_mul_mv_id_q4_K_f32(
4732
4951
  device const char * ids,
4733
4952
  device const char * src1,
4734
- device uchar * dst,
4735
- constant int64_t & nbi1,
4953
+ device float * dst,
4954
+ constant uint64_t & nbi1,
4736
4955
  constant int64_t & ne00,
4737
4956
  constant int64_t & ne01,
4738
4957
  constant int64_t & ne02,
@@ -4748,7 +4967,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4748
4967
  constant uint64_t & nb12,
4749
4968
  constant int64_t & ne0,
4750
4969
  constant int64_t & ne1,
4751
- constant int64_t & nb1,
4970
+ constant uint64_t & nb1,
4752
4971
  constant uint & r2,
4753
4972
  constant uint & r3,
4754
4973
  constant int & idx,
@@ -4775,7 +4994,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4775
4994
  kernel_mul_mv_q4_K_f32_impl(
4776
4995
  src0[id],
4777
4996
  (device const float *) (src1 + bid*nb11),
4778
- (device float *) ( dst + bid*nb1),
4997
+ dst + bid*ne0,
4779
4998
  ne00,
4780
4999
  ne01,
4781
5000
  ne02,
@@ -4794,8 +5013,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4794
5013
  kernel void kernel_mul_mv_id_q5_K_f32(
4795
5014
  device const char * ids,
4796
5015
  device const char * src1,
4797
- device uchar * dst,
4798
- constant int64_t & nbi1,
5016
+ device float * dst,
5017
+ constant uint64_t & nbi1,
4799
5018
  constant int64_t & ne00,
4800
5019
  constant int64_t & ne01,
4801
5020
  constant int64_t & ne02,
@@ -4811,7 +5030,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4811
5030
  constant uint64_t & nb12,
4812
5031
  constant int64_t & ne0,
4813
5032
  constant int64_t & ne1,
4814
- constant int64_t & nb1,
5033
+ constant uint64_t & nb1,
4815
5034
  constant uint & r2,
4816
5035
  constant uint & r3,
4817
5036
  constant int & idx,
@@ -4838,7 +5057,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4838
5057
  kernel_mul_mv_q5_K_f32_impl(
4839
5058
  src0[id],
4840
5059
  (device const float *) (src1 + bid*nb11),
4841
- (device float *) ( dst + bid*nb1),
5060
+ dst + bid*ne0,
4842
5061
  ne00,
4843
5062
  ne01,
4844
5063
  ne02,
@@ -4857,8 +5076,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4857
5076
  kernel void kernel_mul_mv_id_q6_K_f32(
4858
5077
  device const char * ids,
4859
5078
  device const char * src1,
4860
- device uchar * dst,
4861
- constant int64_t & nbi1,
5079
+ device float * dst,
5080
+ constant uint64_t & nbi1,
4862
5081
  constant int64_t & ne00,
4863
5082
  constant int64_t & ne01,
4864
5083
  constant int64_t & ne02,
@@ -4874,7 +5093,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4874
5093
  constant uint64_t & nb12,
4875
5094
  constant int64_t & ne0,
4876
5095
  constant int64_t & ne1,
4877
- constant int64_t & nb1,
5096
+ constant uint64_t & nb1,
4878
5097
  constant uint & r2,
4879
5098
  constant uint & r3,
4880
5099
  constant int & idx,
@@ -4901,7 +5120,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4901
5120
  kernel_mul_mv_q6_K_f32_impl(
4902
5121
  src0[id],
4903
5122
  (device const float *) (src1 + bid*nb11),
4904
- (device float *) ( dst + bid*nb1),
5123
+ dst + bid*ne0,
4905
5124
  ne00,
4906
5125
  ne01,
4907
5126
  ne02,