llama_cpp 0.9.3 → 0.9.5

Sign up to get free protection for your applications and to get access to all the features.
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
1028
1028
  int nth = 32; // SIMD width
1029
1029
 
1030
1030
  if (ne00%4 == 0) {
1031
+ while (nth < ne00/4 && nth < 256) {
1032
+ nth *= 2;
1033
+ }
1031
1034
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1032
1035
  } else {
1033
- do {
1036
+ while (nth < ne00 && nth < 1024) {
1034
1037
  nth *= 2;
1035
- } while (nth <= ne00 && nth <= 1024);
1036
- nth /= 2;
1038
+ }
1037
1039
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
1038
1040
  }
1039
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1040
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1041
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1042
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1043
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1044
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1041
+
1042
+ const float scale = ((float *) dst->op_params)[0];
1043
+
1044
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1045
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1046
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1047
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1048
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1049
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1050
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1051
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1045
1052
 
1046
1053
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1047
1054
  } break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
1351
1358
  float eps;
1352
1359
  memcpy(&eps, dst->op_params, sizeof(float));
1353
1360
 
1354
- const int nth = MIN(512, ne00);
1361
+ int nth = 32; // SIMD width
1362
+
1363
+ while (nth < ne00/4 && nth < 1024) {
1364
+ nth *= 2;
1365
+ }
1355
1366
 
1356
1367
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1357
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1358
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1359
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1360
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1361
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1362
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1368
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1369
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1370
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1371
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1372
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1373
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1363
1374
 
1364
1375
  const int64_t nrows = ggml_nrows(src0);
1365
1376
 
@@ -1433,7 +1444,8 @@ void ggml_metal_graph_compute(
1433
1444
  const int n_past = ((int32_t *) dst->op_params)[0];
1434
1445
  const int n_dims = ((int32_t *) dst->op_params)[1];
1435
1446
  const int mode = ((int32_t *) dst->op_params)[2];
1436
- const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
1447
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
1448
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1437
1449
 
1438
1450
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1439
1451
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -39,6 +39,8 @@ typedef struct {
39
39
  int8_t qs[QK8_0]; // quants
40
40
  } block_q8_0;
41
41
 
42
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
43
+
42
44
  // general-purpose kernel for addition of two tensors
43
45
  // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
44
46
  // cons: not very efficient
@@ -180,10 +182,12 @@ kernel void kernel_gelu(
180
182
 
181
183
  kernel void kernel_soft_max(
182
184
  device const float * src0,
185
+ device const float * src1,
183
186
  device float * dst,
184
187
  constant int64_t & ne00,
185
188
  constant int64_t & ne01,
186
189
  constant int64_t & ne02,
190
+ constant float & scale,
187
191
  threadgroup float * buf [[threadgroup(0)]],
188
192
  uint tgpig[[threadgroup_position_in_grid]],
189
193
  uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +198,77 @@ kernel void kernel_soft_max(
194
198
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
199
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
196
200
 
197
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
198
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
201
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
202
+ device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
203
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
199
204
 
200
205
  // parallel max
201
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
206
+ float lmax = -INFINITY;
202
207
 
203
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
204
- lmax = MAX(lmax, psrc0[i00]);
208
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
209
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
205
210
  }
206
211
 
207
- float max = simd_max(lmax);
208
- if (tiisg == 0) {
209
- buf[sgitg] = max;
210
- }
212
+ // find the max value in the block
213
+ float max_val = simd_max(lmax);
214
+ if (ntg > N_SIMDWIDTH) {
215
+ if (sgitg == 0) {
216
+ buf[tiisg] = -INFINITY;
217
+ }
211
218
 
212
- threadgroup_barrier(mem_flags::mem_threadgroup);
219
+ threadgroup_barrier(mem_flags::mem_threadgroup);
213
220
 
214
- // broadcast, simd group number is ntg / 32
215
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
- if (tpitg < i) {
217
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
- }
219
- }
221
+ if (tiisg == 0) {
222
+ buf[sgitg] = max_val;
223
+ }
220
224
 
221
- threadgroup_barrier(mem_flags::mem_threadgroup);
225
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
226
 
223
- max = buf[0];
227
+ max_val = buf[tiisg];
228
+ max_val = simd_max(max_val);
229
+ }
224
230
 
225
231
  // parallel sum
226
232
  float lsum = 0.0f;
227
233
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
228
- const float exp_psrc0 = exp(psrc0[i00] - max);
234
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
229
235
  lsum += exp_psrc0;
230
- // Remember the result of exp here. exp is expensive, so we really do not
231
- // wish to compute it twice.
232
236
  pdst[i00] = exp_psrc0;
233
237
  }
234
238
 
235
239
  float sum = simd_sum(lsum);
236
- if (tiisg == 0) {
237
- buf[sgitg] = sum;
238
- }
240
+ if (ntg > N_SIMDWIDTH) {
241
+ if (sgitg == 0) {
242
+ buf[tiisg] = 0.0f;
243
+ }
239
244
 
240
- threadgroup_barrier(mem_flags::mem_threadgroup);
245
+ threadgroup_barrier(mem_flags::mem_threadgroup);
241
246
 
242
- // broadcast, simd group number is ntg / 32
243
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
- if (tpitg < i) {
245
- buf[tpitg] += buf[tpitg + i];
246
- }
247
- }
247
+ if (tiisg == 0) {
248
+ buf[sgitg] = sum;
249
+ }
248
250
 
249
- threadgroup_barrier(mem_flags::mem_threadgroup);
251
+ threadgroup_barrier(mem_flags::mem_threadgroup);
252
+
253
+ sum = buf[tiisg];
254
+ sum = simd_sum(sum);
255
+ }
250
256
 
251
- sum = buf[0];
257
+ const float inv_sum = 1.0f/sum;
252
258
 
253
259
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
254
- pdst[i00] /= sum;
260
+ pdst[i00] *= inv_sum;
255
261
  }
256
262
  }
257
263
 
258
264
  kernel void kernel_soft_max_4(
259
265
  device const float * src0,
266
+ device const float * src1,
260
267
  device float * dst,
261
268
  constant int64_t & ne00,
262
269
  constant int64_t & ne01,
263
270
  constant int64_t & ne02,
271
+ constant float & scale,
264
272
  threadgroup float * buf [[threadgroup(0)]],
265
273
  uint tgpig[[threadgroup_position_in_grid]],
266
274
  uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +279,68 @@ kernel void kernel_soft_max_4(
271
279
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
280
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
273
281
 
274
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
275
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
282
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
283
+ device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
284
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
276
285
 
277
286
  // parallel max
278
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
287
+ float4 lmax4 = -INFINITY;
279
288
 
280
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
281
- lmax4 = fmax(lmax4, psrc4[i00]);
289
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
290
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
282
291
  }
283
292
 
284
293
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
- float max = simd_max(lmax);
286
- if (tiisg == 0) {
287
- buf[sgitg] = max;
288
- }
289
294
 
290
- threadgroup_barrier(mem_flags::mem_threadgroup);
295
+ float max_val = simd_max(lmax);
296
+ if (ntg > N_SIMDWIDTH) {
297
+ if (sgitg == 0) {
298
+ buf[tiisg] = -INFINITY;
299
+ }
291
300
 
292
- // broadcast, simd group number is ntg / 32
293
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
- if (tpitg < i) {
295
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
- }
297
- }
301
+ threadgroup_barrier(mem_flags::mem_threadgroup);
298
302
 
299
- threadgroup_barrier(mem_flags::mem_threadgroup);
303
+ if (tiisg == 0) {
304
+ buf[sgitg] = max_val;
305
+ }
306
+
307
+ threadgroup_barrier(mem_flags::mem_threadgroup);
300
308
 
301
- max = buf[0];
309
+ max_val = buf[tiisg];
310
+ max_val = simd_max(max_val);
311
+ }
302
312
 
303
313
  // parallel sum
304
314
  float4 lsum4 = 0.0f;
305
315
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
306
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
316
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
307
317
  lsum4 += exp_psrc4;
308
318
  pdst4[i00] = exp_psrc4;
309
319
  }
310
320
 
311
321
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
322
  float sum = simd_sum(lsum);
313
- if (tiisg == 0) {
314
- buf[sgitg] = sum;
315
- }
323
+ if (ntg > N_SIMDWIDTH) {
324
+ if (sgitg == 0) {
325
+ buf[tiisg] = 0.0f;
326
+ }
316
327
 
317
- threadgroup_barrier(mem_flags::mem_threadgroup);
328
+ threadgroup_barrier(mem_flags::mem_threadgroup);
318
329
 
319
- // broadcast, simd group number is ntg / 32
320
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
- if (tpitg < i) {
322
- buf[tpitg] += buf[tpitg + i];
323
- }
324
- }
330
+ if (tiisg == 0) {
331
+ buf[sgitg] = sum;
332
+ }
325
333
 
326
- threadgroup_barrier(mem_flags::mem_threadgroup);
334
+ threadgroup_barrier(mem_flags::mem_threadgroup);
335
+
336
+ sum = buf[tiisg];
337
+ sum = simd_sum(sum);
338
+ }
327
339
 
328
- sum = buf[0];
340
+ const float inv_sum = 1.0f/sum;
329
341
 
330
342
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
331
- pdst4[i00] /= sum;
343
+ pdst4[i00] *= inv_sum;
332
344
  }
333
345
  }
334
346
 
@@ -435,14 +447,13 @@ kernel void kernel_rms_norm(
435
447
  constant int64_t & ne00,
436
448
  constant uint64_t & nb01,
437
449
  constant float & eps,
438
- threadgroup float * sum [[threadgroup(0)]],
450
+ threadgroup float * buf [[threadgroup(0)]],
439
451
  uint tgpig[[threadgroup_position_in_grid]],
440
452
  uint tpitg[[thread_position_in_threadgroup]],
441
453
  uint sgitg[[simdgroup_index_in_threadgroup]],
442
454
  uint tiisg[[thread_index_in_simdgroup]],
443
455
  uint ntg[[threads_per_threadgroup]]) {
444
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
445
- device const float * x_scalar = (device const float *) x;
456
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
446
457
 
447
458
  float4 sumf = 0;
448
459
  float all_sum = 0;
@@ -453,40 +464,30 @@ kernel void kernel_rms_norm(
453
464
  }
454
465
  all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
455
466
  all_sum = simd_sum(all_sum);
456
- if (tiisg == 0) {
457
- sum[sgitg] = all_sum;
458
- }
467
+ if (ntg > N_SIMDWIDTH) {
468
+ if (sgitg == 0) {
469
+ buf[tiisg] = 0.0f;
470
+ }
459
471
 
460
- threadgroup_barrier(mem_flags::mem_threadgroup);
472
+ threadgroup_barrier(mem_flags::mem_threadgroup);
461
473
 
462
- // broadcast, simd group number is ntg / 32
463
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
464
- if (tpitg < i) {
465
- sum[tpitg] += sum[tpitg + i];
466
- }
467
- }
468
- if (tpitg == 0) {
469
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
470
- sum[0] += x_scalar[i];
474
+ if (tiisg == 0) {
475
+ buf[sgitg] = all_sum;
471
476
  }
472
- sum[0] /= ne00;
473
- }
474
477
 
475
- threadgroup_barrier(mem_flags::mem_threadgroup);
478
+ threadgroup_barrier(mem_flags::mem_threadgroup);
476
479
 
477
- const float mean = sum[0];
480
+ all_sum = buf[tiisg];
481
+ all_sum = simd_sum(all_sum);
482
+ }
483
+
484
+ const float mean = all_sum/ne00;
478
485
  const float scale = 1.0f/sqrt(mean + eps);
479
486
 
480
487
  device float4 * y = (device float4 *) (dst + tgpig*ne00);
481
- device float * y_scalar = (device float *) y;
482
488
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
483
489
  y[i00] = x[i00] * scale;
484
490
  }
485
- if (tpitg == 0) {
486
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
487
- y_scalar[i00] = x_scalar[i00] * scale;
488
- }
489
- }
490
491
  }
491
492
 
492
493
  // function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,7 +577,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
576
577
  // putting them in the kernel cause a significant performance penalty
577
578
  #define N_DST 4 // each SIMD group works on 4 rows
578
579
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
579
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
580
580
  //Note: This is a template, but strictly speaking it only applies to
581
581
  // quantizations where the block size is 32. It also does not
582
582
  // guard against the number of rows not being divisible by
@@ -1,20 +1,18 @@
1
+ #include "ggml.h"
1
2
  #include "ggml-opencl.h"
2
3
 
3
4
  #include <array>
4
5
  #include <atomic>
6
+ #include <cstdio>
7
+ #include <cstdlib>
8
+ #include <cstring>
9
+ #include <limits>
5
10
  #include <sstream>
6
11
  #include <vector>
7
- #include <limits>
8
12
 
9
13
  #define CL_TARGET_OPENCL_VERSION 110
10
14
  #include <clblast.h>
11
15
 
12
- #include <stdlib.h>
13
- #include <stdio.h>
14
- #include <string.h>
15
-
16
- #include "ggml.h"
17
-
18
16
  #if defined(_MSC_VER)
19
17
  #pragma warning(disable: 4244 4267) // possible loss of data
20
18
  #endif
@@ -19,7 +19,7 @@
19
19
  #ifdef __wasm_simd128__
20
20
  #include <wasm_simd128.h>
21
21
  #else
22
- #ifdef __POWER9_VECTOR__
22
+ #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
23
23
  #include <altivec.h>
24
24
  #undef bool
25
25
  #define bool _Bool