node-llama-cpp 2.8.3 → 2.8.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,26 +59,26 @@ kernel void kernel_add(
59
59
  constant int64_t & ne01,
60
60
  constant int64_t & ne02,
61
61
  constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
62
+ constant uint64_t & nb00,
63
+ constant uint64_t & nb01,
64
+ constant uint64_t & nb02,
65
+ constant uint64_t & nb03,
66
66
  constant int64_t & ne10,
67
67
  constant int64_t & ne11,
68
68
  constant int64_t & ne12,
69
69
  constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
70
+ constant uint64_t & nb10,
71
+ constant uint64_t & nb11,
72
+ constant uint64_t & nb12,
73
+ constant uint64_t & nb13,
74
74
  constant int64_t & ne0,
75
75
  constant int64_t & ne1,
76
76
  constant int64_t & ne2,
77
77
  constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
78
+ constant uint64_t & nb0,
79
+ constant uint64_t & nb1,
80
+ constant uint64_t & nb2,
81
+ constant uint64_t & nb3,
82
82
  constant int64_t & offs,
83
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
84
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -109,26 +109,26 @@ kernel void kernel_mul(
109
109
  constant int64_t & ne01,
110
110
  constant int64_t & ne02,
111
111
  constant int64_t & ne03,
112
- constant int64_t & nb00,
113
- constant int64_t & nb01,
114
- constant int64_t & nb02,
115
- constant int64_t & nb03,
112
+ constant uint64_t & nb00,
113
+ constant uint64_t & nb01,
114
+ constant uint64_t & nb02,
115
+ constant uint64_t & nb03,
116
116
  constant int64_t & ne10,
117
117
  constant int64_t & ne11,
118
118
  constant int64_t & ne12,
119
119
  constant int64_t & ne13,
120
- constant int64_t & nb10,
121
- constant int64_t & nb11,
122
- constant int64_t & nb12,
123
- constant int64_t & nb13,
120
+ constant uint64_t & nb10,
121
+ constant uint64_t & nb11,
122
+ constant uint64_t & nb12,
123
+ constant uint64_t & nb13,
124
124
  constant int64_t & ne0,
125
125
  constant int64_t & ne1,
126
126
  constant int64_t & ne2,
127
127
  constant int64_t & ne3,
128
- constant int64_t & nb0,
129
- constant int64_t & nb1,
130
- constant int64_t & nb2,
131
- constant int64_t & nb3,
128
+ constant uint64_t & nb0,
129
+ constant uint64_t & nb1,
130
+ constant uint64_t & nb2,
131
+ constant uint64_t & nb3,
132
132
  uint3 tgpig[[threadgroup_position_in_grid]],
133
133
  uint3 tpitg[[thread_position_in_threadgroup]],
134
134
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -158,26 +158,26 @@ kernel void kernel_div(
158
158
  constant int64_t & ne01,
159
159
  constant int64_t & ne02,
160
160
  constant int64_t & ne03,
161
- constant int64_t & nb00,
162
- constant int64_t & nb01,
163
- constant int64_t & nb02,
164
- constant int64_t & nb03,
161
+ constant uint64_t & nb00,
162
+ constant uint64_t & nb01,
163
+ constant uint64_t & nb02,
164
+ constant uint64_t & nb03,
165
165
  constant int64_t & ne10,
166
166
  constant int64_t & ne11,
167
167
  constant int64_t & ne12,
168
168
  constant int64_t & ne13,
169
- constant int64_t & nb10,
170
- constant int64_t & nb11,
171
- constant int64_t & nb12,
172
- constant int64_t & nb13,
169
+ constant uint64_t & nb10,
170
+ constant uint64_t & nb11,
171
+ constant uint64_t & nb12,
172
+ constant uint64_t & nb13,
173
173
  constant int64_t & ne0,
174
174
  constant int64_t & ne1,
175
175
  constant int64_t & ne2,
176
176
  constant int64_t & ne3,
177
- constant int64_t & nb0,
178
- constant int64_t & nb1,
179
- constant int64_t & nb2,
180
- constant int64_t & nb3,
177
+ constant uint64_t & nb0,
178
+ constant uint64_t & nb1,
179
+ constant uint64_t & nb2,
180
+ constant uint64_t & nb3,
181
181
  uint3 tgpig[[threadgroup_position_in_grid]],
182
182
  uint3 tpitg[[thread_position_in_threadgroup]],
183
183
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
205
205
  device const float4 * src0,
206
206
  device const float4 * src1,
207
207
  device float4 * dst,
208
- constant int64_t & nb [[buffer(28)]],
208
+ constant uint64_t & nb [[buffer(28)]],
209
209
  uint tpig[[thread_position_in_grid]]) {
210
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
211
  }
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
214
214
  device const float4 * src0,
215
215
  device const float4 * src1,
216
216
  device float4 * dst,
217
- constant int64_t & nb [[buffer(28)]],
217
+ constant uint64_t & nb [[buffer(28)]],
218
218
  uint tpig[[thread_position_in_grid]]) {
219
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
220
  }
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
223
223
  device const float4 * src0,
224
224
  device const float4 * src1,
225
225
  device float4 * dst,
226
- constant int64_t & nb [[buffer(28)]],
226
+ constant uint64_t & nb [[buffer(28)]],
227
227
  uint tpig[[thread_position_in_grid]]) {
228
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
229
  }
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
307
307
  constant int64_t & ne01,
308
308
  constant int64_t & ne02,
309
309
  constant int64_t & ne03,
310
- constant int64_t & nb00,
311
- constant int64_t & nb01,
312
- constant int64_t & nb02,
313
- constant int64_t & nb03,
310
+ constant uint64_t & nb00,
311
+ constant uint64_t & nb01,
312
+ constant uint64_t & nb02,
313
+ constant uint64_t & nb03,
314
314
  constant int64_t & ne10,
315
315
  constant int64_t & ne11,
316
316
  constant int64_t & ne12,
317
317
  constant int64_t & ne13,
318
- constant int64_t & nb10,
319
- constant int64_t & nb11,
320
- constant int64_t & nb12,
321
- constant int64_t & nb13,
318
+ constant uint64_t & nb10,
319
+ constant uint64_t & nb11,
320
+ constant uint64_t & nb12,
321
+ constant uint64_t & nb13,
322
322
  constant int64_t & ne0,
323
323
  constant int64_t & ne1,
324
324
  constant int64_t & ne2,
325
325
  constant int64_t & ne3,
326
- constant int64_t & nb0,
327
- constant int64_t & nb1,
328
- constant int64_t & nb2,
329
- constant int64_t & nb3,
326
+ constant uint64_t & nb0,
327
+ constant uint64_t & nb1,
328
+ constant uint64_t & nb2,
329
+ constant uint64_t & nb3,
330
330
  uint3 tpig[[thread_position_in_grid]]) {
331
331
  int64_t i3 = tpig.z;
332
332
  int64_t i2 = tpig.y;
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
846
846
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
847
847
  //Note: This is a template, but strictly speaking it only applies to
848
848
  // quantizations where the block size is 32. It also does not
849
- // giard against the number of rows not being divisible by
849
+ // guard against the number of rows not being divisible by
850
850
  // N_DST, so this is another explicit assumption of the implementation.
851
851
  template<typename block_q_type, int nr, int nsg, int nw>
852
852
  void mul_vec_q_n_f32_impl(
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
920
920
  device const float * src1,
921
921
  device float * dst,
922
922
  constant int64_t & ne00,
923
- constant int64_t & ne01[[buffer(4)]],
924
- constant int64_t & ne02[[buffer(5)]],
925
- constant int64_t & ne10[[buffer(9)]],
926
- constant int64_t & ne12[[buffer(11)]],
927
- constant int64_t & ne0 [[buffer(15)]],
928
- constant int64_t & ne1 [[buffer(16)]],
929
- constant uint & r2 [[buffer(17)]],
930
- constant uint & r3 [[buffer(18)]],
923
+ constant int64_t & ne01,
924
+ constant int64_t & ne02,
925
+ constant uint64_t & nb00,
926
+ constant uint64_t & nb01,
927
+ constant uint64_t & nb02,
928
+ constant int64_t & ne10,
929
+ constant int64_t & ne11,
930
+ constant int64_t & ne12,
931
+ constant uint64_t & nb10,
932
+ constant uint64_t & nb11,
933
+ constant uint64_t & nb12,
934
+ constant int64_t & ne0,
935
+ constant int64_t & ne1,
936
+ constant uint & r2,
937
+ constant uint & r3,
931
938
  uint3 tgpig[[threadgroup_position_in_grid]],
932
939
  uint tiisg[[thread_index_in_simdgroup]],
933
940
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
939
946
  device const float * src1,
940
947
  device float * dst,
941
948
  constant int64_t & ne00,
942
- constant int64_t & ne01[[buffer(4)]],
943
- constant int64_t & ne02[[buffer(5)]],
944
- constant int64_t & ne10[[buffer(9)]],
945
- constant int64_t & ne12[[buffer(11)]],
946
- constant int64_t & ne0 [[buffer(15)]],
947
- constant int64_t & ne1 [[buffer(16)]],
948
- constant uint & r2 [[buffer(17)]],
949
- constant uint & r3 [[buffer(18)]],
949
+ constant int64_t & ne01,
950
+ constant int64_t & ne02,
951
+ constant uint64_t & nb00,
952
+ constant uint64_t & nb01,
953
+ constant uint64_t & nb02,
954
+ constant int64_t & ne10,
955
+ constant int64_t & ne11,
956
+ constant int64_t & ne12,
957
+ constant uint64_t & nb10,
958
+ constant uint64_t & nb11,
959
+ constant uint64_t & nb12,
960
+ constant int64_t & ne0,
961
+ constant int64_t & ne1,
962
+ constant uint & r2,
963
+ constant uint & r3,
950
964
  uint3 tgpig[[threadgroup_position_in_grid]],
951
965
  uint tiisg[[thread_index_in_simdgroup]],
952
966
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
958
972
  device const float * src1,
959
973
  device float * dst,
960
974
  constant int64_t & ne00,
961
- constant int64_t & ne01[[buffer(4)]],
962
- constant int64_t & ne02[[buffer(5)]],
963
- constant int64_t & ne10[[buffer(9)]],
964
- constant int64_t & ne12[[buffer(11)]],
965
- constant int64_t & ne0 [[buffer(15)]],
966
- constant int64_t & ne1 [[buffer(16)]],
967
- constant uint & r2 [[buffer(17)]],
968
- constant uint & r3 [[buffer(18)]],
975
+ constant int64_t & ne01,
976
+ constant int64_t & ne02,
977
+ constant uint64_t & nb00,
978
+ constant uint64_t & nb01,
979
+ constant uint64_t & nb02,
980
+ constant int64_t & ne10,
981
+ constant int64_t & ne11,
982
+ constant int64_t & ne12,
983
+ constant uint64_t & nb10,
984
+ constant uint64_t & nb11,
985
+ constant uint64_t & nb12,
986
+ constant int64_t & ne0,
987
+ constant int64_t & ne1,
988
+ constant uint & r2,
989
+ constant uint & r3,
969
990
  uint3 tgpig[[threadgroup_position_in_grid]],
970
991
  uint tiisg[[thread_index_in_simdgroup]],
971
992
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
977
998
  device const float * src1,
978
999
  device float * dst,
979
1000
  constant int64_t & ne00,
980
- constant int64_t & ne01[[buffer(4)]],
981
- constant int64_t & ne02[[buffer(5)]],
982
- constant int64_t & ne10[[buffer(9)]],
983
- constant int64_t & ne12[[buffer(11)]],
984
- constant int64_t & ne0 [[buffer(15)]],
985
- constant int64_t & ne1 [[buffer(16)]],
986
- constant uint & r2 [[buffer(17)]],
987
- constant uint & r3 [[buffer(18)]],
1001
+ constant int64_t & ne01,
1002
+ constant int64_t & ne02,
1003
+ constant uint64_t & nb00,
1004
+ constant uint64_t & nb01,
1005
+ constant uint64_t & nb02,
1006
+ constant int64_t & ne10,
1007
+ constant int64_t & ne11,
1008
+ constant int64_t & ne12,
1009
+ constant uint64_t & nb10,
1010
+ constant uint64_t & nb11,
1011
+ constant uint64_t & nb12,
1012
+ constant int64_t & ne0,
1013
+ constant int64_t & ne1,
1014
+ constant uint & r2,
1015
+ constant uint & r3,
988
1016
  uint3 tgpig[[threadgroup_position_in_grid]],
989
1017
  uint tiisg[[thread_index_in_simdgroup]],
990
1018
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
1071
1099
  constant int64_t & ne00,
1072
1100
  constant int64_t & ne01,
1073
1101
  constant int64_t & ne02,
1102
+ constant uint64_t & nb00,
1103
+ constant uint64_t & nb01,
1104
+ constant uint64_t & nb02,
1074
1105
  constant int64_t & ne10,
1106
+ constant int64_t & ne11,
1075
1107
  constant int64_t & ne12,
1108
+ constant uint64_t & nb10,
1109
+ constant uint64_t & nb11,
1110
+ constant uint64_t & nb12,
1076
1111
  constant int64_t & ne0,
1077
1112
  constant int64_t & ne1,
1078
- constant uint & r2 [[buffer(17)]],
1079
- constant uint & r3 [[buffer(18)]],
1113
+ constant uint & r2,
1114
+ constant uint & r3,
1080
1115
  uint3 tgpig[[threadgroup_position_in_grid]],
1081
1116
  uint tiisg[[thread_index_in_simdgroup]],
1082
1117
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
1182
1217
  constant uint64_t & nb12,
1183
1218
  constant int64_t & ne0,
1184
1219
  constant int64_t & ne1,
1185
- constant uint & r2 [[buffer(17)]],
1186
- constant uint & r3 [[buffer(18)]],
1220
+ constant uint & r2,
1221
+ constant uint & r3,
1187
1222
  uint3 tgpig[[threadgroup_position_in_grid]],
1188
1223
  uint tiisg[[thread_index_in_simdgroup]]) {
1189
1224
  kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
1209
1244
  constant uint64_t & nb12,
1210
1245
  constant int64_t & ne0,
1211
1246
  constant int64_t & ne1,
1212
- constant uint & r2 [[buffer(17)]],
1213
- constant uint & r3 [[buffer(18)]],
1247
+ constant uint & r2,
1248
+ constant uint & r3,
1214
1249
  uint3 tgpig[[threadgroup_position_in_grid]],
1215
1250
  uint tiisg[[thread_index_in_simdgroup]]) {
1216
1251
 
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1346
1381
  constant uint64_t & nb12,
1347
1382
  constant int64_t & ne0,
1348
1383
  constant int64_t & ne1,
1349
- constant uint & r2 [[buffer(17)]],
1350
- constant uint & r3 [[buffer(18)]],
1384
+ constant uint & r2,
1385
+ constant uint & r3,
1351
1386
  uint3 tgpig[[threadgroup_position_in_grid]],
1352
1387
  uint tiisg[[thread_index_in_simdgroup]]) {
1353
1388
  kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
1452
1487
  constant uint64_t & nb12,
1453
1488
  constant int64_t & ne0,
1454
1489
  constant int64_t & ne1,
1455
- constant uint & r2 [[buffer(17)]],
1456
- constant uint & r3 [[buffer(18)]],
1490
+ constant uint & r2,
1491
+ constant uint & r3,
1457
1492
  uint3 tgpig[[threadgroup_position_in_grid]],
1458
1493
  uint tiisg[[thread_index_in_simdgroup]]) {
1459
1494
  kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1478
1513
  constant uint64_t & nb12,
1479
1514
  constant int64_t & ne0,
1480
1515
  constant int64_t & ne1,
1481
- constant uint & r2 [[buffer(17)]],
1482
- constant uint & r3 [[buffer(18)]],
1516
+ constant uint & r2,
1517
+ constant uint & r3,
1483
1518
  uint3 tgpig[[threadgroup_position_in_grid]],
1484
1519
  uint tiisg[[thread_index_in_simdgroup]]) {
1485
1520
 
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
1543
1578
  const int64_t i3 = n / (ne2*ne1*ne0);
1544
1579
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1545
1580
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1546
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1581
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1582
+
1547
1583
  const int64_t k = i3*ne3 + i2;
1548
1584
 
1549
1585
  float m_k;
@@ -2410,21 +2446,18 @@ typedef struct {
2410
2446
  } block_q6_K;
2411
2447
  // 210 bytes / block
2412
2448
 
2413
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2414
- uchar4 r;
2415
- if (j < 4) {
2416
- r[0] = q[j+0] & 63;
2417
- r[2] = q[j+1] & 63;
2418
- r[1] = q[j+4] & 63;
2419
- r[3] = q[j+5] & 63;
2420
- } else {
2421
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2422
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2423
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2424
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2425
- }
2426
- return r;
2427
- }
2449
+ typedef struct {
2450
+ half d;
2451
+ uint16_t qs[QK_K/8];
2452
+ } block_iq2_xxs;
2453
+ // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
+
2455
+ typedef struct {
2456
+ half d;
2457
+ uint16_t qs[QK_K/8];
2458
+ uint8_t scales[QK_K/32];
2459
+ } block_iq2_xs;
2460
+ // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2428
2461
 
2429
2462
  //====================================== dot products =========================
2430
2463
 
@@ -2584,14 +2617,21 @@ kernel void kernel_mul_mv_q2_K_f32(
2584
2617
  device const float * src1,
2585
2618
  device float * dst,
2586
2619
  constant int64_t & ne00,
2587
- constant int64_t & ne01[[buffer(4)]],
2588
- constant int64_t & ne02[[buffer(5)]],
2589
- constant int64_t & ne10[[buffer(9)]],
2590
- constant int64_t & ne12[[buffer(11)]],
2591
- constant int64_t & ne0 [[buffer(15)]],
2592
- constant int64_t & ne1 [[buffer(16)]],
2593
- constant uint & r2 [[buffer(17)]],
2594
- constant uint & r3 [[buffer(18)]],
2620
+ constant int64_t & ne01,
2621
+ constant int64_t & ne02,
2622
+ constant uint64_t & nb00,
2623
+ constant uint64_t & nb01,
2624
+ constant uint64_t & nb02,
2625
+ constant int64_t & ne10,
2626
+ constant int64_t & ne11,
2627
+ constant int64_t & ne12,
2628
+ constant uint64_t & nb10,
2629
+ constant uint64_t & nb11,
2630
+ constant uint64_t & nb12,
2631
+ constant int64_t & ne0,
2632
+ constant int64_t & ne1,
2633
+ constant uint & r2,
2634
+ constant uint & r3,
2595
2635
  uint3 tgpig[[threadgroup_position_in_grid]],
2596
2636
  uint tiisg[[thread_index_in_simdgroup]],
2597
2637
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2841,14 +2881,21 @@ kernel void kernel_mul_mv_q3_K_f32(
2841
2881
  device const float * src1,
2842
2882
  device float * dst,
2843
2883
  constant int64_t & ne00,
2844
- constant int64_t & ne01[[buffer(4)]],
2845
- constant int64_t & ne02[[buffer(5)]],
2846
- constant int64_t & ne10[[buffer(9)]],
2847
- constant int64_t & ne12[[buffer(11)]],
2848
- constant int64_t & ne0 [[buffer(15)]],
2849
- constant int64_t & ne1 [[buffer(16)]],
2850
- constant uint & r2 [[buffer(17)]],
2851
- constant uint & r3 [[buffer(18)]],
2884
+ constant int64_t & ne01,
2885
+ constant int64_t & ne02,
2886
+ constant uint64_t & nb00,
2887
+ constant uint64_t & nb01,
2888
+ constant uint64_t & nb02,
2889
+ constant int64_t & ne10,
2890
+ constant int64_t & ne11,
2891
+ constant int64_t & ne12,
2892
+ constant uint64_t & nb10,
2893
+ constant uint64_t & nb11,
2894
+ constant uint64_t & nb12,
2895
+ constant int64_t & ne0,
2896
+ constant int64_t & ne1,
2897
+ constant uint & r2,
2898
+ constant uint & r3,
2852
2899
  uint3 tgpig[[threadgroup_position_in_grid]],
2853
2900
  uint tiisg[[thread_index_in_simdgroup]],
2854
2901
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2984,8 +3031,8 @@ void kernel_mul_mv_q4_K_f32_impl(
2984
3031
  constant uint & r2,
2985
3032
  constant uint & r3,
2986
3033
  uint3 tgpig[[threadgroup_position_in_grid]],
2987
- uint tiisg[[thread_index_in_simdgroup]],
2988
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3034
+ uint tiisg[[thread_index_in_simdgroup]],
3035
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2989
3036
 
2990
3037
  const int ix = tiisg/4; // 0...7
2991
3038
  const int it = tiisg%4; // 0...3
@@ -2994,7 +3041,7 @@ void kernel_mul_mv_q4_K_f32_impl(
2994
3041
  const int r0 = tgpig.x;
2995
3042
  const int r1 = tgpig.y;
2996
3043
  const int im = tgpig.z;
2997
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3044
+ const int first_row = r0 * N_DST;
2998
3045
  const int ib_row = first_row * nb;
2999
3046
 
3000
3047
  const uint i12 = im%ne12;
@@ -3060,7 +3107,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3060
3107
  for (int row = 0; row < N_DST; ++row) {
3061
3108
  all_sum = simd_sum(sumf[row]);
3062
3109
  if (tiisg == 0) {
3063
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3110
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
3064
3111
  }
3065
3112
  }
3066
3113
  }
@@ -3072,14 +3119,21 @@ kernel void kernel_mul_mv_q4_K_f32(
3072
3119
  device const float * src1,
3073
3120
  device float * dst,
3074
3121
  constant int64_t & ne00,
3075
- constant int64_t & ne01[[buffer(4)]],
3076
- constant int64_t & ne02[[buffer(5)]],
3077
- constant int64_t & ne10[[buffer(9)]],
3078
- constant int64_t & ne12[[buffer(11)]],
3079
- constant int64_t & ne0 [[buffer(15)]],
3080
- constant int64_t & ne1 [[buffer(16)]],
3081
- constant uint & r2 [[buffer(17)]],
3082
- constant uint & r3 [[buffer(18)]],
3122
+ constant int64_t & ne01,
3123
+ constant int64_t & ne02,
3124
+ constant uint64_t & nb00,
3125
+ constant uint64_t & nb01,
3126
+ constant uint64_t & nb02,
3127
+ constant int64_t & ne10,
3128
+ constant int64_t & ne11,
3129
+ constant int64_t & ne12,
3130
+ constant uint64_t & nb10,
3131
+ constant uint64_t & nb11,
3132
+ constant uint64_t & nb12,
3133
+ constant int64_t & ne0,
3134
+ constant int64_t & ne1,
3135
+ constant uint & r2,
3136
+ constant uint & r3,
3083
3137
  uint3 tgpig[[threadgroup_position_in_grid]],
3084
3138
  uint tiisg[[thread_index_in_simdgroup]],
3085
3139
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3271,14 +3325,21 @@ kernel void kernel_mul_mv_q5_K_f32(
3271
3325
  device const float * src1,
3272
3326
  device float * dst,
3273
3327
  constant int64_t & ne00,
3274
- constant int64_t & ne01[[buffer(4)]],
3275
- constant int64_t & ne02[[buffer(5)]],
3276
- constant int64_t & ne10[[buffer(9)]],
3277
- constant int64_t & ne12[[buffer(11)]],
3278
- constant int64_t & ne0 [[buffer(15)]],
3279
- constant int64_t & ne1 [[buffer(16)]],
3280
- constant uint & r2 [[buffer(17)]],
3281
- constant uint & r3 [[buffer(18)]],
3328
+ constant int64_t & ne01,
3329
+ constant int64_t & ne02,
3330
+ constant uint64_t & nb00,
3331
+ constant uint64_t & nb01,
3332
+ constant uint64_t & nb02,
3333
+ constant int64_t & ne10,
3334
+ constant int64_t & ne11,
3335
+ constant int64_t & ne12,
3336
+ constant uint64_t & nb10,
3337
+ constant uint64_t & nb11,
3338
+ constant uint64_t & nb12,
3339
+ constant int64_t & ne0,
3340
+ constant int64_t & ne1,
3341
+ constant uint & r2,
3342
+ constant uint & r3,
3282
3343
  uint3 tgpig[[threadgroup_position_in_grid]],
3283
3344
  uint tiisg[[thread_index_in_simdgroup]],
3284
3345
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3398,14 +3459,21 @@ kernel void kernel_mul_mv_q6_K_f32(
3398
3459
  device const float * src1,
3399
3460
  device float * dst,
3400
3461
  constant int64_t & ne00,
3401
- constant int64_t & ne01[[buffer(4)]],
3402
- constant int64_t & ne02[[buffer(5)]],
3403
- constant int64_t & ne10[[buffer(9)]],
3404
- constant int64_t & ne12[[buffer(11)]],
3405
- constant int64_t & ne0 [[buffer(15)]],
3406
- constant int64_t & ne1 [[buffer(16)]],
3407
- constant uint & r2 [[buffer(17)]],
3408
- constant uint & r3 [[buffer(18)]],
3462
+ constant int64_t & ne01,
3463
+ constant int64_t & ne02,
3464
+ constant uint64_t & nb00,
3465
+ constant uint64_t & nb01,
3466
+ constant uint64_t & nb02,
3467
+ constant int64_t & ne10,
3468
+ constant int64_t & ne11,
3469
+ constant int64_t & ne12,
3470
+ constant uint64_t & nb10,
3471
+ constant uint64_t & nb11,
3472
+ constant uint64_t & nb12,
3473
+ constant int64_t & ne0,
3474
+ constant int64_t & ne1,
3475
+ constant uint & r2,
3476
+ constant uint & r3,
3409
3477
  uint3 tgpig[[threadgroup_position_in_grid]],
3410
3478
  uint tiisg[[thread_index_in_simdgroup]],
3411
3479
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3413,53 +3481,542 @@ kernel void kernel_mul_mv_q6_K_f32(
3413
3481
  kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3414
3482
  }
3415
3483
 
3416
- //============================= templates and their specializations =============================
3484
+ // ======================= "True" 2-bit
3485
+
3486
+ constexpr constant static uint64_t iq2xxs_grid[256] = {
3487
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3488
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3489
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
3490
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
3491
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
3492
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
3493
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
3494
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
3495
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
3496
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
3497
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
3498
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
3499
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
3500
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
3501
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
3502
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
3503
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
3504
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
3505
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
3506
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
3507
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
3508
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
3509
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
3510
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
3511
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
3512
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
3513
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
3514
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
3515
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
3516
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
3517
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
3518
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
3519
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
3520
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
3521
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
3522
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
3523
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
3524
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
3525
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
3526
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
3527
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
3528
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
3529
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
3530
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
3531
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
3532
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
3533
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
3534
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
3535
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
3536
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
3537
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
3538
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
3539
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
3540
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
3541
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
3542
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
3543
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
3544
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
3545
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
3546
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
3547
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
3548
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
3549
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
3550
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3551
+ };
3417
3552
 
3418
- // NOTE: this is not dequantizing - we are simply fitting the template
3419
- template <typename type4x4>
3420
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
3421
- float4x4 temp = *(((device float4x4 *)src));
3422
- for (int i = 0; i < 16; i++){
3423
- reg[i/4][i%4] = temp[i/4][i%4];
3424
- }
3425
- }
3553
+ constexpr constant static uint64_t iq2xs_grid[512] = {
3554
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3555
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3556
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3557
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3558
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3559
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
3560
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
3561
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
3562
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
3563
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
3564
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
3565
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
3566
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
3567
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
3568
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
3569
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
3570
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
3571
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
3572
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
3573
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
3574
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
3575
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
3576
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
3577
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
3578
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
3579
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
3580
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
3581
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
3582
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
3583
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
3584
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
3585
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
3586
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
3587
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
3588
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
3589
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
3590
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
3591
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
3592
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
3593
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
3594
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
3595
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
3596
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
3597
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
3598
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
3599
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
3600
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
3601
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
3602
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
3603
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
3604
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
3605
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
3606
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
3607
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
3608
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
3609
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
3610
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
3611
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
3612
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
3613
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
3614
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
3615
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
3616
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
3617
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
3618
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
3619
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
3620
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
3621
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
3622
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
3623
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
3624
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
3625
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
3626
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
3627
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
3628
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
3629
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
3630
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
3631
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
3632
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
3633
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
3634
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
3635
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
3636
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
3637
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
3638
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
3639
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
3640
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
3641
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
3642
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
3643
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
3644
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
3645
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
3646
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
3647
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
3648
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
3649
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
3650
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
3651
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
3652
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
3653
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
3654
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
3655
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
3656
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
3657
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
3658
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
3659
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
3660
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
3661
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
3662
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
3663
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
3664
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
3665
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
3666
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
3667
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
3668
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
3669
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
3670
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
3671
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
3672
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
3673
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
3674
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
3675
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
3676
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
3677
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
3678
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
3679
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
3680
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
3681
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3682
+ };
3426
3683
 
3427
- template <typename type4x4>
3428
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
3429
- half4x4 temp = *(((device half4x4 *)src));
3430
- for (int i = 0; i < 16; i++){
3431
- reg[i/4][i%4] = temp[i/4][i%4];
3432
- }
3433
- }
3684
+ constexpr constant static uint8_t ksigns_iq2xs[128] = {
3685
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3686
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
3687
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
3688
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
3689
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
3690
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
3691
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
3692
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
3693
+ };
3434
3694
 
3435
- template <typename type4x4>
3436
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3437
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
3438
- const float d1 = il ? (xb->d / 16.h) : xb->d;
3439
- const float d2 = d1 / 256.f;
3440
- const float md = -8.h * xb->d;
3441
- const ushort mask0 = il ? 0x00F0 : 0x000F;
3442
- const ushort mask1 = mask0 << 8;
3695
+ constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
3443
3696
 
3444
- for (int i=0;i<8;i++) {
3445
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
3446
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
3447
- }
3448
- }
3697
+ void kernel_mul_mv_iq2_xxs_f32_impl(
3698
+ device const void * src0,
3699
+ device const float * src1,
3700
+ device float * dst,
3701
+ constant int64_t & ne00,
3702
+ constant int64_t & ne01,
3703
+ constant int64_t & ne02,
3704
+ constant int64_t & ne10,
3705
+ constant int64_t & ne12,
3706
+ constant int64_t & ne0,
3707
+ constant int64_t & ne1,
3708
+ constant uint & r2,
3709
+ constant uint & r3,
3710
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3711
+ uint3 tgpig[[threadgroup_position_in_grid]],
3712
+ uint tiisg[[thread_index_in_simdgroup]],
3713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3449
3714
 
3450
- template <typename type4x4>
3451
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
3452
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
3453
- const float d1 = il ? (xb->d / 16.h) : xb->d;
3454
- const float d2 = d1 / 256.f;
3455
- const float m = xb->m;
3456
- const ushort mask0 = il ? 0x00F0 : 0x000F;
3457
- const ushort mask1 = mask0 << 8;
3715
+ const int nb = ne00/QK_K;
3716
+ const int r0 = tgpig.x;
3717
+ const int r1 = tgpig.y;
3718
+ const int im = tgpig.z;
3458
3719
 
3459
- for (int i=0;i<8;i++) {
3460
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
3461
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
3462
- }
3720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3721
+ const int ib_row = first_row * nb;
3722
+
3723
+ const uint i12 = im%ne12;
3724
+ const uint i13 = im/ne12;
3725
+
3726
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3727
+
3728
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
3729
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3730
+
3731
+ float yl[32];
3732
+ float sumf[N_DST]={0.f}, all_sum;
3733
+
3734
+ const int nb32 = nb * (QK_K / 32);
3735
+
3736
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3737
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
3738
+ {
3739
+ int nval = 4;
3740
+ int pos = (32*sgitg + tiisg)*nval;
3741
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
3742
+ nval = 2;
3743
+ pos = (32*sgitg + tiisg)*nval;
3744
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3745
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3746
+ }
3747
+
3748
+ #if QK_K == 256
3749
+ const int ix = tiisg;
3750
+
3751
+ device const float * y4 = y + 32 * ix;
3752
+
3753
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3754
+
3755
+ for (int i = 0; i < 32; ++i) {
3756
+ yl[i] = y4[i];
3757
+ }
3758
+
3759
+ const int ibl = ib32 / (QK_K / 32);
3760
+ const int ib = ib32 % (QK_K / 32);
3761
+
3762
+ device const block_iq2_xxs * xr = x + ibl;
3763
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3764
+ device const half * dh = &xr->d;
3765
+
3766
+ for (int row = 0; row < N_DST; row++) {
3767
+
3768
+ const float db = dh[0];
3769
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
3770
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
3771
+ const float d = db * (0.5f + (aux32 >> 28));
3772
+
3773
+ float sum = 0;
3774
+ for (int l = 0; l < 4; ++l) {
3775
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
3776
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
3777
+ for (int j = 0; j < 8; ++j) {
3778
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3779
+ }
3780
+ }
3781
+ sumf[row] += d * sum;
3782
+
3783
+ dh += nb*sizeof(block_iq2_xxs)/2;
3784
+ q2 += nb*sizeof(block_iq2_xxs)/2;
3785
+ }
3786
+
3787
+ y4 += 32 * 32;
3788
+ }
3789
+ #else
3790
+ // TODO
3791
+ #endif
3792
+
3793
+ for (int row = 0; row < N_DST; ++row) {
3794
+ all_sum = simd_sum(sumf[row]);
3795
+ if (tiisg == 0) {
3796
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3797
+ }
3798
+ }
3799
+ }
3800
+
3801
+ [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
3802
+ kernel void kernel_mul_mv_iq2_xxs_f32(
3803
+ device const void * src0,
3804
+ device const float * src1,
3805
+ device float * dst,
3806
+ constant int64_t & ne00,
3807
+ constant int64_t & ne01,
3808
+ constant int64_t & ne02,
3809
+ constant uint64_t & nb00,
3810
+ constant uint64_t & nb01,
3811
+ constant uint64_t & nb02,
3812
+ constant int64_t & ne10,
3813
+ constant int64_t & ne11,
3814
+ constant int64_t & ne12,
3815
+ constant uint64_t & nb10,
3816
+ constant uint64_t & nb11,
3817
+ constant uint64_t & nb12,
3818
+ constant int64_t & ne0,
3819
+ constant int64_t & ne1,
3820
+ constant uint & r2,
3821
+ constant uint & r3,
3822
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3823
+ uint3 tgpig[[threadgroup_position_in_grid]],
3824
+ uint tiisg[[thread_index_in_simdgroup]],
3825
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3826
+
3827
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3828
+ }
3829
+
3830
+ void kernel_mul_mv_iq2_xs_f32_impl(
3831
+ device const void * src0,
3832
+ device const float * src1,
3833
+ device float * dst,
3834
+ constant int64_t & ne00,
3835
+ constant int64_t & ne01,
3836
+ constant int64_t & ne02,
3837
+ constant int64_t & ne10,
3838
+ constant int64_t & ne12,
3839
+ constant int64_t & ne0,
3840
+ constant int64_t & ne1,
3841
+ constant uint & r2,
3842
+ constant uint & r3,
3843
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiisg[[thread_index_in_simdgroup]],
3846
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3847
+
3848
+ const int nb = ne00/QK_K;
3849
+ const int r0 = tgpig.x;
3850
+ const int r1 = tgpig.y;
3851
+ const int im = tgpig.z;
3852
+
3853
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3854
+ const int ib_row = first_row * nb;
3855
+
3856
+ const uint i12 = im%ne12;
3857
+ const uint i13 = im/ne12;
3858
+
3859
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3860
+
3861
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
3862
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3863
+
3864
+ float yl[32];
3865
+ float sumf[N_DST]={0.f}, all_sum;
3866
+
3867
+ const int nb32 = nb * (QK_K / 32);
3868
+
3869
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3870
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
3871
+ {
3872
+ int nval = 8;
3873
+ int pos = (32*sgitg + tiisg)*nval;
3874
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
3875
+ nval = 2;
3876
+ pos = (32*sgitg + tiisg)*nval;
3877
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3878
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3879
+ }
3880
+
3881
+ #if QK_K == 256
3882
+ const int ix = tiisg;
3883
+
3884
+ device const float * y4 = y + 32 * ix;
3885
+
3886
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3887
+
3888
+ for (int i = 0; i < 32; ++i) {
3889
+ yl[i] = y4[i];
3890
+ }
3891
+
3892
+ const int ibl = ib32 / (QK_K / 32);
3893
+ const int ib = ib32 % (QK_K / 32);
3894
+
3895
+ device const block_iq2_xs * xr = x + ibl;
3896
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3897
+ device const uint8_t * sc = xr->scales + ib;
3898
+ device const half * dh = &xr->d;
3899
+
3900
+ for (int row = 0; row < N_DST; row++) {
3901
+
3902
+ const float db = dh[0];
3903
+ const uint8_t ls1 = sc[0] & 0xf;
3904
+ const uint8_t ls2 = sc[0] >> 4;
3905
+ const float d1 = db * (0.5f + ls1);
3906
+ const float d2 = db * (0.5f + ls2);
3907
+
3908
+ float sum1 = 0, sum2 = 0;
3909
+ for (int l = 0; l < 2; ++l) {
3910
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3911
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3912
+ for (int j = 0; j < 8; ++j) {
3913
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3914
+ }
3915
+ }
3916
+ for (int l = 2; l < 4; ++l) {
3917
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3918
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3919
+ for (int j = 0; j < 8; ++j) {
3920
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3921
+ }
3922
+ }
3923
+ sumf[row] += d1 * sum1 + d2 * sum2;
3924
+
3925
+ dh += nb*sizeof(block_iq2_xs)/2;
3926
+ q2 += nb*sizeof(block_iq2_xs)/2;
3927
+ sc += nb*sizeof(block_iq2_xs);
3928
+ }
3929
+
3930
+ y4 += 32 * 32;
3931
+ }
3932
+ #else
3933
+ // TODO
3934
+ #endif
3935
+
3936
+ for (int row = 0; row < N_DST; ++row) {
3937
+ all_sum = simd_sum(sumf[row]);
3938
+ if (tiisg == 0) {
3939
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3940
+ }
3941
+ }
3942
+ }
3943
+
3944
+ [[host_name("kernel_mul_mv_iq2_xs_f32")]]
3945
+ kernel void kernel_mul_mv_iq2_xs_f32(
3946
+ device const void * src0,
3947
+ device const float * src1,
3948
+ device float * dst,
3949
+ constant int64_t & ne00,
3950
+ constant int64_t & ne01,
3951
+ constant int64_t & ne02,
3952
+ constant uint64_t & nb00,
3953
+ constant uint64_t & nb01,
3954
+ constant uint64_t & nb02,
3955
+ constant int64_t & ne10,
3956
+ constant int64_t & ne11,
3957
+ constant int64_t & ne12,
3958
+ constant uint64_t & nb10,
3959
+ constant uint64_t & nb11,
3960
+ constant uint64_t & nb12,
3961
+ constant int64_t & ne0,
3962
+ constant int64_t & ne1,
3963
+ constant uint & r2,
3964
+ constant uint & r3,
3965
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3966
+ uint3 tgpig[[threadgroup_position_in_grid]],
3967
+ uint tiisg[[thread_index_in_simdgroup]],
3968
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3969
+
3970
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
+ }
3972
+
3973
+ //============================= templates and their specializations =============================
3974
+
3975
+ // NOTE: this is not dequantizing - we are simply fitting the template
3976
+ template <typename type4x4>
3977
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
3978
+ float4x4 temp = *(((device float4x4 *)src));
3979
+ for (int i = 0; i < 16; i++){
3980
+ reg[i/4][i%4] = temp[i/4][i%4];
3981
+ }
3982
+ }
3983
+
3984
+ template <typename type4x4>
3985
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
3986
+ half4x4 temp = *(((device half4x4 *)src));
3987
+ for (int i = 0; i < 16; i++){
3988
+ reg[i/4][i%4] = temp[i/4][i%4];
3989
+ }
3990
+ }
3991
+
3992
+ template <typename type4x4>
3993
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3994
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
3995
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
3996
+ const float d2 = d1 / 256.f;
3997
+ const float md = -8.h * xb->d;
3998
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
3999
+ const ushort mask1 = mask0 << 8;
4000
+
4001
+ for (int i=0;i<8;i++) {
4002
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
4003
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
4004
+ }
4005
+ }
4006
+
4007
+ template <typename type4x4>
4008
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
4009
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
4010
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
4011
+ const float d2 = d1 / 256.f;
4012
+ const float m = xb->m;
4013
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
4014
+ const ushort mask1 = mask0 << 8;
4015
+
4016
+ for (int i=0;i<8;i++) {
4017
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
4018
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
4019
+ }
3463
4020
  }
3464
4021
 
3465
4022
  template <typename type4x4>
@@ -3523,7 +4080,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3523
4080
  device const int8_t * qs = ((device const int8_t *)xb->qs);
3524
4081
  const half d = xb->d;
3525
4082
 
3526
- for (int i=0;i<16;i++) {
4083
+ for (int i = 0; i < 16; i++) {
3527
4084
  reg[i/4][i%4] = (qs[i + 16*il] * d);
3528
4085
  }
3529
4086
  }
@@ -3565,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
3565
4122
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3566
4123
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3567
4124
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3568
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3569
- const half ml = 4.h * dl;
4125
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
4126
+ const float ml = 4.f * dl;
3570
4127
 
3571
4128
  il = (il/2) & 3;
3572
4129
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3633,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3633
4190
  uint8_t ul = 1 << (il/2);
3634
4191
  il = il & 3;
3635
4192
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3636
- const float d = il < 2 ? xb->d : xb->d / 16.h;
4193
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
3637
4194
  const float min = xb->dmin;
3638
4195
  const float dl = d * sc[0];
3639
4196
  const float ml = min * sc[1];
@@ -3666,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3666
4223
  #if QK_K == 256
3667
4224
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3668
4225
  qh = qh + 32*(il/8) + 16*(il&1);
3669
- half sc = scales[(il%2) + 2 * ((il/2))];
4226
+ float sc = scales[(il%2) + 2 * ((il/2))];
3670
4227
  il = (il/2) & 3;
3671
4228
  #else
3672
4229
  ql = ql + 16 * (il&1);
3673
- half sc = scales[il];
4230
+ float sc = scales[il];
3674
4231
  #endif
3675
4232
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3676
4233
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3677
- const half coef = il>1 ? 1.f/16.h : 1.h;
3678
- const half ml = d_all * sc * 32.h;
3679
- const half dl = d_all * sc * coef;
4234
+ const float coef = il>1 ? 1.f/16.f : 1.f;
4235
+ const float ml = d_all * sc * 32.f;
4236
+ const float dl = d_all * sc * coef;
3680
4237
  for (int i = 0; i < 16; ++i) {
3681
4238
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3682
4239
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@@ -3684,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3684
4241
  }
3685
4242
  }
3686
4243
 
4244
+ template <typename type4x4>
4245
+ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
4246
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4247
+ const float d = xb->d;
4248
+ const int ib32 = il/2;
4249
+ il = il%2;
4250
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4251
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
4252
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4253
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
4254
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
4255
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
4256
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
4257
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4258
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
4259
+ for (int i = 0; i < 8; ++i) {
4260
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4261
+ }
4262
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4263
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
4264
+ for (int i = 0; i < 8; ++i) {
4265
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4266
+ }
4267
+ }
4268
+
4269
+ template <typename type4x4>
4270
+ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
4271
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4272
+ const float d = xb->d;
4273
+ const int ib32 = il/2;
4274
+ il = il%2;
4275
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4276
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4277
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
4278
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
4279
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
4280
+ for (int i = 0; i < 8; ++i) {
4281
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4282
+ }
4283
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
4284
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
4285
+ for (int i = 0; i < 8; ++i) {
4286
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4287
+ }
4288
+ }
4289
+
3687
4290
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3688
4291
  kernel void kernel_get_rows(
3689
4292
  device const void * src0,
@@ -3764,48 +4367,212 @@ kernel void kernel_get_rows_f16(
3764
4367
  const int64_t i10 = tgpig.x;
3765
4368
  const int64_t i11 = tgpig.y;
3766
4369
 
3767
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4370
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4371
+
4372
+ const int64_t i02 = i11;
4373
+
4374
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4375
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4376
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4377
+ }
4378
+ }
4379
+
4380
+ kernel void kernel_get_rows_i32(
4381
+ device const void * src0,
4382
+ device const char * src1,
4383
+ device int32_t * dst,
4384
+ constant int64_t & ne00,
4385
+ constant uint64_t & nb01,
4386
+ constant uint64_t & nb02,
4387
+ constant int64_t & ne10,
4388
+ constant uint64_t & nb10,
4389
+ constant uint64_t & nb11,
4390
+ constant uint64_t & nb1,
4391
+ constant uint64_t & nb2,
4392
+ uint3 tgpig[[threadgroup_position_in_grid]],
4393
+ uint tiitg[[thread_index_in_threadgroup]],
4394
+ uint3 tptg [[threads_per_threadgroup]]) {
4395
+ const int64_t i10 = tgpig.x;
4396
+ const int64_t i11 = tgpig.y;
4397
+
4398
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4399
+
4400
+ const int64_t i02 = i11;
4401
+
4402
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4403
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4404
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4405
+ }
4406
+ }
4407
+
4408
+
4409
+ #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
4410
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
4411
+ #define BLOCK_SIZE_K 32
4412
+ #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
4413
+ #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
4414
+ #define THREAD_PER_BLOCK 128
4415
+ #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
4416
+ #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
4417
+ #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
4418
+ #define SG_MAT_ROW 8
4419
+
4420
+ // each block_q contains 16*nl weights
4421
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4422
+ void kernel_mul_mm_impl(device const uchar * src0,
4423
+ device const uchar * src1,
4424
+ device float * dst,
4425
+ constant int64_t & ne00,
4426
+ constant int64_t & ne02,
4427
+ constant uint64_t & nb01,
4428
+ constant uint64_t & nb02,
4429
+ constant int64_t & ne12,
4430
+ constant uint64_t & nb10,
4431
+ constant uint64_t & nb11,
4432
+ constant uint64_t & nb12,
4433
+ constant int64_t & ne0,
4434
+ constant int64_t & ne1,
4435
+ constant uint & r2,
4436
+ constant uint & r3,
4437
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
4438
+ uint3 tgpig[[threadgroup_position_in_grid]],
4439
+ uint tiitg[[thread_index_in_threadgroup]],
4440
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4441
+
4442
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
4443
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
4444
+
4445
+ const uint r0 = tgpig.y;
4446
+ const uint r1 = tgpig.x;
4447
+ const uint im = tgpig.z;
4448
+
4449
+ // if this block is of 64x32 shape or smaller
4450
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4451
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
4452
+
4453
+ // a thread shouldn't load data outside of the matrix
4454
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4455
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
4456
+
4457
+ simdgroup_half8x8 ma[4];
4458
+ simdgroup_float8x8 mb[2];
4459
+ simdgroup_float8x8 c_res[8];
4460
+ for (int i = 0; i < 8; i++){
4461
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
4462
+ }
4463
+
4464
+ short il = (tiitg % THREAD_PER_ROW);
4465
+
4466
+ const uint i12 = im%ne12;
4467
+ const uint i13 = im/ne12;
4468
+
4469
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4470
+ ushort offset1 = il/nl;
4471
+
4472
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4473
+ device const float * y = (device const float *)(src1
4474
+ + nb12 * im
4475
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
4476
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
4477
+
4478
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4479
+ // load data and store to threadgroup memory
4480
+ half4x4 temp_a;
4481
+ dequantize_func(x, il, temp_a);
4482
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4483
+
4484
+ #pragma unroll(16)
4485
+ for (int i = 0; i < 16; i++) {
4486
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4487
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4488
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4489
+ }
4490
+
4491
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
4492
+
4493
+ il = (il + 2 < nl) ? il + 2 : il % 2;
4494
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
4495
+ y += BLOCK_SIZE_K;
4496
+
4497
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4498
+
4499
+ // load matrices from threadgroup memory and conduct outer products
4500
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4501
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
4502
+
4503
+ #pragma unroll(4)
4504
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4505
+ #pragma unroll(4)
4506
+ for (int i = 0; i < 4; i++) {
4507
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4508
+ }
4509
+ simdgroup_barrier(mem_flags::mem_none);
4510
+ #pragma unroll(2)
4511
+ for (int i = 0; i < 2; i++) {
4512
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4513
+ }
4514
+
4515
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4516
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
3768
4517
 
3769
- const int64_t i02 = i11;
4518
+ #pragma unroll(8)
4519
+ for (int i = 0; i < 8; i++){
4520
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4521
+ }
4522
+ }
4523
+ }
3770
4524
 
3771
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3772
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3773
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4525
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
4526
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
4527
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
4528
+ for (int i = 0; i < 8; i++) {
4529
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
4530
+ }
4531
+ } else {
4532
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
4533
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4534
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4535
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4536
+ for (int i = 0; i < 8; i++) {
4537
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4538
+ }
4539
+
4540
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4541
+
4542
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
4543
+ if (sgitg == 0) {
4544
+ for (int i = 0; i < n_rows; i++) {
4545
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4546
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4547
+ }
4548
+ }
4549
+ }
3774
4550
  }
3775
4551
  }
3776
4552
 
3777
- #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3778
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3779
- #define BLOCK_SIZE_K 32
3780
- #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
3781
- #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
3782
- #define THREAD_PER_BLOCK 128
3783
- #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
3784
- #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
3785
- #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
3786
- #define SG_MAT_ROW 8
3787
-
3788
- // each block_q contains 16*nl weights
4553
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
3789
4554
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3790
- void kernel_mul_mm_impl(device const uchar * src0,
3791
- device const uchar * src1,
3792
- device float * dst,
3793
- constant int64_t & ne00,
3794
- constant int64_t & ne02,
3795
- constant int64_t & nb01,
3796
- constant int64_t & nb02,
3797
- constant int64_t & ne12,
3798
- constant int64_t & nb10,
3799
- constant int64_t & nb11,
3800
- constant int64_t & nb12,
3801
- constant int64_t & ne0,
3802
- constant int64_t & ne1,
3803
- constant uint & r2,
3804
- constant uint & r3,
3805
- threadgroup uchar * shared_memory [[threadgroup(0)]],
3806
- uint3 tgpig[[threadgroup_position_in_grid]],
3807
- uint tiitg[[thread_index_in_threadgroup]],
3808
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4555
+ void kernel_mul_mm_id_impl(
4556
+ device const uchar * src0,
4557
+ device const uchar * src1,
4558
+ thread short * src1ids,
4559
+ device float * dst,
4560
+ constant int64_t & ne00,
4561
+ constant int64_t & ne02,
4562
+ constant uint64_t & nb01,
4563
+ constant uint64_t & nb02,
4564
+ constant int64_t & ne12,
4565
+ constant uint64_t & nb10,
4566
+ constant uint64_t & nb11,
4567
+ constant uint64_t & nb12,
4568
+ constant int64_t & ne0,
4569
+ int64_t ne1,
4570
+ constant uint & r2,
4571
+ constant uint & r3,
4572
+ threadgroup uchar * shared_memory,
4573
+ uint3 tgpig[[threadgroup_position_in_grid]],
4574
+ uint tiitg[[thread_index_in_threadgroup]],
4575
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3809
4576
 
3810
4577
  threadgroup half * sa = (threadgroup half *)(shared_memory);
3811
4578
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -3814,6 +4581,8 @@ void kernel_mul_mm_impl(device const uchar * src0,
3814
4581
  const uint r1 = tgpig.x;
3815
4582
  const uint im = tgpig.z;
3816
4583
 
4584
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
4585
+
3817
4586
  // if this block is of 64x32 shape or smaller
3818
4587
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
3819
4588
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
@@ -3840,7 +4609,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
3840
4609
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
3841
4610
  device const float * y = (device const float *)(src1
3842
4611
  + nb12 * im
3843
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
4612
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
3844
4613
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
3845
4614
 
3846
4615
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -3849,7 +4618,6 @@ void kernel_mul_mm_impl(device const uchar * src0,
3849
4618
  dequantize_func(x, il, temp_a);
3850
4619
  threadgroup_barrier(mem_flags::mem_threadgroup);
3851
4620
 
3852
- #pragma unroll(16)
3853
4621
  for (int i = 0; i < 16; i++) {
3854
4622
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
3855
4623
  + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
@@ -3868,14 +4636,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
3868
4636
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
3869
4637
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
3870
4638
 
3871
- #pragma unroll(4)
3872
4639
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
3873
- #pragma unroll(4)
3874
4640
  for (int i = 0; i < 4; i++) {
3875
4641
  simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
3876
4642
  }
3877
4643
  simdgroup_barrier(mem_flags::mem_none);
3878
- #pragma unroll(2)
3879
4644
  for (int i = 0; i < 2; i++) {
3880
4645
  simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
3881
4646
  }
@@ -3883,21 +4648,13 @@ void kernel_mul_mm_impl(device const uchar * src0,
3883
4648
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
3884
4649
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
3885
4650
 
3886
- #pragma unroll(8)
3887
4651
  for (int i = 0; i < 8; i++){
3888
4652
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
3889
4653
  }
3890
4654
  }
3891
4655
  }
3892
4656
 
3893
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
3894
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
3895
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
3896
- for (int i = 0; i < 8; i++) {
3897
- simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
3898
- }
3899
- } else {
3900
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
4657
+ {
3901
4658
  threadgroup_barrier(mem_flags::mem_threadgroup);
3902
4659
  threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
3903
4660
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
@@ -3907,11 +4664,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
3907
4664
 
3908
4665
  threadgroup_barrier(mem_flags::mem_threadgroup);
3909
4666
 
3910
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
4667
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
3911
4668
  if (sgitg == 0) {
3912
4669
  for (int i = 0; i < n_rows; i++) {
3913
4670
  for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
3914
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4671
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
3915
4672
  }
3916
4673
  }
3917
4674
  }
@@ -3924,12 +4681,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
3924
4681
  device float * dst,
3925
4682
  constant int64_t & ne00,
3926
4683
  constant int64_t & ne02,
3927
- constant int64_t & nb01,
3928
- constant int64_t & nb02,
4684
+ constant uint64_t & nb01,
4685
+ constant uint64_t & nb02,
3929
4686
  constant int64_t & ne12,
3930
- constant int64_t & nb10,
3931
- constant int64_t & nb11,
3932
- constant int64_t & nb12,
4687
+ constant uint64_t & nb10,
4688
+ constant uint64_t & nb11,
4689
+ constant uint64_t & nb12,
3933
4690
  constant int64_t & ne0,
3934
4691
  constant int64_t & ne1,
3935
4692
  constant uint & r2,
@@ -3964,20 +4721,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
3964
4721
  kernel void kernel_mul_mm_id(
3965
4722
  device const uchar * ids,
3966
4723
  device const uchar * src1,
3967
- device uchar * dst,
3968
- constant int64_t & nbi1,
4724
+ device float * dst,
4725
+ constant uint64_t & nbi1,
3969
4726
  constant int64_t & ne00,
3970
4727
  constant int64_t & ne02,
3971
- constant int64_t & nb01,
3972
- constant int64_t & nb02,
4728
+ constant uint64_t & nb01,
4729
+ constant uint64_t & nb02,
3973
4730
  constant int64_t & ne12,
3974
4731
  constant int64_t & ne13,
3975
- constant int64_t & nb10,
3976
- constant int64_t & nb11,
3977
- constant int64_t & nb12,
4732
+ constant uint64_t & nb10,
4733
+ constant uint64_t & nb11,
4734
+ constant uint64_t & nb12,
3978
4735
  constant int64_t & ne0,
3979
4736
  constant int64_t & ne1,
3980
- constant int64_t & nb1,
4737
+ constant uint64_t & nb1,
3981
4738
  constant uint & r2,
3982
4739
  constant uint & r3,
3983
4740
  constant int & idx,
@@ -3993,18 +4750,28 @@ kernel void kernel_mul_mm_id(
3993
4750
  uint3 tgpig[[threadgroup_position_in_grid]],
3994
4751
  uint tiitg[[thread_index_in_threadgroup]],
3995
4752
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3996
- device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4753
+ device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3997
4754
 
3998
- const int64_t bid = tgpig.z/(ne12*ne13);
4755
+ // expert id
4756
+ const int32_t id = tgpig.z/(ne12*ne13);
3999
4757
 
4000
4758
  tgpig.z = tgpig.z%(ne12*ne13);
4001
4759
 
4002
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4760
+ // row indices of src1 for expert id
4761
+ int64_t _ne1 = 0;
4762
+ short src1ids[512];
4003
4763
 
4004
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
4005
- src0[id],
4006
- src1 + bid*nb11,
4007
- (device float *) (dst + bid*nb1),
4764
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
4765
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
4766
+ src1ids[_ne1++] = i1;
4767
+ }
4768
+ }
4769
+
4770
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
4771
+ src0s[id],
4772
+ src1,
4773
+ src1ids,
4774
+ dst,
4008
4775
  ne00,
4009
4776
  ne02,
4010
4777
  nb01,
@@ -4014,7 +4781,7 @@ kernel void kernel_mul_mm_id(
4014
4781
  nb11,
4015
4782
  nb12,
4016
4783
  ne0,
4017
- ne1,
4784
+ _ne1,
4018
4785
  r2,
4019
4786
  r3,
4020
4787
  shared_memory,
@@ -4059,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
4059
4826
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
4060
4827
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4061
4828
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4062
4831
 
4063
4832
  //
4064
4833
  // matrix-matrix multiplication
@@ -4070,12 +4839,12 @@ typedef void (mat_mm_t)(
4070
4839
  device float * dst,
4071
4840
  constant int64_t & ne00,
4072
4841
  constant int64_t & ne02,
4073
- constant int64_t & nb01,
4074
- constant int64_t & nb02,
4842
+ constant uint64_t & nb01,
4843
+ constant uint64_t & nb02,
4075
4844
  constant int64_t & ne12,
4076
- constant int64_t & nb10,
4077
- constant int64_t & nb11,
4078
- constant int64_t & nb12,
4845
+ constant uint64_t & nb10,
4846
+ constant uint64_t & nb11,
4847
+ constant uint64_t & nb12,
4079
4848
  constant int64_t & ne0,
4080
4849
  constant int64_t & ne1,
4081
4850
  constant uint & r2,
@@ -4095,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4095
4864
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
4096
4865
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4097
4866
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4098
4869
 
4099
4870
  //
4100
4871
  // indirect matrix-matrix multiplication
@@ -4103,20 +4874,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4103
4874
  typedef void (mat_mm_id_t)(
4104
4875
  device const uchar * ids,
4105
4876
  device const uchar * src1,
4106
- device uchar * dst,
4107
- constant int64_t & nbi1,
4877
+ device float * dst,
4878
+ constant uint64_t & nbi1,
4108
4879
  constant int64_t & ne00,
4109
4880
  constant int64_t & ne02,
4110
- constant int64_t & nb01,
4111
- constant int64_t & nb02,
4881
+ constant uint64_t & nb01,
4882
+ constant uint64_t & nb02,
4112
4883
  constant int64_t & ne12,
4113
4884
  constant int64_t & ne13,
4114
- constant int64_t & nb10,
4115
- constant int64_t & nb11,
4116
- constant int64_t & nb12,
4885
+ constant uint64_t & nb10,
4886
+ constant uint64_t & nb11,
4887
+ constant uint64_t & nb12,
4117
4888
  constant int64_t & ne0,
4118
4889
  constant int64_t & ne1,
4119
- constant int64_t & nb1,
4890
+ constant uint64_t & nb1,
4120
4891
  constant uint & r2,
4121
4892
  constant uint & r3,
4122
4893
  constant int & idx,
@@ -4143,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
4143
4914
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4144
4915
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4145
4916
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4146
4919
 
4147
4920
  //
4148
4921
  // matrix-vector multiplication
@@ -4152,8 +4925,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
4152
4925
  kernel void kernel_mul_mv_id_f32_f32(
4153
4926
  device const char * ids,
4154
4927
  device const char * src1,
4155
- device uchar * dst,
4156
- constant int64_t & nbi1,
4928
+ device float * dst,
4929
+ constant uint64_t & nbi1,
4157
4930
  constant int64_t & ne00,
4158
4931
  constant int64_t & ne01,
4159
4932
  constant int64_t & ne02,
@@ -4169,7 +4942,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4169
4942
  constant uint64_t & nb12,
4170
4943
  constant int64_t & ne0,
4171
4944
  constant int64_t & ne1,
4172
- constant int64_t & nb1,
4945
+ constant uint64_t & nb1,
4173
4946
  constant uint & r2,
4174
4947
  constant uint & r3,
4175
4948
  constant int & idx,
@@ -4196,7 +4969,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4196
4969
  kernel_mul_mv_f32_f32_impl(
4197
4970
  src0[id],
4198
4971
  src1 + bid*nb11,
4199
- (device float *) (dst + bid*nb1),
4972
+ dst + bid*ne0,
4200
4973
  ne00,
4201
4974
  ne01,
4202
4975
  ne02,
@@ -4221,8 +4994,8 @@ kernel void kernel_mul_mv_id_f32_f32(
4221
4994
  kernel void kernel_mul_mv_id_f16_f32(
4222
4995
  device const char * ids,
4223
4996
  device const char * src1,
4224
- device uchar * dst,
4225
- constant int64_t & nbi1,
4997
+ device float * dst,
4998
+ constant uint64_t & nbi1,
4226
4999
  constant int64_t & ne00,
4227
5000
  constant int64_t & ne01,
4228
5001
  constant int64_t & ne02,
@@ -4238,7 +5011,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4238
5011
  constant uint64_t & nb12,
4239
5012
  constant int64_t & ne0,
4240
5013
  constant int64_t & ne1,
4241
- constant int64_t & nb1,
5014
+ constant uint64_t & nb1,
4242
5015
  constant uint & r2,
4243
5016
  constant uint & r3,
4244
5017
  constant int & idx,
@@ -4265,7 +5038,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4265
5038
  kernel_mul_mv_f16_f32_impl(
4266
5039
  src0[id],
4267
5040
  src1 + bid*nb11,
4268
- (device float *) (dst + bid*nb1),
5041
+ dst + bid*ne0,
4269
5042
  ne00,
4270
5043
  ne01,
4271
5044
  ne02,
@@ -4290,8 +5063,8 @@ kernel void kernel_mul_mv_id_f16_f32(
4290
5063
  kernel void kernel_mul_mv_id_q8_0_f32(
4291
5064
  device const char * ids,
4292
5065
  device const char * src1,
4293
- device uchar * dst,
4294
- constant int64_t & nbi1,
5066
+ device float * dst,
5067
+ constant uint64_t & nbi1,
4295
5068
  constant int64_t & ne00,
4296
5069
  constant int64_t & ne01,
4297
5070
  constant int64_t & ne02,
@@ -4307,7 +5080,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4307
5080
  constant uint64_t & nb12,
4308
5081
  constant int64_t & ne0,
4309
5082
  constant int64_t & ne1,
4310
- constant int64_t & nb1,
5083
+ constant uint64_t & nb1,
4311
5084
  constant uint & r2,
4312
5085
  constant uint & r3,
4313
5086
  constant int & idx,
@@ -4334,7 +5107,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4334
5107
  kernel_mul_mv_q8_0_f32_impl(
4335
5108
  src0[id],
4336
5109
  (device const float *) (src1 + bid*nb11),
4337
- (device float *) ( dst + bid*nb1),
5110
+ dst + bid*ne0,
4338
5111
  ne00,
4339
5112
  ne01,
4340
5113
  ne02,
@@ -4353,8 +5126,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4353
5126
  kernel void kernel_mul_mv_id_q4_0_f32(
4354
5127
  device const char * ids,
4355
5128
  device const char * src1,
4356
- device uchar * dst,
4357
- constant int64_t & nbi1,
5129
+ device float * dst,
5130
+ constant uint64_t & nbi1,
4358
5131
  constant int64_t & ne00,
4359
5132
  constant int64_t & ne01,
4360
5133
  constant int64_t & ne02,
@@ -4370,7 +5143,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4370
5143
  constant uint64_t & nb12,
4371
5144
  constant int64_t & ne0,
4372
5145
  constant int64_t & ne1,
4373
- constant int64_t & nb1,
5146
+ constant uint64_t & nb1,
4374
5147
  constant uint & r2,
4375
5148
  constant uint & r3,
4376
5149
  constant int & idx,
@@ -4397,7 +5170,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4397
5170
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4398
5171
  src0[id],
4399
5172
  (device const float *) (src1 + bid*nb11),
4400
- (device float *) ( dst + bid*nb1),
5173
+ dst + bid*ne0,
4401
5174
  ne00,
4402
5175
  ne01,
4403
5176
  ne02,
@@ -4416,8 +5189,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4416
5189
  kernel void kernel_mul_mv_id_q4_1_f32(
4417
5190
  device const char * ids,
4418
5191
  device const char * src1,
4419
- device uchar * dst,
4420
- constant int64_t & nbi1,
5192
+ device float * dst,
5193
+ constant uint64_t & nbi1,
4421
5194
  constant int64_t & ne00,
4422
5195
  constant int64_t & ne01,
4423
5196
  constant int64_t & ne02,
@@ -4433,7 +5206,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4433
5206
  constant uint64_t & nb12,
4434
5207
  constant int64_t & ne0,
4435
5208
  constant int64_t & ne1,
4436
- constant int64_t & nb1,
5209
+ constant uint64_t & nb1,
4437
5210
  constant uint & r2,
4438
5211
  constant uint & r3,
4439
5212
  constant int & idx,
@@ -4460,7 +5233,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4460
5233
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4461
5234
  src0[id],
4462
5235
  (device const float *) (src1 + bid*nb11),
4463
- (device float *) ( dst + bid*nb1),
5236
+ dst + bid*ne0,
4464
5237
  ne00,
4465
5238
  ne01,
4466
5239
  ne02,
@@ -4479,8 +5252,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4479
5252
  kernel void kernel_mul_mv_id_q5_0_f32(
4480
5253
  device const char * ids,
4481
5254
  device const char * src1,
4482
- device uchar * dst,
4483
- constant int64_t & nbi1,
5255
+ device float * dst,
5256
+ constant uint64_t & nbi1,
4484
5257
  constant int64_t & ne00,
4485
5258
  constant int64_t & ne01,
4486
5259
  constant int64_t & ne02,
@@ -4496,7 +5269,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4496
5269
  constant uint64_t & nb12,
4497
5270
  constant int64_t & ne0,
4498
5271
  constant int64_t & ne1,
4499
- constant int64_t & nb1,
5272
+ constant uint64_t & nb1,
4500
5273
  constant uint & r2,
4501
5274
  constant uint & r3,
4502
5275
  constant int & idx,
@@ -4523,7 +5296,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4523
5296
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4524
5297
  src0[id],
4525
5298
  (device const float *) (src1 + bid*nb11),
4526
- (device float *) ( dst + bid*nb1),
5299
+ dst + bid*ne0,
4527
5300
  ne00,
4528
5301
  ne01,
4529
5302
  ne02,
@@ -4542,8 +5315,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4542
5315
  kernel void kernel_mul_mv_id_q5_1_f32(
4543
5316
  device const char * ids,
4544
5317
  device const char * src1,
4545
- device uchar * dst,
4546
- constant int64_t & nbi1,
5318
+ device float * dst,
5319
+ constant uint64_t & nbi1,
4547
5320
  constant int64_t & ne00,
4548
5321
  constant int64_t & ne01,
4549
5322
  constant int64_t & ne02,
@@ -4559,7 +5332,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4559
5332
  constant uint64_t & nb12,
4560
5333
  constant int64_t & ne0,
4561
5334
  constant int64_t & ne1,
4562
- constant int64_t & nb1,
5335
+ constant uint64_t & nb1,
4563
5336
  constant uint & r2,
4564
5337
  constant uint & r3,
4565
5338
  constant int & idx,
@@ -4586,7 +5359,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4586
5359
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4587
5360
  src0[id],
4588
5361
  (device const float *) (src1 + bid*nb11),
4589
- (device float *) ( dst + bid*nb1),
5362
+ dst + bid*ne0,
4590
5363
  ne00,
4591
5364
  ne01,
4592
5365
  ne02,
@@ -4605,8 +5378,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4605
5378
  kernel void kernel_mul_mv_id_q2_K_f32(
4606
5379
  device const char * ids,
4607
5380
  device const char * src1,
4608
- device uchar * dst,
4609
- constant int64_t & nbi1,
5381
+ device float * dst,
5382
+ constant uint64_t & nbi1,
4610
5383
  constant int64_t & ne00,
4611
5384
  constant int64_t & ne01,
4612
5385
  constant int64_t & ne02,
@@ -4622,7 +5395,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4622
5395
  constant uint64_t & nb12,
4623
5396
  constant int64_t & ne0,
4624
5397
  constant int64_t & ne1,
4625
- constant int64_t & nb1,
5398
+ constant uint64_t & nb1,
4626
5399
  constant uint & r2,
4627
5400
  constant uint & r3,
4628
5401
  constant int & idx,
@@ -4649,7 +5422,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4649
5422
  kernel_mul_mv_q2_K_f32_impl(
4650
5423
  src0[id],
4651
5424
  (device const float *) (src1 + bid*nb11),
4652
- (device float *) ( dst + bid*nb1),
5425
+ dst + bid*ne0,
4653
5426
  ne00,
4654
5427
  ne01,
4655
5428
  ne02,
@@ -4668,8 +5441,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4668
5441
  kernel void kernel_mul_mv_id_q3_K_f32(
4669
5442
  device const char * ids,
4670
5443
  device const char * src1,
4671
- device uchar * dst,
4672
- constant int64_t & nbi1,
5444
+ device float * dst,
5445
+ constant uint64_t & nbi1,
4673
5446
  constant int64_t & ne00,
4674
5447
  constant int64_t & ne01,
4675
5448
  constant int64_t & ne02,
@@ -4685,7 +5458,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4685
5458
  constant uint64_t & nb12,
4686
5459
  constant int64_t & ne0,
4687
5460
  constant int64_t & ne1,
4688
- constant int64_t & nb1,
5461
+ constant uint64_t & nb1,
4689
5462
  constant uint & r2,
4690
5463
  constant uint & r3,
4691
5464
  constant int & idx,
@@ -4712,7 +5485,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4712
5485
  kernel_mul_mv_q3_K_f32_impl(
4713
5486
  src0[id],
4714
5487
  (device const float *) (src1 + bid*nb11),
4715
- (device float *) ( dst + bid*nb1),
5488
+ dst + bid*ne0,
4716
5489
  ne00,
4717
5490
  ne01,
4718
5491
  ne02,
@@ -4731,8 +5504,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4731
5504
  kernel void kernel_mul_mv_id_q4_K_f32(
4732
5505
  device const char * ids,
4733
5506
  device const char * src1,
4734
- device uchar * dst,
4735
- constant int64_t & nbi1,
5507
+ device float * dst,
5508
+ constant uint64_t & nbi1,
4736
5509
  constant int64_t & ne00,
4737
5510
  constant int64_t & ne01,
4738
5511
  constant int64_t & ne02,
@@ -4748,7 +5521,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4748
5521
  constant uint64_t & nb12,
4749
5522
  constant int64_t & ne0,
4750
5523
  constant int64_t & ne1,
4751
- constant int64_t & nb1,
5524
+ constant uint64_t & nb1,
4752
5525
  constant uint & r2,
4753
5526
  constant uint & r3,
4754
5527
  constant int & idx,
@@ -4775,7 +5548,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4775
5548
  kernel_mul_mv_q4_K_f32_impl(
4776
5549
  src0[id],
4777
5550
  (device const float *) (src1 + bid*nb11),
4778
- (device float *) ( dst + bid*nb1),
5551
+ dst + bid*ne0,
4779
5552
  ne00,
4780
5553
  ne01,
4781
5554
  ne02,
@@ -4794,8 +5567,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4794
5567
  kernel void kernel_mul_mv_id_q5_K_f32(
4795
5568
  device const char * ids,
4796
5569
  device const char * src1,
4797
- device uchar * dst,
4798
- constant int64_t & nbi1,
5570
+ device float * dst,
5571
+ constant uint64_t & nbi1,
4799
5572
  constant int64_t & ne00,
4800
5573
  constant int64_t & ne01,
4801
5574
  constant int64_t & ne02,
@@ -4811,7 +5584,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4811
5584
  constant uint64_t & nb12,
4812
5585
  constant int64_t & ne0,
4813
5586
  constant int64_t & ne1,
4814
- constant int64_t & nb1,
5587
+ constant uint64_t & nb1,
4815
5588
  constant uint & r2,
4816
5589
  constant uint & r3,
4817
5590
  constant int & idx,
@@ -4838,7 +5611,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4838
5611
  kernel_mul_mv_q5_K_f32_impl(
4839
5612
  src0[id],
4840
5613
  (device const float *) (src1 + bid*nb11),
4841
- (device float *) ( dst + bid*nb1),
5614
+ dst + bid*ne0,
4842
5615
  ne00,
4843
5616
  ne01,
4844
5617
  ne02,
@@ -4857,8 +5630,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4857
5630
  kernel void kernel_mul_mv_id_q6_K_f32(
4858
5631
  device const char * ids,
4859
5632
  device const char * src1,
4860
- device uchar * dst,
4861
- constant int64_t & nbi1,
5633
+ device float * dst,
5634
+ constant uint64_t & nbi1,
4862
5635
  constant int64_t & ne00,
4863
5636
  constant int64_t & ne01,
4864
5637
  constant int64_t & ne02,
@@ -4874,7 +5647,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4874
5647
  constant uint64_t & nb12,
4875
5648
  constant int64_t & ne0,
4876
5649
  constant int64_t & ne1,
4877
- constant int64_t & nb1,
5650
+ constant uint64_t & nb1,
4878
5651
  constant uint & r2,
4879
5652
  constant uint & r3,
4880
5653
  constant int & idx,
@@ -4901,7 +5674,136 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4901
5674
  kernel_mul_mv_q6_K_f32_impl(
4902
5675
  src0[id],
4903
5676
  (device const float *) (src1 + bid*nb11),
4904
- (device float *) ( dst + bid*nb1),
5677
+ dst + bid*ne0,
5678
+ ne00,
5679
+ ne01,
5680
+ ne02,
5681
+ ne10,
5682
+ ne12,
5683
+ ne0,
5684
+ ne1,
5685
+ r2,
5686
+ r3,
5687
+ tgpig,
5688
+ tiisg,
5689
+ sgitg);
5690
+ }
5691
+
5692
+ [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
5693
+ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5694
+ device const char * ids,
5695
+ device const char * src1,
5696
+ device float * dst,
5697
+ constant uint64_t & nbi1,
5698
+ constant int64_t & ne00,
5699
+ constant int64_t & ne01,
5700
+ constant int64_t & ne02,
5701
+ constant uint64_t & nb00,
5702
+ constant uint64_t & nb01,
5703
+ constant uint64_t & nb02,
5704
+ constant int64_t & ne10,
5705
+ constant int64_t & ne11,
5706
+ constant int64_t & ne12,
5707
+ constant int64_t & ne13,
5708
+ constant uint64_t & nb10,
5709
+ constant uint64_t & nb11,
5710
+ constant uint64_t & nb12,
5711
+ constant int64_t & ne0,
5712
+ constant int64_t & ne1,
5713
+ constant uint64_t & nb1,
5714
+ constant uint & r2,
5715
+ constant uint & r3,
5716
+ constant int & idx,
5717
+ device const char * src00,
5718
+ device const char * src01,
5719
+ device const char * src02,
5720
+ device const char * src03,
5721
+ device const char * src04,
5722
+ device const char * src05,
5723
+ device const char * src06,
5724
+ device const char * src07,
5725
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5726
+ uint3 tgpig[[threadgroup_position_in_grid]],
5727
+ uint tiitg[[thread_index_in_threadgroup]],
5728
+ uint tiisg[[thread_index_in_simdgroup]],
5729
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5730
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5731
+
5732
+ const int64_t bid = tgpig.z/(ne12*ne13);
5733
+
5734
+ tgpig.z = tgpig.z%(ne12*ne13);
5735
+
5736
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5737
+
5738
+ kernel_mul_mv_iq2_xxs_f32_impl(
5739
+ src0[id],
5740
+ (device const float *) (src1 + bid*nb11),
5741
+ dst + bid*ne0,
5742
+ ne00,
5743
+ ne01,
5744
+ ne02,
5745
+ ne10,
5746
+ ne12,
5747
+ ne0,
5748
+ ne1,
5749
+ r2,
5750
+ r3,
5751
+ shared_values,
5752
+ tgpig,
5753
+ tiisg,
5754
+ sgitg);
5755
+ }
5756
+
5757
+ [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
5758
+ kernel void kernel_mul_mv_id_iq2_xs_f32(
5759
+ device const char * ids,
5760
+ device const char * src1,
5761
+ device float * dst,
5762
+ constant uint64_t & nbi1,
5763
+ constant int64_t & ne00,
5764
+ constant int64_t & ne01,
5765
+ constant int64_t & ne02,
5766
+ constant uint64_t & nb00,
5767
+ constant uint64_t & nb01,
5768
+ constant uint64_t & nb02,
5769
+ constant int64_t & ne10,
5770
+ constant int64_t & ne11,
5771
+ constant int64_t & ne12,
5772
+ constant int64_t & ne13,
5773
+ constant uint64_t & nb10,
5774
+ constant uint64_t & nb11,
5775
+ constant uint64_t & nb12,
5776
+ constant int64_t & ne0,
5777
+ constant int64_t & ne1,
5778
+ constant uint64_t & nb1,
5779
+ constant uint & r2,
5780
+ constant uint & r3,
5781
+ constant int & idx,
5782
+ device const char * src00,
5783
+ device const char * src01,
5784
+ device const char * src02,
5785
+ device const char * src03,
5786
+ device const char * src04,
5787
+ device const char * src05,
5788
+ device const char * src06,
5789
+ device const char * src07,
5790
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5791
+ uint3 tgpig[[threadgroup_position_in_grid]],
5792
+ uint tiitg[[thread_index_in_threadgroup]],
5793
+ uint tiisg[[thread_index_in_simdgroup]],
5794
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
+
5797
+ const int64_t bid = tgpig.z/(ne12*ne13);
5798
+
5799
+ tgpig.z = tgpig.z%(ne12*ne13);
5800
+
5801
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5802
+
5803
+ kernel_mul_mv_iq2_xs_f32_impl(
5804
+ src0[id],
5805
+ (device const float *) (src1 + bid*nb11),
5806
+ dst + bid*ne0,
4905
5807
  ne00,
4906
5808
  ne01,
4907
5809
  ne02,
@@ -4911,6 +5813,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4911
5813
  ne1,
4912
5814
  r2,
4913
5815
  r3,
5816
+ shared_values,
4914
5817
  tgpig,
4915
5818
  tiisg,
4916
5819
  sgitg);