llama_cpp 0.5.1 → 0.5.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
63
63
  }
64
64
 
65
65
  kernel void kernel_scale(
66
- device const float * src0,
67
- device float * dst,
66
+ device const float4 * src0,
67
+ device float4 * dst,
68
68
  constant float & scale,
69
69
  uint tpig[[thread_position_in_grid]]) {
70
70
  dst[tpig] = src0[tpig] * scale;
71
71
  }
72
72
 
73
73
  kernel void kernel_silu(
74
- device const float * src0,
75
- device float * dst,
74
+ device const float4 * src0,
75
+ device float4 * dst,
76
76
  uint tpig[[thread_position_in_grid]]) {
77
- float x = src0[tpig];
77
+ device const float4 & x = src0[tpig];
78
78
  dst[tpig] = x / (1.0f + exp(-x));
79
79
  }
80
80
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
89
89
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
90
 
91
91
  kernel void kernel_gelu(
92
- device const float * src0,
93
- device float * dst,
92
+ device const float4 * src0,
93
+ device float4 * dst,
94
94
  uint tpig[[thread_position_in_grid]]) {
95
- float x = src0[tpig];
95
+ device const float4 & x = src0[tpig];
96
96
 
97
97
  // BEWARE !!!
98
98
  // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
107
107
  constant int64_t & ne00,
108
108
  constant int64_t & ne01,
109
109
  constant int64_t & ne02,
110
- threadgroup float * buf [[threadgroup(0)]],
111
110
  uint3 tgpig[[threadgroup_position_in_grid]],
112
111
  uint3 tpitg[[thread_position_in_threadgroup]],
113
112
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
119
118
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
120
119
 
121
120
  // parallel max
122
- buf[tpitg[0]] = -INFINITY;
123
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
124
- buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
125
- }
126
-
127
- // reduce
128
- threadgroup_barrier(mem_flags::mem_threadgroup);
129
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
130
- if (tpitg[0] < i) {
131
- buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
132
- }
133
- threadgroup_barrier(mem_flags::mem_threadgroup);
121
+ float lmax = psrc0[tpitg[0]];
122
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
123
+ lmax = MAX(lmax, psrc0[i00]);
134
124
  }
135
-
136
- //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
137
- // the loop, and when that is done, buf[0] has the correct (synchronized) value
138
- //if (tpitg[0] == 0) {
139
- // buf[0] = buf[0];
140
- //}
141
-
142
- //threadgroup_barrier(mem_flags::mem_threadgroup);
143
-
144
- const float max = buf[0];
125
+ const float max = simd_max(lmax);
145
126
 
146
127
  // parallel sum
147
- buf[tpitg[0]] = 0.0f;
128
+ float lsum = 0.0f;
148
129
  for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
149
130
  const float exp_psrc0 = exp(psrc0[i00] - max);
150
- buf[tpitg[0]] += exp_psrc0;
131
+ lsum += exp_psrc0;
151
132
  // Remember the result of exp here. exp is expensive, so we really do not
152
133
  // whish to compute it twice.
153
134
  pdst[i00] = exp_psrc0;
154
135
  }
155
136
 
156
- // reduce
157
- threadgroup_barrier(mem_flags::mem_threadgroup);
158
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
159
- if (tpitg[0] < i) {
160
- buf[tpitg[0]] += buf[tpitg[0] + i];
161
- }
162
- threadgroup_barrier(mem_flags::mem_threadgroup);
137
+ const float sum = simd_sum(lsum);
138
+
139
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
140
+ pdst[i00] /= sum;
141
+ }
142
+ }
143
+
144
+ kernel void kernel_soft_max_4(
145
+ device const float * src0,
146
+ device float * dst,
147
+ constant int64_t & ne00,
148
+ constant int64_t & ne01,
149
+ constant int64_t & ne02,
150
+ uint3 tgpig[[threadgroup_position_in_grid]],
151
+ uint3 tpitg[[thread_position_in_threadgroup]],
152
+ uint3 ntg[[threads_per_threadgroup]]) {
153
+ const int64_t i03 = tgpig[2];
154
+ const int64_t i02 = tgpig[1];
155
+ const int64_t i01 = tgpig[0];
156
+
157
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
158
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
159
+
160
+ // parallel max
161
+ float4 lmax4 = psrc4[tpitg[0]];
162
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
163
+ lmax4 = fmax(lmax4, psrc4[i00]);
163
164
  }
165
+ float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
164
166
 
165
- // broadcast - not needed, see above
166
- //// broadcast
167
- //if (tpitg[0] == 0) {
168
- // buf[0] = buf[0];
169
- //}
167
+ const float max = simd_max(lmax);
170
168
 
171
- //threadgroup_barrier(mem_flags::mem_threadgroup);
169
+ // parallel sum
170
+ float4 lsum4 = 0.0f;
171
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
172
+ const float4 exp_psrc4 = exp(psrc4[i00] - max);
173
+ lsum4 += exp_psrc4;
174
+ pdst4[i00] = exp_psrc4;
175
+ }
176
+ float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
172
177
 
173
- const float sum = buf[0];
178
+ const float sum = simd_sum(lsum);
174
179
 
175
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
176
- pdst[i00] /= sum;
180
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
181
+ pdst4[i00] /= sum;
177
182
  }
178
183
  }
179
184
 
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
192
197
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
193
198
  } else {
194
199
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
200
+ }
201
+ }
202
+
203
+ kernel void kernel_diag_mask_inf_8(
204
+ device const float4 * src0,
205
+ device float4 * dst,
206
+ constant int64_t & ne00,
207
+ constant int64_t & ne01,
208
+ constant int & n_past,
209
+ uint3 tpig[[thread_position_in_grid]]) {
210
+
211
+ const int64_t i = 2*tpig[0];
212
+
213
+ dst[i+0] = src0[i+0];
214
+ dst[i+1] = src0[i+1];
215
+ int64_t i4 = 4*i;
216
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
217
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
218
+ const int64_t i00 = i4;
219
+ for (int k = 3; k >= 0; --k) {
220
+ if (i00 + 4 + k <= n_past + i01) {
221
+ break;
222
+ }
223
+ dst[i+1][k] = -INFINITY;
224
+ if (i00 + k > n_past + i01) {
225
+ dst[i][k] = -INFINITY;
226
+ }
195
227
  }
196
228
  }
197
229
 
@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
616
648
  }
617
649
  }
618
650
 
651
+ // Assumes row size (ne00) is a multiple of 4
652
+ kernel void kernel_mul_mat_f16_f32_l4(
653
+ device const char * src0,
654
+ device const char * src1,
655
+ device float * dst,
656
+ constant int64_t & ne00,
657
+ constant int64_t & ne01,
658
+ constant int64_t & ne02,
659
+ constant uint64_t & nb00,
660
+ constant uint64_t & nb01,
661
+ constant uint64_t & nb02,
662
+ constant int64_t & ne10,
663
+ constant int64_t & ne11,
664
+ constant int64_t & ne12,
665
+ constant uint64_t & nb10,
666
+ constant uint64_t & nb11,
667
+ constant uint64_t & nb12,
668
+ constant int64_t & ne0,
669
+ constant int64_t & ne1,
670
+ uint3 tgpig[[threadgroup_position_in_grid]],
671
+ uint tiisg[[thread_index_in_simdgroup]]) {
672
+
673
+ const int nrows = ne11;
674
+ const int64_t r0 = tgpig.x;
675
+ const int64_t im = tgpig.z;
676
+
677
+ device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
678
+
679
+ for (int r1 = 0; r1 < nrows; ++r1) {
680
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
681
+
682
+ float sumf = 0;
683
+ for (int i = tiisg; i < ne00/4; i += 32) {
684
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
685
+ }
686
+
687
+ float all_sum = simd_sum(sumf);
688
+ if (tiisg == 0) {
689
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
690
+ }
691
+ }
692
+ }
693
+
619
694
  kernel void kernel_alibi_f32(
620
695
  device const float * src0,
621
696
  device float * dst,
@@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
1123
1198
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1124
1199
  device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1125
1200
 
1126
- float yl[16];
1201
+ float yl[32];
1127
1202
 
1128
- const uint16_t kmask1 = 0x0303;
1203
+ const uint16_t kmask1 = 0x3030;
1129
1204
  const uint16_t kmask2 = 0x0f0f;
1130
1205
 
1131
- const int tid = tiisg/2;
1132
- const int ix = tiisg%2;
1133
- const int ip = tid/8; // 0 or 1
1134
- const int il = tid/2 - 4*ip; // 0...3
1206
+ const int tid = tiisg/4;
1207
+ const int ix = tiisg%4;
1208
+ const int ip = tid/4; // 0 or 1
1209
+ const int il = 2*((tid%4)/2); // 0 or 2
1135
1210
  const int ir = tid%2;
1136
1211
  const int n = 8;
1137
1212
  const int l0 = n*ir;
1138
1213
 
1139
- const uint16_t m1 = 1 << (4*ip + il);
1140
- const uint16_t m2 = m1 << 8;
1214
+ // One would think that the Metal compiler would figure out that ip and il can only have
1215
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
1216
+ // with these two tales.
1217
+ //
1218
+ // Possible masks for the high bit
1219
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
1220
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
1221
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
1222
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
1223
+
1224
+ // Possible masks for the low 2 bits
1225
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
1226
+
1227
+ const ushort4 hm = mm[2*ip + il/2];
1141
1228
 
1142
1229
  const int shift = 2*il;
1143
- const uint16_t qm1 = 0x0003 << shift;
1144
- const uint16_t qm2 = 0x0300 << shift;
1145
- const int32_t v1 = 4 << shift;
1146
- const int32_t v2 = 1024 << shift;
1230
+ const float v1 = il == 0 ? 4.f : 64.f;
1231
+ const float v2 = 4.f * v1;
1147
1232
 
1148
1233
  const uint16_t s_shift1 = 4*ip;
1149
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1150
- const int ik = 4 + (il%2);
1234
+ const uint16_t s_shift2 = s_shift1 + il;
1151
1235
 
1152
1236
  const int q_offset = 32*ip + l0;
1153
1237
  const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
1156
1240
 
1157
1241
  device const float * y1 = yy + ix*QK_K + y_offset;
1158
1242
 
1159
- float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1160
- for (int i = ix; i < nb; i += 2) {
1243
+ uint32_t scales32, aux32;
1244
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
1245
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
1246
+
1247
+ float sumf1[2] = {0.f};
1248
+ float sumf2[2] = {0.f};
1249
+ for (int i = ix; i < nb; i += 4) {
1161
1250
 
1162
1251
  for (int l = 0; l < 8; ++l) {
1163
- yl[l+0] = y1[l+ 0];
1164
- yl[l+8] = y1[l+16];
1252
+ yl[l+ 0] = y1[l+ 0];
1253
+ yl[l+ 8] = y1[l+16];
1254
+ yl[l+16] = y1[l+32];
1255
+ yl[l+24] = y1[l+48];
1165
1256
  }
1166
1257
 
1167
1258
  device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
1172
1263
  for (int row = 0; row < 2; ++row) {
1173
1264
 
1174
1265
  const float d_all = (float)dh[0];
1175
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1176
1266
 
1177
- float s1 = 0, s2 = 0;
1267
+ scales16[0] = a[4];
1268
+ scales16[1] = a[5];
1269
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
1270
+ scales16[0] = a[il+0];
1271
+ scales16[1] = a[il+1];
1272
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
1273
+
1274
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
1178
1275
  for (int l = 0; l < n; l += 2) {
1179
- const uint16_t qs = q[l/2];
1180
- s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1181
- s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1276
+ const int32_t qs = q[l/2];
1277
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
1278
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
1279
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
1280
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
1281
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
1282
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
1182
1283
  }
1183
- float d = d_all * (s1 + 1.f/256.f * s2);
1184
- sumf1[row] += d * scales[0];
1185
- sumf2[row] += d;
1284
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1285
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1286
+ sumf1[row] += d1 * (scales[0] - 32);
1287
+ sumf2[row] += d2 * (scales[2] - 32);
1186
1288
 
1187
- s1 = s2 = 0;
1289
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
1188
1290
  for (int l = 0; l < n; l += 2) {
1189
- const uint16_t qs = q[l/2+8];
1190
- s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1191
- s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1291
+ const int32_t qs = q[l/2+8];
1292
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
1293
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
1294
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
1295
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
1296
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
1297
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
1192
1298
  }
1193
- d = d_all * (s1 + 1.f/256.f * s2);
1194
- sumf1[row] += d * scales[1];
1195
- sumf2[row] += d;
1299
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1300
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1301
+ sumf1[row] += d1 * (scales[1] - 32);
1302
+ sumf2[row] += d2 * (scales[3] - 32);
1196
1303
 
1197
1304
  q += step;
1198
1305
  h += step;
@@ -1201,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
1201
1308
 
1202
1309
  }
1203
1310
 
1204
- y1 += 2 * QK_K;
1311
+ y1 += 4 * QK_K;
1205
1312
 
1206
1313
  }
1207
1314
 
1208
1315
  for (int row = 0; row < 2; ++row) {
1209
- const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1210
- const float tot = simd_sum(sumf);
1211
- if (tiisg == 0) {
1212
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1316
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
1317
+ sumf1[row] = simd_sum(sumf);
1318
+ }
1319
+ if (tiisg == 0) {
1320
+ for (int row = 0; row < 2; ++row) {
1321
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
1213
1322
  }
1214
1323
  }
1324
+
1215
1325
  }
1216
1326
  #else
1217
1327
  kernel void kernel_mul_mat_q3_K_f32(
@@ -1564,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
1564
1674
  sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1565
1675
  sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1566
1676
 
1567
- float4 acc = {0.f, 0.f, 0.f, 0.f};
1677
+ float4 acc1 = {0.f};
1678
+ float4 acc2 = {0.f};
1568
1679
  for (int l = 0; l < n; ++l) {
1569
1680
  uint8_t h = qh[l];
1570
- acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1571
- acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1572
- acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1573
- acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1681
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
1682
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
1683
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
1684
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
1685
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
1686
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
1687
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
1688
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
1574
1689
  }
1575
1690
  const float dall = dh[0];
1576
1691
  const float dmin = dh[1];
1577
- sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1692
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
1693
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
1694
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
1695
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
1578
1696
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1579
1697
 
1580
1698
  q1 += step;
@@ -1757,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
1757
1875
 
1758
1876
  template <typename type4x4>
1759
1877
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1878
+
1760
1879
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1761
- const half d = il ? (xb->d / 16.h) : xb->d;
1762
- const half m = il ? ( -8.h * 16.h) : -8.h;
1880
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
1881
+ const float d2 = d1 / 256.f;
1882
+ const float md = -8.h * xb->d;
1763
1883
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1764
- const ushort mask1 = il ? 0xF000 : 0x0F00;
1884
+ const ushort mask1 = mask0 << 8;
1765
1885
 
1766
1886
  for (int i=0;i<8;i++) {
1767
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
1768
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
1887
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
1888
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
1769
1889
  }
1890
+
1770
1891
  }
1771
1892
 
1772
1893
  template <typename type4x4>
1773
1894
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1895
+
1774
1896
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1775
- const half d = il ? (xb->d / 16.h) : xb->d;
1776
- const half m = xb->m;
1897
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
1898
+ const float d2 = d1 / 256.f;
1899
+ const float m = xb->m;
1777
1900
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1778
- const ushort mask1 = il ? 0xF000 : 0x0F00;
1901
+ const ushort mask1 = mask0 << 8;
1779
1902
 
1780
1903
  for (int i=0;i<8;i++) {
1781
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
1782
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
1904
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
1905
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
1783
1906
  }
1784
1907
  }
1785
1908
 
@@ -1815,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
1815
1938
 
1816
1939
  template <typename type4x4>
1817
1940
  void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1818
- const float d_all = (float)(xb->d);
1941
+ const half d_all = xb->d;
1819
1942
  device const uint8_t * q = (device const uint8_t *)xb->qs;
1820
1943
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
1821
1944
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1828
1951
  ((il/4)>0 ? 12 : 3);
1829
1952
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1830
1953
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1831
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1832
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1833
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
1954
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
1955
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1956
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
1957
+ const half ml = 4.h * dl;
1834
1958
 
1835
- il = (il/2)%4;
1836
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1837
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1959
+ il = (il/2) & 3;
1960
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1961
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1962
+ dl *= coef;
1838
1963
 
1839
1964
  for (int i = 0; i < 16; ++i) {
1840
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1965
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
1841
1966
  }
1967
+
1842
1968
  #else
1843
1969
  float kcoef = il&1 ? 1.f/16.f : 1.f;
1844
1970
  uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -1852,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1852
1978
  #endif
1853
1979
  }
1854
1980
 
1981
+ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
1982
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
1983
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
1984
+ }
1985
+
1855
1986
  template <typename type4x4>
1856
1987
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1857
- device const uint8_t * q = xb->qs;
1988
+ device const uchar * q = xb->qs;
1858
1989
 
1859
1990
  #if QK_K == 256
1860
- const float d = (float)(xb->d);
1861
- const float min = (float)(xb->dmin);
1862
1991
  short is = (il/4) * 2;
1863
1992
  q = q + (il/4) * 32 + 16 * (il&1);
1864
- il = il%4;
1865
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
1866
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1867
- const float ml = il<2 ? min * sc[1] : min * sc[3];
1993
+ il = il & 3;
1994
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
1995
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
1996
+ const half min = xb->dmin;
1997
+ const half dl = d * sc[0];
1998
+ const half ml = min * sc[1];
1868
1999
  #else
1869
2000
  q = q + 16 * (il&1);
1870
2001
  device const uint8_t * s = xb->scales;
1871
2002
  device const half2 * dh = (device const half2 *)xb->d;
1872
2003
  const float2 d = (float2)dh[0];
1873
2004
  const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
1874
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
2005
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
1875
2006
  #endif
1876
2007
  const ushort mask = il<2 ? 0x0F : 0xF0;
1877
2008
  for (int i = 0; i < 16; ++i) {
1878
2009
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1879
2010
  }
2011
+
1880
2012
  }
1881
2013
 
1882
2014
  template <typename type4x4>
@@ -1885,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
1885
2017
  device const uint8_t * qh = xb->qh;
1886
2018
 
1887
2019
  #if QK_K == 256
1888
- const float d = (float)(xb->d);
1889
- const float min = (float)(xb->dmin);
1890
2020
  short is = (il/4) * 2;
1891
2021
  q = q + 32 * (il/4) + 16 * (il&1);
1892
2022
  qh = qh + 16 * (il&1);
1893
2023
  uint8_t ul = 1 << (il/2);
1894
- il = il%4;
1895
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
1896
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1897
- const float ml = il<2 ? min * sc[1] : min * sc[3];
2024
+ il = il & 3;
2025
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2026
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
2027
+ const half min = xb->dmin;
2028
+ const half dl = d * sc[0];
2029
+ const half ml = min * sc[1];
1898
2030
 
1899
- const ushort mask = il<2 ? 0x0F : 0xF0;
1900
- const float qh_val = il<2 ? 16.f : 256.f;
2031
+ const ushort mask = il<2 ? 0x0F : 0xF0;
2032
+ const half qh_val = il<2 ? 16.h : 256.h;
1901
2033
  for (int i = 0; i < 16; ++i) {
1902
2034
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
1903
2035
  }
@@ -1916,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
1916
2048
 
1917
2049
  template <typename type4x4>
1918
2050
  void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
1919
- const float d_all = (float)(xb->d);
2051
+ const half d_all = xb->d;
1920
2052
  device const uint8_t * ql = (device const uint8_t *)xb->ql;
1921
2053
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
1922
2054
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
1924
2056
  #if QK_K == 256
1925
2057
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1926
2058
  qh = qh + 32*(il/8) + 16*(il&1);
1927
- float sc = scales[(il%2) + 2 * ((il/2))];
1928
- il = (il/2)%4;
2059
+ half sc = scales[(il%2) + 2 * ((il/2))];
2060
+ il = (il/2) & 3;
1929
2061
  #else
1930
2062
  ql = ql + 16 * (il&1);
1931
- float sc = scales[il];
2063
+ half sc = scales[il];
1932
2064
  #endif
2065
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2066
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
2067
+ const half coef = il>1 ? 1.f/16.h : 1.h;
2068
+ const half ml = d_all * sc * 32.h;
2069
+ const half dl = d_all * sc * coef;
1933
2070
  for (int i = 0; i < 16; ++i) {
1934
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1935
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1936
- const float coef = il>1 ? 1.f/16.f : 1.f;
1937
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1938
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1939
- reg[i/4][i%4] = d_all * sc * q * coef;
2071
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
2072
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
2073
+ reg[i/4][i%4] = dl * q - ml;
1940
2074
  }
1941
2075
  }
1942
2076
 
@@ -1,4 +1,3 @@
1
- #define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
2
1
  #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
3
2
 
4
3
  #include "ggml.h"
@@ -47,6 +46,10 @@
47
46
  // disable "possible loss of data" to avoid hundreds of casts
48
47
  // we should just be careful :)
49
48
  #pragma warning(disable: 4244 4267)
49
+
50
+ // disable POSIX deprecation warnigns
51
+ // these functions are never going away, anyway
52
+ #pragma warning(disable: 4996)
50
53
  #endif
51
54
 
52
55
  #if defined(_WIN32)
@@ -280,7 +283,7 @@ typedef double ggml_float;
280
283
  // 16-bit float
281
284
  // on Arm, we use __fp16
282
285
  // on x86, we use uint16_t
283
- #ifdef __ARM_NEON
286
+ #if defined(__ARM_NEON) && !defined(_MSC_VER)
284
287
 
285
288
  // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
286
289
  //
@@ -307,12 +310,14 @@ typedef double ggml_float;
307
310
  #if defined(_MSC_VER) || defined(__MINGW32__)
308
311
  #include <intrin.h>
309
312
  #else
313
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
310
314
  #if !defined(__riscv)
311
315
  #include <immintrin.h>
312
316
  #endif
313
317
  #endif
314
318
  #endif
315
319
  #endif
320
+ #endif
316
321
 
317
322
  #ifdef __riscv_v_intrinsic
318
323
  #include <riscv_vector.h>
@@ -18872,7 +18877,6 @@ static enum ggml_opt_result linesearch_backtracking(
18872
18877
  // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
18873
18878
  return count;
18874
18879
  }
18875
- return count;
18876
18880
  }
18877
18881
  }
18878
18882