llama_cpp 0.10.3 → 0.10.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -59,26 +59,26 @@ kernel void kernel_add(
59
59
  constant int64_t & ne01,
60
60
  constant int64_t & ne02,
61
61
  constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
62
+ constant uint64_t & nb00,
63
+ constant uint64_t & nb01,
64
+ constant uint64_t & nb02,
65
+ constant uint64_t & nb03,
66
66
  constant int64_t & ne10,
67
67
  constant int64_t & ne11,
68
68
  constant int64_t & ne12,
69
69
  constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
70
+ constant uint64_t & nb10,
71
+ constant uint64_t & nb11,
72
+ constant uint64_t & nb12,
73
+ constant uint64_t & nb13,
74
74
  constant int64_t & ne0,
75
75
  constant int64_t & ne1,
76
76
  constant int64_t & ne2,
77
77
  constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
78
+ constant uint64_t & nb0,
79
+ constant uint64_t & nb1,
80
+ constant uint64_t & nb2,
81
+ constant uint64_t & nb3,
82
82
  constant int64_t & offs,
83
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
84
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -109,26 +109,26 @@ kernel void kernel_mul(
109
109
  constant int64_t & ne01,
110
110
  constant int64_t & ne02,
111
111
  constant int64_t & ne03,
112
- constant int64_t & nb00,
113
- constant int64_t & nb01,
114
- constant int64_t & nb02,
115
- constant int64_t & nb03,
112
+ constant uint64_t & nb00,
113
+ constant uint64_t & nb01,
114
+ constant uint64_t & nb02,
115
+ constant uint64_t & nb03,
116
116
  constant int64_t & ne10,
117
117
  constant int64_t & ne11,
118
118
  constant int64_t & ne12,
119
119
  constant int64_t & ne13,
120
- constant int64_t & nb10,
121
- constant int64_t & nb11,
122
- constant int64_t & nb12,
123
- constant int64_t & nb13,
120
+ constant uint64_t & nb10,
121
+ constant uint64_t & nb11,
122
+ constant uint64_t & nb12,
123
+ constant uint64_t & nb13,
124
124
  constant int64_t & ne0,
125
125
  constant int64_t & ne1,
126
126
  constant int64_t & ne2,
127
127
  constant int64_t & ne3,
128
- constant int64_t & nb0,
129
- constant int64_t & nb1,
130
- constant int64_t & nb2,
131
- constant int64_t & nb3,
128
+ constant uint64_t & nb0,
129
+ constant uint64_t & nb1,
130
+ constant uint64_t & nb2,
131
+ constant uint64_t & nb3,
132
132
  uint3 tgpig[[threadgroup_position_in_grid]],
133
133
  uint3 tpitg[[thread_position_in_threadgroup]],
134
134
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -158,26 +158,26 @@ kernel void kernel_div(
158
158
  constant int64_t & ne01,
159
159
  constant int64_t & ne02,
160
160
  constant int64_t & ne03,
161
- constant int64_t & nb00,
162
- constant int64_t & nb01,
163
- constant int64_t & nb02,
164
- constant int64_t & nb03,
161
+ constant uint64_t & nb00,
162
+ constant uint64_t & nb01,
163
+ constant uint64_t & nb02,
164
+ constant uint64_t & nb03,
165
165
  constant int64_t & ne10,
166
166
  constant int64_t & ne11,
167
167
  constant int64_t & ne12,
168
168
  constant int64_t & ne13,
169
- constant int64_t & nb10,
170
- constant int64_t & nb11,
171
- constant int64_t & nb12,
172
- constant int64_t & nb13,
169
+ constant uint64_t & nb10,
170
+ constant uint64_t & nb11,
171
+ constant uint64_t & nb12,
172
+ constant uint64_t & nb13,
173
173
  constant int64_t & ne0,
174
174
  constant int64_t & ne1,
175
175
  constant int64_t & ne2,
176
176
  constant int64_t & ne3,
177
- constant int64_t & nb0,
178
- constant int64_t & nb1,
179
- constant int64_t & nb2,
180
- constant int64_t & nb3,
177
+ constant uint64_t & nb0,
178
+ constant uint64_t & nb1,
179
+ constant uint64_t & nb2,
180
+ constant uint64_t & nb3,
181
181
  uint3 tgpig[[threadgroup_position_in_grid]],
182
182
  uint3 tpitg[[thread_position_in_threadgroup]],
183
183
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
205
205
  device const float4 * src0,
206
206
  device const float4 * src1,
207
207
  device float4 * dst,
208
- constant int64_t & nb [[buffer(28)]],
208
+ constant uint64_t & nb [[buffer(28)]],
209
209
  uint tpig[[thread_position_in_grid]]) {
210
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
211
  }
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
214
214
  device const float4 * src0,
215
215
  device const float4 * src1,
216
216
  device float4 * dst,
217
- constant int64_t & nb [[buffer(28)]],
217
+ constant uint64_t & nb [[buffer(28)]],
218
218
  uint tpig[[thread_position_in_grid]]) {
219
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
220
  }
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
223
223
  device const float4 * src0,
224
224
  device const float4 * src1,
225
225
  device float4 * dst,
226
- constant int64_t & nb [[buffer(28)]],
226
+ constant uint64_t & nb [[buffer(28)]],
227
227
  uint tpig[[thread_position_in_grid]]) {
228
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
229
  }
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
307
307
  constant int64_t & ne01,
308
308
  constant int64_t & ne02,
309
309
  constant int64_t & ne03,
310
- constant int64_t & nb00,
311
- constant int64_t & nb01,
312
- constant int64_t & nb02,
313
- constant int64_t & nb03,
310
+ constant uint64_t & nb00,
311
+ constant uint64_t & nb01,
312
+ constant uint64_t & nb02,
313
+ constant uint64_t & nb03,
314
314
  constant int64_t & ne10,
315
315
  constant int64_t & ne11,
316
316
  constant int64_t & ne12,
317
317
  constant int64_t & ne13,
318
- constant int64_t & nb10,
319
- constant int64_t & nb11,
320
- constant int64_t & nb12,
321
- constant int64_t & nb13,
318
+ constant uint64_t & nb10,
319
+ constant uint64_t & nb11,
320
+ constant uint64_t & nb12,
321
+ constant uint64_t & nb13,
322
322
  constant int64_t & ne0,
323
323
  constant int64_t & ne1,
324
324
  constant int64_t & ne2,
325
325
  constant int64_t & ne3,
326
- constant int64_t & nb0,
327
- constant int64_t & nb1,
328
- constant int64_t & nb2,
329
- constant int64_t & nb3,
326
+ constant uint64_t & nb0,
327
+ constant uint64_t & nb1,
328
+ constant uint64_t & nb2,
329
+ constant uint64_t & nb3,
330
330
  uint3 tpig[[thread_position_in_grid]]) {
331
331
  int64_t i3 = tpig.z;
332
332
  int64_t i2 = tpig.y;
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
846
846
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
847
847
  //Note: This is a template, but strictly speaking it only applies to
848
848
  // quantizations where the block size is 32. It also does not
849
- // giard against the number of rows not being divisible by
849
+ // guard against the number of rows not being divisible by
850
850
  // N_DST, so this is another explicit assumption of the implementation.
851
851
  template<typename block_q_type, int nr, int nsg, int nw>
852
852
  void mul_vec_q_n_f32_impl(
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
920
920
  device const float * src1,
921
921
  device float * dst,
922
922
  constant int64_t & ne00,
923
- constant int64_t & ne01[[buffer(4)]],
924
- constant int64_t & ne02[[buffer(5)]],
925
- constant int64_t & ne10[[buffer(9)]],
926
- constant int64_t & ne12[[buffer(11)]],
927
- constant int64_t & ne0 [[buffer(15)]],
928
- constant int64_t & ne1 [[buffer(16)]],
929
- constant uint & r2 [[buffer(17)]],
930
- constant uint & r3 [[buffer(18)]],
923
+ constant int64_t & ne01,
924
+ constant int64_t & ne02,
925
+ constant uint64_t & nb00,
926
+ constant uint64_t & nb01,
927
+ constant uint64_t & nb02,
928
+ constant int64_t & ne10,
929
+ constant int64_t & ne11,
930
+ constant int64_t & ne12,
931
+ constant uint64_t & nb10,
932
+ constant uint64_t & nb11,
933
+ constant uint64_t & nb12,
934
+ constant int64_t & ne0,
935
+ constant int64_t & ne1,
936
+ constant uint & r2,
937
+ constant uint & r3,
931
938
  uint3 tgpig[[threadgroup_position_in_grid]],
932
939
  uint tiisg[[thread_index_in_simdgroup]],
933
940
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
939
946
  device const float * src1,
940
947
  device float * dst,
941
948
  constant int64_t & ne00,
942
- constant int64_t & ne01[[buffer(4)]],
943
- constant int64_t & ne02[[buffer(5)]],
944
- constant int64_t & ne10[[buffer(9)]],
945
- constant int64_t & ne12[[buffer(11)]],
946
- constant int64_t & ne0 [[buffer(15)]],
947
- constant int64_t & ne1 [[buffer(16)]],
948
- constant uint & r2 [[buffer(17)]],
949
- constant uint & r3 [[buffer(18)]],
949
+ constant int64_t & ne01,
950
+ constant int64_t & ne02,
951
+ constant uint64_t & nb00,
952
+ constant uint64_t & nb01,
953
+ constant uint64_t & nb02,
954
+ constant int64_t & ne10,
955
+ constant int64_t & ne11,
956
+ constant int64_t & ne12,
957
+ constant uint64_t & nb10,
958
+ constant uint64_t & nb11,
959
+ constant uint64_t & nb12,
960
+ constant int64_t & ne0,
961
+ constant int64_t & ne1,
962
+ constant uint & r2,
963
+ constant uint & r3,
950
964
  uint3 tgpig[[threadgroup_position_in_grid]],
951
965
  uint tiisg[[thread_index_in_simdgroup]],
952
966
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
958
972
  device const float * src1,
959
973
  device float * dst,
960
974
  constant int64_t & ne00,
961
- constant int64_t & ne01[[buffer(4)]],
962
- constant int64_t & ne02[[buffer(5)]],
963
- constant int64_t & ne10[[buffer(9)]],
964
- constant int64_t & ne12[[buffer(11)]],
965
- constant int64_t & ne0 [[buffer(15)]],
966
- constant int64_t & ne1 [[buffer(16)]],
967
- constant uint & r2 [[buffer(17)]],
968
- constant uint & r3 [[buffer(18)]],
975
+ constant int64_t & ne01,
976
+ constant int64_t & ne02,
977
+ constant uint64_t & nb00,
978
+ constant uint64_t & nb01,
979
+ constant uint64_t & nb02,
980
+ constant int64_t & ne10,
981
+ constant int64_t & ne11,
982
+ constant int64_t & ne12,
983
+ constant uint64_t & nb10,
984
+ constant uint64_t & nb11,
985
+ constant uint64_t & nb12,
986
+ constant int64_t & ne0,
987
+ constant int64_t & ne1,
988
+ constant uint & r2,
989
+ constant uint & r3,
969
990
  uint3 tgpig[[threadgroup_position_in_grid]],
970
991
  uint tiisg[[thread_index_in_simdgroup]],
971
992
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
977
998
  device const float * src1,
978
999
  device float * dst,
979
1000
  constant int64_t & ne00,
980
- constant int64_t & ne01[[buffer(4)]],
981
- constant int64_t & ne02[[buffer(5)]],
982
- constant int64_t & ne10[[buffer(9)]],
983
- constant int64_t & ne12[[buffer(11)]],
984
- constant int64_t & ne0 [[buffer(15)]],
985
- constant int64_t & ne1 [[buffer(16)]],
986
- constant uint & r2 [[buffer(17)]],
987
- constant uint & r3 [[buffer(18)]],
1001
+ constant int64_t & ne01,
1002
+ constant int64_t & ne02,
1003
+ constant uint64_t & nb00,
1004
+ constant uint64_t & nb01,
1005
+ constant uint64_t & nb02,
1006
+ constant int64_t & ne10,
1007
+ constant int64_t & ne11,
1008
+ constant int64_t & ne12,
1009
+ constant uint64_t & nb10,
1010
+ constant uint64_t & nb11,
1011
+ constant uint64_t & nb12,
1012
+ constant int64_t & ne0,
1013
+ constant int64_t & ne1,
1014
+ constant uint & r2,
1015
+ constant uint & r3,
988
1016
  uint3 tgpig[[threadgroup_position_in_grid]],
989
1017
  uint tiisg[[thread_index_in_simdgroup]],
990
1018
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
1071
1099
  constant int64_t & ne00,
1072
1100
  constant int64_t & ne01,
1073
1101
  constant int64_t & ne02,
1102
+ constant uint64_t & nb00,
1103
+ constant uint64_t & nb01,
1104
+ constant uint64_t & nb02,
1074
1105
  constant int64_t & ne10,
1106
+ constant int64_t & ne11,
1075
1107
  constant int64_t & ne12,
1108
+ constant uint64_t & nb10,
1109
+ constant uint64_t & nb11,
1110
+ constant uint64_t & nb12,
1076
1111
  constant int64_t & ne0,
1077
1112
  constant int64_t & ne1,
1078
- constant uint & r2 [[buffer(17)]],
1079
- constant uint & r3 [[buffer(18)]],
1113
+ constant uint & r2,
1114
+ constant uint & r3,
1080
1115
  uint3 tgpig[[threadgroup_position_in_grid]],
1081
1116
  uint tiisg[[thread_index_in_simdgroup]],
1082
1117
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
1182
1217
  constant uint64_t & nb12,
1183
1218
  constant int64_t & ne0,
1184
1219
  constant int64_t & ne1,
1185
- constant uint & r2 [[buffer(17)]],
1186
- constant uint & r3 [[buffer(18)]],
1220
+ constant uint & r2,
1221
+ constant uint & r3,
1187
1222
  uint3 tgpig[[threadgroup_position_in_grid]],
1188
1223
  uint tiisg[[thread_index_in_simdgroup]]) {
1189
1224
  kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
1209
1244
  constant uint64_t & nb12,
1210
1245
  constant int64_t & ne0,
1211
1246
  constant int64_t & ne1,
1212
- constant uint & r2 [[buffer(17)]],
1213
- constant uint & r3 [[buffer(18)]],
1247
+ constant uint & r2,
1248
+ constant uint & r3,
1214
1249
  uint3 tgpig[[threadgroup_position_in_grid]],
1215
1250
  uint tiisg[[thread_index_in_simdgroup]]) {
1216
1251
 
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1346
1381
  constant uint64_t & nb12,
1347
1382
  constant int64_t & ne0,
1348
1383
  constant int64_t & ne1,
1349
- constant uint & r2 [[buffer(17)]],
1350
- constant uint & r3 [[buffer(18)]],
1384
+ constant uint & r2,
1385
+ constant uint & r3,
1351
1386
  uint3 tgpig[[threadgroup_position_in_grid]],
1352
1387
  uint tiisg[[thread_index_in_simdgroup]]) {
1353
1388
  kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
1452
1487
  constant uint64_t & nb12,
1453
1488
  constant int64_t & ne0,
1454
1489
  constant int64_t & ne1,
1455
- constant uint & r2 [[buffer(17)]],
1456
- constant uint & r3 [[buffer(18)]],
1490
+ constant uint & r2,
1491
+ constant uint & r3,
1457
1492
  uint3 tgpig[[threadgroup_position_in_grid]],
1458
1493
  uint tiisg[[thread_index_in_simdgroup]]) {
1459
1494
  kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1478
1513
  constant uint64_t & nb12,
1479
1514
  constant int64_t & ne0,
1480
1515
  constant int64_t & ne1,
1481
- constant uint & r2 [[buffer(17)]],
1482
- constant uint & r3 [[buffer(18)]],
1516
+ constant uint & r2,
1517
+ constant uint & r3,
1483
1518
  uint3 tgpig[[threadgroup_position_in_grid]],
1484
1519
  uint tiisg[[thread_index_in_simdgroup]]) {
1485
1520
 
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
1543
1578
  const int64_t i3 = n / (ne2*ne1*ne0);
1544
1579
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1545
1580
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1546
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1581
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1582
+
1547
1583
  const int64_t k = i3*ne3 + i2;
1548
1584
 
1549
1585
  float m_k;
@@ -2410,22 +2446,6 @@ typedef struct {
2410
2446
  } block_q6_K;
2411
2447
  // 210 bytes / block
2412
2448
 
2413
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2414
- uchar4 r;
2415
- if (j < 4) {
2416
- r[0] = q[j+0] & 63;
2417
- r[2] = q[j+1] & 63;
2418
- r[1] = q[j+4] & 63;
2419
- r[3] = q[j+5] & 63;
2420
- } else {
2421
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2422
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2423
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2424
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2425
- }
2426
- return r;
2427
- }
2428
-
2429
2449
  //====================================== dot products =========================
2430
2450
 
2431
2451
  void kernel_mul_mv_q2_K_f32_impl(
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
2584
2604
  device const float * src1,
2585
2605
  device float * dst,
2586
2606
  constant int64_t & ne00,
2587
- constant int64_t & ne01[[buffer(4)]],
2588
- constant int64_t & ne02[[buffer(5)]],
2589
- constant int64_t & ne10[[buffer(9)]],
2590
- constant int64_t & ne12[[buffer(11)]],
2591
- constant int64_t & ne0 [[buffer(15)]],
2592
- constant int64_t & ne1 [[buffer(16)]],
2593
- constant uint & r2 [[buffer(17)]],
2594
- constant uint & r3 [[buffer(18)]],
2607
+ constant int64_t & ne01,
2608
+ constant int64_t & ne02,
2609
+ constant uint64_t & nb00,
2610
+ constant uint64_t & nb01,
2611
+ constant uint64_t & nb02,
2612
+ constant int64_t & ne10,
2613
+ constant int64_t & ne11,
2614
+ constant int64_t & ne12,
2615
+ constant uint64_t & nb10,
2616
+ constant uint64_t & nb11,
2617
+ constant uint64_t & nb12,
2618
+ constant int64_t & ne0,
2619
+ constant int64_t & ne1,
2620
+ constant uint & r2,
2621
+ constant uint & r3,
2595
2622
  uint3 tgpig[[threadgroup_position_in_grid]],
2596
2623
  uint tiisg[[thread_index_in_simdgroup]],
2597
2624
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
2841
2868
  device const float * src1,
2842
2869
  device float * dst,
2843
2870
  constant int64_t & ne00,
2844
- constant int64_t & ne01[[buffer(4)]],
2845
- constant int64_t & ne02[[buffer(5)]],
2846
- constant int64_t & ne10[[buffer(9)]],
2847
- constant int64_t & ne12[[buffer(11)]],
2848
- constant int64_t & ne0 [[buffer(15)]],
2849
- constant int64_t & ne1 [[buffer(16)]],
2850
- constant uint & r2 [[buffer(17)]],
2851
- constant uint & r3 [[buffer(18)]],
2871
+ constant int64_t & ne01,
2872
+ constant int64_t & ne02,
2873
+ constant uint64_t & nb00,
2874
+ constant uint64_t & nb01,
2875
+ constant uint64_t & nb02,
2876
+ constant int64_t & ne10,
2877
+ constant int64_t & ne11,
2878
+ constant int64_t & ne12,
2879
+ constant uint64_t & nb10,
2880
+ constant uint64_t & nb11,
2881
+ constant uint64_t & nb12,
2882
+ constant int64_t & ne0,
2883
+ constant int64_t & ne1,
2884
+ constant uint & r2,
2885
+ constant uint & r3,
2852
2886
  uint3 tgpig[[threadgroup_position_in_grid]],
2853
2887
  uint tiisg[[thread_index_in_simdgroup]],
2854
2888
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
2984
3018
  constant uint & r2,
2985
3019
  constant uint & r3,
2986
3020
  uint3 tgpig[[threadgroup_position_in_grid]],
2987
- uint tiisg[[thread_index_in_simdgroup]],
2988
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3021
+ uint tiisg[[thread_index_in_simdgroup]],
3022
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2989
3023
 
2990
3024
  const int ix = tiisg/4; // 0...7
2991
3025
  const int it = tiisg%4; // 0...3
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
2994
3028
  const int r0 = tgpig.x;
2995
3029
  const int r1 = tgpig.y;
2996
3030
  const int im = tgpig.z;
2997
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3031
+ const int first_row = r0 * N_DST;
2998
3032
  const int ib_row = first_row * nb;
2999
3033
 
3000
3034
  const uint i12 = im%ne12;
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3060
3094
  for (int row = 0; row < N_DST; ++row) {
3061
3095
  all_sum = simd_sum(sumf[row]);
3062
3096
  if (tiisg == 0) {
3063
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3097
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
3064
3098
  }
3065
3099
  }
3066
3100
  }
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
3072
3106
  device const float * src1,
3073
3107
  device float * dst,
3074
3108
  constant int64_t & ne00,
3075
- constant int64_t & ne01[[buffer(4)]],
3076
- constant int64_t & ne02[[buffer(5)]],
3077
- constant int64_t & ne10[[buffer(9)]],
3078
- constant int64_t & ne12[[buffer(11)]],
3079
- constant int64_t & ne0 [[buffer(15)]],
3080
- constant int64_t & ne1 [[buffer(16)]],
3081
- constant uint & r2 [[buffer(17)]],
3082
- constant uint & r3 [[buffer(18)]],
3109
+ constant int64_t & ne01,
3110
+ constant int64_t & ne02,
3111
+ constant uint64_t & nb00,
3112
+ constant uint64_t & nb01,
3113
+ constant uint64_t & nb02,
3114
+ constant int64_t & ne10,
3115
+ constant int64_t & ne11,
3116
+ constant int64_t & ne12,
3117
+ constant uint64_t & nb10,
3118
+ constant uint64_t & nb11,
3119
+ constant uint64_t & nb12,
3120
+ constant int64_t & ne0,
3121
+ constant int64_t & ne1,
3122
+ constant uint & r2,
3123
+ constant uint & r3,
3083
3124
  uint3 tgpig[[threadgroup_position_in_grid]],
3084
3125
  uint tiisg[[thread_index_in_simdgroup]],
3085
3126
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
3271
3312
  device const float * src1,
3272
3313
  device float * dst,
3273
3314
  constant int64_t & ne00,
3274
- constant int64_t & ne01[[buffer(4)]],
3275
- constant int64_t & ne02[[buffer(5)]],
3276
- constant int64_t & ne10[[buffer(9)]],
3277
- constant int64_t & ne12[[buffer(11)]],
3278
- constant int64_t & ne0 [[buffer(15)]],
3279
- constant int64_t & ne1 [[buffer(16)]],
3280
- constant uint & r2 [[buffer(17)]],
3281
- constant uint & r3 [[buffer(18)]],
3315
+ constant int64_t & ne01,
3316
+ constant int64_t & ne02,
3317
+ constant uint64_t & nb00,
3318
+ constant uint64_t & nb01,
3319
+ constant uint64_t & nb02,
3320
+ constant int64_t & ne10,
3321
+ constant int64_t & ne11,
3322
+ constant int64_t & ne12,
3323
+ constant uint64_t & nb10,
3324
+ constant uint64_t & nb11,
3325
+ constant uint64_t & nb12,
3326
+ constant int64_t & ne0,
3327
+ constant int64_t & ne1,
3328
+ constant uint & r2,
3329
+ constant uint & r3,
3282
3330
  uint3 tgpig[[threadgroup_position_in_grid]],
3283
3331
  uint tiisg[[thread_index_in_simdgroup]],
3284
3332
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
3398
3446
  device const float * src1,
3399
3447
  device float * dst,
3400
3448
  constant int64_t & ne00,
3401
- constant int64_t & ne01[[buffer(4)]],
3402
- constant int64_t & ne02[[buffer(5)]],
3403
- constant int64_t & ne10[[buffer(9)]],
3404
- constant int64_t & ne12[[buffer(11)]],
3405
- constant int64_t & ne0 [[buffer(15)]],
3406
- constant int64_t & ne1 [[buffer(16)]],
3407
- constant uint & r2 [[buffer(17)]],
3408
- constant uint & r3 [[buffer(18)]],
3449
+ constant int64_t & ne01,
3450
+ constant int64_t & ne02,
3451
+ constant uint64_t & nb00,
3452
+ constant uint64_t & nb01,
3453
+ constant uint64_t & nb02,
3454
+ constant int64_t & ne10,
3455
+ constant int64_t & ne11,
3456
+ constant int64_t & ne12,
3457
+ constant uint64_t & nb10,
3458
+ constant uint64_t & nb11,
3459
+ constant uint64_t & nb12,
3460
+ constant int64_t & ne0,
3461
+ constant int64_t & ne1,
3462
+ constant uint & r2,
3463
+ constant uint & r3,
3409
3464
  uint3 tgpig[[threadgroup_position_in_grid]],
3410
3465
  uint tiisg[[thread_index_in_simdgroup]],
3411
3466
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3523
3578
  device const int8_t * qs = ((device const int8_t *)xb->qs);
3524
3579
  const half d = xb->d;
3525
3580
 
3526
- for (int i=0;i<16;i++) {
3581
+ for (int i = 0; i < 16; i++) {
3527
3582
  reg[i/4][i%4] = (qs[i + 16*il] * d);
3528
3583
  }
3529
3584
  }
@@ -3774,6 +3829,35 @@ kernel void kernel_get_rows_f16(
3774
3829
  }
3775
3830
  }
3776
3831
 
3832
+ kernel void kernel_get_rows_i32(
3833
+ device const void * src0,
3834
+ device const char * src1,
3835
+ device int32_t * dst,
3836
+ constant int64_t & ne00,
3837
+ constant uint64_t & nb01,
3838
+ constant uint64_t & nb02,
3839
+ constant int64_t & ne10,
3840
+ constant uint64_t & nb10,
3841
+ constant uint64_t & nb11,
3842
+ constant uint64_t & nb1,
3843
+ constant uint64_t & nb2,
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiitg[[thread_index_in_threadgroup]],
3846
+ uint3 tptg [[threads_per_threadgroup]]) {
3847
+ const int64_t i10 = tgpig.x;
3848
+ const int64_t i11 = tgpig.y;
3849
+
3850
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3851
+
3852
+ const int64_t i02 = i11;
3853
+
3854
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3855
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3856
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3857
+ }
3858
+ }
3859
+
3860
+
3777
3861
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3778
3862
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3779
3863
  #define BLOCK_SIZE_K 32
@@ -3792,12 +3876,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
3792
3876
  device float * dst,
3793
3877
  constant int64_t & ne00,
3794
3878
  constant int64_t & ne02,
3795
- constant int64_t & nb01,
3796
- constant int64_t & nb02,
3879
+ constant uint64_t & nb01,
3880
+ constant uint64_t & nb02,
3797
3881
  constant int64_t & ne12,
3798
- constant int64_t & nb10,
3799
- constant int64_t & nb11,
3800
- constant int64_t & nb12,
3882
+ constant uint64_t & nb10,
3883
+ constant uint64_t & nb11,
3884
+ constant uint64_t & nb12,
3801
3885
  constant int64_t & ne0,
3802
3886
  constant int64_t & ne1,
3803
3887
  constant uint & r2,
@@ -3918,18 +4002,143 @@ void kernel_mul_mm_impl(device const uchar * src0,
3918
4002
  }
3919
4003
  }
3920
4004
 
4005
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
4006
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4007
+ void kernel_mul_mm_id_impl(
4008
+ device const uchar * src0,
4009
+ device const uchar * src1,
4010
+ thread short * src1ids,
4011
+ device float * dst,
4012
+ constant int64_t & ne00,
4013
+ constant int64_t & ne02,
4014
+ constant uint64_t & nb01,
4015
+ constant uint64_t & nb02,
4016
+ constant int64_t & ne12,
4017
+ constant uint64_t & nb10,
4018
+ constant uint64_t & nb11,
4019
+ constant uint64_t & nb12,
4020
+ constant int64_t & ne0,
4021
+ int64_t ne1,
4022
+ constant uint & r2,
4023
+ constant uint & r3,
4024
+ threadgroup uchar * shared_memory,
4025
+ uint3 tgpig[[threadgroup_position_in_grid]],
4026
+ uint tiitg[[thread_index_in_threadgroup]],
4027
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4028
+
4029
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
4030
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
4031
+
4032
+ const uint r0 = tgpig.y;
4033
+ const uint r1 = tgpig.x;
4034
+ const uint im = tgpig.z;
4035
+
4036
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
4037
+
4038
+ // if this block is of 64x32 shape or smaller
4039
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4040
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
4041
+
4042
+ // a thread shouldn't load data outside of the matrix
4043
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4044
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
4045
+
4046
+ simdgroup_half8x8 ma[4];
4047
+ simdgroup_float8x8 mb[2];
4048
+ simdgroup_float8x8 c_res[8];
4049
+ for (int i = 0; i < 8; i++){
4050
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
4051
+ }
4052
+
4053
+ short il = (tiitg % THREAD_PER_ROW);
4054
+
4055
+ const uint i12 = im%ne12;
4056
+ const uint i13 = im/ne12;
4057
+
4058
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4059
+ ushort offset1 = il/nl;
4060
+
4061
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4062
+ device const float * y = (device const float *)(src1
4063
+ + nb12 * im
4064
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
4065
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
4066
+
4067
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4068
+ // load data and store to threadgroup memory
4069
+ half4x4 temp_a;
4070
+ dequantize_func(x, il, temp_a);
4071
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4072
+
4073
+ for (int i = 0; i < 16; i++) {
4074
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4075
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4076
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4077
+ }
4078
+
4079
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
4080
+
4081
+ il = (il + 2 < nl) ? il + 2 : il % 2;
4082
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
4083
+ y += BLOCK_SIZE_K;
4084
+
4085
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4086
+
4087
+ // load matrices from threadgroup memory and conduct outer products
4088
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4089
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
4090
+
4091
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4092
+ for (int i = 0; i < 4; i++) {
4093
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4094
+ }
4095
+ simdgroup_barrier(mem_flags::mem_none);
4096
+ for (int i = 0; i < 2; i++) {
4097
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4098
+ }
4099
+
4100
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4101
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
4102
+
4103
+ for (int i = 0; i < 8; i++){
4104
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4105
+ }
4106
+ }
4107
+ }
4108
+
4109
+ {
4110
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4111
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4112
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4113
+ for (int i = 0; i < 8; i++) {
4114
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4115
+ }
4116
+
4117
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4118
+
4119
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
4120
+ if (sgitg == 0) {
4121
+ for (int i = 0; i < n_rows; i++) {
4122
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4123
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4124
+ }
4125
+ }
4126
+ }
4127
+ }
4128
+ }
4129
+
3921
4130
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3922
4131
  kernel void kernel_mul_mm(device const uchar * src0,
3923
4132
  device const uchar * src1,
3924
4133
  device float * dst,
3925
4134
  constant int64_t & ne00,
3926
4135
  constant int64_t & ne02,
3927
- constant int64_t & nb01,
3928
- constant int64_t & nb02,
4136
+ constant uint64_t & nb01,
4137
+ constant uint64_t & nb02,
3929
4138
  constant int64_t & ne12,
3930
- constant int64_t & nb10,
3931
- constant int64_t & nb11,
3932
- constant int64_t & nb12,
4139
+ constant uint64_t & nb10,
4140
+ constant uint64_t & nb11,
4141
+ constant uint64_t & nb12,
3933
4142
  constant int64_t & ne0,
3934
4143
  constant int64_t & ne1,
3935
4144
  constant uint & r2,
@@ -3964,20 +4173,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
3964
4173
  kernel void kernel_mul_mm_id(
3965
4174
  device const uchar * ids,
3966
4175
  device const uchar * src1,
3967
- device uchar * dst,
3968
- constant int64_t & nbi1,
4176
+ device float * dst,
4177
+ constant uint64_t & nbi1,
3969
4178
  constant int64_t & ne00,
3970
4179
  constant int64_t & ne02,
3971
- constant int64_t & nb01,
3972
- constant int64_t & nb02,
4180
+ constant uint64_t & nb01,
4181
+ constant uint64_t & nb02,
3973
4182
  constant int64_t & ne12,
3974
4183
  constant int64_t & ne13,
3975
- constant int64_t & nb10,
3976
- constant int64_t & nb11,
3977
- constant int64_t & nb12,
4184
+ constant uint64_t & nb10,
4185
+ constant uint64_t & nb11,
4186
+ constant uint64_t & nb12,
3978
4187
  constant int64_t & ne0,
3979
4188
  constant int64_t & ne1,
3980
- constant int64_t & nb1,
4189
+ constant uint64_t & nb1,
3981
4190
  constant uint & r2,
3982
4191
  constant uint & r3,
3983
4192
  constant int & idx,
@@ -3993,18 +4202,28 @@ kernel void kernel_mul_mm_id(
3993
4202
  uint3 tgpig[[threadgroup_position_in_grid]],
3994
4203
  uint tiitg[[thread_index_in_threadgroup]],
3995
4204
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3996
- device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4205
+ device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3997
4206
 
3998
- const int64_t bid = tgpig.z/(ne12*ne13);
4207
+ // expert id
4208
+ const int32_t id = tgpig.z/(ne12*ne13);
3999
4209
 
4000
4210
  tgpig.z = tgpig.z%(ne12*ne13);
4001
4211
 
4002
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4212
+ // row indices of src1 for expert id
4213
+ int64_t _ne1 = 0;
4214
+ short src1ids[512];
4003
4215
 
4004
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
4005
- src0[id],
4006
- src1 + bid*nb11,
4007
- (device float *) (dst + bid*nb1),
4216
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
4217
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
4218
+ src1ids[_ne1++] = i1;
4219
+ }
4220
+ }
4221
+
4222
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
4223
+ src0s[id],
4224
+ src1,
4225
+ src1ids,
4226
+ dst,
4008
4227
  ne00,
4009
4228
  ne02,
4010
4229
  nb01,
@@ -4014,7 +4233,7 @@ kernel void kernel_mul_mm_id(
4014
4233
  nb11,
4015
4234
  nb12,
4016
4235
  ne0,
4017
- ne1,
4236
+ _ne1,
4018
4237
  r2,
4019
4238
  r3,
4020
4239
  shared_memory,
@@ -4070,12 +4289,12 @@ typedef void (mat_mm_t)(
4070
4289
  device float * dst,
4071
4290
  constant int64_t & ne00,
4072
4291
  constant int64_t & ne02,
4073
- constant int64_t & nb01,
4074
- constant int64_t & nb02,
4292
+ constant uint64_t & nb01,
4293
+ constant uint64_t & nb02,
4075
4294
  constant int64_t & ne12,
4076
- constant int64_t & nb10,
4077
- constant int64_t & nb11,
4078
- constant int64_t & nb12,
4295
+ constant uint64_t & nb10,
4296
+ constant uint64_t & nb11,
4297
+ constant uint64_t & nb12,
4079
4298
  constant int64_t & ne0,
4080
4299
  constant int64_t & ne1,
4081
4300
  constant uint & r2,
@@ -4103,20 +4322,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4103
4322
  typedef void (mat_mm_id_t)(
4104
4323
  device const uchar * ids,
4105
4324
  device const uchar * src1,
4106
- device uchar * dst,
4107
- constant int64_t & nbi1,
4325
+ device float * dst,
4326
+ constant uint64_t & nbi1,
4108
4327
  constant int64_t & ne00,
4109
4328
  constant int64_t & ne02,
4110
- constant int64_t & nb01,
4111
- constant int64_t & nb02,
4329
+ constant uint64_t & nb01,
4330
+ constant uint64_t & nb02,
4112
4331
  constant int64_t & ne12,
4113
4332
  constant int64_t & ne13,
4114
- constant int64_t & nb10,
4115
- constant int64_t & nb11,
4116
- constant int64_t & nb12,
4333
+ constant uint64_t & nb10,
4334
+ constant uint64_t & nb11,
4335
+ constant uint64_t & nb12,
4117
4336
  constant int64_t & ne0,
4118
4337
  constant int64_t & ne1,
4119
- constant int64_t & nb1,
4338
+ constant uint64_t & nb1,
4120
4339
  constant uint & r2,
4121
4340
  constant uint & r3,
4122
4341
  constant int & idx,
@@ -4152,8 +4371,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
4152
4371
  kernel void kernel_mul_mv_id_f32_f32(
4153
4372
  device const char * ids,
4154
4373
  device const char * src1,
4155
- device uchar * dst,
4156
- constant int64_t & nbi1,
4374
+ device float * dst,
4375
+ constant uint64_t & nbi1,
4157
4376
  constant int64_t & ne00,
4158
4377
  constant int64_t & ne01,
4159
4378
  constant int64_t & ne02,
@@ -4169,7 +4388,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4169
4388
  constant uint64_t & nb12,
4170
4389
  constant int64_t & ne0,
4171
4390
  constant int64_t & ne1,
4172
- constant int64_t & nb1,
4391
+ constant uint64_t & nb1,
4173
4392
  constant uint & r2,
4174
4393
  constant uint & r3,
4175
4394
  constant int & idx,
@@ -4196,7 +4415,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4196
4415
  kernel_mul_mv_f32_f32_impl(
4197
4416
  src0[id],
4198
4417
  src1 + bid*nb11,
4199
- (device float *) (dst + bid*nb1),
4418
+ dst + bid*ne0,
4200
4419
  ne00,
4201
4420
  ne01,
4202
4421
  ne02,
@@ -4221,8 +4440,8 @@ kernel void kernel_mul_mv_id_f32_f32(
4221
4440
  kernel void kernel_mul_mv_id_f16_f32(
4222
4441
  device const char * ids,
4223
4442
  device const char * src1,
4224
- device uchar * dst,
4225
- constant int64_t & nbi1,
4443
+ device float * dst,
4444
+ constant uint64_t & nbi1,
4226
4445
  constant int64_t & ne00,
4227
4446
  constant int64_t & ne01,
4228
4447
  constant int64_t & ne02,
@@ -4238,7 +4457,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4238
4457
  constant uint64_t & nb12,
4239
4458
  constant int64_t & ne0,
4240
4459
  constant int64_t & ne1,
4241
- constant int64_t & nb1,
4460
+ constant uint64_t & nb1,
4242
4461
  constant uint & r2,
4243
4462
  constant uint & r3,
4244
4463
  constant int & idx,
@@ -4265,7 +4484,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4265
4484
  kernel_mul_mv_f16_f32_impl(
4266
4485
  src0[id],
4267
4486
  src1 + bid*nb11,
4268
- (device float *) (dst + bid*nb1),
4487
+ dst + bid*ne0,
4269
4488
  ne00,
4270
4489
  ne01,
4271
4490
  ne02,
@@ -4290,8 +4509,8 @@ kernel void kernel_mul_mv_id_f16_f32(
4290
4509
  kernel void kernel_mul_mv_id_q8_0_f32(
4291
4510
  device const char * ids,
4292
4511
  device const char * src1,
4293
- device uchar * dst,
4294
- constant int64_t & nbi1,
4512
+ device float * dst,
4513
+ constant uint64_t & nbi1,
4295
4514
  constant int64_t & ne00,
4296
4515
  constant int64_t & ne01,
4297
4516
  constant int64_t & ne02,
@@ -4307,7 +4526,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4307
4526
  constant uint64_t & nb12,
4308
4527
  constant int64_t & ne0,
4309
4528
  constant int64_t & ne1,
4310
- constant int64_t & nb1,
4529
+ constant uint64_t & nb1,
4311
4530
  constant uint & r2,
4312
4531
  constant uint & r3,
4313
4532
  constant int & idx,
@@ -4334,7 +4553,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4334
4553
  kernel_mul_mv_q8_0_f32_impl(
4335
4554
  src0[id],
4336
4555
  (device const float *) (src1 + bid*nb11),
4337
- (device float *) ( dst + bid*nb1),
4556
+ dst + bid*ne0,
4338
4557
  ne00,
4339
4558
  ne01,
4340
4559
  ne02,
@@ -4353,8 +4572,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4353
4572
  kernel void kernel_mul_mv_id_q4_0_f32(
4354
4573
  device const char * ids,
4355
4574
  device const char * src1,
4356
- device uchar * dst,
4357
- constant int64_t & nbi1,
4575
+ device float * dst,
4576
+ constant uint64_t & nbi1,
4358
4577
  constant int64_t & ne00,
4359
4578
  constant int64_t & ne01,
4360
4579
  constant int64_t & ne02,
@@ -4370,7 +4589,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4370
4589
  constant uint64_t & nb12,
4371
4590
  constant int64_t & ne0,
4372
4591
  constant int64_t & ne1,
4373
- constant int64_t & nb1,
4592
+ constant uint64_t & nb1,
4374
4593
  constant uint & r2,
4375
4594
  constant uint & r3,
4376
4595
  constant int & idx,
@@ -4397,7 +4616,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4397
4616
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4398
4617
  src0[id],
4399
4618
  (device const float *) (src1 + bid*nb11),
4400
- (device float *) ( dst + bid*nb1),
4619
+ dst + bid*ne0,
4401
4620
  ne00,
4402
4621
  ne01,
4403
4622
  ne02,
@@ -4416,8 +4635,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4416
4635
  kernel void kernel_mul_mv_id_q4_1_f32(
4417
4636
  device const char * ids,
4418
4637
  device const char * src1,
4419
- device uchar * dst,
4420
- constant int64_t & nbi1,
4638
+ device float * dst,
4639
+ constant uint64_t & nbi1,
4421
4640
  constant int64_t & ne00,
4422
4641
  constant int64_t & ne01,
4423
4642
  constant int64_t & ne02,
@@ -4433,7 +4652,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4433
4652
  constant uint64_t & nb12,
4434
4653
  constant int64_t & ne0,
4435
4654
  constant int64_t & ne1,
4436
- constant int64_t & nb1,
4655
+ constant uint64_t & nb1,
4437
4656
  constant uint & r2,
4438
4657
  constant uint & r3,
4439
4658
  constant int & idx,
@@ -4460,7 +4679,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4460
4679
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4461
4680
  src0[id],
4462
4681
  (device const float *) (src1 + bid*nb11),
4463
- (device float *) ( dst + bid*nb1),
4682
+ dst + bid*ne0,
4464
4683
  ne00,
4465
4684
  ne01,
4466
4685
  ne02,
@@ -4479,8 +4698,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4479
4698
  kernel void kernel_mul_mv_id_q5_0_f32(
4480
4699
  device const char * ids,
4481
4700
  device const char * src1,
4482
- device uchar * dst,
4483
- constant int64_t & nbi1,
4701
+ device float * dst,
4702
+ constant uint64_t & nbi1,
4484
4703
  constant int64_t & ne00,
4485
4704
  constant int64_t & ne01,
4486
4705
  constant int64_t & ne02,
@@ -4496,7 +4715,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4496
4715
  constant uint64_t & nb12,
4497
4716
  constant int64_t & ne0,
4498
4717
  constant int64_t & ne1,
4499
- constant int64_t & nb1,
4718
+ constant uint64_t & nb1,
4500
4719
  constant uint & r2,
4501
4720
  constant uint & r3,
4502
4721
  constant int & idx,
@@ -4523,7 +4742,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4523
4742
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4524
4743
  src0[id],
4525
4744
  (device const float *) (src1 + bid*nb11),
4526
- (device float *) ( dst + bid*nb1),
4745
+ dst + bid*ne0,
4527
4746
  ne00,
4528
4747
  ne01,
4529
4748
  ne02,
@@ -4542,8 +4761,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4542
4761
  kernel void kernel_mul_mv_id_q5_1_f32(
4543
4762
  device const char * ids,
4544
4763
  device const char * src1,
4545
- device uchar * dst,
4546
- constant int64_t & nbi1,
4764
+ device float * dst,
4765
+ constant uint64_t & nbi1,
4547
4766
  constant int64_t & ne00,
4548
4767
  constant int64_t & ne01,
4549
4768
  constant int64_t & ne02,
@@ -4559,7 +4778,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4559
4778
  constant uint64_t & nb12,
4560
4779
  constant int64_t & ne0,
4561
4780
  constant int64_t & ne1,
4562
- constant int64_t & nb1,
4781
+ constant uint64_t & nb1,
4563
4782
  constant uint & r2,
4564
4783
  constant uint & r3,
4565
4784
  constant int & idx,
@@ -4586,7 +4805,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4586
4805
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4587
4806
  src0[id],
4588
4807
  (device const float *) (src1 + bid*nb11),
4589
- (device float *) ( dst + bid*nb1),
4808
+ dst + bid*ne0,
4590
4809
  ne00,
4591
4810
  ne01,
4592
4811
  ne02,
@@ -4605,8 +4824,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4605
4824
  kernel void kernel_mul_mv_id_q2_K_f32(
4606
4825
  device const char * ids,
4607
4826
  device const char * src1,
4608
- device uchar * dst,
4609
- constant int64_t & nbi1,
4827
+ device float * dst,
4828
+ constant uint64_t & nbi1,
4610
4829
  constant int64_t & ne00,
4611
4830
  constant int64_t & ne01,
4612
4831
  constant int64_t & ne02,
@@ -4622,7 +4841,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4622
4841
  constant uint64_t & nb12,
4623
4842
  constant int64_t & ne0,
4624
4843
  constant int64_t & ne1,
4625
- constant int64_t & nb1,
4844
+ constant uint64_t & nb1,
4626
4845
  constant uint & r2,
4627
4846
  constant uint & r3,
4628
4847
  constant int & idx,
@@ -4649,7 +4868,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4649
4868
  kernel_mul_mv_q2_K_f32_impl(
4650
4869
  src0[id],
4651
4870
  (device const float *) (src1 + bid*nb11),
4652
- (device float *) ( dst + bid*nb1),
4871
+ dst + bid*ne0,
4653
4872
  ne00,
4654
4873
  ne01,
4655
4874
  ne02,
@@ -4668,8 +4887,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4668
4887
  kernel void kernel_mul_mv_id_q3_K_f32(
4669
4888
  device const char * ids,
4670
4889
  device const char * src1,
4671
- device uchar * dst,
4672
- constant int64_t & nbi1,
4890
+ device float * dst,
4891
+ constant uint64_t & nbi1,
4673
4892
  constant int64_t & ne00,
4674
4893
  constant int64_t & ne01,
4675
4894
  constant int64_t & ne02,
@@ -4685,7 +4904,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4685
4904
  constant uint64_t & nb12,
4686
4905
  constant int64_t & ne0,
4687
4906
  constant int64_t & ne1,
4688
- constant int64_t & nb1,
4907
+ constant uint64_t & nb1,
4689
4908
  constant uint & r2,
4690
4909
  constant uint & r3,
4691
4910
  constant int & idx,
@@ -4712,7 +4931,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4712
4931
  kernel_mul_mv_q3_K_f32_impl(
4713
4932
  src0[id],
4714
4933
  (device const float *) (src1 + bid*nb11),
4715
- (device float *) ( dst + bid*nb1),
4934
+ dst + bid*ne0,
4716
4935
  ne00,
4717
4936
  ne01,
4718
4937
  ne02,
@@ -4731,8 +4950,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4731
4950
  kernel void kernel_mul_mv_id_q4_K_f32(
4732
4951
  device const char * ids,
4733
4952
  device const char * src1,
4734
- device uchar * dst,
4735
- constant int64_t & nbi1,
4953
+ device float * dst,
4954
+ constant uint64_t & nbi1,
4736
4955
  constant int64_t & ne00,
4737
4956
  constant int64_t & ne01,
4738
4957
  constant int64_t & ne02,
@@ -4748,7 +4967,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4748
4967
  constant uint64_t & nb12,
4749
4968
  constant int64_t & ne0,
4750
4969
  constant int64_t & ne1,
4751
- constant int64_t & nb1,
4970
+ constant uint64_t & nb1,
4752
4971
  constant uint & r2,
4753
4972
  constant uint & r3,
4754
4973
  constant int & idx,
@@ -4775,7 +4994,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4775
4994
  kernel_mul_mv_q4_K_f32_impl(
4776
4995
  src0[id],
4777
4996
  (device const float *) (src1 + bid*nb11),
4778
- (device float *) ( dst + bid*nb1),
4997
+ dst + bid*ne0,
4779
4998
  ne00,
4780
4999
  ne01,
4781
5000
  ne02,
@@ -4794,8 +5013,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4794
5013
  kernel void kernel_mul_mv_id_q5_K_f32(
4795
5014
  device const char * ids,
4796
5015
  device const char * src1,
4797
- device uchar * dst,
4798
- constant int64_t & nbi1,
5016
+ device float * dst,
5017
+ constant uint64_t & nbi1,
4799
5018
  constant int64_t & ne00,
4800
5019
  constant int64_t & ne01,
4801
5020
  constant int64_t & ne02,
@@ -4811,7 +5030,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4811
5030
  constant uint64_t & nb12,
4812
5031
  constant int64_t & ne0,
4813
5032
  constant int64_t & ne1,
4814
- constant int64_t & nb1,
5033
+ constant uint64_t & nb1,
4815
5034
  constant uint & r2,
4816
5035
  constant uint & r3,
4817
5036
  constant int & idx,
@@ -4838,7 +5057,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4838
5057
  kernel_mul_mv_q5_K_f32_impl(
4839
5058
  src0[id],
4840
5059
  (device const float *) (src1 + bid*nb11),
4841
- (device float *) ( dst + bid*nb1),
5060
+ dst + bid*ne0,
4842
5061
  ne00,
4843
5062
  ne01,
4844
5063
  ne02,
@@ -4857,8 +5076,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4857
5076
  kernel void kernel_mul_mv_id_q6_K_f32(
4858
5077
  device const char * ids,
4859
5078
  device const char * src1,
4860
- device uchar * dst,
4861
- constant int64_t & nbi1,
5079
+ device float * dst,
5080
+ constant uint64_t & nbi1,
4862
5081
  constant int64_t & ne00,
4863
5082
  constant int64_t & ne01,
4864
5083
  constant int64_t & ne02,
@@ -4874,7 +5093,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4874
5093
  constant uint64_t & nb12,
4875
5094
  constant int64_t & ne0,
4876
5095
  constant int64_t & ne1,
4877
- constant int64_t & nb1,
5096
+ constant uint64_t & nb1,
4878
5097
  constant uint & r2,
4879
5098
  constant uint & r3,
4880
5099
  constant int & idx,
@@ -4901,7 +5120,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4901
5120
  kernel_mul_mv_q6_K_f32_impl(
4902
5121
  src0[id],
4903
5122
  (device const float *) (src1 + bid*nb11),
4904
- (device float *) ( dst + bid*nb1),
5123
+ dst + bid*ne0,
4905
5124
  ne00,
4906
5125
  ne01,
4907
5126
  ne02,