llama_cpp 0.5.1 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
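This hunk (and kernel_gelu below) switches the element-wise kernels from float to float4: each thread now loads, transforms, and stores four values, which quarters the thread count and improves memory throughput. A minimal sketch of the pattern, assuming (as these hunks imply but the host code in this diff does not show) that the dispatch covers n/4 grid positions and n is a multiple of 4; the kernel name is hypothetical:

```metal
#include <metal_stdlib>
using namespace metal;

// Vectorized element-wise pattern: one thread per float4.
// Metal's exp() and the arithmetic operators apply lane-wise to float4,
// so the scalar body carries over unchanged.
kernel void silu_f4_sketch(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];
    dst[tpig] = x / (1.0f + exp(-x));
}
```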
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
-    }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // whish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
+    }
+}
+
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    // broadcast - not needed, see above
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
+    const float max = simd_max(lmax);
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] /= sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
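Both soft_max variants replace the threadgroup-memory tree reduction (two barrier-synchronized loops plus the buf scratch buffer) with simd_max/simd_sum, which reduce across the 32 lanes of a simdgroup entirely in registers. That is only equivalent to a full threadgroup reduction when the threadgroup contains a single simdgroup, so the host presumably launches these kernels with at most 32 threads per threadgroup (the host code is not part of this diff). A reduced sketch of the pattern, with hypothetical names:

```metal
#include <metal_stdlib>
using namespace metal;

// Strided per-thread partial sum followed by a register-level simdgroup
// reduction; no threadgroup memory and no barriers. Assumes the threadgroup
// is a single simdgroup (ntg <= 32), otherwise the partial sums of the
// other simdgroups would be lost.
kernel void row_sum_sketch(
        device const float * src,
        device       float * out,
        constant   int64_t & n,
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    float lsum = 0.0f;
    for (int i = tpitg; i < n; i += ntg) {
        lsum += src[i];
    }
    const float sum = simd_sum(lsum); // every lane receives the full sum
    if (tpitg == 0) {
        out[0] = sum;
    }
}
```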
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
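kernel_diag_mask_inf_8 handles eight consecutive floats per thread: it copies two float4s unconditionally, recovers the tensor coordinates of the first element from the flat position (i4 = 4*i), and walks k from 3 down to 0 so it can break out as soon as the remaining lanes fall inside the causal window i00 <= n_past + i01. For reference, this is the scalar predicate it must reproduce (a sketch restating the kernel_diag_mask_inf logic shown above; the kernel name is hypothetical):

```metal
#include <metal_stdlib>
using namespace metal;

// Scalar reference: element (i02, i01, i00) of a contiguous
// [ne02][ne01][ne00] tensor is masked iff i00 > n_past + i01.
kernel void diag_mask_inf_ref_sketch(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant       int & n_past,
        uint tpig[[thread_position_in_grid]]) {
    int64_t i4 = tpig;
    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
    const int64_t i01 = i4/ ne00;       i4 -= i01*ne00;
    const int64_t i00 = i4;
    dst[tpig] = (i00 > n_past + i01) ? -INFINITY : src0[tpig];
}
```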
@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
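The new kernel_mul_mat_f16_f32_l4 is a matrix-vector product specialized for rows whose length is a multiple of 4: one simdgroup owns one src0 row, each of its 32 lanes accumulates every 32nd half4/float4 pair, and simd_sum folds the 32 partial dot products before lane 0 writes the result. The core pattern, reduced to a single dot product with hypothetical names:

```metal
#include <metal_stdlib>
using namespace metal;

// Simdgroup dot product: the 32 lanes stride across the vectors in
// half4/float4 steps, then simd_sum combines the partials.
// Assumes n is a multiple of 4.
kernel void dot_f16_f32_sketch(
        device const half4  * x4,
        device const float4 * y4,
        device       float  * out,
        constant   int64_t  & n,
        uint tiisg[[thread_index_in_simdgroup]]) {
    float sumf = 0;
    for (int i = tiisg; i < n/4; i += 32) {
        for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
    }
    const float all_sum = simd_sum(sumf);
    if (tiisg == 0) {
        out[0] = all_sum;
    }
}
```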
@@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;        // 0 or 1
-    const int il  = tid/2 - 4*ip; // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float   v1 = il == 0 ? 4.f : 64.f;
+    const float   v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
@@ -1201,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
+
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
@@ -1564,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] +      16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] +      16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
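The q5_K hunk splits each fused accumulation into a pure mask-and-multiply part (acc1) and a high-bit part (acc2), so the inner loop no longer mixes the conditional +16/+256 offsets into the multiply. The regrouping is exact; worked out for the low-nibble lane, in the diff's own names:

    acc[0]  =  Σ yl[l+0] * ((q1[l] & 0x0F) + (h & hm1 ? 16 : 0))
            =  Σ yl[l+0] * (q1[l] & 0x0F)  +  16 * Σ (h & hm1 ? yl[l+0] : 0)
            =  acc1[0] + 16*acc2[0]

The 0xF0 lanes carry a built-in factor of 16, so their high-bit term is 256 = 16*16 and one division by 16 remains: sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) equals the old acc[1] * sc8[1] * 1.f/16.f.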
@@ -1757,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
+
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
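Both dequantizers now fold the high-byte shift into a second scale: (qs[i] & mask1) >> 8 equals (qs[i] & mask1) / 256, so multiplying by d2 = d1/256 gives the same value without the shift, and each nibble becomes a single fused multiply-add. For q4_0 the offset also collapses to an il-independent constant:

    old:  reg = (((qs[i] & mask1) >> 8) + m) * d    with  d = il ? xb->d/16 : xb->d,  m = il ? -8*16 : -8
    new:  reg = d2 * (qs[i] & mask1) + md           with  d2 = d1/256,  md = -8 * xb->d

    check:  ((qs[i] & mask1) >> 8) * d = (qs[i] & mask1) * d/256 = d2 * (qs[i] & mask1)
            m * d = -8 * xb->d  for either value of il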
@@ -1815,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                                  ((il/4)>0 ? 12  : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                              : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
+
 #else
     float kcoef = il&1 ? 1.f/16.f : 1.f;
     uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -1852,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
+
 }
 
 template <typename type4x4>
@@ -1885,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q  = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1916,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+    const half     coef   = il>1 ? 1.f/16.h : 1.h;
+    const half     ml     = d_all * sc * 32.h;
+    const half     dl     = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
-        const float coef = il>1 ? 1.f/16.f : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
     }
 }
 
@@ -1,4 +1,3 @@
-#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 
 #include "ggml.h"
@@ -47,6 +46,10 @@
 // disable "possible loss of data" to avoid hundreds of casts
 // we should just be careful :)
 #pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnigns
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
 #endif
 
 #if defined(_WIN32)
@@ -280,7 +283,7 @@ typedef double ggml_float;
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
-#ifdef __ARM_NEON
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
 
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
 //
@@ -307,12 +310,14 @@ typedef double ggml_float;
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
 #if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
+#endif
 
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
@@ -18872,7 +18877,6 @@ static enum ggml_opt_result linesearch_backtracking(
                 // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
                 return count;
             }
-            return count;
         }
     }
 