whisper.rn 0.4.0-rc.7 → 0.4.0-rc.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,26 +59,26 @@ kernel void kernel_add(
59
59
  constant int64_t & ne01,
60
60
  constant int64_t & ne02,
61
61
  constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
62
+ constant uint64_t & nb00,
63
+ constant uint64_t & nb01,
64
+ constant uint64_t & nb02,
65
+ constant uint64_t & nb03,
66
66
  constant int64_t & ne10,
67
67
  constant int64_t & ne11,
68
68
  constant int64_t & ne12,
69
69
  constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
70
+ constant uint64_t & nb10,
71
+ constant uint64_t & nb11,
72
+ constant uint64_t & nb12,
73
+ constant uint64_t & nb13,
74
74
  constant int64_t & ne0,
75
75
  constant int64_t & ne1,
76
76
  constant int64_t & ne2,
77
77
  constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
78
+ constant uint64_t & nb0,
79
+ constant uint64_t & nb1,
80
+ constant uint64_t & nb2,
81
+ constant uint64_t & nb3,
82
82
  constant int64_t & offs,
83
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
84
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -109,26 +109,26 @@ kernel void kernel_mul(
109
109
  constant int64_t & ne01,
110
110
  constant int64_t & ne02,
111
111
  constant int64_t & ne03,
112
- constant int64_t & nb00,
113
- constant int64_t & nb01,
114
- constant int64_t & nb02,
115
- constant int64_t & nb03,
112
+ constant uint64_t & nb00,
113
+ constant uint64_t & nb01,
114
+ constant uint64_t & nb02,
115
+ constant uint64_t & nb03,
116
116
  constant int64_t & ne10,
117
117
  constant int64_t & ne11,
118
118
  constant int64_t & ne12,
119
119
  constant int64_t & ne13,
120
- constant int64_t & nb10,
121
- constant int64_t & nb11,
122
- constant int64_t & nb12,
123
- constant int64_t & nb13,
120
+ constant uint64_t & nb10,
121
+ constant uint64_t & nb11,
122
+ constant uint64_t & nb12,
123
+ constant uint64_t & nb13,
124
124
  constant int64_t & ne0,
125
125
  constant int64_t & ne1,
126
126
  constant int64_t & ne2,
127
127
  constant int64_t & ne3,
128
- constant int64_t & nb0,
129
- constant int64_t & nb1,
130
- constant int64_t & nb2,
131
- constant int64_t & nb3,
128
+ constant uint64_t & nb0,
129
+ constant uint64_t & nb1,
130
+ constant uint64_t & nb2,
131
+ constant uint64_t & nb3,
132
132
  uint3 tgpig[[threadgroup_position_in_grid]],
133
133
  uint3 tpitg[[thread_position_in_threadgroup]],
134
134
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -158,26 +158,26 @@ kernel void kernel_div(
158
158
  constant int64_t & ne01,
159
159
  constant int64_t & ne02,
160
160
  constant int64_t & ne03,
161
- constant int64_t & nb00,
162
- constant int64_t & nb01,
163
- constant int64_t & nb02,
164
- constant int64_t & nb03,
161
+ constant uint64_t & nb00,
162
+ constant uint64_t & nb01,
163
+ constant uint64_t & nb02,
164
+ constant uint64_t & nb03,
165
165
  constant int64_t & ne10,
166
166
  constant int64_t & ne11,
167
167
  constant int64_t & ne12,
168
168
  constant int64_t & ne13,
169
- constant int64_t & nb10,
170
- constant int64_t & nb11,
171
- constant int64_t & nb12,
172
- constant int64_t & nb13,
169
+ constant uint64_t & nb10,
170
+ constant uint64_t & nb11,
171
+ constant uint64_t & nb12,
172
+ constant uint64_t & nb13,
173
173
  constant int64_t & ne0,
174
174
  constant int64_t & ne1,
175
175
  constant int64_t & ne2,
176
176
  constant int64_t & ne3,
177
- constant int64_t & nb0,
178
- constant int64_t & nb1,
179
- constant int64_t & nb2,
180
- constant int64_t & nb3,
177
+ constant uint64_t & nb0,
178
+ constant uint64_t & nb1,
179
+ constant uint64_t & nb2,
180
+ constant uint64_t & nb3,
181
181
  uint3 tgpig[[threadgroup_position_in_grid]],
182
182
  uint3 tpitg[[thread_position_in_threadgroup]],
183
183
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
205
205
  device const float4 * src0,
206
206
  device const float4 * src1,
207
207
  device float4 * dst,
208
- constant int64_t & nb [[buffer(28)]],
208
+ constant uint64_t & nb [[buffer(28)]],
209
209
  uint tpig[[thread_position_in_grid]]) {
210
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
211
  }
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
214
214
  device const float4 * src0,
215
215
  device const float4 * src1,
216
216
  device float4 * dst,
217
- constant int64_t & nb [[buffer(28)]],
217
+ constant uint64_t & nb [[buffer(28)]],
218
218
  uint tpig[[thread_position_in_grid]]) {
219
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
220
  }
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
223
223
  device const float4 * src0,
224
224
  device const float4 * src1,
225
225
  device float4 * dst,
226
- constant int64_t & nb [[buffer(28)]],
226
+ constant uint64_t & nb [[buffer(28)]],
227
227
  uint tpig[[thread_position_in_grid]]) {
228
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
229
  }
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
307
307
  constant int64_t & ne01,
308
308
  constant int64_t & ne02,
309
309
  constant int64_t & ne03,
310
- constant int64_t & nb00,
311
- constant int64_t & nb01,
312
- constant int64_t & nb02,
313
- constant int64_t & nb03,
310
+ constant uint64_t & nb00,
311
+ constant uint64_t & nb01,
312
+ constant uint64_t & nb02,
313
+ constant uint64_t & nb03,
314
314
  constant int64_t & ne10,
315
315
  constant int64_t & ne11,
316
316
  constant int64_t & ne12,
317
317
  constant int64_t & ne13,
318
- constant int64_t & nb10,
319
- constant int64_t & nb11,
320
- constant int64_t & nb12,
321
- constant int64_t & nb13,
318
+ constant uint64_t & nb10,
319
+ constant uint64_t & nb11,
320
+ constant uint64_t & nb12,
321
+ constant uint64_t & nb13,
322
322
  constant int64_t & ne0,
323
323
  constant int64_t & ne1,
324
324
  constant int64_t & ne2,
325
325
  constant int64_t & ne3,
326
- constant int64_t & nb0,
327
- constant int64_t & nb1,
328
- constant int64_t & nb2,
329
- constant int64_t & nb3,
326
+ constant uint64_t & nb0,
327
+ constant uint64_t & nb1,
328
+ constant uint64_t & nb2,
329
+ constant uint64_t & nb3,
330
330
  uint3 tpig[[thread_position_in_grid]]) {
331
331
  int64_t i3 = tpig.z;
332
332
  int64_t i2 = tpig.y;
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
846
846
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
847
847
  //Note: This is a template, but strictly speaking it only applies to
848
848
  // quantizations where the block size is 32. It also does not
849
- // giard against the number of rows not being divisible by
849
+ // guard against the number of rows not being divisible by
850
850
  // N_DST, so this is another explicit assumption of the implementation.
851
851
  template<typename block_q_type, int nr, int nsg, int nw>
852
852
  void mul_vec_q_n_f32_impl(
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
920
920
  device const float * src1,
921
921
  device float * dst,
922
922
  constant int64_t & ne00,
923
- constant int64_t & ne01[[buffer(4)]],
924
- constant int64_t & ne02[[buffer(5)]],
925
- constant int64_t & ne10[[buffer(9)]],
926
- constant int64_t & ne12[[buffer(11)]],
927
- constant int64_t & ne0 [[buffer(15)]],
928
- constant int64_t & ne1 [[buffer(16)]],
929
- constant uint & r2 [[buffer(17)]],
930
- constant uint & r3 [[buffer(18)]],
923
+ constant int64_t & ne01,
924
+ constant int64_t & ne02,
925
+ constant uint64_t & nb00,
926
+ constant uint64_t & nb01,
927
+ constant uint64_t & nb02,
928
+ constant int64_t & ne10,
929
+ constant int64_t & ne11,
930
+ constant int64_t & ne12,
931
+ constant uint64_t & nb10,
932
+ constant uint64_t & nb11,
933
+ constant uint64_t & nb12,
934
+ constant int64_t & ne0,
935
+ constant int64_t & ne1,
936
+ constant uint & r2,
937
+ constant uint & r3,
931
938
  uint3 tgpig[[threadgroup_position_in_grid]],
932
939
  uint tiisg[[thread_index_in_simdgroup]],
933
940
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
939
946
  device const float * src1,
940
947
  device float * dst,
941
948
  constant int64_t & ne00,
942
- constant int64_t & ne01[[buffer(4)]],
943
- constant int64_t & ne02[[buffer(5)]],
944
- constant int64_t & ne10[[buffer(9)]],
945
- constant int64_t & ne12[[buffer(11)]],
946
- constant int64_t & ne0 [[buffer(15)]],
947
- constant int64_t & ne1 [[buffer(16)]],
948
- constant uint & r2 [[buffer(17)]],
949
- constant uint & r3 [[buffer(18)]],
949
+ constant int64_t & ne01,
950
+ constant int64_t & ne02,
951
+ constant uint64_t & nb00,
952
+ constant uint64_t & nb01,
953
+ constant uint64_t & nb02,
954
+ constant int64_t & ne10,
955
+ constant int64_t & ne11,
956
+ constant int64_t & ne12,
957
+ constant uint64_t & nb10,
958
+ constant uint64_t & nb11,
959
+ constant uint64_t & nb12,
960
+ constant int64_t & ne0,
961
+ constant int64_t & ne1,
962
+ constant uint & r2,
963
+ constant uint & r3,
950
964
  uint3 tgpig[[threadgroup_position_in_grid]],
951
965
  uint tiisg[[thread_index_in_simdgroup]],
952
966
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
958
972
  device const float * src1,
959
973
  device float * dst,
960
974
  constant int64_t & ne00,
961
- constant int64_t & ne01[[buffer(4)]],
962
- constant int64_t & ne02[[buffer(5)]],
963
- constant int64_t & ne10[[buffer(9)]],
964
- constant int64_t & ne12[[buffer(11)]],
965
- constant int64_t & ne0 [[buffer(15)]],
966
- constant int64_t & ne1 [[buffer(16)]],
967
- constant uint & r2 [[buffer(17)]],
968
- constant uint & r3 [[buffer(18)]],
975
+ constant int64_t & ne01,
976
+ constant int64_t & ne02,
977
+ constant uint64_t & nb00,
978
+ constant uint64_t & nb01,
979
+ constant uint64_t & nb02,
980
+ constant int64_t & ne10,
981
+ constant int64_t & ne11,
982
+ constant int64_t & ne12,
983
+ constant uint64_t & nb10,
984
+ constant uint64_t & nb11,
985
+ constant uint64_t & nb12,
986
+ constant int64_t & ne0,
987
+ constant int64_t & ne1,
988
+ constant uint & r2,
989
+ constant uint & r3,
969
990
  uint3 tgpig[[threadgroup_position_in_grid]],
970
991
  uint tiisg[[thread_index_in_simdgroup]],
971
992
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
977
998
  device const float * src1,
978
999
  device float * dst,
979
1000
  constant int64_t & ne00,
980
- constant int64_t & ne01[[buffer(4)]],
981
- constant int64_t & ne02[[buffer(5)]],
982
- constant int64_t & ne10[[buffer(9)]],
983
- constant int64_t & ne12[[buffer(11)]],
984
- constant int64_t & ne0 [[buffer(15)]],
985
- constant int64_t & ne1 [[buffer(16)]],
986
- constant uint & r2 [[buffer(17)]],
987
- constant uint & r3 [[buffer(18)]],
1001
+ constant int64_t & ne01,
1002
+ constant int64_t & ne02,
1003
+ constant uint64_t & nb00,
1004
+ constant uint64_t & nb01,
1005
+ constant uint64_t & nb02,
1006
+ constant int64_t & ne10,
1007
+ constant int64_t & ne11,
1008
+ constant int64_t & ne12,
1009
+ constant uint64_t & nb10,
1010
+ constant uint64_t & nb11,
1011
+ constant uint64_t & nb12,
1012
+ constant int64_t & ne0,
1013
+ constant int64_t & ne1,
1014
+ constant uint & r2,
1015
+ constant uint & r3,
988
1016
  uint3 tgpig[[threadgroup_position_in_grid]],
989
1017
  uint tiisg[[thread_index_in_simdgroup]],
990
1018
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
1071
1099
  constant int64_t & ne00,
1072
1100
  constant int64_t & ne01,
1073
1101
  constant int64_t & ne02,
1102
+ constant uint64_t & nb00,
1103
+ constant uint64_t & nb01,
1104
+ constant uint64_t & nb02,
1074
1105
  constant int64_t & ne10,
1106
+ constant int64_t & ne11,
1075
1107
  constant int64_t & ne12,
1108
+ constant uint64_t & nb10,
1109
+ constant uint64_t & nb11,
1110
+ constant uint64_t & nb12,
1076
1111
  constant int64_t & ne0,
1077
1112
  constant int64_t & ne1,
1078
- constant uint & r2 [[buffer(17)]],
1079
- constant uint & r3 [[buffer(18)]],
1113
+ constant uint & r2,
1114
+ constant uint & r3,
1080
1115
  uint3 tgpig[[threadgroup_position_in_grid]],
1081
1116
  uint tiisg[[thread_index_in_simdgroup]],
1082
1117
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
1182
1217
  constant uint64_t & nb12,
1183
1218
  constant int64_t & ne0,
1184
1219
  constant int64_t & ne1,
1185
- constant uint & r2 [[buffer(17)]],
1186
- constant uint & r3 [[buffer(18)]],
1220
+ constant uint & r2,
1221
+ constant uint & r3,
1187
1222
  uint3 tgpig[[threadgroup_position_in_grid]],
1188
1223
  uint tiisg[[thread_index_in_simdgroup]]) {
1189
1224
  kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
1209
1244
  constant uint64_t & nb12,
1210
1245
  constant int64_t & ne0,
1211
1246
  constant int64_t & ne1,
1212
- constant uint & r2 [[buffer(17)]],
1213
- constant uint & r3 [[buffer(18)]],
1247
+ constant uint & r2,
1248
+ constant uint & r3,
1214
1249
  uint3 tgpig[[threadgroup_position_in_grid]],
1215
1250
  uint tiisg[[thread_index_in_simdgroup]]) {
1216
1251
 
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1346
1381
  constant uint64_t & nb12,
1347
1382
  constant int64_t & ne0,
1348
1383
  constant int64_t & ne1,
1349
- constant uint & r2 [[buffer(17)]],
1350
- constant uint & r3 [[buffer(18)]],
1384
+ constant uint & r2,
1385
+ constant uint & r3,
1351
1386
  uint3 tgpig[[threadgroup_position_in_grid]],
1352
1387
  uint tiisg[[thread_index_in_simdgroup]]) {
1353
1388
  kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
1452
1487
  constant uint64_t & nb12,
1453
1488
  constant int64_t & ne0,
1454
1489
  constant int64_t & ne1,
1455
- constant uint & r2 [[buffer(17)]],
1456
- constant uint & r3 [[buffer(18)]],
1490
+ constant uint & r2,
1491
+ constant uint & r3,
1457
1492
  uint3 tgpig[[threadgroup_position_in_grid]],
1458
1493
  uint tiisg[[thread_index_in_simdgroup]]) {
1459
1494
  kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1478
1513
  constant uint64_t & nb12,
1479
1514
  constant int64_t & ne0,
1480
1515
  constant int64_t & ne1,
1481
- constant uint & r2 [[buffer(17)]],
1482
- constant uint & r3 [[buffer(18)]],
1516
+ constant uint & r2,
1517
+ constant uint & r3,
1483
1518
  uint3 tgpig[[threadgroup_position_in_grid]],
1484
1519
  uint tiisg[[thread_index_in_simdgroup]]) {
1485
1520
 
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
1543
1578
  const int64_t i3 = n / (ne2*ne1*ne0);
1544
1579
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1545
1580
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1546
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1581
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1582
+
1547
1583
  const int64_t k = i3*ne3 + i2;
1548
1584
 
1549
1585
  float m_k;
@@ -1702,8 +1738,9 @@ kernel void kernel_rope(
1702
1738
  dst_data[1] = x0*sin_theta + x1*cos_theta;
1703
1739
  }
1704
1740
  } else {
1705
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1706
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1741
+ for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1742
+ if (ic < n_dims) {
1743
+ const int64_t ib = 0;
1707
1744
 
1708
1745
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1709
1746
  const float cur_rot = inv_ndims*ic - ib;
@@ -1722,6 +1759,14 @@ kernel void kernel_rope(
1722
1759
 
1723
1760
  dst_data[0] = x0*cos_theta - x1*sin_theta;
1724
1761
  dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1762
+ } else {
1763
+ const int64_t i0 = ic;
1764
+
1765
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1766
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1767
+
1768
+ dst_data[0] = src[0];
1769
+ dst_data[1] = src[1];
1725
1770
  }
1726
1771
  }
1727
1772
  }
@@ -2401,21 +2446,18 @@ typedef struct {
2401
2446
  } block_q6_K;
2402
2447
  // 210 bytes / block
2403
2448
 
2404
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2405
- uchar4 r;
2406
- if (j < 4) {
2407
- r[0] = q[j+0] & 63;
2408
- r[2] = q[j+1] & 63;
2409
- r[1] = q[j+4] & 63;
2410
- r[3] = q[j+5] & 63;
2411
- } else {
2412
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2413
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2414
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2415
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2416
- }
2417
- return r;
2418
- }
2449
+ typedef struct {
2450
+ half d;
2451
+ uint16_t qs[QK_K/8];
2452
+ } block_iq2_xxs;
2453
+ // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
+
2455
+ typedef struct {
2456
+ half d;
2457
+ uint16_t qs[QK_K/8];
2458
+ uint8_t scales[QK_K/32];
2459
+ } block_iq2_xs;
2460
+ // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2419
2461
 
2420
2462
  //====================================== dot products =========================
2421
2463
 
@@ -2575,14 +2617,21 @@ kernel void kernel_mul_mv_q2_K_f32(
2575
2617
  device const float * src1,
2576
2618
  device float * dst,
2577
2619
  constant int64_t & ne00,
2578
- constant int64_t & ne01[[buffer(4)]],
2579
- constant int64_t & ne02[[buffer(5)]],
2580
- constant int64_t & ne10[[buffer(9)]],
2581
- constant int64_t & ne12[[buffer(11)]],
2582
- constant int64_t & ne0 [[buffer(15)]],
2583
- constant int64_t & ne1 [[buffer(16)]],
2584
- constant uint & r2 [[buffer(17)]],
2585
- constant uint & r3 [[buffer(18)]],
2620
+ constant int64_t & ne01,
2621
+ constant int64_t & ne02,
2622
+ constant uint64_t & nb00,
2623
+ constant uint64_t & nb01,
2624
+ constant uint64_t & nb02,
2625
+ constant int64_t & ne10,
2626
+ constant int64_t & ne11,
2627
+ constant int64_t & ne12,
2628
+ constant uint64_t & nb10,
2629
+ constant uint64_t & nb11,
2630
+ constant uint64_t & nb12,
2631
+ constant int64_t & ne0,
2632
+ constant int64_t & ne1,
2633
+ constant uint & r2,
2634
+ constant uint & r3,
2586
2635
  uint3 tgpig[[threadgroup_position_in_grid]],
2587
2636
  uint tiisg[[thread_index_in_simdgroup]],
2588
2637
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2832,14 +2881,21 @@ kernel void kernel_mul_mv_q3_K_f32(
2832
2881
  device const float * src1,
2833
2882
  device float * dst,
2834
2883
  constant int64_t & ne00,
2835
- constant int64_t & ne01[[buffer(4)]],
2836
- constant int64_t & ne02[[buffer(5)]],
2837
- constant int64_t & ne10[[buffer(9)]],
2838
- constant int64_t & ne12[[buffer(11)]],
2839
- constant int64_t & ne0 [[buffer(15)]],
2840
- constant int64_t & ne1 [[buffer(16)]],
2841
- constant uint & r2 [[buffer(17)]],
2842
- constant uint & r3 [[buffer(18)]],
2884
+ constant int64_t & ne01,
2885
+ constant int64_t & ne02,
2886
+ constant uint64_t & nb00,
2887
+ constant uint64_t & nb01,
2888
+ constant uint64_t & nb02,
2889
+ constant int64_t & ne10,
2890
+ constant int64_t & ne11,
2891
+ constant int64_t & ne12,
2892
+ constant uint64_t & nb10,
2893
+ constant uint64_t & nb11,
2894
+ constant uint64_t & nb12,
2895
+ constant int64_t & ne0,
2896
+ constant int64_t & ne1,
2897
+ constant uint & r2,
2898
+ constant uint & r3,
2843
2899
  uint3 tgpig[[threadgroup_position_in_grid]],
2844
2900
  uint tiisg[[thread_index_in_simdgroup]],
2845
2901
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2975,8 +3031,8 @@ void kernel_mul_mv_q4_K_f32_impl(
2975
3031
  constant uint & r2,
2976
3032
  constant uint & r3,
2977
3033
  uint3 tgpig[[threadgroup_position_in_grid]],
2978
- uint tiisg[[thread_index_in_simdgroup]],
2979
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3034
+ uint tiisg[[thread_index_in_simdgroup]],
3035
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2980
3036
 
2981
3037
  const int ix = tiisg/4; // 0...7
2982
3038
  const int it = tiisg%4; // 0...3
@@ -2985,7 +3041,7 @@ void kernel_mul_mv_q4_K_f32_impl(
2985
3041
  const int r0 = tgpig.x;
2986
3042
  const int r1 = tgpig.y;
2987
3043
  const int im = tgpig.z;
2988
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3044
+ const int first_row = r0 * N_DST;
2989
3045
  const int ib_row = first_row * nb;
2990
3046
 
2991
3047
  const uint i12 = im%ne12;
@@ -3051,7 +3107,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3051
3107
  for (int row = 0; row < N_DST; ++row) {
3052
3108
  all_sum = simd_sum(sumf[row]);
3053
3109
  if (tiisg == 0) {
3054
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3110
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
3055
3111
  }
3056
3112
  }
3057
3113
  }
@@ -3063,14 +3119,21 @@ kernel void kernel_mul_mv_q4_K_f32(
3063
3119
  device const float * src1,
3064
3120
  device float * dst,
3065
3121
  constant int64_t & ne00,
3066
- constant int64_t & ne01[[buffer(4)]],
3067
- constant int64_t & ne02[[buffer(5)]],
3068
- constant int64_t & ne10[[buffer(9)]],
3069
- constant int64_t & ne12[[buffer(11)]],
3070
- constant int64_t & ne0 [[buffer(15)]],
3071
- constant int64_t & ne1 [[buffer(16)]],
3072
- constant uint & r2 [[buffer(17)]],
3073
- constant uint & r3 [[buffer(18)]],
3122
+ constant int64_t & ne01,
3123
+ constant int64_t & ne02,
3124
+ constant uint64_t & nb00,
3125
+ constant uint64_t & nb01,
3126
+ constant uint64_t & nb02,
3127
+ constant int64_t & ne10,
3128
+ constant int64_t & ne11,
3129
+ constant int64_t & ne12,
3130
+ constant uint64_t & nb10,
3131
+ constant uint64_t & nb11,
3132
+ constant uint64_t & nb12,
3133
+ constant int64_t & ne0,
3134
+ constant int64_t & ne1,
3135
+ constant uint & r2,
3136
+ constant uint & r3,
3074
3137
  uint3 tgpig[[threadgroup_position_in_grid]],
3075
3138
  uint tiisg[[thread_index_in_simdgroup]],
3076
3139
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3262,14 +3325,21 @@ kernel void kernel_mul_mv_q5_K_f32(
3262
3325
  device const float * src1,
3263
3326
  device float * dst,
3264
3327
  constant int64_t & ne00,
3265
- constant int64_t & ne01[[buffer(4)]],
3266
- constant int64_t & ne02[[buffer(5)]],
3267
- constant int64_t & ne10[[buffer(9)]],
3268
- constant int64_t & ne12[[buffer(11)]],
3269
- constant int64_t & ne0 [[buffer(15)]],
3270
- constant int64_t & ne1 [[buffer(16)]],
3271
- constant uint & r2 [[buffer(17)]],
3272
- constant uint & r3 [[buffer(18)]],
3328
+ constant int64_t & ne01,
3329
+ constant int64_t & ne02,
3330
+ constant uint64_t & nb00,
3331
+ constant uint64_t & nb01,
3332
+ constant uint64_t & nb02,
3333
+ constant int64_t & ne10,
3334
+ constant int64_t & ne11,
3335
+ constant int64_t & ne12,
3336
+ constant uint64_t & nb10,
3337
+ constant uint64_t & nb11,
3338
+ constant uint64_t & nb12,
3339
+ constant int64_t & ne0,
3340
+ constant int64_t & ne1,
3341
+ constant uint & r2,
3342
+ constant uint & r3,
3273
3343
  uint3 tgpig[[threadgroup_position_in_grid]],
3274
3344
  uint tiisg[[thread_index_in_simdgroup]],
3275
3345
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3389,14 +3459,21 @@ kernel void kernel_mul_mv_q6_K_f32(
3389
3459
  device const float * src1,
3390
3460
  device float * dst,
3391
3461
  constant int64_t & ne00,
3392
- constant int64_t & ne01[[buffer(4)]],
3393
- constant int64_t & ne02[[buffer(5)]],
3394
- constant int64_t & ne10[[buffer(9)]],
3395
- constant int64_t & ne12[[buffer(11)]],
3396
- constant int64_t & ne0 [[buffer(15)]],
3397
- constant int64_t & ne1 [[buffer(16)]],
3398
- constant uint & r2 [[buffer(17)]],
3399
- constant uint & r3 [[buffer(18)]],
3462
+ constant int64_t & ne01,
3463
+ constant int64_t & ne02,
3464
+ constant uint64_t & nb00,
3465
+ constant uint64_t & nb01,
3466
+ constant uint64_t & nb02,
3467
+ constant int64_t & ne10,
3468
+ constant int64_t & ne11,
3469
+ constant int64_t & ne12,
3470
+ constant uint64_t & nb10,
3471
+ constant uint64_t & nb11,
3472
+ constant uint64_t & nb12,
3473
+ constant int64_t & ne0,
3474
+ constant int64_t & ne1,
3475
+ constant uint & r2,
3476
+ constant uint & r3,
3400
3477
  uint3 tgpig[[threadgroup_position_in_grid]],
3401
3478
  uint tiisg[[thread_index_in_simdgroup]],
3402
3479
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3404,51 +3481,540 @@ kernel void kernel_mul_mv_q6_K_f32(
3404
3481
  kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
3482
  }
3406
3483
 
3407
- //============================= templates and their specializations =============================
3408
-
3409
- // NOTE: this is not dequantizing - we are simply fitting the template
3410
- template <typename type4x4>
3411
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
3412
- float4x4 temp = *(((device float4x4 *)src));
3413
- for (int i = 0; i < 16; i++){
3414
- reg[i/4][i%4] = temp[i/4][i%4];
3415
- }
3416
- }
3484
+ // ======================= "True" 2-bit
3485
+
3486
+ constexpr constant static uint64_t iq2xxs_grid[256] = {
3487
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3488
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3489
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
3490
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
3491
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
3492
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
3493
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
3494
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
3495
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
3496
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
3497
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
3498
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
3499
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
3500
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
3501
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
3502
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
3503
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
3504
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
3505
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
3506
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
3507
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
3508
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
3509
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
3510
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
3511
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
3512
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
3513
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
3514
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
3515
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
3516
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
3517
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
3518
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
3519
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
3520
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
3521
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
3522
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
3523
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
3524
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
3525
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
3526
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
3527
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
3528
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
3529
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
3530
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
3531
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
3532
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
3533
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
3534
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
3535
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
3536
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
3537
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
3538
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
3539
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
3540
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
3541
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
3542
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
3543
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
3544
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
3545
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
3546
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
3547
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
3548
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
3549
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
3550
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3551
+ };
3417
3552
 
3418
- template <typename type4x4>
3419
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
3420
- half4x4 temp = *(((device half4x4 *)src));
3421
- for (int i = 0; i < 16; i++){
3422
- reg[i/4][i%4] = temp[i/4][i%4];
3423
- }
3424
- }
3553
+ constexpr constant static uint64_t iq2xs_grid[512] = {
3554
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3555
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3556
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3557
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3558
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3559
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
3560
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
3561
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
3562
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
3563
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
3564
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
3565
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
3566
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
3567
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
3568
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
3569
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
3570
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
3571
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
3572
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
3573
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
3574
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
3575
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
3576
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
3577
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
3578
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
3579
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
3580
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
3581
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
3582
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
3583
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
3584
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
3585
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
3586
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
3587
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
3588
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
3589
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
3590
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
3591
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
3592
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
3593
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
3594
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
3595
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
3596
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
3597
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
3598
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
3599
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
3600
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
3601
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
3602
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
3603
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
3604
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
3605
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
3606
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
3607
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
3608
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
3609
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
3610
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
3611
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
3612
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
3613
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
3614
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
3615
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
3616
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
3617
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
3618
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
3619
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
3620
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
3621
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
3622
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
3623
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
3624
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
3625
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
3626
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
3627
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
3628
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
3629
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
3630
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
3631
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
3632
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
3633
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
3634
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
3635
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
3636
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
3637
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
3638
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
3639
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
3640
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
3641
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
3642
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
3643
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
3644
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
3645
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
3646
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
3647
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
3648
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
3649
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
3650
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
3651
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
3652
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
3653
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
3654
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
3655
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
3656
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
3657
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
3658
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
3659
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
3660
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
3661
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
3662
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
3663
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
3664
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
3665
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
3666
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
3667
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
3668
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
3669
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
3670
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
3671
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
3672
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
3673
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
3674
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
3675
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
3676
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
3677
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
3678
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
3679
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
3680
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
3681
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3682
+ };
3425
3683
 
3426
- template <typename type4x4>
3427
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3428
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
3429
- const float d1 = il ? (xb->d / 16.h) : xb->d;
3430
- const float d2 = d1 / 256.f;
3431
- const float md = -8.h * xb->d;
3432
- const ushort mask0 = il ? 0x00F0 : 0x000F;
3433
- const ushort mask1 = mask0 << 8;
3684
+ constexpr constant static uint8_t ksigns_iq2xs[128] = {
3685
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3686
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
3687
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
3688
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
3689
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
3690
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
3691
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
3692
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
3693
+ };
3434
3694
 
3435
- for (int i=0;i<8;i++) {
3436
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
3437
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
3438
- }
3439
- }
3695
+ constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
3440
3696
 
3441
- template <typename type4x4>
3442
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
3443
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
3444
- const float d1 = il ? (xb->d / 16.h) : xb->d;
3445
- const float d2 = d1 / 256.f;
3446
- const float m = xb->m;
3447
- const ushort mask0 = il ? 0x00F0 : 0x000F;
3448
- const ushort mask1 = mask0 << 8;
3697
+ void kernel_mul_mv_iq2_xxs_f32_impl(
3698
+ device const void * src0,
3699
+ device const float * src1,
3700
+ device float * dst,
3701
+ constant int64_t & ne00,
3702
+ constant int64_t & ne01,
3703
+ constant int64_t & ne02,
3704
+ constant int64_t & ne10,
3705
+ constant int64_t & ne12,
3706
+ constant int64_t & ne0,
3707
+ constant int64_t & ne1,
3708
+ constant uint & r2,
3709
+ constant uint & r3,
3710
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3711
+ uint3 tgpig[[threadgroup_position_in_grid]],
3712
+ uint tiisg[[thread_index_in_simdgroup]],
3713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3449
3714
 
3450
- for (int i=0;i<8;i++) {
3451
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
3715
+ const int nb = ne00/QK_K;
3716
+ const int r0 = tgpig.x;
3717
+ const int r1 = tgpig.y;
3718
+ const int im = tgpig.z;
3719
+
3720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3721
+ const int ib_row = first_row * nb;
3722
+
3723
+ const uint i12 = im%ne12;
3724
+ const uint i13 = im/ne12;
3725
+
3726
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3727
+
3728
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
3729
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3730
+
3731
+ float yl[32];
3732
+ float sumf[N_DST]={0.f}, all_sum;
3733
+
3734
+ const int nb32 = nb * (QK_K / 32);
3735
+
3736
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3737
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
3738
+ {
3739
+ int nval = 4;
3740
+ int pos = (32*sgitg + tiisg)*nval;
3741
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
3742
+ nval = 2;
3743
+ pos = (32*sgitg + tiisg)*nval;
3744
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3745
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3746
+ }
3747
+
3748
+ #if QK_K == 256
3749
+ const int ix = tiisg;
3750
+
3751
+ device const float * y4 = y + 32 * ix;
3752
+
3753
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3754
+
3755
+ for (int i = 0; i < 32; ++i) {
3756
+ yl[i] = y4[i];
3757
+ }
3758
+
3759
+ const int ibl = ib32 / (QK_K / 32);
3760
+ const int ib = ib32 % (QK_K / 32);
3761
+
3762
+ device const block_iq2_xxs * xr = x + ibl;
3763
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3764
+ device const half * dh = &xr->d;
3765
+
3766
+ for (int row = 0; row < N_DST; row++) {
3767
+
3768
+ const float db = dh[0];
3769
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
3770
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
3771
+ const float d = db * (0.5f + (aux32 >> 28));
3772
+
3773
+ float sum = 0;
3774
+ for (int l = 0; l < 4; ++l) {
3775
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
3776
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
3777
+ for (int j = 0; j < 8; ++j) {
3778
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3779
+ }
3780
+ }
3781
+ sumf[row] += d * sum;
3782
+
3783
+ dh += nb*sizeof(block_iq2_xxs)/2;
3784
+ q2 += nb*sizeof(block_iq2_xxs)/2;
3785
+ }
3786
+
3787
+ y4 += 32 * 32;
3788
+ }
3789
+ #else
3790
+ // TODO
3791
+ #endif
3792
+
3793
+ for (int row = 0; row < N_DST; ++row) {
3794
+ all_sum = simd_sum(sumf[row]);
3795
+ if (tiisg == 0) {
3796
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3797
+ }
3798
+ }
3799
+ }
3800
+
3801
+ [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
3802
+ kernel void kernel_mul_mv_iq2_xxs_f32(
3803
+ device const void * src0,
3804
+ device const float * src1,
3805
+ device float * dst,
3806
+ constant int64_t & ne00,
3807
+ constant int64_t & ne01,
3808
+ constant int64_t & ne02,
3809
+ constant uint64_t & nb00,
3810
+ constant uint64_t & nb01,
3811
+ constant uint64_t & nb02,
3812
+ constant int64_t & ne10,
3813
+ constant int64_t & ne11,
3814
+ constant int64_t & ne12,
3815
+ constant uint64_t & nb10,
3816
+ constant uint64_t & nb11,
3817
+ constant uint64_t & nb12,
3818
+ constant int64_t & ne0,
3819
+ constant int64_t & ne1,
3820
+ constant uint & r2,
3821
+ constant uint & r3,
3822
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3823
+ uint3 tgpig[[threadgroup_position_in_grid]],
3824
+ uint tiisg[[thread_index_in_simdgroup]],
3825
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3826
+
3827
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3828
+ }
3829
+
3830
+ void kernel_mul_mv_iq2_xs_f32_impl(
3831
+ device const void * src0,
3832
+ device const float * src1,
3833
+ device float * dst,
3834
+ constant int64_t & ne00,
3835
+ constant int64_t & ne01,
3836
+ constant int64_t & ne02,
3837
+ constant int64_t & ne10,
3838
+ constant int64_t & ne12,
3839
+ constant int64_t & ne0,
3840
+ constant int64_t & ne1,
3841
+ constant uint & r2,
3842
+ constant uint & r3,
3843
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiisg[[thread_index_in_simdgroup]],
3846
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3847
+
3848
+ const int nb = ne00/QK_K;
3849
+ const int r0 = tgpig.x;
3850
+ const int r1 = tgpig.y;
3851
+ const int im = tgpig.z;
3852
+
3853
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3854
+ const int ib_row = first_row * nb;
3855
+
3856
+ const uint i12 = im%ne12;
3857
+ const uint i13 = im/ne12;
3858
+
3859
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3860
+
3861
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
3862
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3863
+
3864
+ float yl[32];
3865
+ float sumf[N_DST]={0.f}, all_sum;
3866
+
3867
+ const int nb32 = nb * (QK_K / 32);
3868
+
3869
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3870
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
3871
+ {
3872
+ int nval = 8;
3873
+ int pos = (32*sgitg + tiisg)*nval;
3874
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
3875
+ nval = 2;
3876
+ pos = (32*sgitg + tiisg)*nval;
3877
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3878
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3879
+ }
3880
+
3881
+ #if QK_K == 256
3882
+ const int ix = tiisg;
3883
+
3884
+ device const float * y4 = y + 32 * ix;
3885
+
3886
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3887
+
3888
+ for (int i = 0; i < 32; ++i) {
3889
+ yl[i] = y4[i];
3890
+ }
3891
+
3892
+ const int ibl = ib32 / (QK_K / 32);
3893
+ const int ib = ib32 % (QK_K / 32);
3894
+
3895
+ device const block_iq2_xs * xr = x + ibl;
3896
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3897
+ device const uint8_t * sc = xr->scales + ib;
3898
+ device const half * dh = &xr->d;
3899
+
3900
+ for (int row = 0; row < N_DST; row++) {
3901
+
3902
+ const float db = dh[0];
3903
+ const uint8_t ls1 = sc[0] & 0xf;
3904
+ const uint8_t ls2 = sc[0] >> 4;
3905
+ const float d1 = db * (0.5f + ls1);
3906
+ const float d2 = db * (0.5f + ls2);
3907
+
3908
+ float sum1 = 0, sum2 = 0;
3909
+ for (int l = 0; l < 2; ++l) {
3910
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3911
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3912
+ for (int j = 0; j < 8; ++j) {
3913
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3914
+ }
3915
+ }
3916
+ for (int l = 2; l < 4; ++l) {
3917
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3918
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3919
+ for (int j = 0; j < 8; ++j) {
3920
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3921
+ }
3922
+ }
3923
+ sumf[row] += d1 * sum1 + d2 * sum2;
3924
+
3925
+ dh += nb*sizeof(block_iq2_xs)/2;
3926
+ q2 += nb*sizeof(block_iq2_xs)/2;
3927
+ sc += nb*sizeof(block_iq2_xs);
3928
+ }
3929
+
3930
+ y4 += 32 * 32;
3931
+ }
3932
+ #else
3933
+ // TODO
3934
+ #endif
3935
+
3936
+ for (int row = 0; row < N_DST; ++row) {
3937
+ all_sum = simd_sum(sumf[row]);
3938
+ if (tiisg == 0) {
3939
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3940
+ }
3941
+ }
3942
+ }
3943
+
3944
+ [[host_name("kernel_mul_mv_iq2_xs_f32")]]
3945
+ kernel void kernel_mul_mv_iq2_xs_f32(
3946
+ device const void * src0,
3947
+ device const float * src1,
3948
+ device float * dst,
3949
+ constant int64_t & ne00,
3950
+ constant int64_t & ne01,
3951
+ constant int64_t & ne02,
3952
+ constant uint64_t & nb00,
3953
+ constant uint64_t & nb01,
3954
+ constant uint64_t & nb02,
3955
+ constant int64_t & ne10,
3956
+ constant int64_t & ne11,
3957
+ constant int64_t & ne12,
3958
+ constant uint64_t & nb10,
3959
+ constant uint64_t & nb11,
3960
+ constant uint64_t & nb12,
3961
+ constant int64_t & ne0,
3962
+ constant int64_t & ne1,
3963
+ constant uint & r2,
3964
+ constant uint & r3,
3965
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3966
+ uint3 tgpig[[threadgroup_position_in_grid]],
3967
+ uint tiisg[[thread_index_in_simdgroup]],
3968
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3969
+
3970
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
+ }
3972
+
3973
+ //============================= templates and their specializations =============================
3974
+
3975
+ // NOTE: this is not dequantizing - we are simply fitting the template
3976
+ template <typename type4x4>
3977
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
3978
+ float4x4 temp = *(((device float4x4 *)src));
3979
+ for (int i = 0; i < 16; i++){
3980
+ reg[i/4][i%4] = temp[i/4][i%4];
3981
+ }
3982
+ }
3983
+
3984
+ template <typename type4x4>
3985
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
3986
+ half4x4 temp = *(((device half4x4 *)src));
3987
+ for (int i = 0; i < 16; i++){
3988
+ reg[i/4][i%4] = temp[i/4][i%4];
3989
+ }
3990
+ }
3991
+
3992
+ template <typename type4x4>
3993
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3994
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
3995
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
3996
+ const float d2 = d1 / 256.f;
3997
+ const float md = -8.h * xb->d;
3998
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
3999
+ const ushort mask1 = mask0 << 8;
4000
+
4001
+ for (int i=0;i<8;i++) {
4002
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
4003
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
4004
+ }
4005
+ }
4006
+
4007
+ template <typename type4x4>
4008
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
4009
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
4010
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
4011
+ const float d2 = d1 / 256.f;
4012
+ const float m = xb->m;
4013
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
4014
+ const ushort mask1 = mask0 << 8;
4015
+
4016
+ for (int i=0;i<8;i++) {
4017
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
3452
4018
  reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
3453
4019
  }
3454
4020
  }
@@ -3514,7 +4080,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3514
4080
  device const int8_t * qs = ((device const int8_t *)xb->qs);
3515
4081
  const half d = xb->d;
3516
4082
 
3517
- for (int i=0;i<16;i++) {
4083
+ for (int i = 0; i < 16; i++) {
3518
4084
  reg[i/4][i%4] = (qs[i + 16*il] * d);
3519
4085
  }
3520
4086
  }
@@ -3556,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
3556
4122
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3557
4123
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3558
4124
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3559
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3560
- const half ml = 4.h * dl;
4125
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
4126
+ const float ml = 4.f * dl;
3561
4127
 
3562
4128
  il = (il/2) & 3;
3563
4129
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3624,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3624
4190
  uint8_t ul = 1 << (il/2);
3625
4191
  il = il & 3;
3626
4192
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3627
- const float d = il < 2 ? xb->d : xb->d / 16.h;
4193
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
3628
4194
  const float min = xb->dmin;
3629
4195
  const float dl = d * sc[0];
3630
4196
  const float ml = min * sc[1];
@@ -3657,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3657
4223
  #if QK_K == 256
3658
4224
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3659
4225
  qh = qh + 32*(il/8) + 16*(il&1);
3660
- half sc = scales[(il%2) + 2 * ((il/2))];
4226
+ float sc = scales[(il%2) + 2 * ((il/2))];
3661
4227
  il = (il/2) & 3;
3662
4228
  #else
3663
4229
  ql = ql + 16 * (il&1);
3664
- half sc = scales[il];
4230
+ float sc = scales[il];
3665
4231
  #endif
3666
4232
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3667
4233
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3668
- const half coef = il>1 ? 1.f/16.h : 1.h;
3669
- const half ml = d_all * sc * 32.h;
3670
- const half dl = d_all * sc * coef;
4234
+ const float coef = il>1 ? 1.f/16.f : 1.f;
4235
+ const float ml = d_all * sc * 32.f;
4236
+ const float dl = d_all * sc * coef;
3671
4237
  for (int i = 0; i < 16; ++i) {
3672
4238
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3673
4239
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@@ -3675,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3675
4241
  }
3676
4242
  }
3677
4243
 
4244
+ template <typename type4x4>
4245
+ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
4246
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4247
+ const float d = xb->d;
4248
+ const int ib32 = il/2;
4249
+ il = il%2;
4250
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4251
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
4252
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4253
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
4254
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
4255
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
4256
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
4257
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4258
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
4259
+ for (int i = 0; i < 8; ++i) {
4260
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4261
+ }
4262
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4263
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
4264
+ for (int i = 0; i < 8; ++i) {
4265
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4266
+ }
4267
+ }
4268
+
4269
+ template <typename type4x4>
4270
+ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
4271
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4272
+ const float d = xb->d;
4273
+ const int ib32 = il/2;
4274
+ il = il%2;
4275
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4276
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4277
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
4278
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
4279
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
4280
+ for (int i = 0; i < 8; ++i) {
4281
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4282
+ }
4283
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
4284
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
4285
+ for (int i = 0; i < 8; ++i) {
4286
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4287
+ }
4288
+ }
4289
+
3678
4290
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3679
4291
  kernel void kernel_get_rows(
3680
4292
  device const void * src0,
@@ -3755,48 +4367,212 @@ kernel void kernel_get_rows_f16(
3755
4367
  const int64_t i10 = tgpig.x;
3756
4368
  const int64_t i11 = tgpig.y;
3757
4369
 
3758
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4370
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4371
+
4372
+ const int64_t i02 = i11;
4373
+
4374
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4375
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4376
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4377
+ }
4378
+ }
4379
+
4380
+ kernel void kernel_get_rows_i32(
4381
+ device const void * src0,
4382
+ device const char * src1,
4383
+ device int32_t * dst,
4384
+ constant int64_t & ne00,
4385
+ constant uint64_t & nb01,
4386
+ constant uint64_t & nb02,
4387
+ constant int64_t & ne10,
4388
+ constant uint64_t & nb10,
4389
+ constant uint64_t & nb11,
4390
+ constant uint64_t & nb1,
4391
+ constant uint64_t & nb2,
4392
+ uint3 tgpig[[threadgroup_position_in_grid]],
4393
+ uint tiitg[[thread_index_in_threadgroup]],
4394
+ uint3 tptg [[threads_per_threadgroup]]) {
4395
+ const int64_t i10 = tgpig.x;
4396
+ const int64_t i11 = tgpig.y;
4397
+
4398
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4399
+
4400
+ const int64_t i02 = i11;
4401
+
4402
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4403
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4404
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4405
+ }
4406
+ }
4407
+
4408
+
4409
+ #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
4410
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
4411
+ #define BLOCK_SIZE_K 32
4412
+ #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
4413
+ #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
4414
+ #define THREAD_PER_BLOCK 128
4415
+ #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
4416
+ #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
4417
+ #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
4418
+ #define SG_MAT_ROW 8
4419
+
4420
+ // each block_q contains 16*nl weights
4421
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4422
+ void kernel_mul_mm_impl(device const uchar * src0,
4423
+ device const uchar * src1,
4424
+ device float * dst,
4425
+ constant int64_t & ne00,
4426
+ constant int64_t & ne02,
4427
+ constant uint64_t & nb01,
4428
+ constant uint64_t & nb02,
4429
+ constant int64_t & ne12,
4430
+ constant uint64_t & nb10,
4431
+ constant uint64_t & nb11,
4432
+ constant uint64_t & nb12,
4433
+ constant int64_t & ne0,
4434
+ constant int64_t & ne1,
4435
+ constant uint & r2,
4436
+ constant uint & r3,
4437
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
4438
+ uint3 tgpig[[threadgroup_position_in_grid]],
4439
+ uint tiitg[[thread_index_in_threadgroup]],
4440
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4441
+
4442
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
4443
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
4444
+
4445
+ const uint r0 = tgpig.y;
4446
+ const uint r1 = tgpig.x;
4447
+ const uint im = tgpig.z;
4448
+
4449
+ // if this block is of 64x32 shape or smaller
4450
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4451
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
4452
+
4453
+ // a thread shouldn't load data outside of the matrix
4454
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4455
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
4456
+
4457
+ simdgroup_half8x8 ma[4];
4458
+ simdgroup_float8x8 mb[2];
4459
+ simdgroup_float8x8 c_res[8];
4460
+ for (int i = 0; i < 8; i++){
4461
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
4462
+ }
4463
+
4464
+ short il = (tiitg % THREAD_PER_ROW);
4465
+
4466
+ const uint i12 = im%ne12;
4467
+ const uint i13 = im/ne12;
4468
+
4469
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4470
+ ushort offset1 = il/nl;
4471
+
4472
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4473
+ device const float * y = (device const float *)(src1
4474
+ + nb12 * im
4475
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
4476
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
4477
+
4478
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4479
+ // load data and store to threadgroup memory
4480
+ half4x4 temp_a;
4481
+ dequantize_func(x, il, temp_a);
4482
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4483
+
4484
+ #pragma unroll(16)
4485
+ for (int i = 0; i < 16; i++) {
4486
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4487
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4488
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4489
+ }
4490
+
4491
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
4492
+
4493
+ il = (il + 2 < nl) ? il + 2 : il % 2;
4494
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
4495
+ y += BLOCK_SIZE_K;
4496
+
4497
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4498
+
4499
+ // load matrices from threadgroup memory and conduct outer products
4500
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4501
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
4502
+
4503
+ #pragma unroll(4)
4504
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4505
+ #pragma unroll(4)
4506
+ for (int i = 0; i < 4; i++) {
4507
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4508
+ }
4509
+ simdgroup_barrier(mem_flags::mem_none);
4510
+ #pragma unroll(2)
4511
+ for (int i = 0; i < 2; i++) {
4512
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4513
+ }
4514
+
4515
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4516
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
3759
4517
 
3760
- const int64_t i02 = i11;
4518
+ #pragma unroll(8)
4519
+ for (int i = 0; i < 8; i++){
4520
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4521
+ }
4522
+ }
4523
+ }
3761
4524
 
3762
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4525
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
4526
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
4527
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
4528
+ for (int i = 0; i < 8; i++) {
4529
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
4530
+ }
4531
+ } else {
4532
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
4533
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4534
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4535
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4536
+ for (int i = 0; i < 8; i++) {
4537
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4538
+ }
4539
+
4540
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4541
+
4542
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
4543
+ if (sgitg == 0) {
4544
+ for (int i = 0; i < n_rows; i++) {
4545
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4546
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4547
+ }
4548
+ }
4549
+ }
3765
4550
  }
3766
4551
  }
3767
4552
 
3768
- #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3769
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3770
- #define BLOCK_SIZE_K 32
3771
- #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
3772
- #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
3773
- #define THREAD_PER_BLOCK 128
3774
- #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
3775
- #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
3776
- #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
3777
- #define SG_MAT_ROW 8
3778
-
3779
- // each block_q contains 16*nl weights
4553
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
3780
4554
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3781
- void kernel_mul_mm_impl(device const uchar * src0,
3782
- device const uchar * src1,
3783
- device float * dst,
3784
- constant int64_t & ne00,
3785
- constant int64_t & ne02,
3786
- constant int64_t & nb01,
3787
- constant int64_t & nb02,
3788
- constant int64_t & ne12,
3789
- constant int64_t & nb10,
3790
- constant int64_t & nb11,
3791
- constant int64_t & nb12,
3792
- constant int64_t & ne0,
3793
- constant int64_t & ne1,
3794
- constant uint & r2,
3795
- constant uint & r3,
3796
- threadgroup uchar * shared_memory [[threadgroup(0)]],
3797
- uint3 tgpig[[threadgroup_position_in_grid]],
3798
- uint tiitg[[thread_index_in_threadgroup]],
3799
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4555
+ void kernel_mul_mm_id_impl(
4556
+ device const uchar * src0,
4557
+ device const uchar * src1,
4558
+ thread short * src1ids,
4559
+ device float * dst,
4560
+ constant int64_t & ne00,
4561
+ constant int64_t & ne02,
4562
+ constant uint64_t & nb01,
4563
+ constant uint64_t & nb02,
4564
+ constant int64_t & ne12,
4565
+ constant uint64_t & nb10,
4566
+ constant uint64_t & nb11,
4567
+ constant uint64_t & nb12,
4568
+ constant int64_t & ne0,
4569
+ int64_t ne1,
4570
+ constant uint & r2,
4571
+ constant uint & r3,
4572
+ threadgroup uchar * shared_memory,
4573
+ uint3 tgpig[[threadgroup_position_in_grid]],
4574
+ uint tiitg[[thread_index_in_threadgroup]],
4575
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3800
4576
 
3801
4577
  threadgroup half * sa = (threadgroup half *)(shared_memory);
3802
4578
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -3805,6 +4581,8 @@ void kernel_mul_mm_impl(device const uchar * src0,
3805
4581
  const uint r1 = tgpig.x;
3806
4582
  const uint im = tgpig.z;
3807
4583
 
4584
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
4585
+
3808
4586
  // if this block is of 64x32 shape or smaller
3809
4587
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
3810
4588
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
@@ -3831,7 +4609,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
3831
4609
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
3832
4610
  device const float * y = (device const float *)(src1
3833
4611
  + nb12 * im
3834
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
4612
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
3835
4613
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
3836
4614
 
3837
4615
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -3840,7 +4618,6 @@ void kernel_mul_mm_impl(device const uchar * src0,
3840
4618
  dequantize_func(x, il, temp_a);
3841
4619
  threadgroup_barrier(mem_flags::mem_threadgroup);
3842
4620
 
3843
- #pragma unroll(16)
3844
4621
  for (int i = 0; i < 16; i++) {
3845
4622
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
3846
4623
  + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
@@ -3859,14 +4636,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
3859
4636
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
3860
4637
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
3861
4638
 
3862
- #pragma unroll(4)
3863
4639
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
3864
- #pragma unroll(4)
3865
4640
  for (int i = 0; i < 4; i++) {
3866
4641
  simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
3867
4642
  }
3868
4643
  simdgroup_barrier(mem_flags::mem_none);
3869
- #pragma unroll(2)
3870
4644
  for (int i = 0; i < 2; i++) {
3871
4645
  simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
3872
4646
  }
@@ -3874,21 +4648,13 @@ void kernel_mul_mm_impl(device const uchar * src0,
3874
4648
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
3875
4649
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
3876
4650
 
3877
- #pragma unroll(8)
3878
4651
  for (int i = 0; i < 8; i++){
3879
4652
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
3880
4653
  }
3881
4654
  }
3882
4655
  }
3883
4656
 
3884
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
3885
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
3886
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
3887
- for (int i = 0; i < 8; i++) {
3888
- simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
3889
- }
3890
- } else {
3891
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
4657
+ {
3892
4658
  threadgroup_barrier(mem_flags::mem_threadgroup);
3893
4659
  threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
3894
4660
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
@@ -3898,11 +4664,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
3898
4664
 
3899
4665
  threadgroup_barrier(mem_flags::mem_threadgroup);
3900
4666
 
3901
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
4667
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
3902
4668
  if (sgitg == 0) {
3903
4669
  for (int i = 0; i < n_rows; i++) {
3904
4670
  for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
3905
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4671
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
3906
4672
  }
3907
4673
  }
3908
4674
  }
@@ -3915,12 +4681,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
3915
4681
  device float * dst,
3916
4682
  constant int64_t & ne00,
3917
4683
  constant int64_t & ne02,
3918
- constant int64_t & nb01,
3919
- constant int64_t & nb02,
4684
+ constant uint64_t & nb01,
4685
+ constant uint64_t & nb02,
3920
4686
  constant int64_t & ne12,
3921
- constant int64_t & nb10,
3922
- constant int64_t & nb11,
3923
- constant int64_t & nb12,
4687
+ constant uint64_t & nb10,
4688
+ constant uint64_t & nb11,
4689
+ constant uint64_t & nb12,
3924
4690
  constant int64_t & ne0,
3925
4691
  constant int64_t & ne1,
3926
4692
  constant uint & r2,
@@ -3955,20 +4721,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
3955
4721
  kernel void kernel_mul_mm_id(
3956
4722
  device const uchar * ids,
3957
4723
  device const uchar * src1,
3958
- device uchar * dst,
3959
- constant int64_t & nbi1,
4724
+ device float * dst,
4725
+ constant uint64_t & nbi1,
3960
4726
  constant int64_t & ne00,
3961
4727
  constant int64_t & ne02,
3962
- constant int64_t & nb01,
3963
- constant int64_t & nb02,
4728
+ constant uint64_t & nb01,
4729
+ constant uint64_t & nb02,
3964
4730
  constant int64_t & ne12,
3965
4731
  constant int64_t & ne13,
3966
- constant int64_t & nb10,
3967
- constant int64_t & nb11,
3968
- constant int64_t & nb12,
4732
+ constant uint64_t & nb10,
4733
+ constant uint64_t & nb11,
4734
+ constant uint64_t & nb12,
3969
4735
  constant int64_t & ne0,
3970
4736
  constant int64_t & ne1,
3971
- constant int64_t & nb1,
4737
+ constant uint64_t & nb1,
3972
4738
  constant uint & r2,
3973
4739
  constant uint & r3,
3974
4740
  constant int & idx,
@@ -3984,18 +4750,28 @@ kernel void kernel_mul_mm_id(
3984
4750
  uint3 tgpig[[threadgroup_position_in_grid]],
3985
4751
  uint tiitg[[thread_index_in_threadgroup]],
3986
4752
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3987
- device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4753
+ device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3988
4754
 
3989
- const int64_t bid = tgpig.z/(ne12*ne13);
4755
+ // expert id
4756
+ const int32_t id = tgpig.z/(ne12*ne13);
3990
4757
 
3991
4758
  tgpig.z = tgpig.z%(ne12*ne13);
3992
4759
 
3993
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4760
+ // row indices of src1 for expert id
4761
+ int64_t _ne1 = 0;
4762
+ short src1ids[512];
3994
4763
 
3995
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3996
- src0[id],
3997
- src1 + bid*nb11,
3998
- (device float *) (dst + bid*nb1),
4764
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
4765
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
4766
+ src1ids[_ne1++] = i1;
4767
+ }
4768
+ }
4769
+
4770
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
4771
+ src0s[id],
4772
+ src1,
4773
+ src1ids,
4774
+ dst,
3999
4775
  ne00,
4000
4776
  ne02,
4001
4777
  nb01,
@@ -4005,7 +4781,7 @@ kernel void kernel_mul_mm_id(
4005
4781
  nb11,
4006
4782
  nb12,
4007
4783
  ne0,
4008
- ne1,
4784
+ _ne1,
4009
4785
  r2,
4010
4786
  r3,
4011
4787
  shared_memory,
@@ -4050,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
4050
4826
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
4051
4827
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4052
4828
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4053
4831
 
4054
4832
  //
4055
4833
  // matrix-matrix multiplication
@@ -4061,12 +4839,12 @@ typedef void (mat_mm_t)(
4061
4839
  device float * dst,
4062
4840
  constant int64_t & ne00,
4063
4841
  constant int64_t & ne02,
4064
- constant int64_t & nb01,
4065
- constant int64_t & nb02,
4842
+ constant uint64_t & nb01,
4843
+ constant uint64_t & nb02,
4066
4844
  constant int64_t & ne12,
4067
- constant int64_t & nb10,
4068
- constant int64_t & nb11,
4069
- constant int64_t & nb12,
4845
+ constant uint64_t & nb10,
4846
+ constant uint64_t & nb11,
4847
+ constant uint64_t & nb12,
4070
4848
  constant int64_t & ne0,
4071
4849
  constant int64_t & ne1,
4072
4850
  constant uint & r2,
@@ -4086,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4086
4864
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
4087
4865
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4088
4866
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4089
4869
 
4090
4870
  //
4091
4871
  // indirect matrix-matrix multiplication
@@ -4094,20 +4874,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4094
4874
  typedef void (mat_mm_id_t)(
4095
4875
  device const uchar * ids,
4096
4876
  device const uchar * src1,
4097
- device uchar * dst,
4098
- constant int64_t & nbi1,
4877
+ device float * dst,
4878
+ constant uint64_t & nbi1,
4099
4879
  constant int64_t & ne00,
4100
4880
  constant int64_t & ne02,
4101
- constant int64_t & nb01,
4102
- constant int64_t & nb02,
4881
+ constant uint64_t & nb01,
4882
+ constant uint64_t & nb02,
4103
4883
  constant int64_t & ne12,
4104
4884
  constant int64_t & ne13,
4105
- constant int64_t & nb10,
4106
- constant int64_t & nb11,
4107
- constant int64_t & nb12,
4885
+ constant uint64_t & nb10,
4886
+ constant uint64_t & nb11,
4887
+ constant uint64_t & nb12,
4108
4888
  constant int64_t & ne0,
4109
4889
  constant int64_t & ne1,
4110
- constant int64_t & nb1,
4890
+ constant uint64_t & nb1,
4111
4891
  constant uint & r2,
4112
4892
  constant uint & r3,
4113
4893
  constant int & idx,
@@ -4134,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
4134
4914
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4135
4915
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4136
4916
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4137
4919
 
4138
4920
  //
4139
4921
  // matrix-vector multiplication
@@ -4143,8 +4925,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
4143
4925
  kernel void kernel_mul_mv_id_f32_f32(
4144
4926
  device const char * ids,
4145
4927
  device const char * src1,
4146
- device uchar * dst,
4147
- constant int64_t & nbi1,
4928
+ device float * dst,
4929
+ constant uint64_t & nbi1,
4148
4930
  constant int64_t & ne00,
4149
4931
  constant int64_t & ne01,
4150
4932
  constant int64_t & ne02,
@@ -4160,7 +4942,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4160
4942
  constant uint64_t & nb12,
4161
4943
  constant int64_t & ne0,
4162
4944
  constant int64_t & ne1,
4163
- constant int64_t & nb1,
4945
+ constant uint64_t & nb1,
4164
4946
  constant uint & r2,
4165
4947
  constant uint & r3,
4166
4948
  constant int & idx,
@@ -4187,7 +4969,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4187
4969
  kernel_mul_mv_f32_f32_impl(
4188
4970
  src0[id],
4189
4971
  src1 + bid*nb11,
4190
- (device float *) (dst + bid*nb1),
4972
+ dst + bid*ne0,
4191
4973
  ne00,
4192
4974
  ne01,
4193
4975
  ne02,
@@ -4212,8 +4994,8 @@ kernel void kernel_mul_mv_id_f32_f32(
4212
4994
  kernel void kernel_mul_mv_id_f16_f32(
4213
4995
  device const char * ids,
4214
4996
  device const char * src1,
4215
- device uchar * dst,
4216
- constant int64_t & nbi1,
4997
+ device float * dst,
4998
+ constant uint64_t & nbi1,
4217
4999
  constant int64_t & ne00,
4218
5000
  constant int64_t & ne01,
4219
5001
  constant int64_t & ne02,
@@ -4229,7 +5011,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4229
5011
  constant uint64_t & nb12,
4230
5012
  constant int64_t & ne0,
4231
5013
  constant int64_t & ne1,
4232
- constant int64_t & nb1,
5014
+ constant uint64_t & nb1,
4233
5015
  constant uint & r2,
4234
5016
  constant uint & r3,
4235
5017
  constant int & idx,
@@ -4256,7 +5038,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4256
5038
  kernel_mul_mv_f16_f32_impl(
4257
5039
  src0[id],
4258
5040
  src1 + bid*nb11,
4259
- (device float *) (dst + bid*nb1),
5041
+ dst + bid*ne0,
4260
5042
  ne00,
4261
5043
  ne01,
4262
5044
  ne02,
@@ -4281,8 +5063,8 @@ kernel void kernel_mul_mv_id_f16_f32(
4281
5063
  kernel void kernel_mul_mv_id_q8_0_f32(
4282
5064
  device const char * ids,
4283
5065
  device const char * src1,
4284
- device uchar * dst,
4285
- constant int64_t & nbi1,
5066
+ device float * dst,
5067
+ constant uint64_t & nbi1,
4286
5068
  constant int64_t & ne00,
4287
5069
  constant int64_t & ne01,
4288
5070
  constant int64_t & ne02,
@@ -4298,7 +5080,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4298
5080
  constant uint64_t & nb12,
4299
5081
  constant int64_t & ne0,
4300
5082
  constant int64_t & ne1,
4301
- constant int64_t & nb1,
5083
+ constant uint64_t & nb1,
4302
5084
  constant uint & r2,
4303
5085
  constant uint & r3,
4304
5086
  constant int & idx,
@@ -4325,7 +5107,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4325
5107
  kernel_mul_mv_q8_0_f32_impl(
4326
5108
  src0[id],
4327
5109
  (device const float *) (src1 + bid*nb11),
4328
- (device float *) ( dst + bid*nb1),
5110
+ dst + bid*ne0,
4329
5111
  ne00,
4330
5112
  ne01,
4331
5113
  ne02,
@@ -4344,8 +5126,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4344
5126
  kernel void kernel_mul_mv_id_q4_0_f32(
4345
5127
  device const char * ids,
4346
5128
  device const char * src1,
4347
- device uchar * dst,
4348
- constant int64_t & nbi1,
5129
+ device float * dst,
5130
+ constant uint64_t & nbi1,
4349
5131
  constant int64_t & ne00,
4350
5132
  constant int64_t & ne01,
4351
5133
  constant int64_t & ne02,
@@ -4361,7 +5143,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4361
5143
  constant uint64_t & nb12,
4362
5144
  constant int64_t & ne0,
4363
5145
  constant int64_t & ne1,
4364
- constant int64_t & nb1,
5146
+ constant uint64_t & nb1,
4365
5147
  constant uint & r2,
4366
5148
  constant uint & r3,
4367
5149
  constant int & idx,
@@ -4388,7 +5170,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4388
5170
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
5171
  src0[id],
4390
5172
  (device const float *) (src1 + bid*nb11),
4391
- (device float *) ( dst + bid*nb1),
5173
+ dst + bid*ne0,
4392
5174
  ne00,
4393
5175
  ne01,
4394
5176
  ne02,
@@ -4407,8 +5189,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4407
5189
  kernel void kernel_mul_mv_id_q4_1_f32(
4408
5190
  device const char * ids,
4409
5191
  device const char * src1,
4410
- device uchar * dst,
4411
- constant int64_t & nbi1,
5192
+ device float * dst,
5193
+ constant uint64_t & nbi1,
4412
5194
  constant int64_t & ne00,
4413
5195
  constant int64_t & ne01,
4414
5196
  constant int64_t & ne02,
@@ -4424,7 +5206,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4424
5206
  constant uint64_t & nb12,
4425
5207
  constant int64_t & ne0,
4426
5208
  constant int64_t & ne1,
4427
- constant int64_t & nb1,
5209
+ constant uint64_t & nb1,
4428
5210
  constant uint & r2,
4429
5211
  constant uint & r3,
4430
5212
  constant int & idx,
@@ -4451,7 +5233,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4451
5233
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
5234
  src0[id],
4453
5235
  (device const float *) (src1 + bid*nb11),
4454
- (device float *) ( dst + bid*nb1),
5236
+ dst + bid*ne0,
4455
5237
  ne00,
4456
5238
  ne01,
4457
5239
  ne02,
@@ -4470,8 +5252,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4470
5252
  kernel void kernel_mul_mv_id_q5_0_f32(
4471
5253
  device const char * ids,
4472
5254
  device const char * src1,
4473
- device uchar * dst,
4474
- constant int64_t & nbi1,
5255
+ device float * dst,
5256
+ constant uint64_t & nbi1,
4475
5257
  constant int64_t & ne00,
4476
5258
  constant int64_t & ne01,
4477
5259
  constant int64_t & ne02,
@@ -4487,7 +5269,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4487
5269
  constant uint64_t & nb12,
4488
5270
  constant int64_t & ne0,
4489
5271
  constant int64_t & ne1,
4490
- constant int64_t & nb1,
5272
+ constant uint64_t & nb1,
4491
5273
  constant uint & r2,
4492
5274
  constant uint & r3,
4493
5275
  constant int & idx,
@@ -4514,7 +5296,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4514
5296
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
5297
  src0[id],
4516
5298
  (device const float *) (src1 + bid*nb11),
4517
- (device float *) ( dst + bid*nb1),
5299
+ dst + bid*ne0,
4518
5300
  ne00,
4519
5301
  ne01,
4520
5302
  ne02,
@@ -4533,8 +5315,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4533
5315
  kernel void kernel_mul_mv_id_q5_1_f32(
4534
5316
  device const char * ids,
4535
5317
  device const char * src1,
4536
- device uchar * dst,
4537
- constant int64_t & nbi1,
5318
+ device float * dst,
5319
+ constant uint64_t & nbi1,
4538
5320
  constant int64_t & ne00,
4539
5321
  constant int64_t & ne01,
4540
5322
  constant int64_t & ne02,
@@ -4550,7 +5332,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4550
5332
  constant uint64_t & nb12,
4551
5333
  constant int64_t & ne0,
4552
5334
  constant int64_t & ne1,
4553
- constant int64_t & nb1,
5335
+ constant uint64_t & nb1,
4554
5336
  constant uint & r2,
4555
5337
  constant uint & r3,
4556
5338
  constant int & idx,
@@ -4577,7 +5359,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4577
5359
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
5360
  src0[id],
4579
5361
  (device const float *) (src1 + bid*nb11),
4580
- (device float *) ( dst + bid*nb1),
5362
+ dst + bid*ne0,
4581
5363
  ne00,
4582
5364
  ne01,
4583
5365
  ne02,
@@ -4596,8 +5378,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4596
5378
  kernel void kernel_mul_mv_id_q2_K_f32(
4597
5379
  device const char * ids,
4598
5380
  device const char * src1,
4599
- device uchar * dst,
4600
- constant int64_t & nbi1,
5381
+ device float * dst,
5382
+ constant uint64_t & nbi1,
4601
5383
  constant int64_t & ne00,
4602
5384
  constant int64_t & ne01,
4603
5385
  constant int64_t & ne02,
@@ -4613,7 +5395,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4613
5395
  constant uint64_t & nb12,
4614
5396
  constant int64_t & ne0,
4615
5397
  constant int64_t & ne1,
4616
- constant int64_t & nb1,
5398
+ constant uint64_t & nb1,
4617
5399
  constant uint & r2,
4618
5400
  constant uint & r3,
4619
5401
  constant int & idx,
@@ -4640,7 +5422,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4640
5422
  kernel_mul_mv_q2_K_f32_impl(
4641
5423
  src0[id],
4642
5424
  (device const float *) (src1 + bid*nb11),
4643
- (device float *) ( dst + bid*nb1),
5425
+ dst + bid*ne0,
4644
5426
  ne00,
4645
5427
  ne01,
4646
5428
  ne02,
@@ -4659,8 +5441,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4659
5441
  kernel void kernel_mul_mv_id_q3_K_f32(
4660
5442
  device const char * ids,
4661
5443
  device const char * src1,
4662
- device uchar * dst,
4663
- constant int64_t & nbi1,
5444
+ device float * dst,
5445
+ constant uint64_t & nbi1,
4664
5446
  constant int64_t & ne00,
4665
5447
  constant int64_t & ne01,
4666
5448
  constant int64_t & ne02,
@@ -4676,7 +5458,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4676
5458
  constant uint64_t & nb12,
4677
5459
  constant int64_t & ne0,
4678
5460
  constant int64_t & ne1,
4679
- constant int64_t & nb1,
5461
+ constant uint64_t & nb1,
4680
5462
  constant uint & r2,
4681
5463
  constant uint & r3,
4682
5464
  constant int & idx,
@@ -4703,7 +5485,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4703
5485
  kernel_mul_mv_q3_K_f32_impl(
4704
5486
  src0[id],
4705
5487
  (device const float *) (src1 + bid*nb11),
4706
- (device float *) ( dst + bid*nb1),
5488
+ dst + bid*ne0,
4707
5489
  ne00,
4708
5490
  ne01,
4709
5491
  ne02,
@@ -4722,8 +5504,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4722
5504
  kernel void kernel_mul_mv_id_q4_K_f32(
4723
5505
  device const char * ids,
4724
5506
  device const char * src1,
4725
- device uchar * dst,
4726
- constant int64_t & nbi1,
5507
+ device float * dst,
5508
+ constant uint64_t & nbi1,
4727
5509
  constant int64_t & ne00,
4728
5510
  constant int64_t & ne01,
4729
5511
  constant int64_t & ne02,
@@ -4739,7 +5521,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4739
5521
  constant uint64_t & nb12,
4740
5522
  constant int64_t & ne0,
4741
5523
  constant int64_t & ne1,
4742
- constant int64_t & nb1,
5524
+ constant uint64_t & nb1,
4743
5525
  constant uint & r2,
4744
5526
  constant uint & r3,
4745
5527
  constant int & idx,
@@ -4766,7 +5548,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4766
5548
  kernel_mul_mv_q4_K_f32_impl(
4767
5549
  src0[id],
4768
5550
  (device const float *) (src1 + bid*nb11),
4769
- (device float *) ( dst + bid*nb1),
5551
+ dst + bid*ne0,
4770
5552
  ne00,
4771
5553
  ne01,
4772
5554
  ne02,
@@ -4785,8 +5567,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4785
5567
  kernel void kernel_mul_mv_id_q5_K_f32(
4786
5568
  device const char * ids,
4787
5569
  device const char * src1,
4788
- device uchar * dst,
4789
- constant int64_t & nbi1,
5570
+ device float * dst,
5571
+ constant uint64_t & nbi1,
4790
5572
  constant int64_t & ne00,
4791
5573
  constant int64_t & ne01,
4792
5574
  constant int64_t & ne02,
@@ -4802,7 +5584,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4802
5584
  constant uint64_t & nb12,
4803
5585
  constant int64_t & ne0,
4804
5586
  constant int64_t & ne1,
4805
- constant int64_t & nb1,
5587
+ constant uint64_t & nb1,
4806
5588
  constant uint & r2,
4807
5589
  constant uint & r3,
4808
5590
  constant int & idx,
@@ -4829,7 +5611,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4829
5611
  kernel_mul_mv_q5_K_f32_impl(
4830
5612
  src0[id],
4831
5613
  (device const float *) (src1 + bid*nb11),
4832
- (device float *) ( dst + bid*nb1),
5614
+ dst + bid*ne0,
4833
5615
  ne00,
4834
5616
  ne01,
4835
5617
  ne02,
@@ -4848,8 +5630,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4848
5630
  kernel void kernel_mul_mv_id_q6_K_f32(
4849
5631
  device const char * ids,
4850
5632
  device const char * src1,
4851
- device uchar * dst,
4852
- constant int64_t & nbi1,
5633
+ device float * dst,
5634
+ constant uint64_t & nbi1,
4853
5635
  constant int64_t & ne00,
4854
5636
  constant int64_t & ne01,
4855
5637
  constant int64_t & ne02,
@@ -4865,7 +5647,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4865
5647
  constant uint64_t & nb12,
4866
5648
  constant int64_t & ne0,
4867
5649
  constant int64_t & ne1,
4868
- constant int64_t & nb1,
5650
+ constant uint64_t & nb1,
4869
5651
  constant uint & r2,
4870
5652
  constant uint & r3,
4871
5653
  constant int & idx,
@@ -4892,7 +5674,136 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4892
5674
  kernel_mul_mv_q6_K_f32_impl(
4893
5675
  src0[id],
4894
5676
  (device const float *) (src1 + bid*nb11),
4895
- (device float *) ( dst + bid*nb1),
5677
+ dst + bid*ne0,
5678
+ ne00,
5679
+ ne01,
5680
+ ne02,
5681
+ ne10,
5682
+ ne12,
5683
+ ne0,
5684
+ ne1,
5685
+ r2,
5686
+ r3,
5687
+ tgpig,
5688
+ tiisg,
5689
+ sgitg);
5690
+ }
5691
+
5692
+ [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
5693
+ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5694
+ device const char * ids,
5695
+ device const char * src1,
5696
+ device float * dst,
5697
+ constant uint64_t & nbi1,
5698
+ constant int64_t & ne00,
5699
+ constant int64_t & ne01,
5700
+ constant int64_t & ne02,
5701
+ constant uint64_t & nb00,
5702
+ constant uint64_t & nb01,
5703
+ constant uint64_t & nb02,
5704
+ constant int64_t & ne10,
5705
+ constant int64_t & ne11,
5706
+ constant int64_t & ne12,
5707
+ constant int64_t & ne13,
5708
+ constant uint64_t & nb10,
5709
+ constant uint64_t & nb11,
5710
+ constant uint64_t & nb12,
5711
+ constant int64_t & ne0,
5712
+ constant int64_t & ne1,
5713
+ constant uint64_t & nb1,
5714
+ constant uint & r2,
5715
+ constant uint & r3,
5716
+ constant int & idx,
5717
+ device const char * src00,
5718
+ device const char * src01,
5719
+ device const char * src02,
5720
+ device const char * src03,
5721
+ device const char * src04,
5722
+ device const char * src05,
5723
+ device const char * src06,
5724
+ device const char * src07,
5725
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5726
+ uint3 tgpig[[threadgroup_position_in_grid]],
5727
+ uint tiitg[[thread_index_in_threadgroup]],
5728
+ uint tiisg[[thread_index_in_simdgroup]],
5729
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5730
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5731
+
5732
+ const int64_t bid = tgpig.z/(ne12*ne13);
5733
+
5734
+ tgpig.z = tgpig.z%(ne12*ne13);
5735
+
5736
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5737
+
5738
+ kernel_mul_mv_iq2_xxs_f32_impl(
5739
+ src0[id],
5740
+ (device const float *) (src1 + bid*nb11),
5741
+ dst + bid*ne0,
5742
+ ne00,
5743
+ ne01,
5744
+ ne02,
5745
+ ne10,
5746
+ ne12,
5747
+ ne0,
5748
+ ne1,
5749
+ r2,
5750
+ r3,
5751
+ shared_values,
5752
+ tgpig,
5753
+ tiisg,
5754
+ sgitg);
5755
+ }
5756
+
5757
+ [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
5758
+ kernel void kernel_mul_mv_id_iq2_xs_f32(
5759
+ device const char * ids,
5760
+ device const char * src1,
5761
+ device float * dst,
5762
+ constant uint64_t & nbi1,
5763
+ constant int64_t & ne00,
5764
+ constant int64_t & ne01,
5765
+ constant int64_t & ne02,
5766
+ constant uint64_t & nb00,
5767
+ constant uint64_t & nb01,
5768
+ constant uint64_t & nb02,
5769
+ constant int64_t & ne10,
5770
+ constant int64_t & ne11,
5771
+ constant int64_t & ne12,
5772
+ constant int64_t & ne13,
5773
+ constant uint64_t & nb10,
5774
+ constant uint64_t & nb11,
5775
+ constant uint64_t & nb12,
5776
+ constant int64_t & ne0,
5777
+ constant int64_t & ne1,
5778
+ constant uint64_t & nb1,
5779
+ constant uint & r2,
5780
+ constant uint & r3,
5781
+ constant int & idx,
5782
+ device const char * src00,
5783
+ device const char * src01,
5784
+ device const char * src02,
5785
+ device const char * src03,
5786
+ device const char * src04,
5787
+ device const char * src05,
5788
+ device const char * src06,
5789
+ device const char * src07,
5790
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5791
+ uint3 tgpig[[threadgroup_position_in_grid]],
5792
+ uint tiitg[[thread_index_in_threadgroup]],
5793
+ uint tiisg[[thread_index_in_simdgroup]],
5794
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
+
5797
+ const int64_t bid = tgpig.z/(ne12*ne13);
5798
+
5799
+ tgpig.z = tgpig.z%(ne12*ne13);
5800
+
5801
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5802
+
5803
+ kernel_mul_mv_iq2_xs_f32_impl(
5804
+ src0[id],
5805
+ (device const float *) (src1 + bid*nb11),
5806
+ dst + bid*ne0,
4896
5807
  ne00,
4897
5808
  ne01,
4898
5809
  ne02,
@@ -4902,6 +5813,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4902
5813
  ne1,
4903
5814
  r2,
4904
5815
  r3,
5816
+ shared_values,
4905
5817
  tgpig,
4906
5818
  tiisg,
4907
5819
  sgitg);